From 36a947de63f615e4e975d80c7c279aecbf461722 Mon Sep 17 00:00:00 2001 From: Robin Webbers Date: Wed, 24 Dec 2025 15:39:25 +0100 Subject: [PATCH 1/2] Initial version of theory of array support! It has some caveats still. :) This requires a version of SBV that is not yet on stackage (at the time of writing) and ruins backwards compatibility. There is no optimisation of expressions involving theory of array expressions as of yet. --- examples/grisette-examples.cabal | 4 +- grisette.cabal | 4 +- src/Grisette/Internal/Backend/Solving.hs | 15 + .../Internal/Impl/Core/Data/Class/SymEq.hs | 13 +- .../Internal/Impl/Core/Data/Class/SymOrd.hs | 6 +- src/Grisette/Internal/SymPrim/Array.hs | 73 +++ src/Grisette/Internal/SymPrim/GeneralFun.hs | 12 + .../Prim/Internal/Instances/PEvalOrdTerm.hs | 44 +- .../SymPrim/Prim/Internal/Serialize.hs | 296 ++++++++----- .../Internal/SymPrim/Prim/Internal/Term.hs | 419 ++++++++++++++++-- src/Grisette/Internal/SymPrim/Prim/Pattern.hs | 11 +- src/Grisette/Internal/SymPrim/SymArray.hs | 159 +++++++ src/Grisette/SymPrim.hs | 6 + 13 files changed, 876 insertions(+), 186 deletions(-) create mode 100644 src/Grisette/Internal/SymPrim/Array.hs create mode 100644 src/Grisette/Internal/SymPrim/SymArray.hs diff --git a/examples/grisette-examples.cabal b/examples/grisette-examples.cabal index bad6f13f1..886f6a2df 100644 --- a/examples/grisette-examples.cabal +++ b/examples/grisette-examples.cabal @@ -1,11 +1,11 @@ cabal-version: 1.12 --- This file has been generated from package.yaml by hpack version 0.38.0. +-- This file has been generated from package.yaml by hpack version 0.38.1. -- -- see: https://github.com/sol/hpack name: grisette-examples -version: 0.13.0.0 +version: 0.13.0.1 synopsis: Examples for Grisette description: More examples are available in the [tutorials](https://github.com/lsrcz/grisette/tree/main/tutorials) of diff --git a/grisette.cabal b/grisette.cabal index 5ce70c779..615047051 100644 --- a/grisette.cabal +++ b/grisette.cabal @@ -1,6 +1,6 @@ cabal-version: 1.12 --- This file has been generated from package.yaml by hpack version 0.37.0. +-- This file has been generated from package.yaml by hpack version 0.38.1. -- -- see: https://github.com/sol/hpack @@ -116,6 +116,7 @@ library Grisette.Internal.Core.Data.UnionBase Grisette.Internal.SymPrim.AlgReal Grisette.Internal.SymPrim.AllSyms + Grisette.Internal.SymPrim.Array Grisette.Internal.SymPrim.BV Grisette.Internal.SymPrim.FP Grisette.Internal.SymPrim.FunInstanceGen @@ -147,6 +148,7 @@ library Grisette.Internal.SymPrim.Quantifier Grisette.Internal.SymPrim.SomeBV Grisette.Internal.SymPrim.SymAlgReal + Grisette.Internal.SymPrim.SymArray Grisette.Internal.SymPrim.SymBool Grisette.Internal.SymPrim.SymBV Grisette.Internal.SymPrim.SymFP diff --git a/src/Grisette/Internal/Backend/Solving.hs b/src/Grisette/Internal/Backend/Solving.hs index 8916a55d9..2d1db9b48 100644 --- a/src/Grisette/Internal/Backend/Solving.hs +++ b/src/Grisette/Internal/Backend/Solving.hs @@ -264,6 +264,9 @@ import Grisette.Internal.SymPrim.Prim.Term pattern SymTerm, pattern ToFPTerm, pattern XorBitsTerm, + pattern SelectTerm, + pattern StoreTerm, + pattern ConstArrayTerm, ) import Grisette.Internal.SymPrim.SymBool (SymBool (SymBool)) @@ -796,6 +799,18 @@ lowerSinglePrimCached t' m' = do mode <- goCached qs mode arg <- goCached qs arg return $ \qst -> sbvToFPTerm @b (mode qst) (arg qst) + goCachedIntermediate qs (SelectTerm (arr :: Term arr) key) = withPrim @arr $ do + arr' <- goCached qs arr + key' <- goCached qs key + pure $ \qst -> SBV.readArray (arr' qst) (key' qst) + goCachedIntermediate qs (StoreTerm arr key val) = withPrim @a $ do + arr' <- goCached qs arr + key' <- goCached qs key + val' <- goCached qs val + pure $ \qst -> SBV.writeArray (arr' qst) (key' qst) (val' qst) + goCachedIntermediate qs (ConstArrayTerm _ val) = withPrim @a $ do + val' <- goCached qs val + pure $ \qst -> SBV.constArray $ val' qst goCachedIntermediate _ ConTerm {} = error "Should not happen" goCachedIntermediate _ SymTerm {} = error "Should not happen" goCachedIntermediate _ ForallTerm {} = error "Should not happen" diff --git a/src/Grisette/Internal/Internal/Impl/Core/Data/Class/SymEq.hs b/src/Grisette/Internal/Internal/Impl/Core/Data/Class/SymEq.hs index b8fe9c838..44f2be6f5 100644 --- a/src/Grisette/Internal/Internal/Impl/Core/Data/Class/SymEq.hs +++ b/src/Grisette/Internal/Internal/Impl/Core/Data/Class/SymEq.hs @@ -80,8 +80,9 @@ import Grisette.Internal.SymPrim.FP ) import Grisette.Internal.SymPrim.Prim.Term ( SupportedPrim (pevalDistinctTerm), + LinkedRep (underlyingTerm, wrapTerm), + SupportedNonFuncPrim, pevalEqTerm, - underlyingTerm, ) import Grisette.Internal.SymPrim.SymAlgReal (SymAlgReal (SymAlgReal)) import Grisette.Internal.SymPrim.SymBV @@ -95,6 +96,7 @@ import Grisette.Internal.SymPrim.SymFP ) import Grisette.Internal.SymPrim.SymInteger (SymInteger (SymInteger)) import Grisette.Internal.TH.Derivation.Derive (derive) +import Grisette.Internal.SymPrim.SymArray (SymArray) #define CONCRETE_SEQ(type) \ instance SymEq type where \ @@ -185,6 +187,15 @@ instance (ValidFP eb sb) => SymEq (SymFP eb sb) where (SymFP l) .== (SymFP r) = SymBool $ pevalEqTerm l r {-# INLINE (.==) #-} +instance + ( SupportedNonFuncPrim ck, + SupportedNonFuncPrim cv, + LinkedRep ck sk, + LinkedRep cv sv + ) => + SymEq (SymArray sk sv) where + lhs .== rhs = wrapTerm $ pevalEqTerm (underlyingTerm lhs) (underlyingTerm rhs) + derive [ ''(), ''AssertionError, diff --git a/src/Grisette/Internal/Internal/Impl/Core/Data/Class/SymOrd.hs b/src/Grisette/Internal/Internal/Impl/Core/Data/Class/SymOrd.hs index 9a6cb8500..2c8f95973 100644 --- a/src/Grisette/Internal/Internal/Impl/Core/Data/Class/SymOrd.hs +++ b/src/Grisette/Internal/Internal/Impl/Core/Data/Class/SymOrd.hs @@ -100,10 +100,7 @@ import Grisette.Internal.SymPrim.SymBV SymWordN (SymWordN), ) import Grisette.Internal.SymPrim.SymBool (SymBool (SymBool)) -import Grisette.Internal.SymPrim.SymFP - ( SymFP (SymFP), - SymFPRoundingMode (SymFPRoundingMode), - ) +import Grisette.Internal.SymPrim.SymFP (SymFP (SymFP)) import Grisette.Internal.SymPrim.SymInteger (SymInteger (SymInteger)) import Grisette.Internal.TH.Derivation.Derive (derive) @@ -254,7 +251,6 @@ instance SymOrd SymBool where #if 1 SORD_SIMPLE(SymInteger) SORD_SIMPLE(SymAlgReal) -SORD_SIMPLE(SymFPRoundingMode) SORD_BV(SymIntN) SORD_BV(SymWordN) #endif diff --git a/src/Grisette/Internal/SymPrim/Array.hs b/src/Grisette/Internal/SymPrim/Array.hs new file mode 100644 index 000000000..ae6797f6f --- /dev/null +++ b/src/Grisette/Internal/SymPrim/Array.hs @@ -0,0 +1,73 @@ +{-# LANGUAGE DeriveAnyClass #-} +{-# LANGUAGE DeriveLift #-} +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE ExplicitForAll #-} +{-# LANGUAGE ImportQualifiedPost #-} + +-- | +-- Module : Grisette.Internal.SymPrim.Array +-- Copyright : (c) Sirui Lu 2021-2023 +-- License : BSD-3-Clause (see the LICENSE file) +-- +-- Maintainer : siruilu@cs.washington.edu +-- Stability : Experimental +-- Portability : GHC only +module Grisette.Internal.SymPrim.Array + ( Array (..) + , const + , select + , store + ) where + +import Control.DeepSeq (NFData) +import Data.Binary qualified as Binary +import Data.Bytes.Serial (Serial (serialize, deserialize)) +import Data.Hashable (Hashable) +import Data.HashMap.Strict qualified as HM +import Data.Serialize qualified as Cereal +import GHC.Generics (Generic) +import Language.Haskell.TH.Syntax (Lift) +import Prelude (Show, Eq, Ord) + +-- TODO: The equality of this array model is incorrect. The easy solution is +-- to disallow it entirely. Alternatively, I already have a version with a +-- working equality check. It works by canonicalising the array. +-- +-- Canonicalisation will not happen for keys with an infinite domain and +-- realistically also not for keys with a sufficiently large domain. In fact, +-- we avoid tracking information for canonicalisation in these cases altogether! +-- The main gripe with this is that at that point is that insertions do require +-- keys for which we know both their cardinality and can enumerate their domain. +-- The latter we could restrict to only enumerable domains given a finite +-- cardinality with type-level shenanigans, but still. +-- +-- Yet another alternative would be to simply accept that we cannot conclude +-- inequality if one of the arrays would require canonicalisation? Then we don't +-- need the additional typeclass constraints. This way, we could still perform +-- normalisation of most terms. +data Array k v = Array (HM.HashMap k v) v + deriving (Show, Eq, Ord, Generic, Lift, Hashable, NFData) + +instance (Hashable k, Serial k, Serial v) => Serial (Array k v) + +instance (Hashable k, Serial k, Serial v) => Cereal.Serialize (Array k v) where + put = serialize + get = deserialize + +instance (Hashable k, Serial k, Serial v) => Binary.Binary (Array k v) where + put = serialize + get = deserialize + +-- TODO: Perhaps it is nice to make this a typeclass and give it names that do +-- not require qualified imports? I don't necessarily mind the qualified import, +-- but we'll see what the library author thinks. +const :: forall k v. v -> Array k v +const = Array HM.empty + +select :: forall k v. Hashable k => Array k v -> k -> v +select (Array entries root) key = HM.lookupDefault root key entries + +store :: forall k v. Hashable k => Array k v -> k -> v -> Array k v +store (Array entries root) key value = do + let entries' = HM.insert key value entries + Array entries' root diff --git a/src/Grisette/Internal/SymPrim/GeneralFun.hs b/src/Grisette/Internal/SymPrim/GeneralFun.hs index d42782657..011b19fa8 100644 --- a/src/Grisette/Internal/SymPrim/GeneralFun.hs +++ b/src/Grisette/Internal/SymPrim/GeneralFun.hs @@ -100,6 +100,9 @@ import Grisette.Internal.SymPrim.Prim.Internal.Term PEvalOrdTerm (pevalLeOrdTerm, pevalLtOrdTerm), PEvalRotateTerm (pevalRotateRightTerm), PEvalShiftTerm (pevalShiftLeftTerm, pevalShiftRightTerm), + pevalSelectTerm, + pevalStoreTerm, + pevalConstArrayTerm, SBVRep (SBVType), SomeTypedAnySymbol, SomeTypedConstantSymbol, @@ -189,6 +192,9 @@ import Grisette.Internal.SymPrim.Prim.Internal.Term pattern SymTerm, pattern ToFPTerm, pattern XorBitsTerm, + pattern SelectTerm, + pattern StoreTerm, + pattern ConstArrayTerm, ) import Grisette.Internal.SymPrim.Prim.Pattern (pattern SubTerms) import Grisette.Internal.SymPrim.Prim.SomeTerm (SomeTerm (SomeTerm), someTerm) @@ -542,6 +548,12 @@ generalSubstSomeTerm subst initialBoundedSymbols = go initialMemo _ (SomeTerm (ToFPTerm mode (arg :: Term a) (_ :: p eb) (_ :: q sb))) = goBinary memo (pevalToFPTerm @a @eb @sb) mode arg + goSome memo _ (SomeTerm (SelectTerm arr key)) = + goBinary memo pevalSelectTerm arr key + goSome memo _ (SomeTerm (StoreTerm arr key val)) = + goTernary memo pevalStoreTerm arr key val + goSome memo _ (SomeTerm (ConstArrayTerm pkey val)) = + goUnary memo (pevalConstArrayTerm pkey) val goUnary memo f a = SomeTerm $ f (go memo a) goBinary memo f a b = SomeTerm $ f (go memo a) (go memo b) goTernary memo f a b c = diff --git a/src/Grisette/Internal/SymPrim/Prim/Internal/Instances/PEvalOrdTerm.hs b/src/Grisette/Internal/SymPrim/Prim/Internal/Instances/PEvalOrdTerm.hs index 9712c5f62..7414ffe63 100644 --- a/src/Grisette/Internal/SymPrim/Prim/Internal/Instances/PEvalOrdTerm.hs +++ b/src/Grisette/Internal/SymPrim/Prim/Internal/Instances/PEvalOrdTerm.hs @@ -30,9 +30,7 @@ import Grisette.Internal.SymPrim.AlgReal (AlgReal) import Grisette.Internal.SymPrim.BV (IntN, WordN) import Grisette.Internal.SymPrim.FP ( FP, - FPRoundingMode, ValidFP, - allFPRoundingMode, ) import Grisette.Internal.SymPrim.Prim.Internal.Instances.PEvalNumTerm () import Grisette.Internal.SymPrim.Prim.Internal.Term @@ -41,10 +39,9 @@ import Grisette.Internal.SymPrim.Prim.Internal.Term ( pevalLeOrdTerm, pevalLtOrdTerm, sbvLeOrdTerm, - sbvLtOrdTerm, withSbvOrdTermConstraint ), - SupportedPrim (conSBVTerm, withPrim), + SupportedPrim (withPrim), Term, conTerm, leOrdTerm, @@ -133,45 +130,6 @@ instance (ValidFP eb sb) => PEvalOrdTerm (FP eb sb) where (SBV.sNot (SBV.fpIsNaN x) SBV..&& SBV.sNot (SBV.fpIsNaN y)) SBV..&& (x SBV..<= y) --- Use this table to avoid accidental breakage introduced by sbv. -fpRoundingModeLtTable :: [(SBV.SRoundingMode, SBV.SRoundingMode)] -fpRoundingModeLtTable = - [ ( conSBVTerm @FPRoundingMode a, - conSBVTerm @FPRoundingMode b - ) - | a <- allFPRoundingMode, - b <- allFPRoundingMode, - a < b - ] - -fpRoundingModeLeTable :: [(SBV.SRoundingMode, SBV.SRoundingMode)] -fpRoundingModeLeTable = - [ ( conSBVTerm @FPRoundingMode a, - conSBVTerm @FPRoundingMode b - ) - | a <- allFPRoundingMode, - b <- allFPRoundingMode, - a <= b - ] - -sbvTableLookup :: - [(SBV.SRoundingMode, SBV.SRoundingMode)] -> - SBV.SRoundingMode -> - SBV.SRoundingMode -> - SBV.SBV Bool -sbvTableLookup tbl lhs rhs = - foldl - (\acc (a, b) -> acc SBV..|| ((lhs SBV..== a) SBV..&& (rhs SBV..== b))) - SBV.sFalse - tbl - -instance PEvalOrdTerm FPRoundingMode where - pevalLtOrdTerm = pevalGeneralLtOrdTerm - pevalLeOrdTerm = pevalGeneralLeOrdTerm - withSbvOrdTermConstraint r = withPrim @FPRoundingMode r - sbvLtOrdTerm = sbvTableLookup fpRoundingModeLtTable - sbvLeOrdTerm = sbvTableLookup fpRoundingModeLeTable - instance PEvalOrdTerm AlgReal where pevalLtOrdTerm = pevalGeneralLtOrdTerm pevalLeOrdTerm = pevalGeneralLeOrdTerm diff --git a/src/Grisette/Internal/SymPrim/Prim/Internal/Serialize.hs b/src/Grisette/Internal/SymPrim/Prim/Internal/Serialize.hs index cfa76b17b..ae81aee42 100644 --- a/src/Grisette/Internal/SymPrim/Prim/Internal/Serialize.hs +++ b/src/Grisette/Internal/SymPrim/Prim/Internal/Serialize.hs @@ -26,6 +26,7 @@ module Grisette.Internal.SymPrim.Prim.Internal.Serialize () where import Control.Monad (replicateM, unless, when) +import Control.Monad.Identity (Identity (runIdentity)) import Control.Monad.State (StateT, evalStateT) import qualified Control.Monad.State as State import qualified Data.Binary as Binary @@ -38,14 +39,17 @@ import qualified Data.HashSet as HS import Data.Hashable (Hashable (hashWithSalt)) import Data.List (intercalate) import Data.List.NonEmpty (NonEmpty ((:|))) +import Data.Maybe (isJust, fromMaybe) import Data.Proxy (Proxy (Proxy)) import qualified Data.Serialize as Cereal +import Data.Typeable (heqT) import Data.Word (Word8) import GHC.Generics (Generic) import GHC.Natural (Natural) import GHC.Stack (HasCallStack) import GHC.TypeNats (KnownNat, natVal, type (+), type (<=)) import Grisette.Internal.SymPrim.AlgReal (AlgReal) +import Grisette.Internal.SymPrim.Array (Array) import Grisette.Internal.SymPrim.BV (IntN, WordN) import Grisette.Internal.SymPrim.FP ( FP, @@ -144,6 +148,9 @@ import Grisette.Internal.SymPrim.Prim.Internal.Term termId, toFPTerm, xorBitsTerm, + selectTerm, + storeTerm, + constArrayTerm, pattern AbsNumTerm, pattern AddNumTerm, pattern AndBitsTerm, @@ -193,6 +200,9 @@ import Grisette.Internal.SymPrim.Prim.Internal.Term pattern SymTerm, pattern ToFPTerm, pattern XorBitsTerm, + pattern SelectTerm, + pattern StoreTerm, + pattern ConstArrayTerm, ) import Grisette.Internal.SymPrim.Prim.SomeTerm ( SomeTerm (SomeTerm), @@ -213,11 +223,9 @@ import Grisette.Internal.Utils.Parameterized unsafeLeqProof, ) import Type.Reflection - ( SomeTypeRep (SomeTypeRep), - TypeRep, + ( TypeRep, Typeable, eqTypeRep, - someTypeRep, typeRep, pattern App, pattern Con, @@ -233,6 +241,7 @@ data KnownNonFuncType where FPType :: (ValidFP eb sb) => Proxy eb -> Proxy sb -> KnownNonFuncType FPRoundingModeType :: KnownNonFuncType AlgRealType :: KnownNonFuncType + ArrayType :: KnownNonFuncType -> KnownNonFuncType -> KnownNonFuncType instance Eq KnownNonFuncType where BoolType == BoolType = True @@ -255,6 +264,8 @@ instance Hashable KnownNonFuncType where s `hashWithSalt` (4 :: Int) `hashWithSalt` natVal p `hashWithSalt` natVal q hashWithSalt s FPRoundingModeType = s `hashWithSalt` (5 :: Int) hashWithSalt s AlgRealType = s `hashWithSalt` (6 :: Int) + hashWithSalt s (ArrayType k v) = + s `hashWithSalt` k `hashWithSalt` v `hashWithSalt` (7 :: Int) data KnownNonFuncTypeWitness where KnownNonFuncTypeWitness :: @@ -280,6 +291,10 @@ witnessKnownNonFuncType (FPType (Proxy :: Proxy eb) (Proxy :: Proxy sb)) = witnessKnownNonFuncType FPRoundingModeType = KnownNonFuncTypeWitness (Proxy @FPRoundingMode) witnessKnownNonFuncType AlgRealType = KnownNonFuncTypeWitness (Proxy @AlgReal) +witnessKnownNonFuncType (ArrayType k v) = runIdentity $ do + KnownNonFuncTypeWitness (_ :: Proxy k) <- pure $ witnessKnownNonFuncType k + KnownNonFuncTypeWitness (_ :: Proxy v) <- pure $ witnessKnownNonFuncType v + pure $ KnownNonFuncTypeWitness @(Array k v) Proxy data KnownType where NonFuncType :: KnownNonFuncType -> KnownType @@ -496,91 +511,90 @@ instance Show KnownNonFuncType where <> show (natVal (Proxy @sb)) show FPRoundingModeType = "FPRoundingMode" show AlgRealType = "AlgReal" + show (ArrayType key val) = "Array (" ++ show key ++ ") (" ++ show val ++ ")" instance Show KnownType where show (NonFuncType t) = show t show (TabularFunType ts) = intercalate " =-> " $ show <$> ts show (GeneralFunType ts) = intercalate " --> " $ show <$> ts -knownNonFuncType :: - forall a p. (SupportedNonFuncPrim a) => p a -> KnownNonFuncType -knownNonFuncType _ = - case tr of - _ | SomeTypeRep tr == someTypeRep (Proxy @Bool) -> BoolType - _ | SomeTypeRep tr == someTypeRep (Proxy @Integer) -> IntegerType - _ - | SomeTypeRep tr == someTypeRep (Proxy @FPRoundingMode) -> - FPRoundingModeType - _ | SomeTypeRep tr == someTypeRep (Proxy @AlgReal) -> AlgRealType - App (ta@(Con _) :: TypeRep w) (_ :: TypeRep n) -> - case ( eqTypeRep ta (typeRep @WordN), - eqTypeRep ta (typeRep @IntN) - ) of - (Just HRefl, _) -> withPrim @a $ WordNType (Proxy @n) - (_, Just HRefl) -> withPrim @a $ IntNType (Proxy @n) - _ -> err - App (App (tf :: TypeRep f) (_ :: TypeRep a0)) (_ :: TypeRep a1) -> - case eqTypeRep tf (typeRep @FP) of - Just HRefl -> withPrim @a $ FPType (Proxy @a0) (Proxy @a1) - _ -> err - _ -> err +knownNonFuncTypeMaybe :: + forall a p. SupportedPrim a => p a -> Maybe KnownNonFuncType +knownNonFuncTypeMaybe _ = withPrim @a $ case tr of + _ | isTy @Bool Proxy -> pure BoolType + | isTy @Integer Proxy -> pure IntegerType + | isTy @FPRoundingMode Proxy -> pure FPRoundingModeType + | isTy @AlgReal Proxy -> pure AlgRealType + App (ta@(Con _) :: TypeRep w) (_ :: TypeRep n) + | Just HRefl <- eqTypeRep ta $ typeRep @WordN -> pure $ WordNType @n Proxy + | Just HRefl <- eqTypeRep ta $ typeRep @IntN -> pure $ IntNType @n Proxy + App (App (tf :: TypeRep f) (_ :: TypeRep eb)) (_ :: TypeRep es) + | Just HRefl <- eqTypeRep tf $ typeRep @FP -> do + pure $ FPType (Proxy @eb) (Proxy @es) + App (App arrR (_ :: TypeRep k)) (_ :: TypeRep v) + | Just HRefl <- eqTypeRep arrR $ typeRep @Array -> do + keyTy <- knownNonFuncTypeMaybe @k Proxy + valTy <- knownNonFuncTypeMaybe @v Proxy + pure $ ArrayType keyTy valTy + _ -> Nothing where - tr = primTypeRep @a - err = error $ "knownNonFuncType: unsupported type: " <> show tr - -knownType :: - forall a p. (SupportedPrim a) => p a -> KnownType -knownType _ = - case tr of - _ | SomeTypeRep tr == someTypeRep (Proxy @Bool) -> NonFuncType BoolType - _ - | SomeTypeRep tr == someTypeRep (Proxy @Integer) -> - NonFuncType IntegerType - _ - | SomeTypeRep tr == someTypeRep (Proxy @FPRoundingMode) -> - NonFuncType FPRoundingModeType - _ - | SomeTypeRep tr == someTypeRep (Proxy @AlgReal) -> - NonFuncType AlgRealType - App (ta@(Con _) :: TypeRep w) (_ :: TypeRep n) -> - case ( eqTypeRep ta (typeRep @WordN), - eqTypeRep ta (typeRep @IntN) - ) of - (Just HRefl, _) -> withPrim @a $ NonFuncType $ WordNType (Proxy @n) - (_, Just HRefl) -> withPrim @a $ NonFuncType $ IntNType (Proxy @n) - _ -> err - App (App (tf :: TypeRep f) (_ :: TypeRep a0)) (_ :: TypeRep a1) -> - case ( eqTypeRep tf (typeRep @FP), - eqTypeRep tf (typeRep @(=->)), - eqTypeRep tf (typeRep @(-->)) - ) of - (Just HRefl, _, _) -> - withPrim @a $ NonFuncType $ FPType (Proxy @a0) (Proxy @a1) - (_, Just HRefl, _) -> - withPrim @a $ - let arg = knownType (Proxy @a0) - ret = knownType (Proxy @a1) - in case arg of - NonFuncType n -> case ret of - NonFuncType m -> TabularFunType [n, m] - TabularFunType ns -> TabularFunType (n : ns) - _ -> err - _ -> err - (_, _, Just HRefl) -> - withPrim @a $ - let arg = knownType (Proxy @a0) - ret = knownType (Proxy @a1) - in case arg of - NonFuncType n -> case ret of - NonFuncType m -> GeneralFunType [n, m] - GeneralFunType ns -> GeneralFunType (n : ns) - _ -> err - _ -> err - _ -> err - _ -> err + tr = typeRep @a + + isTy :: forall b. Typeable b => Proxy b -> Bool + isTy _ = isJust . eqTypeRep tr $ typeRep @b + +knownNonFuncType :: + forall a p. SupportedPrim a => p a -> KnownNonFuncType +knownNonFuncType proxy = do + let err = error $ "knownNonFuncType: unsupported type: " <> show (typeRep @a) + fromMaybe err $ knownNonFuncTypeMaybe proxy + +knownTypeMaybe :: forall a p. SupportedPrim a => p a -> Maybe KnownType +knownTypeMaybe proxy = withPrim @a $ case tr of + _ | Just result <- knownNonFuncTypeMaybe proxy -> pure $ NonFuncType result + App (App (funR :: TypeRep f) (_ :: TypeRep arg)) (_ :: TypeRep res) + | Just HRefl <- eqTypeRep funR $ typeRep @(=->) -> do + -- Gather the argument type. + arg <- knownTypeMaybe @arg Proxy + n <- case arg of + NonFuncType n -> pure n + _ -> Nothing + + -- Gather the result type. + ret <- knownTypeMaybe @res Proxy + ns <- case ret of + NonFuncType m -> pure [m] + TabularFunType ns -> pure ns + _ -> Nothing + + -- Create the tabular function type. + pure $ TabularFunType (n : ns) + + | Just HRefl <- eqTypeRep funR $ typeRep @(-->) -> do + -- Gather the argument type. + arg <- knownTypeMaybe @arg Proxy + n <- case arg of + NonFuncType n -> pure n + _ -> Nothing + + -- Gather the result type. + ret <- knownTypeMaybe @res Proxy + ns <- case ret of + NonFuncType m -> pure [m] + GeneralFunType ns -> pure ns + _ -> Nothing + + -- Create the general function type. + pure $ GeneralFunType (n : ns) + + _ -> Nothing where - tr = primTypeRep @a - err = error $ "knownType: unsupported type: " <> show tr + tr = typeRep @a + +knownType :: forall a p. SupportedPrim a => p a -> KnownType +knownType proxy = do + let err = error $ "knownType: unsupported type: " <> show (typeRep @a) + fromMaybe err $ knownTypeMaybe proxy -- Bool: 0 -- Integer: 1 @@ -589,6 +603,7 @@ knownType _ = -- FP: 4 -- FPRoundingMode: 5 -- AlgReal: 6 +-- Array: 7 serializeKnownNonFuncType :: (MonadPut m) => KnownNonFuncType -> m () serializeKnownNonFuncType BoolType = putWord8 0 serializeKnownNonFuncType IntegerType = putWord8 1 @@ -600,6 +615,10 @@ serializeKnownNonFuncType (FPType (Proxy :: Proxy eb) (Proxy :: Proxy sb)) = putWord8 4 >> serialize (natVal (Proxy @eb)) >> serialize (natVal (Proxy @sb)) serializeKnownNonFuncType FPRoundingModeType = putWord8 5 serializeKnownNonFuncType AlgRealType = putWord8 6 +serializeKnownNonFuncType (ArrayType key val) = do + putWord8 7 + serializeKnownNonFuncType key + serializeKnownNonFuncType val serializeKnownType :: (MonadPut m) => KnownType -> m () serializeKnownType (NonFuncType t) = putWord8 0 >> serializeKnownNonFuncType t @@ -639,6 +658,10 @@ deserializeKnownNonFuncType = do withUnsafeValidFP @eb @sb $ return $ FPType (Proxy @eb) (Proxy @sb) 5 -> return FPRoundingModeType 6 -> return AlgRealType + 7 -> do + keyT <- deserializeKnownNonFuncType + valT <- deserializeKnownNonFuncType + pure $ ArrayType keyT valT _ -> fail "deserializeKnownNonFuncType: Unknown type tag" deserializeKnownType :: (MonadGet m) => m KnownType @@ -877,6 +900,15 @@ fromFPOrTermTag = 46 toFPTermTag :: Word8 toFPTermTag = 47 +selectTermTag :: Word8 +selectTermTag = 48 + +storeTermTag :: Word8 +storeTermTag = 49 + +constArrayTermTag :: Word8 +constArrayTermTag = 50 + terminalTag :: Word8 terminalTag = 255 @@ -931,33 +963,18 @@ asNumTypeTerm (SomeTerm (t1 :: Term a)) f = err = error $ "asNumTypeTerm: unsupported type: " <> show ta asOrdTypeTerm :: - (HasCallStack) => SomeTerm -> (forall n. (PEvalOrdTerm n) => Term n -> r) -> r -asOrdTypeTerm (SomeTerm (t1 :: Term a)) f = - case ( eqTypeRep ta (typeRep @Integer), - eqTypeRep ta (typeRep @AlgReal), - eqTypeRep ta (typeRep @FPRoundingMode) - ) of - (Just HRefl, _, _) -> f t1 - (_, Just HRefl, _) -> f t1 - (_, _, Just HRefl) -> f t1 - _ -> - case ta of - App (ta@(Con _) :: TypeRep w) (_ :: TypeRep n) -> - case ( eqTypeRep ta (typeRep @WordN), - eqTypeRep ta (typeRep @IntN) - ) of - (Just HRefl, _) -> withPrim @a $ f t1 - (_, Just HRefl) -> withPrim @a $ f t1 - _ -> err - App (App (tf :: TypeRep f) (_ :: TypeRep a0)) (_ :: TypeRep a1) -> - case eqTypeRep tf (typeRep @FP) of - Just HRefl -> - withPrim @a $ withPrim @a $ f t1 - _ -> err - _ -> err + HasCallStack => SomeTerm -> (forall n. PEvalOrdTerm n => Term n -> r) -> r +asOrdTypeTerm (SomeTerm (t1 :: Term a)) f = case ta of + _ | Just HRefl <- eqTypeRep ta $ typeRep @Integer -> f t1 + _ | Just HRefl <- eqTypeRep ta $ typeRep @AlgReal -> f t1 + App bvR _nR + | Just HRefl <- eqTypeRep bvR $ typeRep @WordN -> withPrim @a $ f t1 + | Just HRefl <- eqTypeRep bvR $ typeRep @IntN -> withPrim @a $ f t1 + App (App fpR _ebR) _esR + | Just HRefl <- eqTypeRep fpR $ typeRep @FP -> withPrim @a $ f t1 + _ -> error $ "asNumTypeTerm: unsupported type: " <> show ta where ta = primTypeRep @a - err = error $ "asOrdTypeTerm: unsupported type: " <> show ta asBitsTypeTerm :: (HasCallStack) => @@ -1527,6 +1544,67 @@ statefulDeserializeSomeTerm = do ktTmId ) else error "statefulDeserializeSomeTerm: invalid FP type" + | tag == selectTermTag -> do + -- Deserialize the array and key. + SomeTerm (arr :: Term a) <- deserializeTerm + SomeTerm (key :: Term k) <- deserializeTerm + + -- Deserialize the resulting value type and get the required + -- dictionaries for this type. + -- TODO: How do I know I get the correct known type here? + valType <- deserializeKnownType + KnownTypeWitness (_ :: Proxy v) <- pure $ witnessKnownType valType + + -- Ensure that the key is indeed a valid key for the array and that + -- the resulting value matches the value type such that we can provide + -- the dictionary. Using this, we construct the final term. + let term = case typeRep @a of + App (App aRep kRep) vRep + | Just HRefl <- eqTypeRep aRep $ typeRep @Array + , Just HRefl <- eqTypeRep kRep $ typeRep @k + , Just HRefl <- eqTypeRep vRep $ typeRep @v -> do + someTerm $ selectTerm @k @v arr key + _ -> error "statefulDeserializeSomeTerm: invalid Array type" + + pure $ Just (term, ktTmId) + | tag == storeTermTag -> do + -- Deserialize the array, key and value. + SomeTerm (arr :: Term a) <- deserializeTerm + SomeTerm (key :: Term k) <- deserializeTerm + SomeTerm (val :: Term v) <- deserializeTerm + + -- Ensure the types match up such that we can construct the term. + let term = case typeRep @a of + App (App aRep kRep) vRep + | Just HRefl <- eqTypeRep aRep $ typeRep @Array + , Just HRefl <- eqTypeRep kRep $ typeRep @k + , Just HRefl <- eqTypeRep vRep $ typeRep @v -> do + someTerm $ storeTerm @k @v arr key val + _ -> error "statefulDeserializeSomeTerm: invalid Array type" + + pure $ Just (term, ktTmId) + | tag == constArrayTermTag -> do + -- Get the value term and non-function primitive dictionary. + SomeTerm (val :: Term v) <- deserializeTerm + let valType = knownNonFuncType @v Proxy + KnownNonFuncTypeWitness (_ :: p v') <- do + pure $ witnessKnownNonFuncType valType + + -- Gather the key type and its non-function primitive dictionary + -- TODO: How do I know I get the correct known type for the key? + keyType <- deserializeKnownNonFuncType + KnownNonFuncTypeWitness (_ :: p k) <- do + pure $ witnessKnownNonFuncType keyType + + -- Really, this should never fail but I guess we can check instead of + -- coercing unsafely. + HRefl <- case heqT @v @v' of + Just refl -> pure refl + Nothing -> error "statefulDeserializeSomeTerm: non-injective type translation" + + let term = someTerm $ constArrayTerm @k Proxy val + + pure $ Just (term, ktTmId) | otherwise -> error $ "statefulDeserializeSomeTerm: unknown tag: " <> show tag case r of @@ -1811,6 +1889,14 @@ serializeSingleSomeTerm (SomeTerm (tm :: Term t)) = do serialize $ natVal sb serialize $ knownTypeTermId rd serialize $ knownTypeTermId t + SelectTerm arr key -> do + serializeBinary ktTmId selectTermTag arr key + serializeKnownType $ knownType @t Proxy + StoreTerm arr key val -> do + serializeTernary ktTmId storeTermTag arr key val + ConstArrayTerm pkey val -> withPrim @t $ do + serializeUnary ktTmId constArrayTermTag val + serializeKnownType $ knownType pkey State.put $ HS.insert ktTmId st where serializeQuantified :: diff --git a/src/Grisette/Internal/SymPrim/Prim/Internal/Term.hs b/src/Grisette/Internal/SymPrim/Prim/Internal/Term.hs index 112a48f4c..8f4c1fbf0 100644 --- a/src/Grisette/Internal/SymPrim/Prim/Internal/Term.hs +++ b/src/Grisette/Internal/SymPrim/Prim/Internal/Term.hs @@ -69,6 +69,9 @@ module Grisette.Internal.SymPrim.Prim.Internal.Term PEvalFloatingTerm (..), PEvalFromIntegralTerm (..), PEvalIEEEFPConvertibleTerm (..), + pevalSelectTerm, + pevalStoreTerm, + pevalConstArrayTerm, -- * Typed symbols SymbolKind (..), @@ -165,6 +168,9 @@ module Grisette.Internal.SymPrim.Prim.Internal.Term fromIntegralTerm, fromFPOrTerm, toFPTerm, + selectTerm, + storeTerm, + constArrayTerm, -- * Patterns pattern SupportedTerm, @@ -220,6 +226,9 @@ module Grisette.Internal.SymPrim.Prim.Internal.Term pattern FromIntegralTerm, pattern FromFPOrTerm, pattern ToFPTerm, + pattern SelectTerm, + pattern StoreTerm, + pattern ConstArrayTerm, -- * Support for boolean type trueTerm, @@ -329,6 +338,7 @@ import qualified Control.Monad.Writer.Lazy as Lazy import qualified Control.Monad.Writer.Strict as Strict import Data.Atomics (atomicModifyIORefCAS_) import qualified Data.Binary as Binary +import Data.Bifunctor (Bifunctor(bimap)) import Data.Bits ( Bits (complement, isSigned, xor, zeroBits, (.&.), (.|.)), FiniteBits (countLeadingZeros), @@ -377,6 +387,7 @@ import Grisette.Internal.Core.Data.Symbol Symbol (IndexedSymbol, SimpleSymbol), ) import Grisette.Internal.SymPrim.AlgReal (AlgReal, fromSBVAlgReal, toSBVAlgReal) +import Grisette.Internal.SymPrim.Array (Array (Array)) import Grisette.Internal.SymPrim.BV (IntN, WordN) import Grisette.Internal.SymPrim.FP ( FP (FP), @@ -495,7 +506,8 @@ class Eq a, Show a, Hashable a, - Typeable a + Typeable a, + SBVType a ~ SBV.SBV (NonFuncSBVBaseType a) ) => NonFuncSBVRep a where @@ -509,7 +521,7 @@ type NonFuncPrimConstraint a = SBV.Mergeable (SBVType a), SBV.SMTDefinable (SBVType a), SBV.Mergeable (SBVType a), - SBVType a ~ SBV.SBV (NonFuncSBVBaseType a), + SBVT.SatModel (NonFuncSBVBaseType a), PrimConstraint a ) @@ -520,6 +532,7 @@ class (NonFuncSBVRep a) => SupportedNonFuncPrim a where symNonFuncSBVTerm :: (SBVFreshMonad m) => String -> m (SBV.SBV (NonFuncSBVBaseType a)) withNonFuncPrim :: ((NonFuncPrimConstraint a) => r) -> r + sbvToCon :: NonFuncSBVBaseType a -> a -- | Partition the list of CVs for models for functions. partitionCVArg :: @@ -644,6 +657,13 @@ class (SBVT.EqSymbolic (SBVType t)) => NonEmpty (SBVType t) -> SBV.SBV Bool sbvDistinct = SBV.distinct . toList parseSMTModelResult :: Int -> ([([SBVD.CV], SBVD.CV)], SBVD.CV) -> t + default parseSMTModelResult :: + SupportedNonFuncPrim t => + Int -> + ([([SBVD.CV], SBVD.CV)], SBVD.CV) -> + t + parseSMTModelResult _ = withNonFuncPrim @t $ do + parseScalarSMTModelResult sbvToCon castTypedSymbol :: (IsSymbolKind knd') => TypedSymbol knd t -> Maybe (TypedSymbol knd' t) funcDummyConstraint :: SBVType t -> SBV.SBV Bool @@ -1734,6 +1754,25 @@ data Term t where Proxy eb -> Proxy sb -> Term (FP eb sb) + SelectTerm' :: + SupportedPrim (Array k v) => + {-# UNPACK #-} !CachedInfo -> + !(Term (Array k v)) -> + !(Term k) -> + Term v + StoreTerm' :: + SupportedPrim (Array k v) => + {-# UNPACK #-} !CachedInfo -> + !(Term (Array k v)) -> + !(Term k) -> + !(Term v) -> + Term (Array k v) + ConstArrayTerm' :: + SupportedPrim (Array k v) => + {-# UNPACK #-} !CachedInfo -> + Proxy k -> + !(Term v) -> + Term (Array k v) data SupportedPrimEvidence t where SupportedPrimEvidence :: (SupportedPrim t) => SupportedPrimEvidence t @@ -2699,6 +2738,67 @@ pattern ToFPTerm rm t eb sb <- (ToFPTerm' _ rm t@SupportedTerm eb sb) {-# INLINE ToFPTerm #-} #endif +-- | Pattern synonym for 'SelectTerm''. Note that using this pattern to +-- construct a 'Term' will do term simplification. +pattern SelectTerm :: + forall ret. + () => + forall k v. + ( SupportedPrim (Array k v), + ret ~ v + ) => + Term (Array k v) -> + Term k -> + Term ret +pattern SelectTerm arr key <- SelectTerm' _ arr key + where + SelectTerm arr key = pevalSelectTerm arr key + +#if MIN_VERSION_base(4, 16, 4) +{-# INLINE SelectTerm #-} +#endif + +-- | Pattern synonym for 'StoreTerm''. Note that using this pattern to +-- construct a 'Term' will do term simplification. +pattern StoreTerm :: + forall ret. + () => + forall k v. + ( SupportedPrim (Array k v), + ret ~ Array k v + ) => + Term (Array k v) -> + Term k -> + Term v -> + Term ret +pattern StoreTerm arr key val <- StoreTerm' _ arr key val + where + StoreTerm arr key = pevalStoreTerm arr key + +#if MIN_VERSION_base(4, 16, 4) +{-# INLINE StoreTerm #-} +#endif + +-- | Pattern synonym for 'StoreTerm''. Note that using this pattern to +-- construct a 'Term' will do term simplification. +pattern ConstArrayTerm :: + forall ret. + () => + forall k v. + ( SupportedPrim (Array k v), + ret ~ Array k v + ) => + Proxy k -> + Term v -> + Term ret +pattern ConstArrayTerm pkey val <- ConstArrayTerm' _ pkey val + where + ConstArrayTerm pkey val = pevalConstArrayTerm pkey val + +#if MIN_VERSION_base(4, 16, 4) +{-# INLINE ConstArrayTerm #-} +#endif + #if MIN_VERSION_base(4, 16, 4) {-# COMPLETE ConTerm, @@ -2748,7 +2848,10 @@ pattern ToFPTerm rm t eb sb <- (ToFPTerm' _ rm t@SupportedTerm eb sb) FPFMATerm, FromIntegralTerm, FromFPOrTerm, - ToFPTerm + ToFPTerm, + SelectTerm, + StoreTerm, + ConstArrayTerm #-} #endif @@ -2802,6 +2905,9 @@ termInfo (FPFMATerm' i _ _ _ _) = i termInfo (FromIntegralTerm' i _) = i termInfo (FromFPOrTerm' i _ _ _) = i termInfo (ToFPTerm' i _ _ _ _) = i +termInfo (SelectTerm' i _ _) = i +termInfo (StoreTerm' i _ _ _) = i +termInfo (ConstArrayTerm' i _ _) = i -- | Get the thread ID for a term. {-# INLINE termThreadId #-} @@ -2923,6 +3029,10 @@ introSupportedPrimConstraint0 FPFMATerm' {} x = x introSupportedPrimConstraint0 FromIntegralTerm' {} x = x introSupportedPrimConstraint0 FromFPOrTerm' {} x = x introSupportedPrimConstraint0 ToFPTerm' {} x = x +introSupportedPrimConstraint0 (SelectTerm' _ (_ :: Term arr) _) x = do + withPrim @arr x +introSupportedPrimConstraint0 StoreTerm' {} x = x +introSupportedPrimConstraint0 ConstArrayTerm' {} x = x -- | Introduce the 'SupportedPrim' constraint from a term. introSupportedPrimConstraint :: @@ -2984,6 +3094,9 @@ pformatTerm (FPFMATerm mode arg1 arg2 arg3) = pformatTerm (FromIntegralTerm arg) = "(from_integral " ++ pformatTerm arg ++ ")" pformatTerm (FromFPOrTerm d r arg) = "(from_fp_or " ++ pformatTerm d ++ " " ++ pformatTerm r ++ " " ++ pformatTerm arg ++ ")" pformatTerm (ToFPTerm r arg _ _) = "(to_fp " ++ pformatTerm r ++ " " ++ pformatTerm arg ++ ")" +pformatTerm (SelectTerm arr key) = "(select " ++ pformatTerm arr ++ " " ++ pformatTerm key ++ ")" +pformatTerm (StoreTerm arr key val) = "(store " ++ pformatTerm arr ++ " " ++ pformatTerm key ++ " " ++ pformatTerm val ++ ")" +pformatTerm (ConstArrayTerm _ val) = "(const_array " ++ pformatTerm val ++ ")" -- {-# INLINE pformatTerm #-} @@ -3054,6 +3167,11 @@ instance Lift (Term t) where liftTyped (FromFPOrTerm t1 t2 t3) = [||fromFPOrTerm t1 t2 t3||] liftTyped (ToFPTerm t1 t2 _ _) = [||toFPTerm t1 t2||] + liftTyped (SelectTerm t1 t2) = [||selectTerm t1 t2||] + liftTyped (StoreTerm t1 t2 t3) = [||storeTerm t1 t2 t3||] + liftTyped (ConstArrayTerm (_ :: p k) t2) = do + let pkey = [||Proxy||] :: CODE (Proxy k) + [||constArrayTerm $$pkey t2||] instance Show (Term ty) where show t@(ConTerm v) = @@ -3530,6 +3648,38 @@ instance Show (Term ty) where ++ ", arg=" ++ show arg ++ "}" + show t@(SelectTerm arr key) = + "SelectTerm{tid=" + ++ show (termThreadId t) + ++ ", id=" + ++ show (termId t) + ++ ", array=" + ++ show arr + ++ ", key=" + ++ show key + ++ "}" + show t@(StoreTerm arr key val) = + "StoreTerm{tid=" + ++ show (termThreadId t) + ++ ", id=" + ++ show (termId t) + ++ ", array=" + ++ show arr + ++ ", key=" + ++ show key + ++ ", val=" + ++ show val + ++ "}" + show t@(ConstArrayTerm (_ :: p k) val) = + "ConstArrayTerm{tid=" + ++ show (termThreadId t) + ++ ", id=" + ++ show (termId t) + ++ ", key=" + ++ withPrim @ty (show $ typeRep @k) + ++ ", val=" + ++ show val + ++ "}" -- {-# INLINE show #-} @@ -3751,6 +3901,22 @@ data UTerm t where Proxy eb -> Proxy sb -> UTerm (FP eb sb) + USelectTerm :: + SupportedPrim (Array k v) => + !(Term (Array k v)) -> + !(Term k) -> + UTerm v + UStoreTerm :: + SupportedPrim (Array k v) => + !(Term (Array k v)) -> + !(Term k) -> + !(Term v) -> + UTerm (Array k v) + UConstArrayTerm :: + SupportedPrim (Array k v) => + Proxy k -> + !(Term v) -> + UTerm (Array k v) -- | Compare two t'TypedSymbol's for equality. eqHeteroSymbol :: forall ta a tb b. TypedSymbol ta a -> TypedSymbol tb b -> Bool @@ -4017,6 +4183,20 @@ preHashToFPTermDescription h1 h2 = fromIntegral (50 `hashWithSalt` h1 `hashWithSalt` h2) {-# INLINE preHashToFPTermDescription #-} +preHashSelectDescription :: HashId -> HashId -> Digest +preHashSelectDescription h1 h2 = + fromIntegral (51 `hashWithSalt` h1 `hashWithSalt` h2) +{-# INLINE preHashSelectDescription #-} + +preHashStoreDescription :: HashId -> HashId -> HashId -> Digest +preHashStoreDescription h1 h2 h3 = + fromIntegral (52 `hashWithSalt` h1 `hashWithSalt` h2 `hashWithSalt` h3) +{-# INLINE preHashStoreDescription #-} + +preHashConstArrayDescription :: HashId -> Digest +preHashConstArrayDescription h1 = fromIntegral (53 `hashWithSalt` h1) +{-# INLINE preHashConstArrayDescription #-} + instance Interned (Term t) where type Uninterned (Term t) = UTerm t data Description (Term t) where @@ -4265,6 +4445,22 @@ instance Interned (Term t) where {-# UNPACK #-} !HashId -> {-# UNPACK #-} !TypeHashId -> Description (Term (FP eb sb)) + DSelectTerm :: + {-# UNPACK #-} !Digest -> + {-# UNPACK #-} !HashId -> + {-# UNPACK #-} !HashId -> + Description (Term v) + DStoreTerm :: + {-# UNPACK #-} !Digest -> + {-# UNPACK #-} !HashId -> + {-# UNPACK #-} !HashId -> + {-# UNPACK #-} !HashId -> + Description (Term v) + DConstArrayTerm :: + {-# UNPACK #-} !Digest -> + {-# UNPACK #-} !Fingerprint -> + {-# UNPACK #-} !HashId -> + Description (Term v) describe (UConTerm v) = DConTerm sameCon (preHashConDescription v) v describe ((USymTerm name) :: UTerm t) = @@ -4576,6 +4772,22 @@ instance Interned (Term t) where (preHashToFPTermDescription modeHashId argHashId) modeHashId argHashId + describe (USelectTerm arr key) = do + let arrHashId = termHashId arr + let keyHashId = termHashId key + let digest = preHashSelectDescription arrHashId keyHashId + DSelectTerm digest arrHashId keyHashId + describe (UStoreTerm arr key val) = do + let arrHashId = termHashId arr + let keyHashId = termHashId key + let valHashId = termHashId val + let digest = preHashStoreDescription arrHashId keyHashId valHashId + DStoreTerm digest arrHashId keyHashId valHashId + describe (UConstArrayTerm pkey val) = withPrim @t $ do + let keyFingerprint = typeRepFingerprint $ someTypeRep pkey + let valHashId = termHashId val + let digest = preHashConstArrayDescription valHashId + DConstArrayTerm digest keyFingerprint valHashId -- {-# INLINE describe #-} @@ -4640,6 +4852,9 @@ instance Interned (Term t) where go (UFromFPOrTerm d mode arg) = FromFPOrTerm' info d mode arg go (UToFPTerm mode (arg :: Term a) _ _) = goPhantomToFP info getPhantomDict mode arg + go (USelectTerm arr key) = SelectTerm' info arr key + go (UStoreTerm arr key val) = StoreTerm' info arr key val + go (UConstArrayTerm proxy val) = ConstArrayTerm' info proxy val {-# INLINE go #-} -- {-# INLINE identify #-} @@ -4694,6 +4909,9 @@ instance Interned (Term t) where descriptionDigest (DFromIntegralTerm h _) = h descriptionDigest (DFromFPOrTerm h _ _ _) = h descriptionDigest (DToFPTerm h _ _) = h + descriptionDigest (DSelectTerm h _ _) = h + descriptionDigest (DStoreTerm h _ _ _) = h + descriptionDigest (DConstArrayTerm h _ _) = h -- {-# INLINE descriptionDigest #-} {-# NOINLINE goPhantomCon #-} @@ -5087,6 +5305,18 @@ fullReconstructTerm (FromFPOrTerm d r arg) = fullReconstructTerm3 curThreadFromFPOrTerm d r arg fullReconstructTerm (ToFPTerm r arg _ _) = fullReconstructTerm2 curThreadToFPTerm r arg +fullReconstructTerm (SelectTerm (arr :: Term arr) key) = withPrim @arr $ do + arr' <- fullReconstructTerm arr + key' <- fullReconstructTerm key + intern $ USelectTerm arr' key' +fullReconstructTerm (StoreTerm arr key val) = do + arr' <- fullReconstructTerm arr + key' <- fullReconstructTerm key + val' <- fullReconstructTerm val + intern $ UStoreTerm arr' key' val' +fullReconstructTerm (ConstArrayTerm pkey val) = do + val' <- fullReconstructTerm val + intern $ UConstArrayTerm pkey val' toCurThreadImpl :: forall t. WeakThreadId -> Term t -> IO (Term t) toCurThreadImpl tid t | termThreadId t == tid = return t @@ -5505,6 +5735,36 @@ curThreadToFPTerm :: curThreadToFPTerm r f = intern $ UToFPTerm r f (Proxy @eb) (Proxy @sb) {-# INLINE curThreadToFPTerm #-} +-- | Construct and internalizing a 'SelectTerm'. +curThreadSelectTerm :: + forall k v. + SupportedPrim (Array k v) => + Term (Array k v) -> + Term k -> + IO (Term v) +curThreadSelectTerm arr key = withPrim @(Array k v) $ do + intern $ USelectTerm arr key +{-# INLINE curThreadSelectTerm #-} + +curThreadStoreTerm :: + forall k v. + SupportedPrim (Array k v) => + Term (Array k v) -> + Term k -> + Term v -> + IO (Term (Array k v)) +curThreadStoreTerm arr key val = intern $ UStoreTerm arr key val +{-# INLINE curThreadStoreTerm #-} + +curThreadConstArrayTerm :: + forall k v. + SupportedPrim (Array k v) => + Proxy k -> + Term v -> + IO (Term (Array k v)) +curThreadConstArrayTerm pkey val = intern $ UConstArrayTerm pkey val +{-# INLINE curThreadConstArrayTerm #-} + inCurThread1 :: forall a b. (Term a -> IO (Term b)) -> @@ -6045,6 +6305,37 @@ toFPTerm :: toFPTerm = unsafeInCurThread2 curThreadToFPTerm {-# NOINLINE toFPTerm #-} +-- | Construct and internalizing a 'SelectTerm'. +selectTerm :: + forall k v. + SupportedPrim (Array k v) => + Term (Array k v) -> + Term k -> + Term v +selectTerm = unsafeInCurThread2 curThreadSelectTerm +{-# NOINLINE selectTerm #-} + +-- | Construct and internalizing a 'StoreTerm'. +storeTerm :: + forall k v. + SupportedPrim (Array k v) => + Term (Array k v) -> + Term k -> + Term v -> + Term (Array k v) +storeTerm = unsafeInCurThread3 curThreadStoreTerm +{-# NOINLINE storeTerm #-} + +-- | Construct and internalizing a 'ConstArrayTerm'. +constArrayTerm :: + forall k v. + SupportedPrim (Array k v) => + Proxy k -> + Term v -> + Term (Array k v) +constArrayTerm pkey = unsafeInCurThread1 $ curThreadConstArrayTerm pkey +{-# NOINLINE constArrayTerm #-} + -- Support for boolean type defaultValueForBool :: Bool defaultValueForBool = False @@ -6692,7 +6983,6 @@ instance SupportedPrim Bool where symSBVName symbol _ = show symbol symSBVTerm = sbvFresh withPrim r = r - parseSMTModelResult _ = parseScalarSMTModelResult id castTypedSymbol :: forall knd knd'. (IsSymbolKind knd') => @@ -6711,6 +7001,7 @@ instance SupportedNonFuncPrim Bool where conNonFuncSBVTerm = conSBVTerm symNonFuncSBVTerm = symSBVTerm @Bool withNonFuncPrim r = r + sbvToCon = id data PhantomDict a where PhantomDict :: (SupportedPrim a) => PhantomDict a @@ -6807,7 +7098,6 @@ instance SupportedPrim Integer where conSBVTerm n = fromInteger n symSBVName symbol _ = show symbol symSBVTerm name = sbvFresh name - parseSMTModelResult _ = parseScalarSMTModelResult id castTypedSymbol :: forall knd knd'. (IsSymbolKind knd') => @@ -6826,6 +7116,7 @@ instance SupportedNonFuncPrim Integer where conNonFuncSBVTerm = conSBVTerm symNonFuncSBVTerm = symSBVTerm @Integer withNonFuncPrim r = r + sbvToCon = id pevalITEBVTerm :: forall bv n. @@ -6955,9 +7246,6 @@ instance (KnownNat w, 1 <= w) => SupportedPrim (IntN w) where symSBVTerm name = bvIsNonZeroFromGEq1 (Proxy @w) $ sbvFresh name withPrim r = bvIsNonZeroFromGEq1 (Proxy @w) r {-# INLINE withPrim #-} - parseSMTModelResult _ cv = - withPrim @(IntN w) $ - parseScalarSMTModelResult (\(x :: SBV.IntN w) -> fromIntegral x) cv castTypedSymbol :: forall knd knd'. (IsSymbolKind knd') => @@ -6988,6 +7276,7 @@ instance (KnownNat w, 1 <= w) => SupportedNonFuncPrim (IntN w) where conNonFuncSBVTerm = conSBVTerm symNonFuncSBVTerm = symSBVTerm @(IntN w) withNonFuncPrim r = bvIsNonZeroFromGEq1 (Proxy @w) r + sbvToCon = withPrim @(IntN w) fromIntegral -- Unsigned BV instance (KnownNat w, 1 <= w) => SupportedPrimConstraint (WordN w) where @@ -7014,9 +7303,6 @@ instance (KnownNat w, 1 <= w) => SupportedPrim (WordN w) where symSBVTerm name = bvIsNonZeroFromGEq1 (Proxy @w) $ sbvFresh name withPrim r = bvIsNonZeroFromGEq1 (Proxy @w) r {-# INLINE withPrim #-} - parseSMTModelResult _ cv = - withPrim @(WordN w) $ - parseScalarSMTModelResult (\(x :: SBV.WordN w) -> fromIntegral x) cv castTypedSymbol :: forall knd knd'. (IsSymbolKind knd') => @@ -7035,6 +7321,7 @@ instance (KnownNat w, 1 <= w) => SupportedNonFuncPrim (WordN w) where conNonFuncSBVTerm = conSBVTerm symNonFuncSBVTerm = symSBVTerm @(WordN w) withNonFuncPrim r = bvIsNonZeroFromGEq1 (Proxy @w) r + sbvToCon = withPrim @(WordN w) fromIntegral -- FP instance (ValidFP eb sb) => SupportedPrimConstraint (FP eb sb) where @@ -7067,9 +7354,6 @@ instance (ValidFP eb sb) => SupportedPrim (FP eb sb) where conSBVTerm (FP fp) = SBV.literal fp symSBVName symbol _ = show symbol symSBVTerm name = sbvFresh name - parseSMTModelResult _ cv = - withPrim @(FP eb sb) $ - parseScalarSMTModelResult (\(x :: SBV.FloatingPoint eb sb) -> coerce x) cv funcDummyConstraint _ = SBV.sTrue -- Workaround for sbv#702. @@ -7101,6 +7385,7 @@ instance (ValidFP eb sb) => SupportedNonFuncPrim (FP eb sb) where conNonFuncSBVTerm = conSBVTerm symNonFuncSBVTerm = symSBVTerm @(FP eb sb) withNonFuncPrim r = r + sbvToCon = coerce -- FPRoundingMode instance SupportedPrimConstraint FPRoundingMode @@ -7122,17 +7407,6 @@ instance SupportedPrim FPRoundingMode where conSBVTerm RTZ = SBV.sRTZ symSBVName symbol _ = show symbol symSBVTerm name = sbvFresh name - parseSMTModelResult _ cv = - withPrim @(FPRoundingMode) $ - parseScalarSMTModelResult - ( \(x :: SBV.RoundingMode) -> case x of - SBV.RoundNearestTiesToEven -> RNE - SBV.RoundNearestTiesToAway -> RNA - SBV.RoundTowardPositive -> RTP - SBV.RoundTowardNegative -> RTN - SBV.RoundTowardZero -> RTZ - ) - cv castTypedSymbol :: forall knd knd'. (IsSymbolKind knd') => @@ -7151,6 +7425,12 @@ instance SupportedNonFuncPrim FPRoundingMode where conNonFuncSBVTerm = conSBVTerm symNonFuncSBVTerm = symSBVTerm @FPRoundingMode withNonFuncPrim r = r + sbvToCon mode = case mode of + SBV.RoundNearestTiesToEven -> RNE + SBV.RoundNearestTiesToAway -> RNA + SBV.RoundTowardPositive -> RTP + SBV.RoundTowardNegative -> RTN + SBV.RoundTowardZero -> RTZ -- AlgReal @@ -7169,9 +7449,6 @@ instance SupportedPrim AlgReal where conSBVTerm = SBV.literal . toSBVAlgReal symSBVName symbol _ = show symbol symSBVTerm name = sbvFresh name - parseSMTModelResult _ cv = - withPrim @AlgReal $ - parseScalarSMTModelResult fromSBVAlgReal cv castTypedSymbol :: forall knd knd'. (IsSymbolKind knd') => @@ -7190,6 +7467,92 @@ instance SupportedNonFuncPrim AlgReal where conNonFuncSBVTerm = conSBVTerm symNonFuncSBVTerm = symSBVTerm @AlgReal withNonFuncPrim r = r + sbvToCon = fromSBVAlgReal + +-- Array + +pevalSelectTerm :: + forall k v. + SupportedPrim (Array k v) => + Term (Array k v) -> + Term k -> + Term v +pevalSelectTerm = selectTerm -- TODO: perform optimisation + +pevalStoreTerm :: + forall k v. + SupportedPrim (Array k v) => + Term (Array k v) -> + Term k -> + Term v -> + Term (Array k v) +pevalStoreTerm = storeTerm -- TODO: perform optimisation + +pevalConstArrayTerm :: + forall k v. + SupportedPrim (Array k v) => + Proxy k -> + Term v -> + Term (Array k v) +pevalConstArrayTerm = constArrayTerm -- TODO: perform optimisation + +instance SupportedPrimConstraint (Array k v) where + type PrimConstraint (Array k v) = + ( SupportedNonFuncPrim k + , SupportedNonFuncPrim v + , SBVT.SymVal (NonFuncSBVBaseType k) + , SBVT.SymVal (NonFuncSBVBaseType v) + ) + +instance SBVRep (Array k v) where + type SBVType (Array k v) = SBV.SArray (NonFuncSBVBaseType k) (NonFuncSBVBaseType v) + +instance + ( SupportedNonFuncPrim k + , SupportedNonFuncPrim v + ) => SupportedPrim (Array k v) where + defaultValue = Array mempty defaultValue + pevalITETerm = pevalITEBasicTerm + pevalEqTerm = pevalDefaultEqTerm + pevalDistinctTerm = pevalGeneralDistinct + conSBVTerm (Array entries def) = withNonFuncPrim @(Array k v) $ do + let root = SBV.constArray $ conSBVTerm def + let foldlWithKeyBy acc xs f = HM.foldlWithKey' f acc xs + foldlWithKeyBy root entries $ \acc key val -> do + SBV.writeArray acc (conSBVTerm key) (conSBVTerm val) + symSBVName x _ = show x + symSBVTerm = withNonFuncPrim @(Array k v) $ sbvFresh + withPrim = withNonFuncPrim @(Array k v) + sbvEq = withPrim @(Array k v) (SBV..==) + sbvDistinct = withPrim @(Array k v) $ SBV.distinct . toList + castTypedSymbol :: + forall knd' knd. + IsSymbolKind knd' => + TypedSymbol knd (Array k v) -> + Maybe (TypedSymbol knd' (Array k v)) + castTypedSymbol = pure . case decideSymbolKind @knd' of + Left HRefl -> TypedSymbol . unTypedSymbol + Right HRefl -> TypedSymbol . unTypedSymbol + funcDummyConstraint _ = SBV.sTrue + +instance + ( SupportedNonFuncPrim k, Ord k, Typeable k, Hashable k, Show k + , SupportedNonFuncPrim v, Ord v, Typeable v, Hashable v, Show v + ) => NonFuncSBVRep (Array k v) where + type NonFuncSBVBaseType (Array k v) = SBV.ArrayModel (NonFuncSBVBaseType k) (NonFuncSBVBaseType v) + +instance + ( SupportedNonFuncPrim k + , SupportedNonFuncPrim v + ) => SupportedNonFuncPrim (Array k v) where + conNonFuncSBVTerm = conSBVTerm + symNonFuncSBVTerm = withNonFuncPrim @(Array k v) sbvFresh + withNonFuncPrim = withNonFuncPrim @k $ withNonFuncPrim @v $ id + sbvToCon (SBV.ArrayModel tbl def) = do + -- NOTE: We reverse the list as later elements should take precedence. + let tbl' = HM.fromList . reverse . fmap (bimap sbvToCon sbvToCon) $ tbl + let def' = sbvToCon def + Array tbl' def' -- Bitwise diff --git a/src/Grisette/Internal/SymPrim/Prim/Pattern.hs b/src/Grisette/Internal/SymPrim/Prim/Pattern.hs index cc0c2f6e1..e5cbd6e2c 100644 --- a/src/Grisette/Internal/SymPrim/Prim/Pattern.hs +++ b/src/Grisette/Internal/SymPrim/Prim/Pattern.hs @@ -1,6 +1,8 @@ {-# LANGUAGE CPP #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} {-# LANGUAGE ViewPatterns #-} -- | @@ -19,6 +21,7 @@ where import Data.Foldable (Foldable (toList)) import Grisette.Internal.SymPrim.Prim.Internal.Term ( Term, + SupportedPrim (withPrim), pattern AbsNumTerm, pattern AddNumTerm, pattern AndBitsTerm, @@ -67,10 +70,13 @@ import Grisette.Internal.SymPrim.Prim.Internal.Term pattern SymTerm, pattern ToFPTerm, pattern XorBitsTerm, + pattern SelectTerm, + pattern StoreTerm, + pattern ConstArrayTerm, ) import Grisette.Internal.SymPrim.Prim.SomeTerm (SomeTerm (SomeTerm)) -subTermsViewPattern :: Term a -> Maybe [SomeTerm] +subTermsViewPattern :: forall a. Term a -> Maybe [SomeTerm] subTermsViewPattern (ConTerm _) = return [] subTermsViewPattern (SymTerm _) = return [] subTermsViewPattern (ForallTerm _ t) = return [SomeTerm t] @@ -121,6 +127,9 @@ subTermsViewPattern (FPFMATerm rd t1 t2 t3) = subTermsViewPattern (FromIntegralTerm t) = return [SomeTerm t] subTermsViewPattern (FromFPOrTerm t1 rd t2) = return [SomeTerm t1, SomeTerm rd, SomeTerm t2] subTermsViewPattern (ToFPTerm rd t1 _ _) = return [SomeTerm rd, SomeTerm t1] +subTermsViewPattern (SelectTerm (t1 :: Term arr) t2) = withPrim @arr $ return [SomeTerm t1, SomeTerm t2] +subTermsViewPattern (StoreTerm t1 t2 t3) = withPrim @a $ return [SomeTerm t1, SomeTerm t2, SomeTerm t3] +subTermsViewPattern (ConstArrayTerm _ t1) = withPrim @a $ return [SomeTerm t1] -- | Extract all the subterms of a term. pattern SubTerms :: [SomeTerm] -> Term a diff --git a/src/Grisette/Internal/SymPrim/SymArray.hs b/src/Grisette/Internal/SymPrim/SymArray.hs new file mode 100644 index 000000000..3cee6ad3e --- /dev/null +++ b/src/Grisette/Internal/SymPrim/SymArray.hs @@ -0,0 +1,159 @@ +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE DeriveLift #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE ImportQualifiedPost #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE UndecidableInstances #-} + +-- | +-- Module : Grisette.Internal.SymPrim.Array +-- Copyright : (c) Sirui Lu 2021-2023 +-- License : BSD-3-Clause (see the LICENSE file) +-- +-- Maintainer : siruilu@cs.washington.edu +-- Stability : Experimental +-- Portability : GHC only +module Grisette.Internal.SymPrim.SymArray + ( SymArray (..) + , const + , select + , store + ) where + +import Control.DeepSeq (NFData) +import Data.Binary qualified as Binary +import Data.Bytes.Serial (Serial (deserialize, serialize)) +import Data.Data (Proxy(Proxy)) +import Data.Serialize qualified as Cereal +import Data.String (IsString (fromString)) +import Grisette.Internal.SymPrim.Array (Array) +import Grisette.Internal.SymPrim.Prim.Internal.Term + ( Term + , SupportedNonFuncPrim + , ConRep (ConType) + , SymRep (SymType) + , LinkedRep (underlyingTerm, wrapTerm) + , conTerm + , typedConstantSymbol + , symTerm + , pattern ConTerm + , pattern SelectTerm + , pattern StoreTerm + , pattern ConstArrayTerm + ) +import Grisette.Internal.SymPrim.Prim.Internal.Serialize () +import Grisette.Internal.Core.Data.Class.Solvable (Solvable (con, sym, conView), ssym) +import GHC.Generics (Generic) +import Language.Haskell.TH.Syntax (Lift) +import Prelude (Maybe (Just, Nothing), (<$>), ($), (.)) + +newtype SymArray k v = SymArray { underlyingArrayTerm :: Term (Array (ConType k) (ConType v)) } + deriving (Lift, NFData, Generic) + +instance ConRep (SymArray k v) where + type ConType (SymArray k v) = Array (ConType k) (ConType v) + +instance (SupportedNonFuncPrim k, SupportedNonFuncPrim v) => SymRep (Array k v) where + type SymType (Array k v) = SymArray (SymType k) (SymType v) + +instance + ( SupportedNonFuncPrim ck, + SupportedNonFuncPrim cv, + LinkedRep ck sk, + LinkedRep cv sv + ) => + LinkedRep (Array ck cv) (SymArray sk sv) where + underlyingTerm = underlyingArrayTerm + wrapTerm = SymArray + +instance + ( SupportedNonFuncPrim ck, + SupportedNonFuncPrim cv, + LinkedRep ck sk, + LinkedRep cv sv + ) => + Solvable (Array ck cv) (SymArray sk sv) where + con = wrapTerm . conTerm + sym = wrapTerm . symTerm . typedConstantSymbol + conView v = case underlyingTerm v of + ConTerm t -> Just t + _ -> Nothing + +instance + ( SupportedNonFuncPrim ck, + SupportedNonFuncPrim cv, + LinkedRep ck sk, + LinkedRep cv sv + ) => + IsString (SymArray sk sv) where + fromString = ssym . fromString + +instance + ( SupportedNonFuncPrim ck, + SupportedNonFuncPrim cv, + LinkedRep ck sk, + LinkedRep cv sv + ) => + Serial (SymArray sk sv) where + serialize = serialize . underlyingArrayTerm + deserialize = SymArray <$> deserialize + +instance + ( SupportedNonFuncPrim ck, + SupportedNonFuncPrim cv, + LinkedRep ck sk, + LinkedRep cv sv + ) => + Cereal.Serialize (SymArray sk sv) where + put = serialize + get = deserialize + +instance + ( SupportedNonFuncPrim ck, + SupportedNonFuncPrim cv, + LinkedRep ck sk, + LinkedRep cv sv + ) => + Binary.Binary (SymArray sk sv) where + put = serialize + get = deserialize + +const + :: forall k v + . SupportedNonFuncPrim (ConType k) + => SupportedNonFuncPrim (ConType v) + => LinkedRep (ConType k) k + => LinkedRep (ConType v) v + => v + -> SymArray k v +const val = wrapTerm $ ConstArrayTerm Proxy (underlyingTerm val) + +select + :: forall k v + . SupportedNonFuncPrim (ConType k) + => SupportedNonFuncPrim (ConType v) + => LinkedRep (ConType k) k + => LinkedRep (ConType v) v + => SymArray k v + -> k + -> v +select arr key = wrapTerm $ SelectTerm (underlyingTerm arr) (underlyingTerm key) + +store + :: forall k v + . SupportedNonFuncPrim (ConType k) + => SupportedNonFuncPrim (ConType v) + => LinkedRep (ConType k) k + => LinkedRep (ConType v) v + => SymArray k v + -> k + -> v + -> SymArray k v +store arr key val = do + wrapTerm $ StoreTerm (underlyingTerm arr) (underlyingTerm key) (underlyingTerm val) diff --git a/src/Grisette/SymPrim.hs b/src/Grisette/SymPrim.hs index 6ca8637d5..97684f6a1 100644 --- a/src/Grisette/SymPrim.hs +++ b/src/Grisette/SymPrim.hs @@ -290,6 +290,9 @@ module Grisette.SymPrim pattern FromIntegralTerm, pattern FromFPOrTerm, pattern ToFPTerm, + pattern SelectTerm, + pattern StoreTerm, + pattern ConstArrayTerm, pattern SubTerms, ) where @@ -416,6 +419,9 @@ import Grisette.Internal.SymPrim.Prim.Term pattern SupportedTypedSymbol, pattern SymTerm, pattern ToFPTerm, + pattern SelectTerm, + pattern StoreTerm, + pattern ConstArrayTerm, pattern XorBitsTerm, ) import Grisette.Internal.SymPrim.Prim.TermUtils From ae4d837886efb2e7838f89271f343d6fa8130388 Mon Sep 17 00:00:00 2001 From: Robin Webbers Date: Wed, 11 Mar 2026 09:26:41 +0100 Subject: [PATCH 2/2] Added typeclass common typeclass instances for symbolic arrays. --- grisette.cabal | 6 +++--- package.yaml | 2 +- src/Grisette/Internal/Core/Data/Class/ITEOp.hs | 12 ++++++++++++ .../Internal/Impl/Core/Data/Class/EvalSym.hs | 14 ++++++++++++++ .../Internal/Impl/Core/Data/Class/Mergeable.hs | 15 +++++++++++++++ .../Impl/Core/Data/Class/SimpleMergeable.hs | 14 ++++++++++++-- src/Grisette/Internal/SymPrim/SymArray.hs | 10 +++++++--- stack-9.10.yaml | 6 ++---- stack.yaml.lock | 9 ++++++++- 9 files changed, 74 insertions(+), 14 deletions(-) diff --git a/grisette.cabal b/grisette.cabal index 615047051..fe74e6404 100644 --- a/grisette.cabal +++ b/grisette.cabal @@ -331,7 +331,7 @@ library , mtl >=2.2.2 && <2.4 , parallel >=3.2.2 && <3.3 , prettyprinter >=1.5.0 && <1.8 - , sbv >=8.17 && <13 + , sbv >=13.4 , stm ==2.5.* , template-haskell >=2.16 && <2.24 , text >=1.2.4.1 && <2.2 @@ -380,7 +380,7 @@ test-suite doctest , mtl >=2.2.2 && <2.4 , parallel >=3.2.2 && <3.3 , prettyprinter >=1.5.0 && <1.8 - , sbv >=8.17 && <13 + , sbv >=13.4 , stm ==2.5.* , template-haskell >=2.16 && <2.24 , text >=1.2.4.1 && <2.2 @@ -501,7 +501,7 @@ test-suite spec , mtl >=2.2.2 && <2.4 , parallel >=3.2.2 && <3.3 , prettyprinter >=1.5.0 && <1.8 - , sbv >=8.17 && <13 + , sbv >=13.4 , stm ==2.5.* , template-haskell >=2.16 && <2.24 , test-framework >=0.8.2 && <0.9 diff --git a/package.yaml b/package.yaml index f3f65e630..ea9d6e98d 100644 --- a/package.yaml +++ b/package.yaml @@ -36,7 +36,7 @@ dependencies: - th-compat >= 0.1.2 && < 0.2 - th-abstraction >= 0.4 && < 0.8 - array >= 0.5.4 && < 0.6 - - sbv >= 8.17 && < 13 + - sbv >= 13.4 - parallel >= 3.2.2 && < 3.3 - text >= 1.2.4.1 && < 2.2 - QuickCheck >= 2.14 && < 2.17 diff --git a/src/Grisette/Internal/Core/Data/Class/ITEOp.hs b/src/Grisette/Internal/Core/Data/Class/ITEOp.hs index 51842ebc7..48cb37019 100644 --- a/src/Grisette/Internal/Core/Data/Class/ITEOp.hs +++ b/src/Grisette/Internal/Core/Data/Class/ITEOp.hs @@ -33,10 +33,13 @@ import Grisette.Internal.SymPrim.GeneralFun import Grisette.Internal.SymPrim.Prim.SomeTerm (SomeTerm (SomeTerm)) import Grisette.Internal.SymPrim.Prim.Term ( SupportedPrim (pevalITETerm), + SupportedNonFuncPrim, + LinkedRep, TypedConstantSymbol, symTerm, ) import Grisette.Internal.SymPrim.SymAlgReal (SymAlgReal (SymAlgReal)) +import Grisette.Internal.SymPrim.SymArray (SymArray (SymArray)) import Grisette.Internal.SymPrim.SymBV ( SymIntN (SymIntN), SymWordN (SymWordN), @@ -93,6 +96,15 @@ ITEOP_FUN((=->), (=~>), SymTabularFun) ITEOP_FUN((-->), (-~>), SymGeneralFun) #endif +instance + ( SupportedNonFuncPrim ck, + SupportedNonFuncPrim cv, + LinkedRep ck sk, + LinkedRep cv sv + ) => + ITEOp (SymArray sk sv) where + symIte (SymBool c) (SymArray t) (SymArray f) = SymArray $ pevalITETerm c t f + instance ITEOp (a --> b) where symIte (SymBool c) diff --git a/src/Grisette/Internal/Internal/Impl/Core/Data/Class/EvalSym.hs b/src/Grisette/Internal/Internal/Impl/Core/Data/Class/EvalSym.hs index 8d92bb4ee..0c9eab183 100644 --- a/src/Grisette/Internal/Internal/Impl/Core/Data/Class/EvalSym.hs +++ b/src/Grisette/Internal/Internal/Impl/Core/Data/Class/EvalSym.hs @@ -75,6 +75,7 @@ import Grisette.Internal.Internal.Decl.Core.Data.Class.EvalSym evalSym1, ) import Grisette.Internal.SymPrim.AlgReal (AlgReal) +import Grisette.Internal.SymPrim.Array (Array) import Grisette.Internal.SymPrim.BV (IntN, WordN) import Grisette.Internal.SymPrim.FP ( FP, @@ -86,9 +87,12 @@ import Grisette.Internal.SymPrim.GeneralFun (type (-->) (GeneralFun)) import Grisette.Internal.SymPrim.Prim.Model (evalTerm) import Grisette.Internal.SymPrim.Prim.Term ( SymRep (SymType), + SupportedNonFuncPrim, + LinkedRep, someTypedSymbol, ) import Grisette.Internal.SymPrim.SymAlgReal (SymAlgReal (SymAlgReal)) +import Grisette.Internal.SymPrim.SymArray (SymArray (SymArray)) import Grisette.Internal.SymPrim.SymBV ( SymIntN (SymIntN), SymWordN (SymWordN), @@ -137,6 +141,7 @@ CONCRETE_EVALUATESYM(Ordering) CONCRETE_EVALUATESYM_BV(IntN) CONCRETE_EVALUATESYM_BV(WordN) CONCRETE_EVALUATESYM(AlgReal) +CONCRETE_EVALUATESYM((Array k v)) #endif instance EvalSym (Proxy a) where @@ -186,6 +191,15 @@ instance (ValidFP eb sb) => EvalSym (SymFP eb sb) where evalSym fillDefault model (SymFP t) = SymFP $ evalTerm fillDefault model HS.empty t +instance + ( SupportedNonFuncPrim ck, + SupportedNonFuncPrim cv, + LinkedRep ck sk, + LinkedRep cv sv + ) => + EvalSym (SymArray sk sv) where + evalSym fill model (SymArray t) = SymArray $ evalTerm fill model HS.empty t + derive [ ''(), ''AssertionError, diff --git a/src/Grisette/Internal/Internal/Impl/Core/Data/Class/Mergeable.hs b/src/Grisette/Internal/Internal/Impl/Core/Data/Class/Mergeable.hs index 63c26b5c0..6937aef4c 100644 --- a/src/Grisette/Internal/Internal/Impl/Core/Data/Class/Mergeable.hs +++ b/src/Grisette/Internal/Internal/Impl/Core/Data/Class/Mergeable.hs @@ -90,6 +90,7 @@ import Grisette.Internal.Internal.Decl.Core.Data.Class.Mergeable wrapStrategy, ) import Grisette.Internal.SymPrim.AlgReal (AlgReal, AlgRealPoly, RealPoint) +import Grisette.Internal.SymPrim.Array (Array) import Grisette.Internal.SymPrim.BV ( IntN, WordN, @@ -103,6 +104,7 @@ import Grisette.Internal.SymPrim.FP ) import Grisette.Internal.SymPrim.GeneralFun (type (-->)) import Grisette.Internal.SymPrim.SymAlgReal (SymAlgReal) +import Grisette.Internal.SymPrim.SymArray (SymArray) import Grisette.Internal.SymPrim.SymBV (SymIntN, SymWordN) import Grisette.Internal.SymPrim.SymFP (SymFP, SymFPRoundingMode) import Grisette.Internal.SymPrim.SymGeneralFun (type (-~>)) @@ -111,6 +113,7 @@ import Grisette.Internal.SymPrim.SymTabularFun (type (=~>)) import Grisette.Internal.SymPrim.TabularFun (type (=->)) import Grisette.Internal.TH.Derivation.Derive (derive) import Unsafe.Coerce (unsafeCoerce) +import Grisette.Internal.SymPrim.Prim.Internal.Term (SupportedNonFuncPrim, LinkedRep) #define CONCRETE_ORD_MERGEABLE(type) \ instance Mergeable type where \ @@ -175,6 +178,9 @@ instance Mergeable (a =-> b) where instance Mergeable (a --> b) where rootStrategy = SimpleStrategy symIte +instance Mergeable (Array k v) where + rootStrategy = NoStrategy + #define MERGEABLE_SIMPLE(symtype) \ instance Mergeable symtype where \ rootStrategy = SimpleStrategy symIte @@ -197,6 +203,15 @@ MERGEABLE_FUN((=->), (=~>), SymTabularFun) MERGEABLE_FUN((-->), (-~>), SymGeneralFun) #endif +instance + ( SupportedNonFuncPrim ck, + SupportedNonFuncPrim cv, + LinkedRep ck sk, + LinkedRep cv sv + ) => + Mergeable (SymArray sk sv) where + rootStrategy = SimpleStrategy $ symIte + instance (ValidFP eb sb) => Mergeable (SymFP eb sb) where rootStrategy = SimpleStrategy symIte diff --git a/src/Grisette/Internal/Internal/Impl/Core/Data/Class/SimpleMergeable.hs b/src/Grisette/Internal/Internal/Impl/Core/Data/Class/SimpleMergeable.hs index 2cc3b47dc..d3d8247b1 100644 --- a/src/Grisette/Internal/Internal/Impl/Core/Data/Class/SimpleMergeable.hs +++ b/src/Grisette/Internal/Internal/Impl/Core/Data/Class/SimpleMergeable.hs @@ -78,10 +78,11 @@ import Grisette.Internal.Internal.Decl.Core.Data.Class.SimpleMergeable import Grisette.Internal.Internal.Impl.Core.Data.Class.TryMerge () import Grisette.Internal.SymPrim.FP (ValidFP) import Grisette.Internal.SymPrim.GeneralFun (freshArgSymbol, substTerm, type (-->) (GeneralFun)) -import Grisette.Internal.SymPrim.Prim.Internal.Term (SupportedPrim (pevalITETerm), symTerm) +import Grisette.Internal.SymPrim.Prim.Internal.Term (SupportedPrim (pevalITETerm), LinkedRep, symTerm) import Grisette.Internal.SymPrim.Prim.SomeTerm (SomeTerm (SomeTerm)) -import Grisette.Internal.SymPrim.Prim.Term (TypedConstantSymbol) +import Grisette.Internal.SymPrim.Prim.Term (TypedConstantSymbol, SupportedNonFuncPrim) import Grisette.Internal.SymPrim.SymAlgReal (SymAlgReal (SymAlgReal)) +import Grisette.Internal.SymPrim.SymArray (SymArray (SymArray)) import Grisette.Internal.SymPrim.SymBV ( SymIntN (SymIntN), SymWordN (SymWordN), @@ -559,6 +560,15 @@ SIMPLE_MERGEABLE_FUN((=->), (=~>), SymTabularFun) SIMPLE_MERGEABLE_FUN((-->), (-~>), SymGeneralFun) #endif +instance + ( SupportedNonFuncPrim ck, + SupportedNonFuncPrim cv, + LinkedRep ck sk, + LinkedRep cv sv + ) => + SimpleMergeable (SymArray sk sv) where + mrgIte (SymBool c) (SymArray t) (SymArray f) = SymArray $ pevalITETerm c t f + instance SimpleMergeable (a --> b) where mrgIte (SymBool c) diff --git a/src/Grisette/Internal/SymPrim/SymArray.hs b/src/Grisette/Internal/SymPrim/SymArray.hs index 3cee6ad3e..c5391142b 100644 --- a/src/Grisette/Internal/SymPrim/SymArray.hs +++ b/src/Grisette/Internal/SymPrim/SymArray.hs @@ -42,6 +42,7 @@ import Grisette.Internal.SymPrim.Prim.Internal.Term , conTerm , typedConstantSymbol , symTerm + , pformatTerm , pattern ConTerm , pattern SelectTerm , pattern StoreTerm @@ -51,7 +52,7 @@ import Grisette.Internal.SymPrim.Prim.Internal.Serialize () import Grisette.Internal.Core.Data.Class.Solvable (Solvable (con, sym, conView), ssym) import GHC.Generics (Generic) import Language.Haskell.TH.Syntax (Lift) -import Prelude (Maybe (Just, Nothing), (<$>), ($), (.)) +import Prelude (Show (show), Maybe (Just, Nothing), (<$>), ($), (.)) newtype SymArray k v = SymArray { underlyingArrayTerm :: Term (Array (ConType k) (ConType v)) } deriving (Lift, NFData, Generic) @@ -94,6 +95,9 @@ instance IsString (SymArray sk sv) where fromString = ssym . fromString +instance Show (SymArray sk sv) where + show = pformatTerm . underlyingArrayTerm + instance ( SupportedNonFuncPrim ck, SupportedNonFuncPrim cv, @@ -101,8 +105,8 @@ instance LinkedRep cv sv ) => Serial (SymArray sk sv) where - serialize = serialize . underlyingArrayTerm - deserialize = SymArray <$> deserialize + serialize = serialize . underlyingTerm + deserialize = wrapTerm <$> deserialize instance ( SupportedNonFuncPrim ck, diff --git a/stack-9.10.yaml b/stack-9.10.yaml index bb503c62e..e64a01b13 100644 --- a/stack-9.10.yaml +++ b/stack-9.10.yaml @@ -33,13 +33,11 @@ packages: # These entries can reference officially published versions as well as # forks / in-progress versions pinned to a git hash. For example: # -# extra-deps: -# - acme-missiles-0.3 -# - git: https://github.com/commercialhaskell/stack.git -# commit: e7b331f14bcffb8367cd58fbfc8b40ec7642100a # # Override default flag values for local packages and extra-deps # flags: {} +extra-deps: +- sbv-13.5 # Extra package databases containing global packages # extra-package-dbs: [] diff --git a/stack.yaml.lock b/stack.yaml.lock index a1b57cf9c..9d1442db9 100644 --- a/stack.yaml.lock +++ b/stack.yaml.lock @@ -3,7 +3,14 @@ # For more information, please see the documentation at: # https://docs.haskellstack.org/en/stable/topics/lock_files -packages: [] +packages: +- completed: + hackage: sbv-13.5@sha256:00aea86ad09dcefb5d5286dd68bfb894709d2874052bcdadf80af3add9596af7,26690 + pantry-tree: + sha256: 65d23f68088ce549e8aa8ae77a35f2ec7c6491facbc4aeae8c2317101d9ddcda + size: 93933 + original: + hackage: sbv-13.5 snapshots: - completed: sha256: 7a26eba54b469fc72b1e37b881dfec480a2c1cb0636136f96aec7d81be6c762f