diff options
author | Dougal <d.maclaurin@gmail.com> | 2023-12-02 20:12:07 -0500 |
---|---|---|
committer | Dougal <d.maclaurin@gmail.com> | 2023-12-02 20:12:07 -0500 |
commit | 75c41841a321f0984ebae6ee9326f8377e5262ca (patch) | |
tree | b78c414006224843a809abd3f216bf6ce373fd45 | |
parent | 6f62fb4cb19fa9327eaeee4a592afcabcb7f45f2 (diff) |
Move Pretty instances to where the data types are defined.
This avoids circular import issues and orphan instances.
Also move top-level data types out of Types.Core to make the file size more reasonable.
40 files changed, 2077 insertions, 2006 deletions
@@ -94,6 +94,7 @@ library , Types.Primitives , Types.OpNames , Types.Source + , Types.Top , QueryType , QueryTypePure , Util @@ -124,7 +125,6 @@ library , prettyprinter , text -- Portable system utilities - , ansi-terminal , directory , filepath , haskeline @@ -234,6 +234,7 @@ executable dex main-is: dex.hs build-depends: dex , ansi-wl-pprint + , ansi-terminal , base , bytestring , containers @@ -21,8 +21,9 @@ import Data.List import qualified Data.Text as T import qualified Data.Text.Encoding as T import qualified Data.Map.Strict as M +import qualified System.Console.ANSI as ANSI +import System.Console.ANSI hiding (Color) -import PPrint (printOutput) import TopLevel import Err import Name @@ -35,6 +36,7 @@ import Core import Types.Core import Types.Imp import Types.Source +import Types.Top import MonadUtil data DocFmt = ResultOnly @@ -193,6 +195,24 @@ stdOutLogger (Outputs outs) = do isatty <- queryTerminal stdOutput forM_ outs \out -> putStr $ printOutput isatty out +printOutput :: Bool -> Output -> String +printOutput isatty out = case out of + Error _ -> addColor isatty Red $ addPrefix ">" $ pprint out + _ -> addPrefix (addColor isatty Cyan ">") $ pprint $ out + +addPrefix :: String -> String -> String +addPrefix prefix str = unlines $ map prefixLine $ lines str + where prefixLine :: String -> String + prefixLine s = case s of "" -> prefix + _ -> prefix ++ " " ++ s + +addColor :: Bool -> ANSI.Color -> String -> String +addColor False _ s = s +addColor True c s = + setSGRCode [SetConsoleIntensity BoldIntensity, SetColor Foreground Vivid c] + ++ s ++ setSGRCode [Reset] + + pathOption :: ReadM [LibPath] pathOption = splitPaths [] <$> str where diff --git a/src/lib/Builder.hs b/src/lib/Builder.hs index a12b5c8b..7415cec7 100644 --- a/src/lib/Builder.hs +++ b/src/lib/Builder.hs @@ -35,6 +35,7 @@ import Types.Core import Types.Imp import Types.Primitives import Types.Source +import Types.Top import Util (enumerate, transitiveClosureM, bindM2, toSnocList) -- === Ordinary (local) builder class === @@ -281,7 +282,7 @@ emitTopLet hint letAnn expr = do v <- emitBinding hint $ AtomNameBinding $ LetBound (DeclBinding letAnn expr) return $ AtomVar v ty -emitTopFunBinding :: (Mut n, TopBuilder m) => NameHint -> TopFunDef n -> STopLam n -> m n (TopFunName n) +emitTopFunBinding :: (Mut n, TopBuilder m) => NameHint -> TopFunDef n -> TopLam SimpIR n -> m n (TopFunName n) emitTopFunBinding hint def f = do emitBinding hint $ TopFunBinding $ DexTopFun def f Waiting diff --git a/src/lib/CheapReduction.hs b/src/lib/CheapReduction.hs index 6fdd7280..8df743f2 100644 --- a/src/lib/CheapReduction.hs +++ b/src/lib/CheapReduction.hs @@ -32,6 +32,7 @@ import Name import PPrint () import QueryTypePure import Types.Core +import Types.Top import Types.Imp import Types.Primitives import Util diff --git a/src/lib/CheckType.hs b/src/lib/CheckType.hs index f808e715..9bcbf029 100644 --- a/src/lib/CheckType.hs +++ b/src/lib/CheckType.hs @@ -27,6 +27,7 @@ import QueryType import Types.Core import Types.Primitives import Types.Source +import Types.Top -- === top-level API === diff --git a/src/lib/ConcreteSyntax.hs b/src/lib/ConcreteSyntax.hs index 838089e7..70bc67b9 100644 --- a/src/lib/ConcreteSyntax.hs +++ b/src/lib/ConcreteSyntax.hs @@ -17,12 +17,10 @@ import Data.Char import Data.Either import Data.Functor import Data.List.NonEmpty (NonEmpty (..)) -import Data.Map qualified as M import Data.String (fromString) import Data.Text (Text) import Data.Text qualified as T import Data.Text.Encoding qualified as T -import Data.Tuple import Data.Void import Text.Megaparsec hiding (Label, State) import Text.Megaparsec.Char hiding (space, eol) @@ -31,7 +29,6 @@ import Lexing import Types.Core import Types.Source import Types.Primitives -import qualified Types.OpNames as P import Util -- TODO: implement this more efficiently rather than just parsing the whole @@ -697,101 +694,6 @@ withSrcs p = do (sids, result) <- collectAtomicLexemeIds p return $ WithSrcs sid sids result --- === primitive constructors and operators === - -strToPrimName :: String -> Maybe PrimName -strToPrimName s = M.lookup s primNames - -primNameToStr :: PrimName -> String -primNameToStr prim = case lookup prim $ map swap $ M.toList primNames of - Just s -> s - Nothing -> show prim - -showPrimName :: PrimName -> String -showPrimName prim = primNameToStr prim -{-# NOINLINE showPrimName #-} - -primNames :: M.Map String PrimName -primNames = M.fromList - [ ("ask" , UMAsk), ("mextend", UMExtend) - , ("get" , UMGet), ("put" , UMPut) - , ("while" , UWhile) - , ("linearize", ULinearize), ("linearTranspose", UTranspose) - , ("runReader", URunReader), ("runWriter" , URunWriter), ("runState", URunState) - , ("runIO" , URunIO ), ("catchException" , UCatchException) - , ("iadd" , binary IAdd), ("isub" , binary ISub) - , ("imul" , binary IMul), ("fdiv" , binary FDiv) - , ("fadd" , binary FAdd), ("fsub" , binary FSub) - , ("fmul" , binary FMul), ("idiv" , binary IDiv) - , ("irem" , binary IRem) - , ("fpow" , binary FPow) - , ("and" , binary BAnd), ("or" , binary BOr ) - , ("not" , unary BNot), ("xor" , binary BXor) - , ("shl" , binary BShL), ("shr" , binary BShR) - , ("ieq" , binary (ICmp Equal)), ("feq", binary (FCmp Equal)) - , ("igt" , binary (ICmp Greater)), ("fgt", binary (FCmp Greater)) - , ("ilt" , binary (ICmp Less)), ("flt", binary (FCmp Less)) - , ("fneg" , unary FNeg) - , ("exp" , unary Exp), ("exp2" , unary Exp2) - , ("log" , unary Log), ("log2" , unary Log2), ("log10" , unary Log10) - , ("sin" , unary Sin), ("cos" , unary Cos) - , ("tan" , unary Tan), ("sqrt" , unary Sqrt) - , ("floor", unary Floor), ("ceil" , unary Ceil), ("round", unary Round) - , ("log1p", unary Log1p), ("lgamma", unary LGamma) - , ("erf" , unary Erf), ("erfc" , unary Erfc) - , ("TyKind" , UPrimTC $ P.TypeKind) - , ("Float64" , baseTy $ Scalar Float64Type) - , ("Float32" , baseTy $ Scalar Float32Type) - , ("Int64" , baseTy $ Scalar Int64Type) - , ("Int32" , baseTy $ Scalar Int32Type) - , ("Word8" , baseTy $ Scalar Word8Type) - , ("Word32" , baseTy $ Scalar Word32Type) - , ("Word64" , baseTy $ Scalar Word64Type) - , ("Int32Ptr" , baseTy $ ptrTy $ Scalar Int32Type) - , ("Word8Ptr" , baseTy $ ptrTy $ Scalar Word8Type) - , ("Word32Ptr" , baseTy $ ptrTy $ Scalar Word32Type) - , ("Word64Ptr" , baseTy $ ptrTy $ Scalar Word64Type) - , ("Float32Ptr", baseTy $ ptrTy $ Scalar Float32Type) - , ("PtrPtr" , baseTy $ ptrTy $ ptrTy $ Scalar Word8Type) - , ("Nat" , UNat) - , ("Fin" , UFin) - , ("EffKind" , UEffectRowKind) - , ("NatCon" , UNatCon) - , ("Ref" , UPrimTC $ P.RefType) - , ("HeapType" , UPrimTC $ P.HeapType) - , ("indexRef" , UIndexRef) - , ("alloc" , memOp $ P.IOAlloc) - , ("free" , memOp $ P.IOFree) - , ("ptrOffset", memOp $ P.PtrOffset) - , ("ptrLoad" , memOp $ P.PtrLoad) - , ("ptrStore" , memOp $ P.PtrStore) - , ("throwError" , miscOp $ P.ThrowError) - , ("throwException", miscOp $ P.ThrowException) - , ("dataConTag" , miscOp $ P.SumTag) - , ("toEnum" , miscOp $ P.ToEnum) - , ("outputStream" , miscOp $ P.OutputStream) - , ("cast" , miscOp $ P.CastOp) - , ("bitcast" , miscOp $ P.BitcastOp) - , ("unsafeCoerce" , miscOp $ P.UnsafeCoerce) - , ("garbageVal" , miscOp $ P.GarbageVal) - , ("select" , miscOp $ P.Select) - , ("showAny" , miscOp $ P.ShowAny) - , ("showScalar" , miscOp $ P.ShowScalar) - , ("projNewtype" , UProjNewtype) - , ("applyMethod0" , UApplyMethod 0) - , ("applyMethod1" , UApplyMethod 1) - , ("applyMethod2" , UApplyMethod 2) - , ("explicitApply", UExplicitApply) - , ("monoLit", UMonoLiteral) - ] - where - binary op = UBinOp op - baseTy b = UBaseType b - memOp op = UMemOp op - unary op = UUnOp op - ptrTy ty = PtrType (CPU, ty) - miscOp op = UMiscOp op - -- === notes === -- note [if-syntax] diff --git a/src/lib/Core.hs b/src/lib/Core.hs index 2c60f846..e420b50a 100644 --- a/src/lib/Core.hs +++ b/src/lib/Core.hs @@ -37,6 +37,7 @@ import Err import IRVariants import Types.Core +import Types.Top import Types.Imp import Types.Primitives import Types.Source diff --git a/src/lib/Err.hs b/src/lib/Err.hs index d0ad6c9d..51b34eb1 100644 --- a/src/lib/Err.hs +++ b/src/lib/Err.hs @@ -24,12 +24,10 @@ import Control.Monad.State.Strict import Control.Monad.Reader import Data.Coerce import Data.Foldable (fold) -import Data.Text qualified as T -import Data.Text.Prettyprint.Doc.Render.Text import Data.Text.Prettyprint.Doc import GHC.Stack -import System.Environment -import System.IO.Unsafe + +import PPrint -- === core API === @@ -285,20 +283,6 @@ instance Fallible Maybe where throwErr _ = Nothing {-# INLINE throwErr #-} --- === small pretty-printing utils === --- These are here instead of in PPrint.hs for import cycle reasons - -pprint :: Pretty a => a -> String -pprint x = docAsStr $ pretty x -{-# SCC pprint #-} - -docAsStr :: Doc ann -> String -docAsStr doc = T.unpack $ renderStrict $ layoutPretty layout $ doc - -layout :: LayoutOptions -layout = if unbounded then LayoutOptions Unbounded else defaultLayoutOptions - where unbounded = unsafePerformIO $ (Just "1"==) <$> lookupEnv "DEX_PPRINT_UNBOUNDED" - -- === instances === instance Fallible Except where diff --git a/src/lib/Export.hs b/src/lib/Export.hs index 7983f52c..1108507f 100644 --- a/src/lib/Export.hs +++ b/src/lib/Export.hs @@ -29,6 +29,7 @@ import Subst hiding (Rename) import TopLevel import Types.Core import Types.Imp +import Types.Top import Types.Primitives hiding (sizeOf) type ExportAtomNameC = AtomNameC CoreIR diff --git a/src/lib/Generalize.hs b/src/lib/Generalize.hs index 945552a7..7ace599c 100644 --- a/src/lib/Generalize.hs +++ b/src/lib/Generalize.hs @@ -18,6 +18,7 @@ import QueryType import Name import Subst import Types.Primitives +import Types.Top type RolePiBinder = WithAttrB RoleExpl CBinder type RolePiBinders = Nest RolePiBinder diff --git a/src/lib/Imp.hs b/src/lib/Imp.hs index 7ab6c865..07fada48 100644 --- a/src/lib/Imp.hs +++ b/src/lib/Imp.hs @@ -42,10 +42,10 @@ import QueryType import Types.Core import Types.Imp import Types.Primitives +import Types.Top import Util (forMFilter, Tree (..), zipTrees, enumerate) -toImpFunction :: EnvReader m - => CallingConvention -> STopLam n -> m n (ImpFunction n) +toImpFunction :: EnvReader m => CallingConvention -> STopLam n -> m n (ImpFunction n) toImpFunction cc (TopLam True destTy lam) = do LamExpr bsAndRefB body <- return lam PairB bs destB <- case popNest bsAndRefB of diff --git a/src/lib/ImpToLLVM.hs b/src/lib/ImpToLLVM.hs index 556f927b..f333c1f0 100644 --- a/src/lib/ImpToLLVM.hs +++ b/src/lib/ImpToLLVM.hs @@ -56,6 +56,7 @@ import Types.Core import Types.Imp import Types.Primitives import Types.Source +import Types.Top import Util (IsBool (..), bindM2, enumerate) -- === Compile monad === diff --git a/src/lib/Inference.hs b/src/lib/Inference.hs index f6fb4d92..bea45b65 100644 --- a/src/lib/Inference.hs +++ b/src/lib/Inference.hs @@ -43,6 +43,7 @@ import Types.Core import Types.Imp import Types.Primitives import Types.Source +import Types.Top import qualified Types.OpNames as P import Util hiding (group) diff --git a/src/lib/Inline.hs b/src/lib/Inline.hs index f3ada792..f72f24be 100644 --- a/src/lib/Inline.hs +++ b/src/lib/Inline.hs @@ -17,6 +17,7 @@ import Occurrence hiding (Var) import PeepholeOptimize import Types.Core import Types.Primitives +import Types.Top -- === External API === diff --git a/src/lib/JAX/ToSimp.hs b/src/lib/JAX/ToSimp.hs index cdf25d73..7466d237 100644 --- a/src/lib/JAX/ToSimp.hs +++ b/src/lib/JAX/ToSimp.hs @@ -17,6 +17,7 @@ import JAX.Concrete import Subst import QueryType import Types.Core +import Types.Top import Types.Primitives qualified as P newtype JaxSimpM (i::S) (o::S) a = JaxSimpM diff --git a/src/lib/Linearize.hs b/src/lib/Linearize.hs index 98bbb7d3..ee61d843 100644 --- a/src/lib/Linearize.hs +++ b/src/lib/Linearize.hs @@ -27,6 +27,7 @@ import PPrint import QueryType import Types.Core import Types.Primitives +import Types.Top import Util (enumerate) -- === linearization monad === diff --git a/src/lib/Lower.hs b/src/lib/Lower.hs index cf28b066..db7b83fa 100644 --- a/src/lib/Lower.hs +++ b/src/lib/Lower.hs @@ -26,6 +26,7 @@ import Name import Subst import QueryType import Types.Core +import Types.Top import Types.Primitives import Util (enumerate) diff --git a/src/lib/MTL1.hs b/src/lib/MTL1.hs index 2011fa64..47fe8b8c 100644 --- a/src/lib/MTL1.hs +++ b/src/lib/MTL1.hs @@ -17,7 +17,7 @@ import Data.Foldable (toList) import Name import Err -import Types.Core (Env) +import Types.Top (Env) import Core (EnvReader (..), EnvExtender (..)) import Util (SnocList (..), snoc, emptySnocList) diff --git a/src/lib/Name.hs b/src/lib/Name.hs index dc36f6c3..fd23def5 100644 --- a/src/lib/Name.hs +++ b/src/lib/Name.hs @@ -44,6 +44,7 @@ import RawName ( RawNameMap, RawName, NameHint, HasNameHint (..) , freshRawName, rawNameFromHint, rawNames, noHint) import qualified RawName as R import Util ( zipErr, onFst, onSnd, transitiveClosure, SnocList (..), unsnoc ) +import PPrint import Err import IRVariants @@ -445,6 +446,9 @@ type OrdE e = (forall (n::S) . Ord (e n )) :: Constraint type OrdV v = (forall (c::C) (n::S). Ord (v c n)) :: Constraint type OrdB b = (forall (n::S) (l::S). Ord (b n l)) :: Constraint +type PrettyPrecE e = (forall (n::S) . PrettyPrec (e n )) :: Constraint +type PrettyPrecB b = (forall (n::S) (l::S). PrettyPrec (b n l)) :: Constraint + type HashableE (e::E) = forall n. Hashable (e n) data UnitE (n::S) = UnitE @@ -2164,6 +2168,8 @@ instance PrettyE e => Pretty (ListE e n) where instance PrettyE e => Pretty (RListE e n) where pretty (RListE e) = pretty $ unsnoc e +deriving instance (forall c n. Pretty (v c n)) => Pretty (RecSubst v o) + instance ( Generic (b UnsafeS UnsafeS) , Generic (body UnsafeS) ) => Generic (Abs b body n) where @@ -2746,6 +2752,9 @@ canonicalizeForPrinting e cont = do ClosedWithScope scope e' -> cont $ renameE (scope, newSubst id) e' +pprintCanonicalized :: (HoistableE e, RenameE e, PrettyE e) => e n -> String +pprintCanonicalized e = canonicalizeForPrinting e \e' -> pprint e' + liftHoistExcept :: Fallible m => HoistExcept a -> m a liftHoistExcept (HoistSuccess x) = return x liftHoistExcept (HoistFailure vs) = throw EscapedNameErr (pprint vs) @@ -2887,6 +2896,10 @@ abstractFreeVarsNoAnn vs e = Abs bs e' -> Abs bs' e' where bs' = fmapNest (\(b:>UnitE) -> b) bs +unsafeFromNest :: Nest b n l -> [b UnsafeS UnsafeS] +unsafeFromNest Empty = [] +unsafeFromNest (Nest b rest) = unsafeCoerceB b : unsafeFromNest rest + instance Color c => HoistableB (NameBinder c) where freeVarsB _ = mempty @@ -3389,6 +3402,13 @@ hoistNameMap b = ignoreHoistFailure . hoistNameMapE b unsafeCoerceIRE :: forall (r'::IR) (r::IR) (e::IR->E) (n::S). e r n -> e r' n unsafeCoerceIRE = TrulyUnsafe.unsafeCoerce +-- === Pretty instances === + +instance PrettyPrec (Name s n) where prettyPrec = atPrec ArgPrec . pretty + +instance PrettyE ann => Pretty (BinderP c ann n l) + where pretty (b:>ty) = pretty b <> ":" <> pretty ty + -- === notes === {- diff --git a/src/lib/OccAnalysis.hs b/src/lib/OccAnalysis.hs index fcf04cdf..0e75165b 100644 --- a/src/lib/OccAnalysis.hs +++ b/src/lib/OccAnalysis.hs @@ -20,6 +20,7 @@ import Occurrence hiding (Var) import Occurrence qualified as Occ import Types.Core import Types.Primitives +import Types.Top import QueryType -- === External API === @@ -28,7 +29,7 @@ import QueryType -- annotation holding a summary of how that binding is used. It also eliminates -- unused pure bindings as it goes, since it has all the needed information. -analyzeOccurrences :: EnvReader m => STopLam n -> m n (STopLam n) +analyzeOccurrences :: EnvReader m => TopLam SimpIR n -> m n (TopLam SimpIR n) analyzeOccurrences lam = liftLamExpr lam \e -> liftOCCM $ occ accessOnce e {-# INLINE analyzeOccurrences #-} diff --git a/src/lib/Occurrence.hs b/src/lib/Occurrence.hs index 5e024e85..ea8248de 100644 --- a/src/lib/Occurrence.hs +++ b/src/lib/Occurrence.hs @@ -19,6 +19,7 @@ import Data.List (foldl') import Data.Store (Store (..)) import GHC.Generics (Generic (..)) +import PPrint import IRVariants import Name @@ -888,3 +889,15 @@ instance RenameE AccessInfo instance Hashable UsageInfo instance Store UsageInfo + +-- === instances === + +instance Pretty UsageInfo where + pretty (UsageInfo static (ixDepth, ct)) = + "occurs in" <+> pretty static <+> "places, read" + <+> pretty ct <+> "times, to depth" <+> pretty (show ixDepth) + +instance Pretty Count where + pretty = \case + Bounded ct -> "<=" <+> pretty ct + Unbounded -> "many" diff --git a/src/lib/Optimize.hs b/src/lib/Optimize.hs index 425291cd..1ed73ff2 100644 --- a/src/lib/Optimize.hs +++ b/src/lib/Optimize.hs @@ -15,6 +15,7 @@ import Control.Monad.State.Strict import Types.Core import Types.Primitives +import Types.Top import MTL1 import Name import Subst diff --git a/src/lib/PPrint.hs b/src/lib/PPrint.hs index 0344bd86..b16559fa 100644 --- a/src/lib/PPrint.hs +++ b/src/lib/PPrint.hs @@ -6,40 +6,34 @@ {-# LANGUAGE IncoherentInstances #-} -- due to `ConRef` {-# LANGUAGE UndecidableInstances #-} -{-# OPTIONS_GHC -Wno-orphans #-} module PPrint ( - pprint, pprintCanonicalized, pprintList, asStr , atPrec, - PrettyPrec(..), PrecedenceLevel (..), prettyBlock, - printOutput, prettyFromPrettyPrec) where + Pretty (..), Doc, DocPrec, (<+>), pprint, pprintList, asStr , atPrec, + pAppArg, pApp, pArg, hardline, PrettyPrec(..), PrecedenceLevel (..), + docAsStr, parensSep, prettyLines, sep, pLowest, prettyFromPrettyPrec, + indented, commaSep, spaced, spaceIfColinear, encloseSep) where -import GHC.Exts (Constraint) -import GHC.Float import Data.Foldable (toList, fold) -import qualified Data.Map.Strict as M import Data.Text.Prettyprint.Doc.Render.Text import Data.Text.Prettyprint.Doc -import Data.Text (Text, snoc, uncons, unsnoc, unpack) -import qualified Data.Set as S -import Data.String (fromString) -import qualified System.Console.ANSI as ANSI -import System.Console.ANSI hiding (Color) +import Data.Text (unpack) import System.IO.Unsafe import qualified System.Environment as E -import Numeric -import ConcreteSyntax -import Err -import IRVariants -import Name -import Occurrence (Count (Bounded), UsageInfo (..)) -import Occurrence qualified as Occ -import Types.Core -import Types.Imp -import Types.Primitives -import Types.Source -import QueryTypePure -import Util (Tree (..)) +-- === small pretty-printing utils === + +pprint :: Pretty a => a -> String +pprint x = docAsStr $ pretty x +{-# SCC pprint #-} + +docAsStr :: Doc ann -> String +docAsStr doc = unpack $ renderStrict $ layoutPretty layout $ doc + +layout :: LayoutOptions +layout = if unbounded then LayoutOptions Unbounded else defaultLayoutOptions + where unbounded = unsafePerformIO $ (Just "1"==) <$> E.lookupEnv "DEX_PPRINT_UNBOUNDED" + +-- === DocPrec === -- A DocPrec is a slightly context-aware Doc, specifically one that -- knows the precedence level of the immediately enclosing operation, @@ -93,31 +87,12 @@ prettyFromPrettyPrec = pArg pAppArg :: (PrettyPrec a, Foldable f) => Doc ann -> f a -> Doc ann pAppArg name as = align $ name <> group (nest 2 $ foldMap (\a -> line <> pArg a) as) -fromInfix :: Text -> Maybe Text -fromInfix t = do - ('(', t') <- uncons t - (t'', ')') <- unsnoc t' - return t'' - -type PrettyPrecE e = (forall (n::S) . PrettyPrec (e n )) :: Constraint -type PrettyPrecB b = (forall (n::S) (l::S). PrettyPrec (b n l)) :: Constraint - -pprintCanonicalized :: (HoistableE e, RenameE e, PrettyE e) => e n -> String -pprintCanonicalized e = canonicalizeForPrinting e \e' -> pprint e' - pprintList :: Pretty a => [a] -> String -pprintList xs = asStr $ vsep $ punctuate "," (map p xs) - -layout :: LayoutOptions -layout = if unbounded then LayoutOptions Unbounded else defaultLayoutOptions - where unbounded = unsafePerformIO $ (Just "1"==) <$> E.lookupEnv "DEX_PPRINT_UNBOUNDED" +pprintList xs = asStr $ vsep $ punctuate "," (map pretty xs) asStr :: Doc ann -> String asStr doc = unpack $ renderStrict $ layoutPretty layout $ doc -p :: Pretty a => a -> Doc ann -p = pretty - pLowest :: PrettyPrec a => a -> Doc ann pLowest a = prettyPrec a LowestPrec @@ -127,17 +102,8 @@ pApp a = prettyPrec a AppPrec pArg :: PrettyPrec a => a -> Doc ann pArg a = prettyPrec a ArgPrec -prettyBlock :: (IRRep r, PrettyPrec (e l)) => Nest (Decl r) n l -> e l -> Doc ann -prettyBlock Empty expr = group $ line <> pLowest expr -prettyBlock decls expr = prettyLines decls' <> hardline <> pLowest expr - where decls' = fromNest decls - -fromNest :: Nest b n l -> [b UnsafeS UnsafeS] -fromNest Empty = [] -fromNest (Nest b rest) = unsafeCoerceB b : fromNest rest - prettyLines :: (Foldable f, Pretty a) => f a -> Doc ann -prettyLines xs = foldMap (\d -> hardline <> p d) $ toList xs +prettyLines xs = foldMap (\d -> hardline <> pretty d) $ toList xs parensSep :: Doc ann -> [Doc ann] -> Doc ann parensSep separator items = encloseSep "(" ")" separator items @@ -148,907 +114,13 @@ spaceIfColinear = flatAlt "" space instance PrettyPrec a => PrettyPrec [a] where prettyPrec xs = atPrec ArgPrec $ hsep $ map pLowest xs -instance PrettyE ann => Pretty (BinderP c ann n l) - where pretty (b:>ty) = p b <> ":" <> p ty - -instance IRRep r => Pretty (Expr r n) where pretty = prettyFromPrettyPrec -instance IRRep r => PrettyPrec (Expr r n) where - prettyPrec = \case - Atom x -> prettyPrec x - Block _ (Abs decls body) -> atPrec AppPrec $ prettyBlock decls body - App _ f xs -> atPrec AppPrec $ pApp f <+> spaced (toList xs) - TopApp _ f xs -> atPrec AppPrec $ pApp f <+> spaced (toList xs) - TabApp _ f x -> atPrec AppPrec $ pApp f <> brackets (p x) - Case e alts (EffTy effs _) -> prettyPrecCase "case" e alts effs - TabCon _ _ es -> atPrec ArgPrec $ list $ pApp <$> es - PrimOp op -> prettyPrec op - ApplyMethod _ d i xs -> atPrec AppPrec $ "applyMethod" <+> p d <+> p i <+> p xs - Project _ i x -> atPrec AppPrec $ "Project" <+> p i <+> p x - Unwrap _ x -> atPrec AppPrec $ "Unwrap" <+> p x - -prettyPrecCase :: IRRep r => Doc ann -> Atom r n -> [Alt r n] -> EffectRow r n -> DocPrec ann -prettyPrecCase name e alts effs = atPrec LowestPrec $ - name <+> pApp e <+> "of" <> - nest 2 (foldMap (\alt -> hardline <> prettyAlt alt) alts - <> effectLine effs) - where - effectLine :: IRRep r => EffectRow r n -> Doc ann - effectLine Pure = "" - effectLine row = hardline <> "case annotated with effects" <+> p row - -prettyAlt :: IRRep r => Alt r n -> Doc ann -prettyAlt (Abs b body) = prettyBinderNoAnn b <+> "->" <> nest 2 (p body) - -prettyBinderNoAnn :: Binder r n l -> Doc ann -prettyBinderNoAnn (b:>_) = p b - -instance (IRRep r, PrettyPrecE e) => Pretty (Abs (Binder r) e n) where pretty = prettyFromPrettyPrec -instance (IRRep r, PrettyPrecE e) => PrettyPrec (Abs (Binder r) e n) where - prettyPrec (Abs binder body) = atPrec LowestPrec $ "\\" <> p binder <> "." <> pLowest body - -instance IRRep r => Pretty (DeclBinding r n) where - pretty (DeclBinding ann expr) = "Decl" <> p ann <+> p expr - -instance IRRep r => Pretty (Decl r n l) where - pretty (Let b (DeclBinding ann rhs)) = - align $ annDoc <> p (b:>getType rhs) <+> "=" <> (nest 2 $ group $ line <> pLowest rhs) - where annDoc = case ann of NoInlineLet -> pretty ann <> " "; _ -> pretty ann - -instance IRRep r => Pretty (PiType r n) where - pretty (PiType bs (EffTy effs resultTy)) = - (spaced $ fromNest $ bs) <+> "->" <+> "{" <> p effs <> "}" <+> p resultTy - -instance IRRep r => Pretty (LamExpr r n) where pretty = prettyFromPrettyPrec -instance IRRep r => PrettyPrec (LamExpr r n) where - prettyPrec (LamExpr bs body) = - atPrec LowestPrec $ prettyLam (p bs <> ".") body - -instance IRRep r => Pretty (IxType r n) where - pretty (IxType ty dict) = parens $ "IxType" <+> pretty ty <> prettyIxDict dict - -instance IRRep r => Pretty (Dict r n) where - pretty = \case - DictCon con -> pretty con - StuckDict _ stuck -> pretty stuck - -instance IRRep r => Pretty (DictCon r n) where - pretty = \case - InstanceDict _ name args -> "Instance" <+> p name <+> p args - IxFin n -> "Ix (Fin" <+> p n <> ")" - DataData a -> "Data " <+> p a - IxRawFin n -> "Ix (RawFin " <> p n <> ")" - IxSpecialized d xs -> p d <+> p xs - -instance Pretty (DictType n) where - pretty = \case - DictType classSourceName _ params -> p classSourceName <+> spaced params - IxDictType ty -> "Ix" <+> p ty - DataDictType ty -> "Data" <+> p ty - -instance IRRep r => Pretty (DepPairType r n) where pretty = prettyFromPrettyPrec -instance IRRep r => PrettyPrec (DepPairType r n) where - prettyPrec (DepPairType _ b rhs) = - atPrec ArgPrec $ align $ group $ parensSep (spaceIfColinear <> "&> ") [p b, p rhs] - -instance Pretty (CoreLamExpr n) where - pretty (CoreLamExpr _ lam) = p lam - -instance IRRep r => Pretty (Atom r n) where pretty = prettyFromPrettyPrec -instance IRRep r => PrettyPrec (Atom r n) where - prettyPrec atom = case atom of - Con e -> prettyPrec e - Stuck _ e -> prettyPrec e - -instance IRRep r => Pretty (Type r n) where pretty = prettyFromPrettyPrec -instance IRRep r => PrettyPrec (Type r n) where - prettyPrec = \case - TyCon e -> prettyPrec e - StuckTy _ e -> prettyPrec e - -instance IRRep r => Pretty (Stuck r n) where pretty = prettyFromPrettyPrec -instance IRRep r => PrettyPrec (Stuck r n) where - prettyPrec = \case - Var v -> atPrec ArgPrec $ p v - StuckProject i v -> atPrec LowestPrec $ "StuckProject" <+> p i <+> p v - StuckTabApp f xs -> atPrec AppPrec $ pArg f <> "." <> pArg xs - StuckUnwrap v -> atPrec LowestPrec $ "StuckUnwrap" <+> p v - InstantiatedGiven v args -> atPrec LowestPrec $ "Given" <+> p v <+> p (toList args) - SuperclassProj d' i -> atPrec LowestPrec $ "SuperclassProj" <+> p d' <+> p i - PtrVar _ v -> atPrec ArgPrec $ p v - RepValAtom x -> atPrec LowestPrec $ pretty x - ACase e alts _ -> atPrec AppPrec $ "acase" <+> p e <+> p alts - LiftSimp ty x -> atPrec ArgPrec $ "<embedded-simp-atom " <+> p x <+> " : " <+> p ty <+> ">" - LiftSimpFun ty x -> atPrec ArgPrec $ "<embedded-simp-function " <+> p x <+> " : " <+> p ty <+> ">" - TabLam lam -> atPrec AppPrec $ "tablam" <+> p lam - -instance Pretty (RepVal n) where - pretty (RepVal ty tree) = "<RepVal " <+> p tree <+> ":" <+> p ty <> ">" - -instance Pretty a => Pretty (Tree a) where - pretty = \case - Leaf x -> pretty x - Branch xs -> pretty xs - -instance Pretty Projection where - pretty = \case - UnwrapNewtype -> "u" - ProjectProduct i -> p i - -forStr :: ForAnn -> Doc ann -forStr Fwd = "for" -forStr Rev = "rof" - -instance Pretty (CorePiType n) where - pretty (CorePiType appExpl expls bs (EffTy eff resultTy)) = - prettyBindersWithExpl expls bs <+> p appExpl <> prettyEff <> p resultTy - where - prettyEff = case eff of - Pure -> space - _ -> space <> pretty eff <> space - -prettyBindersWithExpl :: forall b n l ann. PrettyB b - => [Explicitness] -> Nest b n l -> Doc ann -prettyBindersWithExpl expls bs = do - let groups = groupByExpl $ zip expls (fromNest bs) - let groups' = case groups of [] -> [(Explicit, [])] - _ -> groups - mconcat [withExplParens expl $ commaSep bsGroup | (expl, bsGroup) <- groups'] - -groupByExpl :: [(Explicitness, b UnsafeS UnsafeS)] -> [(Explicitness, [b UnsafeS UnsafeS])] -groupByExpl [] = [] -groupByExpl ((expl, b):bs) = do - let (matches, rest) = span (\(expl', _) -> expl == expl') bs - let matches' = map snd matches - (expl, b:matches') : groupByExpl rest - -withExplParens :: Explicitness -> Doc ann -> Doc ann -withExplParens Explicit x = parens x -withExplParens (Inferred _ Unify) x = braces $ x -withExplParens (Inferred _ (Synth _)) x = brackets x - -instance IRRep r => Pretty (TabPiType r n) where - pretty (TabPiType dict (b:>ty) body) = let - prettyBody = case body of - TyCon (Pi subpi) -> pretty subpi - _ -> pLowest body - prettyBinder = prettyBinderHelper (b:>ty) body - in prettyBinder <> prettyIxDict dict <> (group $ line <> "=>" <+> prettyBody) - --- A helper to let us turn dict printing on and off. We mostly want it off to --- reduce clutter in prints and error messages, but when debugging synthesis we --- want it on. -prettyIxDict :: IRRep r => IxDict r n -> Doc ann -prettyIxDict dict = if False then " " <> p dict else mempty - -prettyBinderHelper :: IRRep r => HoistableE e => Binder r n l -> e l -> Doc ann -prettyBinderHelper (b:>ty) body = - if binderName b `isFreeIn` body - then parens $ p (b:>ty) - else p ty - -prettyLam :: Pretty a => Doc ann -> a -> Doc ann -prettyLam binders body = - group $ group (nest 4 $ binders) <> group (nest 2 $ p body) - -instance IRRep r => Pretty (EffectRow r n) where - pretty (EffectRow effs t) = - braces $ hsep (punctuate "," (map p (eSetToList effs))) <> p t - -instance IRRep r => Pretty (EffectRowTail r n) where - pretty = \case - NoTail -> mempty - EffectRowTail v -> "|" <> p v - -instance IRRep r => Pretty (Effect r n) where - pretty eff = case eff of - RWSEffect rws h -> p rws <+> p h - ExceptionEffect -> "Except" - IOEffect -> "IO" - InitEffect -> "Init" - -instance Pretty (UEffect n) where - pretty eff = case eff of - URWSEffect rws h -> p rws <+> p h - UExceptionEffect -> "Except" - UIOEffect -> "IO" - -instance PrettyPrec (Name s n) where prettyPrec = atPrec ArgPrec . pretty - -instance PrettyPrec (AtomVar r n) where - prettyPrec (AtomVar v _) = prettyPrec v -instance Pretty (AtomVar r n) where pretty = prettyFromPrettyPrec - -instance IRRep r => Pretty (AtomBinding r n) where - pretty binding = case binding of - LetBound b -> p b - MiscBound t -> p t - SolverBound b -> p b - FFIFunBound s _ -> p s - NoinlineFun ty _ -> "Top function with type: " <+> p ty - TopDataBound (RepVal ty _) -> "Top data with type: " <+> p ty - -instance Pretty (SpecializationSpec n) where - pretty (AppSpecialization f (Abs bs (ListE args))) = - "Specialization" <+> p f <+> p bs <+> p args - -instance Pretty IxMethod where - pretty method = p $ show method - -instance Pretty (SolverBinding n) where - pretty (InfVarBound ty) = "Inference variable of type:" <+> p ty - pretty (SkolemBound ty) = "Skolem variable of type:" <+> p ty - pretty (DictBound ty) = "Dictionary variable of type:" <+> p ty - -instance Pretty (Binding c n) where - pretty b = case b of - -- using `unsafeCoerceIRE` here because otherwise we don't have `IRRep` - -- TODO: can we avoid printing needing IRRep? Presumably it's related to - -- manipulating sets or something, which relies on Eq/Ord, which relies on renaming. - AtomNameBinding info -> "Atom name:" <+> pretty (unsafeCoerceIRE @CoreIR info) - TyConBinding dataDef _ -> "Type constructor: " <+> pretty dataDef - DataConBinding tyConName idx -> "Data constructor:" <+> - pretty tyConName <+> "Constructor index:" <+> pretty idx - ClassBinding classDef -> pretty classDef - InstanceBinding instanceDef _ -> pretty instanceDef - MethodBinding className idx -> "Method" <+> pretty idx <+> "of" <+> pretty className - TopFunBinding f -> pretty f - FunObjCodeBinding _ -> "<object file>" - ModuleBinding _ -> "<module>" - PtrBinding _ _ -> "<ptr>" - SpecializedDictBinding _ -> "<specialized-dict-binding>" - ImpNameBinding ty -> "Imp name of type: " <+> p ty - -instance Pretty (Module n) where - pretty m = prettyRecord - [ ("moduleSourceName" , p $ moduleSourceName m) - , ("moduleDirectDeps" , p $ S.toList $ moduleDirectDeps m) - , ("moduleTransDeps" , p $ S.toList $ moduleTransDeps m) - , ("moduleExports" , p $ moduleExports m) - , ("moduleSynthCandidates", p $ moduleSynthCandidates m) ] - -instance Pretty (TyConParams n) where - pretty (TyConParams _ _) = undefined - -instance Pretty (TyConDef n) where - pretty (TyConDef name _ bs cons) = "data" <+> p name <+> p bs <> pretty cons - -instance Pretty (DataConDefs n) where - pretty = undefined - -instance Pretty (DataConDef n) where - pretty (DataConDef name _ repTy _) = - p name <+> ":" <+> p repTy - -instance Pretty (ClassDef n) where - pretty (ClassDef classSourceName _ methodNames _ _ params superclasses methodTys) = - "Class:" <+> pretty classSourceName <+> pretty methodNames - <> indented ( - line <> "parameter binders:" <+> pretty params <> - line <> "superclasses:" <+> pretty superclasses <> - line <> "methods:" <+> pretty methodTys) - -instance Pretty ParamRole where - pretty r = p (show r) - -instance Pretty (InstanceDef n) where - pretty (InstanceDef className _ bs params _) = - "Instance" <+> p className <+> pretty bs <+> p params - -deriving instance (forall c n. Pretty (v c n)) => Pretty (RecSubst v o) - -instance Pretty (TopEnv n) where - pretty (TopEnv defs rules cache _ _) = - prettyRecord [ ("Defs" , p defs) - , ("Rules" , p rules) - , ("Cache" , p cache) ] - -instance Pretty (CustomRules n) where - pretty _ = "TODO: Rule printing" - -instance Pretty (ImportStatus n) where - pretty imports = pretty $ S.toList $ directImports imports - -instance Pretty (ModuleEnv n) where - pretty (ModuleEnv imports sm sc) = - prettyRecord [ ("Imports" , p imports) - , ("Source map" , p sm) - , ("Synth candidates", p sc) ] - -instance Pretty (Env n) where - pretty (Env env1 env2) = - prettyRecord [ ("Top env" , p env1) - , ("Module env", p env2)] - -prettyRecord :: [(String, Doc ann)] -> Doc ann -prettyRecord xs = foldMap (\(name, val) -> pretty name <> indented val) xs - -instance Pretty SourceBlock where - pretty block = pretty $ ensureNewline (sbText block) where - -- Force the SourceBlock to end in a newline for echoing, even if - -- it was terminated with EOF in the original program. - ensureNewline t = case unsnoc t of - Nothing -> t - Just (_, '\n') -> t - _ -> t `snoc` '\n' - -instance Pretty Output where - pretty = \case - TextOut s -> pretty s - HtmlOut _ -> "<html output>" - SourceInfo _ -> "" - PassInfo _ s -> p s - MiscLog s -> p s - Error e -> p e - -instance Pretty PassName where - pretty x = p $ show x - -instance Pretty Result where - pretty (Result (Outputs outs) r) = vcat (map pretty outs) <> maybeErr - where maybeErr = case r of Failure err -> p err - Success () -> mempty - -instance Pretty (UBinder' c n l) where pretty = prettyFromPrettyPrec -instance PrettyPrec (UBinder' c n l) where - prettyPrec b = atPrec ArgPrec case b of - UBindSource v -> p v - UIgnore -> "_" - UBind v _ -> p v - -instance Pretty e => Pretty (WithSrcs e) where pretty (WithSrcs _ _ x) = p x -instance PrettyPrec e => PrettyPrec (WithSrcs e) where prettyPrec (WithSrcs _ _ x) = prettyPrec x - -instance Pretty e => Pretty (WithSrc e) where pretty (WithSrc _ x) = p x -instance PrettyPrec e => PrettyPrec (WithSrc e) where prettyPrec (WithSrc _ x) = prettyPrec x - -instance PrettyE e => Pretty (WithSrcE e n) where pretty (WithSrcE _ x) = p x -instance PrettyPrecE e => PrettyPrec (WithSrcE e n) where prettyPrec (WithSrcE _ x) = prettyPrec x - -instance PrettyB b => Pretty (WithSrcB b n l) where pretty (WithSrcB _ x) = p x -instance PrettyPrecB b => PrettyPrec (WithSrcB b n l) where prettyPrec (WithSrcB _ x) = prettyPrec x - -instance PrettyE e => Pretty (SourceNameOr e n) where - pretty (SourceName _ v) = p v - pretty (InternalName _ v _) = p v - -instance Pretty (SourceOrInternalName c n) where - pretty (SourceOrInternalName sn) = p sn - -instance Pretty (ULamExpr n) where pretty = prettyFromPrettyPrec -instance PrettyPrec (ULamExpr n) where - prettyPrec (ULamExpr bs _ _ _ body) = atPrec LowestPrec $ - "\\" <> p bs <+> "." <+> indented (p body) - -instance Pretty (UPiExpr n) where pretty = prettyFromPrettyPrec -instance PrettyPrec (UPiExpr n) where - prettyPrec (UPiExpr pats appExpl UPure ty) = atPrec LowestPrec $ align $ - p pats <+> p appExpl <+> pLowest ty - prettyPrec (UPiExpr pats appExpl eff ty) = atPrec LowestPrec $ align $ - p pats <+> p appExpl <+> p eff <+> pLowest ty - -instance Pretty Explicitness where - pretty expl = p (show expl) - -instance Pretty (UTabPiExpr n) where pretty = prettyFromPrettyPrec -instance PrettyPrec (UTabPiExpr n) where - prettyPrec (UTabPiExpr pat ty) = atPrec LowestPrec $ align $ - p pat <+> "=>" <+> pLowest ty - -instance Pretty (UDepPairType n) where pretty = prettyFromPrettyPrec -instance PrettyPrec (UDepPairType n) where - -- TODO: print explicitness info - prettyPrec (UDepPairType _ pat ty) = atPrec LowestPrec $ align $ - p pat <+> "&>" <+> pLowest ty - -instance Pretty (UBlock' n) where - pretty (UBlock decls result) = - prettyLines (fromNest decls) <> hardline <> pLowest result - -instance Pretty (UExpr' n) where pretty = prettyFromPrettyPrec -instance PrettyPrec (UExpr' n) where - prettyPrec expr = case expr of - ULit l -> prettyPrec l - UVar v -> atPrec ArgPrec $ p v - ULam lam -> prettyPrec lam - UApp f xs named -> atPrec AppPrec $ pAppArg (pApp f) xs <+> p named - UTabApp f x -> atPrec AppPrec $ pArg f <> "." <> pArg x - UFor dir (UForExpr binder body) -> - atPrec LowestPrec $ kw <+> p binder <> "." - <+> nest 2 (p body) - where kw = case dir of Fwd -> "for" - Rev -> "rof" - UPi piType -> prettyPrec piType - UTabPi piType -> prettyPrec piType - UDepPairTy depPairType -> prettyPrec depPairType - UDepPair lhs rhs -> atPrec ArgPrec $ parens $ - p lhs <+> ",>" <+> p rhs - UHole -> atPrec ArgPrec "_" - UTypeAnn v ty -> atPrec LowestPrec $ - group $ pApp v <> line <> ":" <+> pApp ty - UTabCon xs -> atPrec ArgPrec $ p xs - UPrim prim xs -> atPrec AppPrec $ p (show prim) <+> p xs - UCase e alts -> atPrec LowestPrec $ "case" <+> p e <> - nest 2 (prettyLines alts) - UFieldAccess x (WithSrc _ f) -> atPrec AppPrec $ p x <> "~" <> p f - UNatLit v -> atPrec ArgPrec $ p v - UIntLit v -> atPrec ArgPrec $ p v - UFloatLit v -> atPrec ArgPrec $ p v - UDo block -> atPrec LowestPrec $ p block - -instance Pretty FieldName' where - pretty = \case - FieldName s -> pretty s - FieldNum n -> pretty n - -instance Pretty (UAlt n) where - pretty (UAlt pat body) = p pat <+> "->" <+> p body - -instance Pretty (UTopDecl n l) where - pretty (UDataDefDecl (UDataDef nm bs dataCons) bTyCon bDataCons) = - "data" <+> p bTyCon <+> p nm <+> spaced (fromNest bs) <+> "where" <> nest 2 - (prettyLines (zip (toList $ fromNest bDataCons) dataCons)) - pretty (UStructDecl bTyCon (UStructDef nm bs fields defs)) = - "struct" <+> p bTyCon <+> p nm <+> spaced (fromNest bs) <+> "where" <> nest 2 - (prettyLines fields <> prettyLines defs) - pretty (UInterface params methodTys interfaceName methodNames) = - "interface" <+> p params <+> p interfaceName - <> hardline <> foldMap (<>hardline) methods - where - methods = [ p b <> ":" <> p (unsafeCoerceE ty) - | (b, ty) <- zip (toList $ fromNest methodNames) methodTys] - pretty (UInstance className bs params methods (RightB UnitB) _) = - "instance" <+> p bs <+> p className <+> spaced params <+> - prettyLines methods - pretty (UInstance className bs params methods (LeftB v) _) = - "named-instance" <+> p v <+> ":" <+> p bs <+> p className <+> p params - <> prettyLines methods - pretty (ULocalDecl decl) = p decl - -instance Pretty (UDecl' n l) where - pretty (ULet ann b _ rhs) = - align $ p ann <+> p b <+> "=" <> (nest 2 $ group $ line <> pLowest rhs) - pretty (UExprDecl expr) = p expr - pretty UPass = "pass" - -instance Pretty (UEffectRow n) where - pretty (UEffectRow x Nothing) = encloseSep "<" ">" "," $ (p <$> toList x) - pretty (UEffectRow x (Just y)) = "{" <> (hsep $ punctuate "," (p <$> toList x)) <+> "|" <+> p y <> "}" - -prettyBinderNest :: PrettyB b => Nest b n l -> Doc ann -prettyBinderNest bs = nest 6 $ line' <> (sep $ map p $ fromNest bs) - -instance Pretty (UDataDefTrail n) where - pretty (UDataDefTrail bs) = p $ fromNest bs - -instance Pretty (UAnnBinder n l) where - pretty (UAnnBinder _ b ty _) = p b <> ":" <> p ty - -instance Pretty (UAnn n) where - pretty (UAnn ty) = ":" <> p ty - pretty UNoAnn = mempty - -instance Pretty (UMethodDef' n) where - pretty (UMethodDef b rhs) = p b <+> "=" <+> p rhs - -instance Pretty (UPat' n l) where pretty = prettyFromPrettyPrec -instance PrettyPrec (UPat' n l) where - prettyPrec pat = case pat of - UPatBinder x -> atPrec ArgPrec $ p x - UPatProd xs -> atPrec ArgPrec $ parens $ commaSep (fromNest xs) - UPatDepPair (PairB x y) -> atPrec ArgPrec $ parens $ p x <> ",> " <> p y - UPatCon con pats -> atPrec AppPrec $ parens $ p con <+> spaced (fromNest pats) - UPatTable pats -> atPrec ArgPrec $ p pats +instance PrettyPrec () where prettyPrec = atPrec ArgPrec . pretty spaced :: (Foldable f, Pretty a) => f a -> Doc ann -spaced xs = hsep $ map p $ toList xs +spaced xs = hsep $ map pretty $ toList xs commaSep :: (Foldable f, Pretty a) => f a -> Doc ann -commaSep xs = fold $ punctuate "," $ map p $ toList xs - -instance Pretty (EnvFrag n l) where - pretty (EnvFrag bindings) = p bindings - -instance Pretty (Cache n) where - pretty (Cache _ _ _ _ _ _) = "<cache>" -- TODO - -instance Pretty (SynthCandidates n) where - pretty scs = "instance dicts:" <+> p (M.toList $ instanceDicts scs) - -instance Pretty (LoadedModules n) where - pretty _ = "<loaded modules>" +commaSep xs = fold $ punctuate "," $ map pretty $ toList xs indented :: Doc ann -> Doc ann indented doc = nest 2 (hardline <> doc) <> hardline - --- ==== Imp IR === - -instance Pretty (IExpr n) where - pretty (ILit v) = p v - pretty (IVar v _) = p v - pretty (IPtrVar v _) = p v - -instance PrettyPrec (IExpr n) where prettyPrec = atPrec ArgPrec . pretty - -instance Pretty (ImpDecl n l) where - pretty (ImpLet Empty instr) = p instr - pretty (ImpLet (Nest b Empty) instr) = p b <+> "=" <+> p instr - pretty (ImpLet bs instr) = p bs <+> "=" <+> p instr - -instance Pretty IFunType where - pretty (IFunType cc argTys retTys) = - "Fun" <+> p cc <+> p argTys <+> "->" <+> p retTys - -instance Pretty (TopFunDef n) where - pretty = \case - Specialization s -> p s - LinearizationPrimal _ -> "<linearization primal>" - LinearizationTangent _ -> "<linearization tangent>" - -instance Pretty (TopFun n) where - pretty = \case - DexTopFun def lam lowering -> - "Top-level Function" - <> hardline <+> "definition:" <+> pretty def - <> hardline <+> "lambda:" <+> pretty lam - <> hardline <+> "lowering:" <+> pretty lowering - FFITopFun f _ -> p f - -instance IRRep r => Pretty (TopLam r n) where - pretty (TopLam _ _ lam) = pretty lam - -instance Pretty a => Pretty (EvalStatus a) where - pretty = \case - Waiting -> "<waiting>" - Running -> "<running>" - Finished a -> pretty a - -instance Pretty (ImpFunction n) where - pretty (ImpFunction (IFunType cc _ _) (Abs bs body)) = - "impfun" <+> p cc <+> prettyBinderNest bs - <> nest 2 (hardline <> p body) <> hardline - -instance Pretty (ImpBlock n) where - pretty (ImpBlock Empty []) = mempty - pretty (ImpBlock Empty expr) = group $ line <> pLowest expr - pretty (ImpBlock decls []) = prettyLines $ fromNest decls - pretty (ImpBlock decls expr) = prettyLines decls' <> hardline <> pLowest expr - where decls' = fromNest decls - -instance Pretty (IBinder n l) where - pretty (IBinder b ty) = p b <+> ":" <+> p ty - -instance Pretty (ImpInstr n) where - pretty = \case - IFor a n (Abs i block) -> forStr a <+> p i <+> "<" <+> p n <> - nest 4 (p block) - IWhile body -> "while" <+> nest 2 (p body) - ICond predicate cons alt -> - "if" <+> p predicate <+> "then" <> nest 2 (p cons) <> - hardline <> "else" <> nest 2 (p alt) - IQueryParallelism f s -> "queryParallelism" <+> p f <+> p s - ILaunch f size args -> - "launch" <+> p f <+> p size <+> spaced args - ICastOp t x -> "cast" <+> p x <+> "to" <+> p t - IBitcastOp t x -> "bitcast" <+> p x <+> "to" <+> p t - Store dest val -> "store" <+> p dest <+> p val - Alloc _ t s -> "alloc" <+> p t <> "[" <> sizeStr s <> "]" - StackAlloc t s -> "alloca" <+> p t <> "[" <> sizeStr s <> "]" - MemCopy dest src numel -> "memcopy" <+> p dest <+> p src <+> p numel - InitializeZeros ptr numel -> "initializeZeros" <+> p ptr <+> p numel - GetAllocSize ptr -> "getAllocSize" <+> p ptr - Free ptr -> "free" <+> p ptr - ISyncWorkgroup -> "syncWorkgroup" - IThrowError -> "throwError" - ICall f args -> "call" <+> p f <+> p args - IVectorBroadcast v _ -> "vbroadcast" <+> p v - IVectorIota _ -> "viota" - DebugPrint s x -> "debug_print" <+> p (show s) <+> p x - IPtrLoad ptr -> "load" <+> p ptr - IPtrOffset ptr idx -> p ptr <+> "+>" <+> p idx - IBinOp op x y -> opDefault (UBinOp op) [x, y] - IUnOp op x -> opDefault (UUnOp op) [x] - ISelect x y z -> "select" <+> p x <+> p y <+> p z - IOutputStream -> "outputStream" - IShowScalar ptr x -> "show_scalar" <+> p ptr <+> p x - where opDefault name xs = prettyOpDefault name xs $ AppPrec - -sizeStr :: IExpr n -> Doc ann -sizeStr s = case s of - ILit (Word32Lit x) -> p x -- print in decimal because it's more readable - _ -> p s - -instance Pretty BaseType where pretty = prettyFromPrettyPrec -instance PrettyPrec BaseType where - prettyPrec b = case b of - Scalar sb -> prettyPrec sb - Vector shape sb -> atPrec ArgPrec $ encloseSep "<" ">" "x" $ (p <$> shape) ++ [p sb] - PtrType ty -> atPrec AppPrec $ "Ptr" <+> p ty - -instance Pretty AddressSpace where pretty d = p (show d) - -instance Pretty ScalarBaseType where pretty = prettyFromPrettyPrec -instance PrettyPrec ScalarBaseType where - prettyPrec sb = atPrec ArgPrec $ case sb of - Int64Type -> "Int64" - Int32Type -> "Int32" - Float64Type -> "Float64" - Float32Type -> "Float32" - Word8Type -> "Word8" - Word32Type -> "Word32" - Word64Type -> "Word64" - -instance IRRep r => Pretty (TyCon r n) where pretty = prettyFromPrettyPrec -instance IRRep r => PrettyPrec (TyCon r n) where - prettyPrec con = case con of - BaseType b -> prettyPrec b - ProdType [] -> atPrec ArgPrec $ "()" - ProdType as -> atPrec ArgPrec $ align $ group $ - encloseSep "(" ")" ", " $ fmap pApp as - SumType cs -> atPrec ArgPrec $ align $ group $ - encloseSep "(|" "|)" " | " $ fmap pApp cs - RefType h a -> atPrec AppPrec $ pAppArg "Ref" [h] <+> p a - TypeKind -> atPrec ArgPrec "Type" - HeapType -> atPrec ArgPrec "Heap" - Pi piType -> atPrec LowestPrec $ align $ p piType - TabPi piType -> atPrec LowestPrec $ align $ p piType - DepPairTy ty -> prettyPrec ty - DictTy t -> atPrec LowestPrec $ p t - NewtypeTyCon con' -> prettyPrec con' - -prettyPrecNewtype :: NewtypeCon n -> CAtom n -> DocPrec ann -prettyPrecNewtype con x = case (con, x) of - (NatCon, (IdxRepVal n)) -> atPrec ArgPrec $ pretty n - (_, x') -> prettyPrec x' - -instance Pretty (NewtypeTyCon n) where pretty = prettyFromPrettyPrec -instance PrettyPrec (NewtypeTyCon n) where - prettyPrec = \case - Nat -> atPrec ArgPrec $ "Nat" - Fin n -> atPrec AppPrec $ "Fin" <+> pArg n - EffectRowKind -> atPrec ArgPrec "EffKind" - UserADTType "RangeTo" _ (TyConParams _ [i]) -> atPrec LowestPrec $ ".." <> pApp i - UserADTType "RangeToExc" _ (TyConParams _ [i]) -> atPrec LowestPrec $ "..<" <> pApp i - UserADTType "RangeFrom" _ (TyConParams _ [i]) -> atPrec LowestPrec $ pApp i <> ".." - UserADTType "RangeFromExc" _ (TyConParams _ [i]) -> atPrec LowestPrec $ pApp i <> "<.." - UserADTType name _ (TyConParams infs params) -> case (infs, params) of - ([], []) -> atPrec ArgPrec $ p name - ([Explicit, Explicit], [l, r]) - | Just sym <- fromInfix (fromString $ pprint name) -> - atPrec ArgPrec $ align $ group $ - parens $ flatAlt " " "" <> pApp l <> line <> p sym <+> pApp r - _ -> atPrec LowestPrec $ pAppArg (p name) $ ignoreSynthParams (TyConParams infs params) - -instance IRRep r => Pretty (Con r n) where pretty = prettyFromPrettyPrec -instance IRRep r => PrettyPrec (Con r n) where - prettyPrec = \case - Lit l -> prettyPrec l - ProdCon [x] -> atPrec ArgPrec $ "(" <> pLowest x <> ",)" - ProdCon xs -> atPrec ArgPrec $ align $ group $ - encloseSep "(" ")" ", " $ fmap pLowest xs - SumCon _ tag payload -> atPrec ArgPrec $ - "(" <> p tag <> "|" <+> pApp payload <+> "|)" - HeapVal -> atPrec ArgPrec "HeapValue" - Lam lam -> atPrec LowestPrec $ p lam - DepPair x y _ -> atPrec ArgPrec $ align $ group $ - parens $ p x <+> ",>" <+> p y - Eff e -> atPrec ArgPrec $ p e - DictConAtom d -> atPrec LowestPrec $ p d - NewtypeCon con x -> prettyPrecNewtype con x - TyConAtom ty -> prettyPrec ty - -instance IRRep r => Pretty (PrimOp r n) where pretty = prettyFromPrettyPrec -instance IRRep r => PrettyPrec (PrimOp r n) where - prettyPrec = \case - MemOp op -> prettyPrec op - VectorOp op -> prettyPrec op - DAMOp op -> prettyPrec op - Hof (TypedHof _ hof) -> prettyPrec hof - RefOp ref eff -> atPrec LowestPrec case eff of - MAsk -> "ask" <+> pApp ref - MExtend _ x -> "extend" <+> pApp ref <+> pApp x - MGet -> "get" <+> pApp ref - MPut x -> pApp ref <+> ":=" <+> pApp x - IndexRef _ i -> pApp ref <+> "!" <+> pApp i - ProjRef _ i -> "proj_ref" <+> pApp ref <+> p i - UnOp op x -> prettyOpDefault (UUnOp op) [x] - BinOp op x y -> prettyOpDefault (UBinOp op) [x, y] - MiscOp op -> prettyOpGeneric op - -instance IRRep r => Pretty (MemOp r n) where pretty = prettyFromPrettyPrec -instance IRRep r => PrettyPrec (MemOp r n) where - prettyPrec = \case - PtrOffset ptr idx -> atPrec LowestPrec $ pApp ptr <+> "+>" <+> pApp idx - PtrLoad ptr -> atPrec AppPrec $ pAppArg "load" [ptr] - op -> prettyOpGeneric op - -instance IRRep r => Pretty (VectorOp r n) where pretty = prettyFromPrettyPrec -instance IRRep r => PrettyPrec (VectorOp r n) where - prettyPrec = \case - VectorBroadcast v vty -> atPrec LowestPrec $ "vbroadcast" <+> pApp v <+> pApp vty - VectorIota vty -> atPrec LowestPrec $ "viota" <+> pApp vty - VectorIdx tbl i vty -> atPrec LowestPrec $ "vslice" <+> pApp tbl <+> pApp i <+> pApp vty - VectorSubref ref i _ -> atPrec LowestPrec $ "vrefslice" <+> pApp ref <+> pApp i - -prettyOpDefault :: PrettyPrec a => PrimName -> [a] -> DocPrec ann -prettyOpDefault name args = - case length args of - 0 -> atPrec ArgPrec primName - _ -> atPrec AppPrec $ pAppArg primName args - where primName = p name - -prettyOpGeneric :: (IRRep r, GenericOp op, Show (OpConst op r)) => op r n -> DocPrec ann -prettyOpGeneric op = case fromEGenericOpRep op of - GenericOpRep op' [] [] [] -> atPrec ArgPrec (p $ show op') - GenericOpRep op' ts xs lams -> atPrec AppPrec $ pAppArg (p (show op')) xs <+> p ts <+> p lams - -instance Pretty PrimName where - pretty primName = p $ "%" ++ showPrimName primName - -instance IRRep r => Pretty (Hof r n) where pretty = prettyFromPrettyPrec -instance IRRep r => PrettyPrec (Hof r n) where - prettyPrec hof = atPrec LowestPrec case hof of - For _ _ lam -> "for" <+> pLowest lam - While body -> "while" <+> pArg body - RunReader x body -> "runReader" <+> pArg x <> nest 2 (line <> p body) - RunWriter _ bm body -> "runWriter" <+> pArg bm <> nest 2 (line <> p body) - RunState _ x body -> "runState" <+> pArg x <> nest 2 (line <> p body) - RunIO body -> "runIO" <+> pArg body - RunInit body -> "runInit" <+> pArg body - CatchException _ body -> "catchException" <+> pArg body - Linearize body x -> "linearize" <+> pArg body <+> pArg x - Transpose body x -> "transpose" <+> pArg body <+> pArg x - -instance IRRep r => Pretty (DAMOp r n) where pretty = prettyFromPrettyPrec -instance IRRep r => PrettyPrec (DAMOp r n) where - prettyPrec op = atPrec LowestPrec case op of - Seq _ ann _ c lamExpr -> case lamExpr of - UnaryLamExpr b body -> do - "seq" <+> pApp ann <+> pApp c <+> prettyLam (p b <> ".") body - _ -> p (show op) -- shouldn't happen, but crashing pretty printers make debugging hard - RememberDest _ x y -> "rememberDest" <+> pArg x <+> pArg y - Place r v -> pApp r <+> "r:=" <+> pApp v - Freeze r -> "freeze" <+> pApp r - AllocDest ty -> "alloc" <+> pApp ty - -instance IRRep r => Pretty (BaseMonoid r n) where pretty = prettyFromPrettyPrec -instance IRRep r => PrettyPrec (BaseMonoid r n) where - prettyPrec (BaseMonoid x f) = - atPrec LowestPrec $ "baseMonoid" <+> pArg x <> nest 2 (line <> pArg f) - -instance PrettyPrec Direction where - prettyPrec d = atPrec ArgPrec $ case d of - Fwd -> "fwd" - Rev -> "rev" - -printDouble :: Double -> Doc ann -printDouble x = p (double2Float x) - -printFloat :: Float -> Doc ann -printFloat x = p $ reverse $ dropWhile (=='0') $ reverse $ - showFFloat (Just 6) x "" - -instance Pretty LitVal where pretty = prettyFromPrettyPrec -instance PrettyPrec LitVal where - prettyPrec (Int64Lit x) = atPrec ArgPrec $ p x - prettyPrec (Int32Lit x) = atPrec ArgPrec $ p x - prettyPrec (Float64Lit x) = atPrec ArgPrec $ printDouble x - prettyPrec (Float32Lit x) = atPrec ArgPrec $ printFloat x - prettyPrec (Word8Lit x) = atPrec ArgPrec $ p $ show $ toEnum @Char $ fromIntegral x - prettyPrec (Word32Lit x) = atPrec ArgPrec $ p $ "0x" ++ showHex x "" - prettyPrec (Word64Lit x) = atPrec ArgPrec $ p $ "0x" ++ showHex x "" - prettyPrec (PtrLit ty (PtrLitVal x)) = - atPrec ArgPrec $ "Ptr" <+> p ty <+> p (show x) - prettyPrec (PtrLit _ NullPtr) = atPrec ArgPrec $ "NullPtr" - prettyPrec (PtrLit _ (PtrSnapshot _)) = atPrec ArgPrec "<ptr snapshot>" - -instance Pretty CallingConvention where - pretty = p . show - -instance Pretty LetAnn where - pretty ann = case ann of - PlainLet -> "" - InlineLet -> "%inline" - NoInlineLet -> "%noinline" - LinearLet -> "%linear" - OccInfoPure u -> p u <> line - OccInfoImpure u -> p u <> ", impure" <> line - -instance Pretty UsageInfo where - pretty (UsageInfo static (ixDepth, ct)) = - "occurs in" <+> p static <+> "places, read" - <+> p ct <+> "times, to depth" <+> p (show ixDepth) - -instance Pretty Count where - pretty (Bounded ct) = "<=" <+> pretty ct - pretty Occ.Unbounded = "many" - -instance PrettyPrec () where prettyPrec = atPrec ArgPrec . pretty - -instance Pretty RWS where - pretty eff = case eff of - Reader -> "Read" - Writer -> "Accum" - State -> "State" - -printOutput :: Bool -> Output -> String -printOutput isatty out = case out of - Error _ -> addColor isatty Red $ addPrefix ">" $ pprint out - _ -> addPrefix (addColor isatty Cyan ">") $ pprint $ out - -addPrefix :: String -> String -> String -addPrefix prefix str = unlines $ map prefixLine $ lines str - where prefixLine :: String -> String - prefixLine s = case s of "" -> prefix - _ -> prefix ++ " " ++ s - -addColor :: Bool -> ANSI.Color -> String -> String -addColor False _ s = s -addColor True c s = - setSGRCode [SetConsoleIntensity BoldIntensity, SetColor Foreground Vivid c] - ++ s ++ setSGRCode [Reset] - --- === Concrete syntax rendering === - -instance Pretty SourceBlock' where - pretty (TopDecl decl) = p decl - pretty d = fromString $ show d - -instance Pretty CTopDecl where - pretty (CSDecl ann decl) = annDoc <> p decl - where annDoc = case ann of - PlainLet -> mempty - _ -> p ann <> " " - pretty d = fromString $ show d - -instance Pretty CSDecl where - pretty = undefined - -- pretty (CLet pat blk) = pArg pat <+> "=" <+> p blk - -- pretty (CBind pat blk) = pArg pat <+> "<-" <+> p blk - -- pretty (CDefDecl (CDef name args maybeAnn blk)) = - -- "def " <> fromString name <> " " <> prettyParamGroups args <+> annDoc - -- <> nest 2 (hardline <> p blk) - -- where annDoc = case maybeAnn of Just (expl, ty) -> p expl <+> pArg ty - -- Nothing -> mempty - -- pretty (CInstance header givens methods name) = - -- name' <> p header <> p givens <> nest 2 (hardline <> p methods) where - -- name' = case name of - -- Nothing -> "instance " - -- (Just n) -> "named-instance " <> p n <> " " - -- pretty (CExpr e) = p e - -instance Pretty AppExplicitness where - pretty ExplicitApp = "->" - pretty ImplicitApp = "->>" - -instance Pretty CSBlock where - pretty (IndentedBlock _ decls) = nest 2 $ prettyLines decls - pretty (ExprBlock g) = pArg g - -instance Pretty Group where pretty = prettyFromPrettyPrec -instance PrettyPrec Group where - prettyPrec = undefined - -- prettyPrec (CIdentifier n) = atPrec ArgPrec $ fromString n - -- prettyPrec (CPrim prim args) = prettyOpDefault prim args - -- prettyPrec (CParens blk) = - -- atPrec ArgPrec $ "(" <> p blk <> ")" - -- prettyPrec (CBrackets g) = atPrec ArgPrec $ pretty g - -- prettyPrec (CBin op lhs rhs) = - -- atPrec LowestPrec $ pArg lhs <+> p op <+> pArg rhs - -- prettyPrec (CLambda args body) = - -- atPrec LowestPrec $ "\\" <> spaced args <> "." <> p body - -- prettyPrec (CCase scrut alts) = - -- atPrec LowestPrec $ "case " <> p scrut <> " of " <> prettyLines alts - -- prettyPrec g = atPrec ArgPrec $ fromString $ show g - -instance Pretty Bin where - pretty (EvalBinOp name) = pretty name - pretty DepAmpersand = "&>" - pretty Dot = "." - pretty DepComma = ",>" - pretty Colon = ":" - pretty DoubleColon = "::" - pretty Dollar = "$" - pretty ImplicitArrow = "->>" - pretty FatArrow = "=>" - pretty Pipe = "|" - pretty CSEqual = "=" diff --git a/src/lib/QueryType.hs b/src/lib/QueryType.hs index 79214c57..a6e3b4b5 100644 --- a/src/lib/QueryType.hs +++ b/src/lib/QueryType.hs @@ -15,6 +15,7 @@ import Data.Functor ((<&>)) import Types.Primitives import Types.Core import Types.Source +import Types.Top import Types.Imp import IRVariants import Core diff --git a/src/lib/QueryTypePure.hs b/src/lib/QueryTypePure.hs index 45a080f7..153ed544 100644 --- a/src/lib/QueryTypePure.hs +++ b/src/lib/QueryTypePure.hs @@ -8,6 +8,7 @@ module QueryTypePure where import Types.Primitives import Types.Core +import Types.Top import IRVariants import Name diff --git a/src/lib/Runtime.hs b/src/lib/Runtime.hs index 1bac0c11..20730e33 100644 --- a/src/lib/Runtime.hs +++ b/src/lib/Runtime.hs @@ -29,7 +29,7 @@ import Err import MonadUtil import PPrint () -import Types.Core hiding (DexDestructor) +import Types.Top hiding (DexDestructor) import Types.Source hiding (CInt) import Types.Primitives diff --git a/src/lib/Simplify.hs b/src/lib/Simplify.hs index 129039bd..789ccb59 100644 --- a/src/lib/Simplify.hs +++ b/src/lib/Simplify.hs @@ -31,6 +31,7 @@ import RuntimePrint import Transpose import Types.Core import Types.Source +import Types.Top import Types.Primitives import Util (enumerate) diff --git a/src/lib/Simplify.hs-boot b/src/lib/Simplify.hs-boot index c14ae648..8e1499c3 100644 --- a/src/lib/Simplify.hs-boot +++ b/src/lib/Simplify.hs-boot @@ -9,5 +9,6 @@ module Simplify (linearizeTopFun) where import Name import Builder import Types.Core +import Types.Top linearizeTopFun :: (Mut n, Fallible1 m, TopBuilder m) => LinearizationSpec n -> m n (TopFunName n, TopFunName n) diff --git a/src/lib/SourceRename.hs b/src/lib/SourceRename.hs index ee2b6f9f..c6b68d82 100644 --- a/src/lib/SourceRename.hs +++ b/src/lib/SourceRename.hs @@ -23,7 +23,7 @@ import PPrint () import IRVariants import Types.Source import Types.Primitives -import Types.Core (Env (..), ModuleEnv (..)) +import Types.Top (Env (..), ModuleEnv (..)) renameSourceNamesTopUDecl :: (Fallible1 m, EnvReader m) diff --git a/src/lib/Subst.hs b/src/lib/Subst.hs index 06265b78..b8124d36 100644 --- a/src/lib/Subst.hs +++ b/src/lib/Subst.hs @@ -19,6 +19,7 @@ import Name import MTL1 import IRVariants import Types.Core +import Types.Top import Core import qualified RawName as R import QueryTypePure diff --git a/src/lib/TopLevel.hs b/src/lib/TopLevel.hs index 78dafb62..69dc417d 100644 --- a/src/lib/TopLevel.hs +++ b/src/lib/TopLevel.hs @@ -63,7 +63,6 @@ import Subst import Name import OccAnalysis import Optimize -import PPrint (pprintCanonicalized) import Paths_dex (getDataFileName) import QueryType import Runtime @@ -75,6 +74,7 @@ import Types.Core import Types.Imp import Types.Primitives import Types.Source +import Types.Top import Util ( Tree (..), File (..), readFileWithHash) import Vectorize diff --git a/src/lib/Transpose.hs b/src/lib/Transpose.hs index 10c87d37..e35305bc 100644 --- a/src/lib/Transpose.hs +++ b/src/lib/Transpose.hs @@ -20,6 +20,7 @@ import Name import Subst import QueryType import Types.Core +import Types.Top import Types.Primitives import Util (enumerate) @@ -36,9 +37,7 @@ transpose lam ct = liftEmitBuilder $ runTransposeM do runTransposeM :: TransposeM n n a -> BuilderM SimpIR n a runTransposeM cont = runSubstReaderT idSubst $ cont -transposeTopFun - :: (MonadFail1 m, EnvReader m) - => STopLam n -> m n (STopLam n) +transposeTopFun :: (MonadFail1 m, EnvReader m) => STopLam n -> m n (STopLam n) transposeTopFun (TopLam False _ lam) = liftBuilder $ runTransposeM do (Abs bsNonlin (Abs bLin body), Abs bsNonlin'' outTy) <- unpackLinearLamExpr lam refreshBinders bsNonlin \bsNonlin' substFrag -> extendRenamer substFrag do diff --git a/src/lib/Types/Core.hs b/src/lib/Types/Core.hs index 0475f6ac..daee7511 100644 --- a/src/lib/Types/Core.hs +++ b/src/lib/Types/Core.hs @@ -4,20 +4,8 @@ -- license that can be found in the LICENSE file or at -- https://developers.google.com/open-source/licenses/bsd -{-# LANGUAGE DeriveGeneric #-} -{-# LANGUAGE FlexibleContexts #-} -{-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE StrictData #-} -{-# LANGUAGE DeriveFunctor #-} -{-# LANGUAGE ViewPatterns #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE InstanceSigs #-} {-# LANGUAGE UndecidableInstances #-} -{-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE DerivingStrategies #-} -{-# LANGUAGE DerivingVia #-} -{-# LANGUAGE DefaultSignatures #-} -{-# LANGUAGE ConstraintKinds #-} +{-# LANGUAGE StrictData #-} -- Core data types for CoreIR and its variations. @@ -25,19 +13,20 @@ module Types.Core (module Types.Core, SymbolicZeros (..)) where import Data.Word import Data.Maybe (fromJust) -import Data.Functor +import Data.Foldable (toList) import Data.Hashable -import Data.Text.Prettyprint.Doc hiding (nest) +import Data.String (fromString) +import Data.Text.Prettyprint.Doc +import Data.Text (Text, unsnoc, uncons) import qualified Data.Map.Strict as M -import qualified Data.Set as S import GHC.Generics (Generic (..)) import Data.Store (Store (..)) -import Foreign.Ptr import Name -import Util (FileHash, SnocList (..), Tree (..)) +import Util (Tree (..)) import IRVariants +import PPrint import qualified Types.OpNames as P import Types.Primitives @@ -141,6 +130,9 @@ data BaseMonoid r n = , baseCombine :: LamExpr r n } deriving (Show, Generic) +data RepVal (n::S) = RepVal (SType n) (Tree (IExpr n)) + deriving (Show, Generic) + data DeclBinding r n = DeclBinding LetAnn (Expr r n) deriving (Show, Generic) data Decl (r::IR) (n::S) (l::S) = Let (AtomNameBinder r n l) (DeclBinding r n) @@ -204,11 +196,6 @@ data TyConParams n = TyConParams [Explicitness] [Atom CoreIR n] type WithDecls (r::IR) = Abs (Decls r) :: E -> E type Block (r::IR) = WithDecls r (Expr r) :: E -type TopBlock = TopLam -- used for nullary lambda -type IsDestLam = Bool -data TopLam (r::IR) (n::S) = TopLam IsDestLam (PiType r n) (LamExpr r n) - deriving (Show, Generic) - data LamExpr (r::IR) (n::S) where LamExpr :: Nest (Binder r) n l -> Expr r l -> LamExpr r n @@ -282,9 +269,6 @@ instance ToBindersAbs TyConDef DataConDefs CoreIR where instance ToBindersAbs ClassDef (Abs (Nest CBinder) (ListE CorePiType)) CoreIR where toAbs (ClassDef _ _ _ _ _ bs scBs tys) = Abs bs (Abs scBs (ListE tys)) -instance ToBindersAbs (TopLam r) (Expr r) r where - toAbs (TopLam _ _ lam) = toAbs lam - -- === GenericOp class === class GenericOp (e::IR->E) where @@ -434,7 +418,6 @@ type CDecl = Decl CoreIR type CDecls = Decls CoreIR type CAtomName = AtomName CoreIR type CAtomVar = AtomVar CoreIR -type CTopLam = TopLam CoreIR type SAtom = Atom SimpIR type SType = Type SimpIR @@ -449,7 +432,6 @@ type SAtomName = AtomName SimpIR type SAtomVar = AtomVar SimpIR type SBinder = Binder SimpIR type SLam = LamExpr SimpIR -type STopLam = TopLam SimpIR -- === newtypes === @@ -522,174 +504,6 @@ data DictCon (r::IR) (n::S) where IxRawFin :: Atom r n -> DictCon r n IxSpecialized :: SpecDictName n -> [SAtom n] -> DictCon SimpIR n --- TODO: Use an IntMap -newtype CustomRules (n::S) = - CustomRules { customRulesMap :: M.Map (AtomName CoreIR n) (AtomRules n) } - deriving (Semigroup, Monoid, Store) -data AtomRules (n::S) = - -- number of implicit args, number of explicit args, linearization function - CustomLinearize Int Int SymbolicZeros (CAtom n) - deriving (Generic) - --- === Runtime representations === - -data RepVal (n::S) = RepVal (SType n) (Tree (IExpr n)) - deriving (Show, Generic) - --- === envs and modules === - --- `ModuleEnv` contains data that only makes sense in the context of evaluating --- a particular module. `TopEnv` contains everything that makes sense "between" --- evaluating modules. -data Env n = Env - { topEnv :: {-# UNPACK #-} TopEnv n - , moduleEnv :: {-# UNPACK #-} ModuleEnv n } - deriving (Generic) - -data TopEnv (n::S) = TopEnv - { envDefs :: RecSubst Binding n - , envCustomRules :: CustomRules n - , envCache :: Cache n - , envLoadedModules :: LoadedModules n - , envLoadedObjects :: LoadedObjects n } - deriving (Generic) - -data SerializedEnv n = SerializedEnv - { serializedEnvDefs :: RecSubst Binding n - , serializedEnvCustomRules :: CustomRules n - , serializedEnvCache :: Cache n } - deriving (Generic) - --- TODO: consider splitting this further into `ModuleEnv` (the env that's --- relevant between top-level decls) and `LocalEnv` (the additional parts of the --- env that's relevant under a lambda binder). Unlike the Top/Module --- distinction, there's some overlap. For example, instances can be defined at --- both the module-level and local level. Similarly, if we start allowing --- top-level effects in `Main` then we'll have module-level effects and local --- effects. -data ModuleEnv (n::S) = ModuleEnv - { envImportStatus :: ImportStatus n - , envSourceMap :: SourceMap n - , envSynthCandidates :: SynthCandidates n } - deriving (Generic) - -data Module (n::S) = Module - { moduleSourceName :: ModuleSourceName - , moduleDirectDeps :: S.Set (ModuleName n) - , moduleTransDeps :: S.Set (ModuleName n) -- XXX: doesn't include the module itself - , moduleExports :: SourceMap n - -- these are just the synth candidates required by this - -- module by itself. We'll usually also need those required by the module's - -- (transitive) dependencies, which must be looked up separately. - , moduleSynthCandidates :: SynthCandidates n } - deriving (Show, Generic) - -data LoadedModules (n::S) = LoadedModules - { fromLoadedModules :: M.Map ModuleSourceName (ModuleName n)} - deriving (Show, Generic) - -emptyModuleEnv :: ModuleEnv n -emptyModuleEnv = ModuleEnv emptyImportStatus (SourceMap mempty) mempty - -emptyLoadedModules :: LoadedModules n -emptyLoadedModules = LoadedModules mempty - -data LoadedObjects (n::S) = LoadedObjects - -- the pointer points to the actual runtime function - { fromLoadedObjects :: M.Map (FunObjCodeName n) NativeFunction} - deriving (Show, Generic) - -emptyLoadedObjects :: LoadedObjects n -emptyLoadedObjects = LoadedObjects mempty - -data ImportStatus (n::S) = ImportStatus - { directImports :: S.Set (ModuleName n) - -- XXX: This are cached for efficiency. It's derivable from `directImports`. - , transImports :: S.Set (ModuleName n) } - deriving (Show, Generic) - -data TopEnvFrag n l = TopEnvFrag (EnvFrag n l) (ModuleEnv l) (SnocList (TopEnvUpdate l)) - -data TopEnvUpdate n = - ExtendCache (Cache n) - | AddCustomRule (CAtomName n) (AtomRules n) - | UpdateLoadedModules ModuleSourceName (ModuleName n) - | UpdateLoadedObjects (FunObjCodeName n) NativeFunction - | FinishDictSpecialization (SpecDictName n) [TopLam SimpIR n] - | LowerDictSpecialization (SpecDictName n) [TopLam SimpIR n] - | UpdateTopFunEvalStatus (TopFunName n) (TopFunEvalStatus n) - | UpdateInstanceDef (InstanceName n) (InstanceDef n) - | UpdateTyConDef (TyConName n) (TyConDef n) - | UpdateFieldDef (TyConName n) SourceName (CAtomName n) - --- TODO: we could add a lot more structure for querying by dict type, caching, etc. -data SynthCandidates n = SynthCandidates - { instanceDicts :: M.Map (ClassName n) [InstanceName n] - , ixInstances :: [InstanceName n] } - deriving (Show, Generic) - -emptyImportStatus :: ImportStatus n -emptyImportStatus = ImportStatus mempty mempty - --- TODO: figure out the additional top-level context we need -- backend, other --- compiler flags etc. We can have a map from those to this. - -data Cache (n::S) = Cache - { specializationCache :: EMap SpecializationSpec TopFunName n - , ixDictCache :: EMap AbsDict SpecDictName n - , linearizationCache :: EMap LinearizationSpec (PairE TopFunName TopFunName) n - , transpositionCache :: EMap TopFunName TopFunName n - -- This is memoizing `parseAndGetDeps :: Text -> [ModuleSourceName]`. But we - -- only want to store one entry per module name as a simple cache eviction - -- policy, so we store it keyed on the module name, with the text hash for - -- the validity check. - , parsedDeps :: M.Map ModuleSourceName (FileHash, [ModuleSourceName]) - , moduleEvaluations :: M.Map ModuleSourceName ((FileHash, [ModuleName n]), ModuleName n) - } deriving (Show, Generic) - --- === runtime function and variable representations === - -type RuntimeEnv = DynamicVarKeyPtrs - -type DexDestructor = FunPtr (IO ()) - -data NativeFunction = NativeFunction - { nativeFunPtr :: FunPtr () - , nativeFunTeardown :: IO () } - -instance Show NativeFunction where - show _ = "<native function>" - --- Holds pointers to thread-local storage used to simulate dynamically scoped --- variables, such as the output stream file descriptor. -type DynamicVarKeyPtrs = [(DynamicVar, Ptr ())] - -data DynamicVar = OutStreamDyvar -- TODO: add others as needed - deriving (Enum, Bounded) - -dynamicVarCName :: DynamicVar -> String -dynamicVarCName OutStreamDyvar = "dex_out_stream_dyvar" - -dynamicVarLinkMap :: DynamicVarKeyPtrs -> [(String, Ptr ())] -dynamicVarLinkMap dyvars = dyvars <&> \(v, ptr) -> (dynamicVarCName v, ptr) - --- === bindings - static information we carry about a lexical scope === - --- TODO: consider making this an open union via a typeable-like class -data Binding (c::C) (n::S) where - AtomNameBinding :: AtomBinding r n -> Binding (AtomNameC r) n - TyConBinding :: Maybe (TyConDef n) -> DotMethods n -> Binding TyConNameC n - DataConBinding :: TyConName n -> Int -> Binding DataConNameC n - ClassBinding :: ClassDef n -> Binding ClassNameC n - InstanceBinding :: InstanceDef n -> CorePiType n -> Binding InstanceNameC n - MethodBinding :: ClassName n -> Int -> Binding MethodNameC n - TopFunBinding :: TopFun n -> Binding TopFunNameC n - FunObjCodeBinding :: CFunction n -> Binding FunObjCodeNameC n - ModuleBinding :: Module n -> Binding ModuleNameC n - -- TODO: add a case for abstracted pointers, as used in `ClosedImpFunction` - PtrBinding :: PtrType -> PtrLitVal -> Binding PtrNameC n - SpecializedDictBinding :: SpecializedDictDef n -> Binding SpecializedDictNameC n - ImpNameBinding :: BaseType -> Binding ImpNameC n data EffectOpDef (n::S) where EffectOpDef :: EffectName n -- name of associated effect @@ -748,108 +562,6 @@ instance RenameE EffectOpType deriving instance Show (EffectOpType n) deriving via WrapE EffectOpType n instance Generic (EffectOpType n) -instance GenericE SpecializedDictDef where - type RepE SpecializedDictDef = AbsDict `PairE` MaybeE (ListE (TopLam SimpIR)) - fromE (SpecializedDict ab methods) = ab `PairE` methods' - where methods' = case methods of Just xs -> LeftE (ListE xs) - Nothing -> RightE UnitE - {-# INLINE fromE #-} - toE (ab `PairE` methods) = SpecializedDict ab methods' - where methods' = case methods of LeftE (ListE xs) -> Just xs - RightE UnitE -> Nothing - {-# INLINE toE #-} - -instance SinkableE SpecializedDictDef -instance HoistableE SpecializedDictDef -instance AlphaEqE SpecializedDictDef -instance AlphaHashableE SpecializedDictDef -instance RenameE SpecializedDictDef - -data EvalStatus a = Waiting | Running | Finished a - deriving (Show, Eq, Ord, Generic, Functor, Foldable, Traversable) -type TopFunEvalStatus n = EvalStatus (TopFunLowerings n) - -data TopFun (n::S) = - DexTopFun (TopFunDef n) (TopLam SimpIR n) (TopFunEvalStatus n) - | FFITopFun String IFunType - deriving (Show, Generic) - -data TopFunDef (n::S) = - Specialization (SpecializationSpec n) - | LinearizationPrimal (LinearizationSpec n) - -- Tangent functions all take some number of nonlinear args, then a *single* - -- linear arg. This is so that transposition can be an involution - you apply - -- it twice and you get back to the original function. - | LinearizationTangent (LinearizationSpec n) - deriving (Show, Generic) - -newtype TopFunLowerings (n::S) = TopFunLowerings - { topFunObjCode :: FunObjCodeName n } -- TODO: add optimized, imp etc. as needed - deriving (Show, Generic, SinkableE, HoistableE, RenameE, AlphaEqE, AlphaHashableE, Pretty) - -data AtomBinding (r::IR) (n::S) where - LetBound :: DeclBinding r n -> AtomBinding r n - MiscBound :: Type r n -> AtomBinding r n - TopDataBound :: RepVal n -> AtomBinding SimpIR n - SolverBound :: SolverBinding n -> AtomBinding CoreIR n - NoinlineFun :: CType n -> CAtom n -> AtomBinding CoreIR n - FFIFunBound :: CorePiType n -> TopFunName n -> AtomBinding CoreIR n - -deriving instance IRRep r => Show (AtomBinding r n) -deriving via WrapE (AtomBinding r) n instance IRRep r => Generic (AtomBinding r n) - --- name of function, name of arg -type InferenceArgDesc = (String, String) -data InfVarDesc = - ImplicitArgInfVar InferenceArgDesc - | AnnotationInfVar String -- name of binder - | TypeInstantiationInfVar String -- name of type - | MiscInfVar - deriving (Show, Generic, Eq, Ord) - -data SolverBinding (n::S) = - InfVarBound (CType n) - | SkolemBound (CType n) - | DictBound (CType n) - deriving (Show, Generic) - -newtype EnvFrag (n::S) (l::S) = EnvFrag (RecSubstFrag Binding n l) - deriving (OutFrag) - -instance HasScope Env where - toScope = toScope . envDefs . topEnv - -instance OutMap Env where - emptyOutMap = - Env (TopEnv (RecSubst emptyInFrag) mempty mempty emptyLoadedModules emptyLoadedObjects) - emptyModuleEnv - {-# INLINE emptyOutMap #-} - -instance ExtOutMap Env (RecSubstFrag Binding) where - -- TODO: We might want to reorganize this struct to make this - -- do less explicit sinking etc. It's a hot operation! - extendOutMap (Env (TopEnv defs rules cache loadedM loadedO) moduleEnv) frag = - withExtEvidence frag $ Env - (TopEnv - (defs `extendRecSubst` frag) - (sink rules) - (sink cache) - (sink loadedM) - (sink loadedO)) - (sink moduleEnv) - {-# INLINE extendOutMap #-} - -instance ExtOutMap Env EnvFrag where - extendOutMap = extendEnv - {-# INLINE extendOutMap #-} - -extendEnv :: Distinct l => Env n -> EnvFrag n l -> Env l -extendEnv env (EnvFrag newEnv) = do - case extendOutMap env newEnv of - Env envTop (ModuleEnv imports sm scs) -> do - Env envTop (ModuleEnv imports sm scs) -{-# NOINLINE [1] extendEnv #-} - -- === effects === data Effect (r::IR) (n::S) = @@ -906,31 +618,6 @@ instance IRRep r => Store (EffectRowTail r n) instance IRRep r => Store (EffectRow r n) instance IRRep r => Store (Effect r n) --- === Specialization and generalization === - -type Generalized (r::IR) (e::E) (n::S) = (Abstracted r e n, [Atom r n]) -type Abstracted (r::IR) (e::E) = Abs (Nest (Binder r)) e -type AbsDict = Abstracted CoreIR (Dict CoreIR) - -data SpecializedDictDef n = - SpecializedDict - (AbsDict n) - -- Methods (thunked if nullary), if they're available. - -- We create specialized dict names during simplification, but we don't - -- actually simplify/lower them until we return to TopLevel - (Maybe [TopLam SimpIR n]) - deriving (Show, Generic) - --- TODO: extend with AD-oriented specializations, backend-specific specializations etc. -data SpecializationSpec (n::S) = - AppSpecialization (AtomVar CoreIR n) (Abstracted CoreIR (ListE CAtom) n) - deriving (Show, Generic) - -type Active = Bool -data LinearizationSpec (n::S) = - LinearizationSpec (TopFunName n) [Active] - deriving (Show, Generic) - -- === Binder utils === binderType :: Binder r n l -> Type r n @@ -946,39 +633,6 @@ bindersVars = \case Nest b bs -> withExtEvidence b $ withSubscopeDistinct bs $ sink (binderVar b) : bindersVars bs --- === ToBinding === - -atomBindingToBinding :: AtomBinding r n -> Binding (AtomNameC r) n -atomBindingToBinding b = AtomNameBinding b - -bindingToAtomBinding :: Binding (AtomNameC r) n -> AtomBinding r n -bindingToAtomBinding (AtomNameBinding b) = b - -class (RenameE e, SinkableE e) => ToBinding (e::E) (c::C) | e -> c where - toBinding :: e n -> Binding c n - -instance Color c => ToBinding (Binding c) c where - toBinding = id - -instance IRRep r => ToBinding (AtomBinding r) (AtomNameC r) where - toBinding = atomBindingToBinding - -instance IRRep r => ToBinding (DeclBinding r) (AtomNameC r) where - toBinding = toBinding . LetBound - -instance IRRep r => ToBinding (Type r) (AtomNameC r) where - toBinding = toBinding . MiscBound - -instance ToBinding SolverBinding (AtomNameC CoreIR) where - toBinding = toBinding . SolverBound - -instance IRRep r => ToBinding (IxType r) (AtomNameC r) where - toBinding (IxType t _) = toBinding t - -instance (ToBinding e1 c, ToBinding e2 c) => ToBinding (EitherE e1 e2) c where - toBinding (LeftE e) = toBinding e - toBinding (RightE e) = toBinding e - -- === ToAtom === class ToAtom (e::E) (r::IR) | e -> r where @@ -1168,15 +822,6 @@ pattern TrueAtom = Con (Lit (Word8Lit 1)) -- === Typeclass instances for Name and other Haskell libraries === -instance GenericE AtomRules where - type RepE AtomRules = (LiftE (Int, Int, SymbolicZeros)) `PairE` CAtom - fromE (CustomLinearize ni ne sz a) = LiftE (ni, ne, sz) `PairE` a - toE (LiftE (ni, ne, sz) `PairE` a) = CustomLinearize ni ne sz a -instance SinkableE AtomRules -instance HoistableE AtomRules -instance AlphaEqE AtomRules -instance RenameE AtomRules - instance GenericE RepVal where type RepE RepVal= PairE SType (ComposeE Tree IExpr) fromE (RepVal ty tree) = ty `PairE` ComposeE tree @@ -1188,15 +833,6 @@ instance HoistableE RepVal instance AlphaHashableE RepVal instance AlphaEqE RepVal -instance GenericE CustomRules where - type RepE CustomRules = ListE (PairE (AtomName CoreIR) AtomRules) - fromE (CustomRules m) = ListE $ toPairE <$> M.toList m - toE (ListE l) = CustomRules $ M.fromList $ fromPairE <$> l -instance SinkableE CustomRules -instance HoistableE CustomRules -instance AlphaEqE CustomRules -instance RenameE CustomRules - instance GenericE TyConParams where type RepE TyConParams = PairE (LiftE [Explicitness]) (ListE CAtom) fromE (TyConParams infs xs) = PairE (LiftE infs) (ListE xs) @@ -2014,45 +1650,6 @@ instance IRRep r => AlphaEqE (DictCon r) instance IRRep r => AlphaHashableE (DictCon r) instance IRRep r => RenameE (DictCon r) -instance GenericE Cache where - type RepE Cache = - EMap SpecializationSpec TopFunName - `PairE` EMap AbsDict SpecDictName - `PairE` EMap LinearizationSpec (PairE TopFunName TopFunName) - `PairE` EMap TopFunName TopFunName - `PairE` LiftE (M.Map ModuleSourceName (FileHash, [ModuleSourceName])) - `PairE` ListE ( LiftE ModuleSourceName - `PairE` LiftE FileHash - `PairE` ListE ModuleName - `PairE` ModuleName) - fromE (Cache x y z w parseCache evalCache) = - x `PairE` y `PairE` z `PairE` w `PairE` LiftE parseCache `PairE` - ListE [LiftE sourceName `PairE` LiftE hashVal `PairE` ListE deps `PairE` result - | (sourceName, ((hashVal, deps), result)) <- M.toList evalCache ] - {-# INLINE fromE #-} - toE (x `PairE` y `PairE` z `PairE` w `PairE` LiftE parseCache `PairE` ListE evalCache) = - Cache x y z w parseCache - (M.fromList - [(sourceName, ((hashVal, deps), result)) - | LiftE sourceName `PairE` LiftE hashVal `PairE` ListE deps `PairE` result - <- evalCache]) - {-# INLINE toE #-} - -instance SinkableE Cache -instance HoistableE Cache -instance AlphaEqE Cache -instance RenameE Cache -instance Store (Cache n) - -instance Monoid (Cache n) where - mempty = Cache mempty mempty mempty mempty mempty mempty - mappend = (<>) - -instance Semigroup (Cache n) where - -- right-biased instead of left-biased - Cache x1 x2 x3 x4 x5 x6 <> Cache y1 y2 y3 y4 y5 y6 = - Cache (y1<>x1) (y2<>x2) (y3<>x3) (y4<>x4) (x5<>y5) (x6<>y6) - instance GenericE (LamExpr r) where type RepE (LamExpr r) = Abs (Nest (Binder r)) (Expr r) fromE (LamExpr b block) = Abs b block @@ -2167,228 +1764,6 @@ instance IRRep r => RenameE (DepPairType r) deriving instance IRRep r => Show (DepPairType r n) deriving via WrapE (DepPairType r) n instance IRRep r => Generic (DepPairType r n) -instance GenericE SynthCandidates where - type RepE SynthCandidates = ListE (PairE ClassName (ListE InstanceName)) - `PairE` ListE InstanceName - fromE (SynthCandidates xs ys) = ListE xs' `PairE` ListE ys - where xs' = map (\(k,vs) -> PairE k (ListE vs)) (M.toList xs) - {-# INLINE fromE #-} - toE (ListE xs `PairE` ListE ys) = SynthCandidates xs' ys - where xs' = M.fromList $ map (\(PairE k (ListE vs)) -> (k,vs)) xs - {-# INLINE toE #-} - -instance SinkableE SynthCandidates -instance HoistableE SynthCandidates -instance AlphaEqE SynthCandidates -instance AlphaHashableE SynthCandidates -instance RenameE SynthCandidates - -instance IRRep r => GenericE (AtomBinding r) where - type RepE (AtomBinding r) = - EitherE2 (EitherE3 - (DeclBinding r) -- LetBound - (Type r) -- MiscBound - (WhenCore r SolverBinding) -- SolverBound - ) (EitherE3 - (WhenCore r (PairE CType CAtom)) -- NoinlineFun - (WhenSimp r RepVal) -- TopDataBound - (WhenCore r (CorePiType `PairE` TopFunName)) -- FFIFunBound - ) - - fromE = \case - LetBound x -> Case0 $ Case0 x - MiscBound x -> Case0 $ Case1 x - SolverBound x -> Case0 $ Case2 $ WhenIRE x - NoinlineFun t x -> Case1 $ Case0 $ WhenIRE $ PairE t x - TopDataBound repVal -> Case1 $ Case1 $ WhenIRE repVal - FFIFunBound ty v -> Case1 $ Case2 $ WhenIRE $ ty `PairE` v - {-# INLINE fromE #-} - - toE = \case - Case0 x' -> case x' of - Case0 x -> LetBound x - Case1 x -> MiscBound x - Case2 (WhenIRE x) -> SolverBound x - _ -> error "impossible" - Case1 x' -> case x' of - Case0 (WhenIRE (PairE t x)) -> NoinlineFun t x - Case1 (WhenIRE repVal) -> TopDataBound repVal - Case2 (WhenIRE (ty `PairE` v)) -> FFIFunBound ty v - _ -> error "impossible" - _ -> error "impossible" - {-# INLINE toE #-} - - -instance IRRep r => SinkableE (AtomBinding r) -instance IRRep r => HoistableE (AtomBinding r) -instance IRRep r => RenameE (AtomBinding r) -instance IRRep r => AlphaEqE (AtomBinding r) -instance IRRep r => AlphaHashableE (AtomBinding r) - -instance GenericE TopFunDef where - type RepE TopFunDef = EitherE3 SpecializationSpec LinearizationSpec LinearizationSpec - fromE = \case - Specialization s -> Case0 s - LinearizationPrimal s -> Case1 s - LinearizationTangent s -> Case2 s - {-# INLINE fromE #-} - toE = \case - Case0 s -> Specialization s - Case1 s -> LinearizationPrimal s - Case2 s -> LinearizationTangent s - _ -> error "impossible" - {-# INLINE toE #-} - -instance SinkableE TopFunDef -instance HoistableE TopFunDef -instance RenameE TopFunDef -instance AlphaEqE TopFunDef -instance AlphaHashableE TopFunDef - -instance IRRep r => GenericE (TopLam r) where - type RepE (TopLam r) = LiftE Bool `PairE` PiType r `PairE` LamExpr r - fromE (TopLam d x y) = LiftE d `PairE` x `PairE` y - {-# INLINE fromE #-} - toE (LiftE d `PairE` x `PairE` y) = TopLam d x y - {-# INLINE toE #-} - -instance IRRep r => SinkableE (TopLam r) -instance IRRep r => HoistableE (TopLam r) -instance IRRep r => RenameE (TopLam r) -instance IRRep r => AlphaEqE (TopLam r) -instance IRRep r => AlphaHashableE (TopLam r) - -instance GenericE TopFun where - type RepE TopFun = EitherE - (TopFunDef `PairE` TopLam SimpIR `PairE` ComposeE EvalStatus TopFunLowerings) - (LiftE (String, IFunType)) - fromE = \case - DexTopFun def lam status -> LeftE (def `PairE` lam `PairE` ComposeE status) - FFITopFun name ty -> RightE (LiftE (name, ty)) - {-# INLINE fromE #-} - toE = \case - LeftE (def `PairE` lam `PairE` ComposeE status) -> DexTopFun def lam status - RightE (LiftE (name, ty)) -> FFITopFun name ty - {-# INLINE toE #-} - -instance SinkableE TopFun -instance HoistableE TopFun -instance RenameE TopFun -instance AlphaEqE TopFun -instance AlphaHashableE TopFun - -instance GenericE SpecializationSpec where - type RepE SpecializationSpec = - PairE (AtomVar CoreIR) (Abs (Nest (Binder CoreIR)) (ListE CAtom)) - fromE (AppSpecialization fname (Abs bs args)) = PairE fname (Abs bs args) - {-# INLINE fromE #-} - toE (PairE fname (Abs bs args)) = AppSpecialization fname (Abs bs args) - {-# INLINE toE #-} - -instance HasNameHint (SpecializationSpec n) where - getNameHint (AppSpecialization f _) = getNameHint f - -instance SinkableE SpecializationSpec -instance HoistableE SpecializationSpec -instance RenameE SpecializationSpec -instance AlphaEqE SpecializationSpec -instance AlphaHashableE SpecializationSpec - -instance GenericE LinearizationSpec where - type RepE LinearizationSpec = PairE TopFunName (LiftE [Active]) - fromE (LinearizationSpec fname actives) = PairE fname (LiftE actives) - {-# INLINE fromE #-} - toE (PairE fname (LiftE actives)) = LinearizationSpec fname actives - {-# INLINE toE #-} - -instance SinkableE LinearizationSpec -instance HoistableE LinearizationSpec -instance RenameE LinearizationSpec -instance AlphaEqE LinearizationSpec -instance AlphaHashableE LinearizationSpec - -instance GenericE SolverBinding where - type RepE SolverBinding = EitherE3 - CType - CType - CType - fromE = \case - InfVarBound ty -> Case0 ty - SkolemBound ty -> Case1 ty - DictBound ty -> Case2 ty - {-# INLINE fromE #-} - - toE = \case - Case0 ty -> InfVarBound ty - Case1 ty -> SkolemBound ty - Case2 ty -> DictBound ty - _ -> error "impossible" - {-# INLINE toE #-} - -instance SinkableE SolverBinding -instance HoistableE SolverBinding -instance RenameE SolverBinding -instance AlphaEqE SolverBinding -instance AlphaHashableE SolverBinding - -instance GenericE (Binding c) where - type RepE (Binding c) = - EitherE3 - (EitherE6 - (WhenAtomName c AtomBinding) - (WhenC TyConNameC c (MaybeE TyConDef `PairE` DotMethods)) - (WhenC DataConNameC c (TyConName `PairE` LiftE Int)) - (WhenC ClassNameC c (ClassDef)) - (WhenC InstanceNameC c (InstanceDef `PairE` CorePiType)) - (WhenC MethodNameC c (ClassName `PairE` LiftE Int))) - (EitherE4 - (WhenC TopFunNameC c (TopFun)) - (WhenC FunObjCodeNameC c (CFunction)) - (WhenC ModuleNameC c (Module)) - (WhenC PtrNameC c (LiftE (PtrType, PtrLitVal)))) - (EitherE2 - (WhenC SpecializedDictNameC c (SpecializedDictDef)) - (WhenC ImpNameC c (LiftE BaseType))) - - fromE = \case - AtomNameBinding binding -> Case0 $ Case0 $ WhenAtomName binding - TyConBinding dataDef methods -> Case0 $ Case1 $ WhenC $ toMaybeE dataDef `PairE` methods - DataConBinding dataDefName idx -> Case0 $ Case2 $ WhenC $ dataDefName `PairE` LiftE idx - ClassBinding classDef -> Case0 $ Case3 $ WhenC $ classDef - InstanceBinding instanceDef ty -> Case0 $ Case4 $ WhenC $ instanceDef `PairE` ty - MethodBinding className idx -> Case0 $ Case5 $ WhenC $ className `PairE` LiftE idx - TopFunBinding fun -> Case1 $ Case0 $ WhenC $ fun - FunObjCodeBinding cFun -> Case1 $ Case1 $ WhenC $ cFun - ModuleBinding m -> Case1 $ Case2 $ WhenC $ m - PtrBinding ty p -> Case1 $ Case3 $ WhenC $ LiftE (ty,p) - SpecializedDictBinding def -> Case2 $ Case0 $ WhenC $ def - ImpNameBinding ty -> Case2 $ Case1 $ WhenC $ LiftE ty - {-# INLINE fromE #-} - - toE = \case - Case0 (Case0 (WhenAtomName binding)) -> AtomNameBinding binding - Case0 (Case1 (WhenC (def `PairE` methods))) -> TyConBinding (fromMaybeE def) methods - Case0 (Case2 (WhenC (n `PairE` LiftE idx))) -> DataConBinding n idx - Case0 (Case3 (WhenC (classDef))) -> ClassBinding classDef - Case0 (Case4 (WhenC (instanceDef `PairE` ty))) -> InstanceBinding instanceDef ty - Case0 (Case5 (WhenC ((n `PairE` LiftE i)))) -> MethodBinding n i - Case1 (Case0 (WhenC (fun))) -> TopFunBinding fun - Case1 (Case1 (WhenC (f))) -> FunObjCodeBinding f - Case1 (Case2 (WhenC (m))) -> ModuleBinding m - Case1 (Case3 (WhenC ((LiftE (ty,p))))) -> PtrBinding ty p - Case2 (Case0 (WhenC (def))) -> SpecializedDictBinding def - Case2 (Case1 (WhenC ((LiftE ty)))) -> ImpNameBinding ty - _ -> error "impossible" - {-# INLINE toE #-} - -deriving via WrapE (Binding c) n instance Generic (Binding c n) -instance SinkableV Binding -instance HoistableV Binding -instance RenameV Binding -instance Color c => SinkableE (Binding c) -instance Color c => HoistableE (Binding c) -instance Color c => RenameE (Binding c) - instance GenericE DotMethods where type RepE DotMethods = ListE (LiftE SourceName `PairE` CAtomName) fromE (DotMethods xys) = ListE $ [LiftE x `PairE` y | (x, y) <- M.toList xys] @@ -2506,277 +1881,9 @@ instance IRRep r => BindsOneName (Decl r) (AtomNameC r) where binderName (Let b _) = binderName b {-# INLINE binderName #-} -instance Semigroup (SynthCandidates n) where - SynthCandidates xs ys <> SynthCandidates xs' ys' = - SynthCandidates (M.unionWith (<>) xs xs') (ys <> ys') - -instance Monoid (SynthCandidates n) where - mempty = SynthCandidates mempty mempty - -instance GenericB EnvFrag where - type RepB EnvFrag = RecSubstFrag Binding - fromB (EnvFrag frag) = frag - toB frag = EnvFrag frag - -instance SinkableB EnvFrag -instance HoistableB EnvFrag -instance ProvesExt EnvFrag -instance BindsNames EnvFrag -instance RenameB EnvFrag - -instance GenericE TopEnvUpdate where - type RepE TopEnvUpdate = EitherE2 ( - EitherE4 - {- ExtendCache -} Cache - {- AddCustomRule -} (CAtomName `PairE` AtomRules) - {- UpdateLoadedModules -} (LiftE ModuleSourceName `PairE` ModuleName) - {- UpdateLoadedObjects -} (FunObjCodeName `PairE` LiftE NativeFunction) - ) ( EitherE6 - {- FinishDictSpecialization -} (SpecDictName `PairE` ListE (TopLam SimpIR)) - {- LowerDictSpecialization -} (SpecDictName `PairE` ListE (TopLam SimpIR)) - {- UpdateTopFunEvalStatus -} (TopFunName `PairE` ComposeE EvalStatus TopFunLowerings) - {- UpdateInstanceDef -} (InstanceName `PairE` InstanceDef) - {- UpdateTyConDef -} (TyConName `PairE` TyConDef) - {- UpdateFieldDef -} (TyConName `PairE` LiftE SourceName `PairE` CAtomName) - ) - fromE = \case - ExtendCache x -> Case0 $ Case0 x - AddCustomRule x y -> Case0 $ Case1 (x `PairE` y) - UpdateLoadedModules x y -> Case0 $ Case2 (LiftE x `PairE` y) - UpdateLoadedObjects x y -> Case0 $ Case3 (x `PairE` LiftE y) - FinishDictSpecialization x y -> Case1 $ Case0 (x `PairE` ListE y) - LowerDictSpecialization x y -> Case1 $ Case1 (x `PairE` ListE y) - UpdateTopFunEvalStatus x y -> Case1 $ Case2 (x `PairE` ComposeE y) - UpdateInstanceDef x y -> Case1 $ Case3 (x `PairE` y) - UpdateTyConDef x y -> Case1 $ Case4 (x `PairE` y) - UpdateFieldDef x y z -> Case1 $ Case5 (x `PairE` LiftE y `PairE` z) - - toE = \case - Case0 e -> case e of - Case0 x -> ExtendCache x - Case1 (x `PairE` y) -> AddCustomRule x y - Case2 (LiftE x `PairE` y) -> UpdateLoadedModules x y - Case3 (x `PairE` LiftE y) -> UpdateLoadedObjects x y - _ -> error "impossible" - Case1 e -> case e of - Case0 (x `PairE` ListE y) -> FinishDictSpecialization x y - Case1 (x `PairE` ListE y) -> LowerDictSpecialization x y - Case2 (x `PairE` ComposeE y) -> UpdateTopFunEvalStatus x y - Case3 (x `PairE` y) -> UpdateInstanceDef x y - Case4 (x `PairE` y) -> UpdateTyConDef x y - Case5 (x `PairE` LiftE y `PairE` z) -> UpdateFieldDef x y z - _ -> error "impossible" - _ -> error "impossible" - -instance SinkableE TopEnvUpdate -instance HoistableE TopEnvUpdate -instance RenameE TopEnvUpdate - -instance GenericB TopEnvFrag where - type RepB TopEnvFrag = PairB EnvFrag (LiftB (ModuleEnv `PairE` ListE TopEnvUpdate)) - fromB (TopEnvFrag x y (ReversedList z)) = PairB x (LiftB (y `PairE` ListE z)) - toB (PairB x (LiftB (y `PairE` ListE z))) = TopEnvFrag x y (ReversedList z) - -instance RenameB TopEnvFrag -instance HoistableB TopEnvFrag -instance SinkableB TopEnvFrag -instance ProvesExt TopEnvFrag -instance BindsNames TopEnvFrag - -instance OutFrag TopEnvFrag where - emptyOutFrag = TopEnvFrag emptyOutFrag mempty mempty - {-# INLINE emptyOutFrag #-} - catOutFrags (TopEnvFrag frag1 env1 partial1) - (TopEnvFrag frag2 env2 partial2) = - withExtEvidence frag2 $ - TopEnvFrag - (catOutFrags frag1 frag2) - (sink env1 <> env2) - (sinkSnocList partial1 <> partial2) - {-# INLINE catOutFrags #-} - --- XXX: unlike `ExtOutMap Env EnvFrag` instance, this once doesn't --- extend the synthesis candidates based on the annotated let-bound names. It --- only extends synth candidates when they're supplied explicitly. -instance ExtOutMap Env TopEnvFrag where - extendOutMap env (TopEnvFrag (EnvFrag frag) mEnv' otherUpdates) = do - let newerTopEnv = foldl applyUpdate newTopEnv otherUpdates - Env newerTopEnv newModuleEnv - where - Env (TopEnv defs rules cache loadedM loadedO) mEnv = env - - newTopEnv = withExtEvidence frag $ TopEnv - (defs `extendRecSubst` frag) - (sink rules) (sink cache) (sink loadedM) (sink loadedO) - - newModuleEnv = - ModuleEnv - (imports <> imports') - (sm <> sm' <> newImportedSM) - (scs <> scs' <> newImportedSC) - where - ModuleEnv imports sm scs = withExtEvidence frag $ sink mEnv - ModuleEnv imports' sm' scs' = mEnv' - newDirectImports = S.difference (directImports imports') (directImports imports) - newTransImports = S.difference (transImports imports') (transImports imports) - newImportedSM = flip foldMap newDirectImports $ moduleExports . lookupModulePure - newImportedSC = flip foldMap newTransImports $ moduleSynthCandidates . lookupModulePure - - lookupModulePure v = case lookupEnvPure newTopEnv v of ModuleBinding m -> m - -applyUpdate :: TopEnv n -> TopEnvUpdate n -> TopEnv n -applyUpdate e = \case - ExtendCache cache -> e { envCache = envCache e <> cache} - AddCustomRule x y -> e { envCustomRules = envCustomRules e <> CustomRules (M.singleton x y)} - UpdateLoadedModules x y -> e { envLoadedModules = envLoadedModules e <> LoadedModules (M.singleton x y)} - UpdateLoadedObjects x y -> e { envLoadedObjects = envLoadedObjects e <> LoadedObjects (M.singleton x y)} - FinishDictSpecialization dName methods -> do - let SpecializedDictBinding (SpecializedDict dAbs oldMethods) = lookupEnvPure e dName - case oldMethods of - Nothing -> do - let newBinding = SpecializedDictBinding $ SpecializedDict dAbs (Just methods) - updateEnv dName newBinding e - Just _ -> error "shouldn't be adding methods if we already have them" - LowerDictSpecialization dName methods -> do - let SpecializedDictBinding (SpecializedDict dAbs _) = lookupEnvPure e dName - let newBinding = SpecializedDictBinding $ SpecializedDict dAbs (Just methods) - updateEnv dName newBinding e - UpdateTopFunEvalStatus f s -> do - case lookupEnvPure e f of - TopFunBinding (DexTopFun def lam _) -> - updateEnv f (TopFunBinding $ DexTopFun def lam s) e - _ -> error "can't update ffi function impl" - UpdateInstanceDef name def -> do - case lookupEnvPure e name of - InstanceBinding _ ty -> updateEnv name (InstanceBinding def ty) e - UpdateTyConDef name def -> do - let TyConBinding _ methods = lookupEnvPure e name - updateEnv name (TyConBinding (Just def) methods) e - UpdateFieldDef name sn x -> do - let TyConBinding def methods = lookupEnvPure e name - updateEnv name (TyConBinding def (methods <> DotMethods (M.singleton sn x))) e - -updateEnv :: Color c => Name c n -> Binding c n -> TopEnv n -> TopEnv n -updateEnv v rhs env = - env { envDefs = RecSubst $ updateSubstFrag v rhs bs } - where (RecSubst bs) = envDefs env - -lookupEnvPure :: Color c => TopEnv n -> Name c n -> Binding c n -lookupEnvPure env v = lookupTerminalSubstFrag (fromRecSubst $ envDefs $ env) v - -instance GenericE Module where - type RepE Module = LiftE ModuleSourceName - `PairE` ListE ModuleName - `PairE` ListE ModuleName - `PairE` SourceMap - `PairE` SynthCandidates - - fromE (Module name deps transDeps sm sc) = - LiftE name `PairE` ListE (S.toList deps) `PairE` ListE (S.toList transDeps) - `PairE` sm `PairE` sc - {-# INLINE fromE #-} - - toE (LiftE name `PairE` ListE deps `PairE` ListE transDeps - `PairE` sm `PairE` sc) = - Module name (S.fromList deps) (S.fromList transDeps) sm sc - {-# INLINE toE #-} - -instance SinkableE Module -instance HoistableE Module -instance AlphaEqE Module -instance AlphaHashableE Module -instance RenameE Module - -instance GenericE ImportStatus where - type RepE ImportStatus = ListE ModuleName `PairE` ListE ModuleName - fromE (ImportStatus direct trans) = ListE (S.toList direct) - `PairE` ListE (S.toList trans) - {-# INLINE fromE #-} - toE (ListE direct `PairE` ListE trans) = - ImportStatus (S.fromList direct) (S.fromList trans) - {-# INLINE toE #-} - -instance SinkableE ImportStatus -instance HoistableE ImportStatus -instance AlphaEqE ImportStatus -instance AlphaHashableE ImportStatus -instance RenameE ImportStatus - -instance Semigroup (ImportStatus n) where - ImportStatus direct trans <> ImportStatus direct' trans' = - ImportStatus (direct <> direct') (trans <> trans') - -instance Monoid (ImportStatus n) where - mappend = (<>) - mempty = ImportStatus mempty mempty - -instance GenericE LoadedModules where - type RepE LoadedModules = ListE (PairE (LiftE ModuleSourceName) ModuleName) - fromE (LoadedModules m) = - ListE $ M.toList m <&> \(v,md) -> PairE (LiftE v) md - {-# INLINE fromE #-} - toE (ListE pairs) = - LoadedModules $ M.fromList $ pairs <&> \(PairE (LiftE v) md) -> (v, md) - {-# INLINE toE #-} - -instance SinkableE LoadedModules -instance HoistableE LoadedModules -instance AlphaEqE LoadedModules -instance AlphaHashableE LoadedModules -instance RenameE LoadedModules - -instance GenericE LoadedObjects where - type RepE LoadedObjects = ListE (PairE FunObjCodeName (LiftE NativeFunction)) - fromE (LoadedObjects m) = - ListE $ M.toList m <&> \(v,p) -> PairE v (LiftE p) - {-# INLINE fromE #-} - toE (ListE pairs) = - LoadedObjects $ M.fromList $ pairs <&> \(PairE v (LiftE p)) -> (v, p) - {-# INLINE toE #-} - -instance SinkableE LoadedObjects -instance HoistableE LoadedObjects -instance RenameE LoadedObjects - -instance GenericE ModuleEnv where - type RepE ModuleEnv = ImportStatus - `PairE` SourceMap - `PairE` SynthCandidates - fromE (ModuleEnv imports sm sc) = imports `PairE` sm `PairE` sc - {-# INLINE fromE #-} - toE (imports `PairE` sm `PairE` sc) = ModuleEnv imports sm sc - {-# INLINE toE #-} - -instance SinkableE ModuleEnv -instance HoistableE ModuleEnv -instance AlphaEqE ModuleEnv -instance AlphaHashableE ModuleEnv -instance RenameE ModuleEnv - -instance Semigroup (ModuleEnv n) where - ModuleEnv x1 x2 x3 <> ModuleEnv y1 y2 y3 = - ModuleEnv (x1<>y1) (x2<>y2) (x3<>y3) - -instance Monoid (ModuleEnv n) where - mempty = ModuleEnv mempty mempty mempty - -instance Semigroup (LoadedModules n) where - LoadedModules m1 <> LoadedModules m2 = LoadedModules (m2 <> m1) - -instance Monoid (LoadedModules n) where - mempty = LoadedModules mempty - -instance Semigroup (LoadedObjects n) where - LoadedObjects m1 <> LoadedObjects m2 = LoadedObjects (m2 <> m1) - -instance Monoid (LoadedObjects n) where - mempty = LoadedObjects mempty - -instance Hashable InfVarDesc instance Hashable IxMethod instance Hashable ParamRole instance Hashable BuiltinClassName -instance Hashable a => Hashable (EvalStatus a) instance IRRep r => Store (MiscOp r n) instance IRRep r => Store (VectorOp r n) @@ -2791,24 +1898,18 @@ instance IRRep r => Store (Stuck r n) instance IRRep r => Store (Atom r n) instance IRRep r => Store (AtomVar r n) instance IRRep r => Store (Expr r n) -instance Store (SolverBinding n) -instance IRRep r => Store (AtomBinding r n) -instance Store (SpecializationSpec n) -instance Store (LinearizationSpec n) instance IRRep r => Store (DeclBinding r n) instance IRRep r => Store (Decl r n l) instance Store (TyConParams n) instance Store (DataConDefs n) instance Store (TyConDef n) instance Store (DataConDef n) -instance IRRep r => Store (TopLam r n) instance IRRep r => Store (LamExpr r n) instance IRRep r => Store (IxType r n) instance Store (CorePiType n) instance Store (CoreLamExpr n) instance IRRep r => Store (TabPiType r n) instance IRRep r => Store (DepPairType r n) -instance Store (AtomRules n) instance Store BuiltinClassName instance Store (ClassDef n) instance Store (InstanceDef n) @@ -2819,21 +1920,9 @@ instance Store (EffectDef n) instance Store (EffectOpDef n) instance Store (EffectOpType n) instance Store (EffectOpIdx) -instance Store (SynthCandidates n) -instance Store (Module n) -instance Store (ImportStatus n) -instance Store (TopFunLowerings n) -instance Store a => Store (EvalStatus a) -instance Store (TopFun n) -instance Store (TopFunDef n) -instance Color c => Store (Binding c n) -instance Store (ModuleEnv n) -instance Store (SerializedEnv n) instance Store (ann n) => Store (NonDepNest r ann n l) -instance Store InfVarDesc instance Store IxMethod instance Store ParamRole -instance Store (SpecializedDictDef n) instance IRRep r => Store (Dict r n) instance IRRep r => Store (TypedHof r n) instance IRRep r => Store (Hof r n) @@ -2843,3 +1932,366 @@ instance IRRep r => Store (DAMOp r n) instance Store (NewtypeCon n) instance Store (NewtypeTyCon n) instance Store (DotMethods n) + +-- === Pretty instances === + +instance IRRep r => Pretty (Hof r n) where pretty = prettyFromPrettyPrec +instance IRRep r => PrettyPrec (Hof r n) where + prettyPrec hof = atPrec LowestPrec case hof of + For _ _ lam -> "for" <+> pLowest lam + While body -> "while" <+> pArg body + RunReader x body -> "runReader" <+> pArg x <> nest 2 (line <> p body) + RunWriter _ bm body -> "runWriter" <+> pArg bm <> nest 2 (line <> p body) + RunState _ x body -> "runState" <+> pArg x <> nest 2 (line <> p body) + RunIO body -> "runIO" <+> pArg body + RunInit body -> "runInit" <+> pArg body + CatchException _ body -> "catchException" <+> pArg body + Linearize body x -> "linearize" <+> pArg body <+> pArg x + Transpose body x -> "transpose" <+> pArg body <+> pArg x + where + p :: Pretty a => a -> Doc ann + p = pretty + +instance IRRep r => Pretty (DAMOp r n) where pretty = prettyFromPrettyPrec +instance IRRep r => PrettyPrec (DAMOp r n) where + prettyPrec op = atPrec LowestPrec case op of + Seq _ ann _ c lamExpr -> case lamExpr of + UnaryLamExpr b body -> do + "seq" <+> pApp ann <+> pApp c <+> prettyLam (pretty b <> ".") body + _ -> pretty (show op) -- shouldn't happen, but crashing pretty printers make debugging hard + RememberDest _ x y -> "rememberDest" <+> pArg x <+> pArg y + Place r v -> pApp r <+> "r:=" <+> pApp v + Freeze r -> "freeze" <+> pApp r + AllocDest ty -> "alloc" <+> pApp ty + +instance IRRep r => Pretty (TyCon r n) where pretty = prettyFromPrettyPrec +instance IRRep r => PrettyPrec (TyCon r n) where + prettyPrec con = case con of + BaseType b -> prettyPrec b + ProdType [] -> atPrec ArgPrec $ "()" + ProdType as -> atPrec ArgPrec $ align $ group $ + encloseSep "(" ")" ", " $ fmap pApp as + SumType cs -> atPrec ArgPrec $ align $ group $ + encloseSep "(|" "|)" " | " $ fmap pApp cs + RefType h a -> atPrec AppPrec $ pAppArg "Ref" [h] <+> p a + TypeKind -> atPrec ArgPrec "Type" + HeapType -> atPrec ArgPrec "Heap" + Pi piType -> atPrec LowestPrec $ align $ p piType + TabPi piType -> atPrec LowestPrec $ align $ p piType + DepPairTy ty -> prettyPrec ty + DictTy t -> atPrec LowestPrec $ p t + NewtypeTyCon con' -> prettyPrec con' + where + p :: Pretty a => a -> Doc ann + p = pretty + +prettyPrecNewtype :: NewtypeCon n -> CAtom n -> DocPrec ann +prettyPrecNewtype con x = case (con, x) of + (NatCon, (IdxRepVal n)) -> atPrec ArgPrec $ pretty n + (_, x') -> prettyPrec x' + +instance Pretty (NewtypeTyCon n) where pretty = prettyFromPrettyPrec +instance PrettyPrec (NewtypeTyCon n) where + prettyPrec = \case + Nat -> atPrec ArgPrec $ "Nat" + Fin n -> atPrec AppPrec $ "Fin" <+> pArg n + EffectRowKind -> atPrec ArgPrec "EffKind" + UserADTType name _ (TyConParams infs params) -> case (infs, params) of + ([], []) -> atPrec ArgPrec $ pretty name + ([Explicit, Explicit], [l, r]) + | Just sym <- fromInfix (fromString $ pprint name) -> + atPrec ArgPrec $ align $ group $ + parens $ flatAlt " " "" <> pApp l <> line <> pretty sym <+> pApp r + _ -> atPrec LowestPrec $ pAppArg (pretty name) $ ignoreSynthParams (TyConParams infs params) + where + fromInfix :: Text -> Maybe Text + fromInfix t = do + ('(', t') <- uncons t + (t'', ')') <- unsnoc t' + return t'' + +instance IRRep r => Pretty (Con r n) where pretty = prettyFromPrettyPrec +instance IRRep r => PrettyPrec (Con r n) where + prettyPrec = \case + Lit l -> prettyPrec l + ProdCon [x] -> atPrec ArgPrec $ "(" <> pLowest x <> ",)" + ProdCon xs -> atPrec ArgPrec $ align $ group $ + encloseSep "(" ")" ", " $ fmap pLowest xs + SumCon _ tag payload -> atPrec ArgPrec $ + "(" <> p tag <> "|" <+> pApp payload <+> "|)" + HeapVal -> atPrec ArgPrec "HeapValue" + Lam lam -> atPrec LowestPrec $ p lam + DepPair x y _ -> atPrec ArgPrec $ align $ group $ + parens $ p x <+> ",>" <+> p y + Eff e -> atPrec ArgPrec $ p e + DictConAtom d -> atPrec LowestPrec $ p d + NewtypeCon con x -> prettyPrecNewtype con x + TyConAtom ty -> prettyPrec ty + where + p :: Pretty a => a -> Doc ann + p = pretty + +instance IRRep r => Pretty (PrimOp r n) where pretty = prettyFromPrettyPrec +instance IRRep r => PrettyPrec (PrimOp r n) where + prettyPrec = \case + MemOp op -> prettyPrec op + VectorOp op -> prettyPrec op + DAMOp op -> prettyPrec op + Hof (TypedHof _ hof) -> prettyPrec hof + RefOp ref eff -> atPrec LowestPrec case eff of + MAsk -> "ask" <+> pApp ref + MExtend _ x -> "extend" <+> pApp ref <+> pApp x + MGet -> "get" <+> pApp ref + MPut x -> pApp ref <+> ":=" <+> pApp x + IndexRef _ i -> pApp ref <+> "!" <+> pApp i + ProjRef _ i -> "proj_ref" <+> pApp ref <+> p i + UnOp op x -> prettyOpDefault (UUnOp op) [x] + BinOp op x y -> prettyOpDefault (UBinOp op) [x, y] + MiscOp op -> prettyOpGeneric op + where + p :: Pretty a => a -> Doc ann + p = pretty + +instance IRRep r => Pretty (MemOp r n) where pretty = prettyFromPrettyPrec +instance IRRep r => PrettyPrec (MemOp r n) where + prettyPrec = \case + PtrOffset ptr idx -> atPrec LowestPrec $ pApp ptr <+> "+>" <+> pApp idx + PtrLoad ptr -> atPrec AppPrec $ pAppArg "load" [ptr] + op -> prettyOpGeneric op + +instance IRRep r => Pretty (VectorOp r n) where pretty = prettyFromPrettyPrec +instance IRRep r => PrettyPrec (VectorOp r n) where + prettyPrec = \case + VectorBroadcast v vty -> atPrec LowestPrec $ "vbroadcast" <+> pApp v <+> pApp vty + VectorIota vty -> atPrec LowestPrec $ "viota" <+> pApp vty + VectorIdx tbl i vty -> atPrec LowestPrec $ "vslice" <+> pApp tbl <+> pApp i <+> pApp vty + VectorSubref ref i _ -> atPrec LowestPrec $ "vrefslice" <+> pApp ref <+> pApp i + +prettyOpGeneric :: (IRRep r, GenericOp op, Show (OpConst op r)) => op r n -> DocPrec ann +prettyOpGeneric op = case fromEGenericOpRep op of + GenericOpRep op' [] [] [] -> atPrec ArgPrec (pretty $ show op') + GenericOpRep op' ts xs lams -> atPrec AppPrec $ pAppArg (pretty (show op')) xs <+> pretty ts <+> pretty lams + +instance Pretty IxMethod where + pretty method = pretty $ show method + +instance Pretty (TyConParams n) where + pretty (TyConParams _ _) = undefined + +instance Pretty (TyConDef n) where + pretty (TyConDef name _ bs cons) = "data" <+> pretty name <+> pretty bs <> pretty cons + +instance Pretty (DataConDefs n) where + pretty = undefined + +instance Pretty (DataConDef n) where + pretty (DataConDef name _ repTy _) = pretty name <+> ":" <+> pretty repTy + +instance Pretty (ClassDef n) where + pretty (ClassDef classSourceName _ methodNames _ _ params superclasses methodTys) = + "Class:" <+> pretty classSourceName <+> pretty methodNames + <> indented ( + line <> "parameter binders:" <+> pretty params <> + line <> "superclasses:" <+> pretty superclasses <> + line <> "methods:" <+> pretty methodTys) + +instance Pretty ParamRole where + pretty r = pretty (show r) + +instance Pretty (InstanceDef n) where + pretty (InstanceDef className _ bs params _) = + "Instance" <+> pretty className <+> pretty bs <+> pretty params + +instance IRRep r => Pretty (Expr r n) where pretty = prettyFromPrettyPrec +instance IRRep r => PrettyPrec (Expr r n) where + prettyPrec = \case + Atom x -> prettyPrec x + Block _ (Abs decls body) -> atPrec AppPrec $ prettyBlock decls body + App _ f xs -> atPrec AppPrec $ pApp f <+> spaced (toList xs) + TopApp _ f xs -> atPrec AppPrec $ pApp f <+> spaced (toList xs) + TabApp _ f x -> atPrec AppPrec $ pApp f <> brackets (p x) + Case e alts (EffTy effs _) -> prettyPrecCase "case" e alts effs + TabCon _ _ es -> atPrec ArgPrec $ list $ pApp <$> es + PrimOp op -> prettyPrec op + ApplyMethod _ d i xs -> atPrec AppPrec $ "applyMethod" <+> p d <+> p i <+> p xs + Project _ i x -> atPrec AppPrec $ "Project" <+> p i <+> p x + Unwrap _ x -> atPrec AppPrec $ "Unwrap" <+> p x + where + p :: Pretty a => a -> Doc ann + p = pretty + +prettyPrecCase :: IRRep r => Doc ann -> Atom r n -> [Alt r n] -> EffectRow r n -> DocPrec ann +prettyPrecCase name e alts effs = atPrec LowestPrec $ + name <+> pApp e <+> "of" <> + nest 2 (foldMap (\alt -> hardline <> prettyAlt alt) alts + <> effectLine effs) + where + effectLine :: IRRep r => EffectRow r n -> Doc ann + effectLine Pure = "" + effectLine row = hardline <> "case annotated with effects" <+> pretty row + +prettyAlt :: IRRep r => Alt r n -> Doc ann +prettyAlt (Abs b body) = prettyBinderNoAnn b <+> "->" <> nest 2 (pretty body) + +prettyBinderNoAnn :: Binder r n l -> Doc ann +prettyBinderNoAnn (b:>_) = pretty b + +instance IRRep r => Pretty (DeclBinding r n) where + pretty (DeclBinding ann expr) = "Decl" <> pretty ann <+> pretty expr + +instance IRRep r => Pretty (Decl r n l) where + pretty (Let b (DeclBinding ann rhs)) = + align $ annDoc <> pretty b <+> "=" <> (nest 2 $ group $ line <> pLowest rhs) + where annDoc = case ann of NoInlineLet -> pretty ann <> " "; _ -> pretty ann + +instance IRRep r => Pretty (PiType r n) where + pretty (PiType bs (EffTy effs resultTy)) = + (spaced $ unsafeFromNest $ bs) <+> "->" <+> "{" <> pretty effs <> "}" <+> pretty resultTy + +instance IRRep r => Pretty (LamExpr r n) where pretty = prettyFromPrettyPrec +instance IRRep r => PrettyPrec (LamExpr r n) where + prettyPrec (LamExpr bs body) = atPrec LowestPrec $ prettyLam (pretty bs <> ".") body + +instance IRRep r => Pretty (IxType r n) where + pretty (IxType ty dict) = parens $ "IxType" <+> pretty ty <> prettyIxDict dict + +instance IRRep r => Pretty (Dict r n) where + pretty = \case + DictCon con -> pretty con + StuckDict _ stuck -> pretty stuck + +instance IRRep r => Pretty (DictCon r n) where + pretty = \case + InstanceDict _ name args -> "Instance" <+> pretty name <+> pretty args + IxFin n -> "Ix (Fin" <+> pretty n <> ")" + DataData a -> "Data " <+> pretty a + IxRawFin n -> "Ix (RawFin " <> pretty n <> ")" + IxSpecialized d xs -> pretty d <+> pretty xs + +instance Pretty (DictType n) where + pretty = \case + DictType classSourceName _ params -> pretty classSourceName <+> spaced params + IxDictType ty -> "Ix" <+> pretty ty + DataDictType ty -> "Data" <+> pretty ty + +instance IRRep r => Pretty (DepPairType r n) where pretty = prettyFromPrettyPrec +instance IRRep r => PrettyPrec (DepPairType r n) where + prettyPrec (DepPairType _ b rhs) = + atPrec ArgPrec $ align $ group $ parensSep (spaceIfColinear <> "&> ") [pretty b, pretty rhs] + +instance Pretty (CoreLamExpr n) where + pretty (CoreLamExpr _ lam) = pretty lam + +instance IRRep r => Pretty (Atom r n) where pretty = prettyFromPrettyPrec +instance IRRep r => PrettyPrec (Atom r n) where + prettyPrec atom = case atom of + Con e -> prettyPrec e + Stuck _ e -> prettyPrec e + +instance IRRep r => Pretty (Type r n) where pretty = prettyFromPrettyPrec +instance IRRep r => PrettyPrec (Type r n) where + prettyPrec = \case + TyCon e -> prettyPrec e + StuckTy _ e -> prettyPrec e + +instance IRRep r => Pretty (Stuck r n) where pretty = prettyFromPrettyPrec +instance IRRep r => PrettyPrec (Stuck r n) where + prettyPrec = \case + Var v -> atPrec ArgPrec $ p v + StuckProject i v -> atPrec LowestPrec $ "StuckProject" <+> p i <+> p v + StuckTabApp f xs -> atPrec AppPrec $ pArg f <> "." <> pArg xs + StuckUnwrap v -> atPrec LowestPrec $ "StuckUnwrap" <+> p v + InstantiatedGiven v args -> atPrec LowestPrec $ "Given" <+> p v <+> p (toList args) + SuperclassProj d' i -> atPrec LowestPrec $ "SuperclassProj" <+> p d' <+> p i + PtrVar _ v -> atPrec ArgPrec $ p v + RepValAtom x -> atPrec LowestPrec $ pretty x + ACase e alts _ -> atPrec AppPrec $ "acase" <+> p e <+> p alts + LiftSimp ty x -> atPrec ArgPrec $ "<embedded-simp-atom " <+> p x <+> " : " <+> p ty <+> ">" + LiftSimpFun ty x -> atPrec ArgPrec $ "<embedded-simp-function " <+> p x <+> " : " <+> p ty <+> ">" + TabLam lam -> atPrec AppPrec $ "tablam" <+> p lam + where + p :: Pretty a => a -> Doc ann + p = pretty + +instance PrettyPrec (AtomVar r n) where + prettyPrec (AtomVar v _) = prettyPrec v +instance Pretty (AtomVar r n) where pretty = prettyFromPrettyPrec + +instance IRRep r => Pretty (EffectRow r n) where + pretty (EffectRow effs t) = braces $ hsep (punctuate "," (map pretty (eSetToList effs))) <> pretty t + +instance IRRep r => Pretty (EffectRowTail r n) where + pretty = \case + NoTail -> mempty + EffectRowTail v -> "|" <> pretty v + +instance IRRep r => Pretty (Effect r n) where + pretty eff = case eff of + RWSEffect rws h -> pretty rws <+> pretty h + ExceptionEffect -> "Except" + IOEffect -> "IO" + InitEffect -> "Init" + +prettyLam :: Pretty a => Doc ann -> a -> Doc ann +prettyLam binders body = group $ group (nest 4 $ binders) <> group (nest 2 $ pretty body) + +instance IRRep r => Pretty (TabPiType r n) where + pretty (TabPiType dict (b:>ty) body) = let + prettyBody = case body of + TyCon (Pi subpi) -> pretty subpi + _ -> pLowest body + prettyBinder = prettyBinderHelper (b:>ty) body + in prettyBinder <> prettyIxDict dict <> (group $ line <> "=>" <+> prettyBody) + +-- A helper to let us turn dict printing on and off. We mostly want it off to +-- reduce clutter in prints and error messages, but when debugging synthesis we +-- want it on. +prettyIxDict :: IRRep r => IxDict r n -> Doc ann +prettyIxDict dict = if False then " " <> pretty dict else mempty + +prettyBinderHelper :: IRRep r => HoistableE e => Binder r n l -> e l -> Doc ann +prettyBinderHelper (b:>ty) body = + if binderName b `isFreeIn` body + then parens $ pretty (b:>ty) + else pretty ty + +instance Pretty (CorePiType n) where + pretty (CorePiType appExpl expls bs (EffTy eff resultTy)) = + prettyBindersWithExpl expls bs <+> pretty appExpl <> prettyEff <> pretty resultTy + where + prettyEff = case eff of + Pure -> space + _ -> space <> pretty eff <> space + +prettyBindersWithExpl :: forall b n l ann. PrettyB b + => [Explicitness] -> Nest b n l -> Doc ann +prettyBindersWithExpl expls bs = do + let groups = groupByExpl $ zip expls (unsafeFromNest bs) + let groups' = case groups of [] -> [(Explicit, [])] + _ -> groups + mconcat [withExplParens expl $ commaSep bsGroup | (expl, bsGroup) <- groups'] + +groupByExpl :: [(Explicitness, b UnsafeS UnsafeS)] -> [(Explicitness, [b UnsafeS UnsafeS])] +groupByExpl [] = [] +groupByExpl ((expl, b):bs) = do + let (matches, rest) = span (\(expl', _) -> expl == expl') bs + let matches' = map snd matches + (expl, b:matches') : groupByExpl rest + +withExplParens :: Explicitness -> Doc ann -> Doc ann +withExplParens Explicit x = parens x +withExplParens (Inferred _ Unify) x = braces $ x +withExplParens (Inferred _ (Synth _)) x = brackets x + +instance Pretty (RepVal n) where + pretty (RepVal ty tree) = "<RepVal " <+> pretty tree <+> ":" <+> pretty ty <> ">" + +prettyBlock :: (IRRep r, PrettyPrec (e l)) => Nest (Decl r) n l -> e l -> Doc ann +prettyBlock Empty expr = group $ line <> pLowest expr +prettyBlock decls expr = prettyLines decls' <> hardline <> pLowest expr + where decls' = unsafeFromNest decls + +instance IRRep r => Pretty (BaseMonoid r n) where pretty = prettyFromPrettyPrec +instance IRRep r => PrettyPrec (BaseMonoid r n) where + prettyPrec (BaseMonoid x f) = + atPrec LowestPrec $ "baseMonoid" <+> pArg x <> nest 2 (line <> pArg f) diff --git a/src/lib/Types/Imp.hs b/src/lib/Types/Imp.hs index d99d66c4..9006745c 100644 --- a/src/lib/Types/Imp.hs +++ b/src/lib/Types/Imp.hs @@ -27,11 +27,16 @@ import qualified Data.ByteString as BS import GHC.Generics (Generic (..)) import Data.Store (Store (..)) +import Data.String (fromString) +import Data.Text.Prettyprint.Doc (line', nest, group) import Name +import PPrint import Util (IsBool (..)) - import Types.Primitives +import Types.Source + +-- === data types === type ImpName = Name ImpNameC @@ -480,3 +485,91 @@ instance Store LinktimeVals instance Hashable IsCUDARequired instance Hashable CallingConvention instance Hashable IFunType + +instance Pretty CallingConvention where pretty = fromString . show + +instance Pretty (ImpFunction n) where + pretty (ImpFunction (IFunType cc _ _) (Abs bs body)) = + "impfun" <+> pretty cc <+> prettyBinderNest bs + <> nest 2 (hardline <> pretty body) <> hardline + +instance Pretty (ImpBlock n) where + pretty = \case + ImpBlock Empty [] -> mempty + ImpBlock Empty expr -> group $ hardline <> pLowest expr + ImpBlock decls [] -> prettyLines $ fromNest decls + ImpBlock decls expr -> prettyLines decls' <> hardline <> pLowest expr + where decls' = fromNest decls + +instance Pretty (IBinder n l) where + pretty (IBinder b ty) = pretty b <+> ":" <+> pretty ty + +instance Pretty (ImpInstr n) where + pretty = \case + IFor a n (Abs i block) -> forStr a <+> p i <+> "<" <+> p n <> + nest 4 (p block) + IWhile body -> "while" <+> nest 2 (p body) + ICond predicate cons alt -> + "if" <+> p predicate <+> "then" <> nest 2 (p cons) <> + hardline <> "else" <> nest 2 (p alt) + IQueryParallelism f s -> "queryParallelism" <+> p f <+> p s + ILaunch f s args -> "launch" <+> p f <+> p s <+> spaced args + ICastOp t x -> "cast" <+> p x <+> "to" <+> p t + IBitcastOp t x -> "bitcast" <+> p x <+> "to" <+> p t + Store dest val -> "store" <+> p dest <+> p val + Alloc _ t s -> "alloc" <+> p t <> "[" <> sizeStr s <> "]" + StackAlloc t s -> "alloca" <+> p t <> "[" <> sizeStr s <> "]" + MemCopy dest src numel -> "memcopy" <+> p dest <+> p src <+> p numel + InitializeZeros ptr numel -> "initializeZeros" <+> p ptr <+> p numel + GetAllocSize ptr -> "getAllocSize" <+> p ptr + Free ptr -> "free" <+> p ptr + ISyncWorkgroup -> "syncWorkgroup" + IThrowError -> "throwError" + ICall f args -> "call" <+> p f <+> p args + IVectorBroadcast v _ -> "vbroadcast" <+> p v + IVectorIota _ -> "viota" + DebugPrint s x -> "debug_print" <+> p (show s) <+> p x + IPtrLoad ptr -> "load" <+> p ptr + IPtrOffset ptr idx -> p ptr <+> "+>" <+> p idx + IBinOp op x y -> opDefault (UBinOp op) [x, y] + IUnOp op x -> opDefault (UUnOp op) [x] + ISelect x y z -> "select" <+> p x <+> p y <+> p z + IOutputStream -> "outputStream" + IShowScalar ptr x -> "show_scalar" <+> p ptr <+> p x + where opDefault name xs = prettyOpDefault name xs $ AppPrec + p :: Pretty a => a -> Doc ann + p = pretty + forStr :: ForAnn -> Doc ann + forStr = \case + Fwd -> "for" + Rev -> "rof" + +sizeStr :: IExpr n -> Doc ann +sizeStr s = case s of + ILit (Word32Lit x) -> pretty x -- print in decimal because it's more readable + _ -> pretty s + +instance Pretty (IExpr n) where + pretty = \case + ILit v -> pretty v + IVar v _ -> pretty v + IPtrVar v _ -> pretty v + +instance PrettyPrec (IExpr n) where prettyPrec = atPrec ArgPrec . pretty + +instance Pretty (ImpDecl n l) where + pretty = \case + ImpLet Empty instr -> pretty instr + ImpLet (Nest b Empty) instr -> pretty b <+> "=" <+> pretty instr + ImpLet bs instr -> pretty bs <+> "=" <+> pretty instr + +instance Pretty IFunType where + pretty (IFunType cc argTys retTys) = + "Fun" <+> pretty cc <+> pretty argTys <+> "->" <+> pretty retTys + +prettyBinderNest :: PrettyB b => Nest b n l -> Doc ann +prettyBinderNest bs = nest 6 $ line' <> (sep $ map pretty $ fromNest bs) + +fromNest :: Nest b n l -> [b UnsafeS UnsafeS] +fromNest Empty = [] +fromNest (Nest b rest) = unsafeCoerceB b : fromNest rest diff --git a/src/lib/Types/OpNames.hs b/src/lib/Types/OpNames.hs index 178936ec..344329ac 100644 --- a/src/lib/Types/OpNames.hs +++ b/src/lib/Types/OpNames.hs @@ -14,6 +14,8 @@ import Data.Hashable import GHC.Generics (Generic (..)) import Data.Store (Store (..)) +import PPrint + data TC = ProdType | SumType | RefType | TypeKind | HeapType data Con = ProdCon | SumCon Int | HeapVal @@ -117,3 +119,8 @@ deriving instance Eq (Hof r) deriving instance Eq DAMOp deriving instance Eq RefOp deriving instance Eq UserEffectOp + +instance Pretty Projection where + pretty = \case + UnwrapNewtype -> "u" + ProjectProduct i -> pretty i diff --git a/src/lib/Types/Primitives.hs b/src/lib/Types/Primitives.hs index 8096f7e6..f449acba 100644 --- a/src/lib/Types/Primitives.hs +++ b/src/lib/Types/Primitives.hs @@ -28,12 +28,14 @@ import Data.String (IsString (..)) import Data.Word import Data.Hashable import Data.Store (Store (..)) -import Data.Text.Prettyprint.Doc (Pretty (..)) import qualified Data.Store.Internal as SI import Foreign.Ptr +import Numeric +import GHC.Float import GHC.Generics (Generic (..)) +import PPrint import Occurrence import Types.OpNames (UnOp (..), BinOp (..), CmpOp (..), Projection (..)) import Name @@ -222,3 +224,75 @@ instance Hashable AppExplicitness instance Hashable DepPairExplicitness instance Hashable InferenceMechanism instance Hashable RequiredMethodAccess + +-- === Pretty instances === + +instance Pretty AppExplicitness where + pretty ExplicitApp = "->" + pretty ImplicitApp = "->>" + +instance Pretty RWS where + pretty eff = case eff of + Reader -> "Read" + Writer -> "Accum" + State -> "State" + +instance Pretty LetAnn where + pretty ann = case ann of + PlainLet -> "" + InlineLet -> "%inline" + NoInlineLet -> "%noinline" + LinearLet -> "%linear" + OccInfoPure u -> pretty u <> hardline + OccInfoImpure u -> pretty u <> ", impure" <> hardline + +instance PrettyPrec Direction where + prettyPrec d = atPrec ArgPrec $ case d of + Fwd -> "fwd" + Rev -> "rev" + +printDouble :: Double -> Doc ann +printDouble x = pretty (double2Float x) + +printFloat :: Float -> Doc ann +printFloat x = pretty $ reverse $ dropWhile (=='0') $ reverse $ + showFFloat (Just 6) x "" + +instance Pretty LitVal where pretty = prettyFromPrettyPrec +instance PrettyPrec LitVal where + prettyPrec = \case + Int64Lit x -> atPrec ArgPrec $ p x + Int32Lit x -> atPrec ArgPrec $ p x + Float64Lit x -> atPrec ArgPrec $ printDouble x + Float32Lit x -> atPrec ArgPrec $ printFloat x + Word8Lit x -> atPrec ArgPrec $ p $ show $ toEnum @Char $ fromIntegral x + Word32Lit x -> atPrec ArgPrec $ p $ "0x" ++ showHex x "" + Word64Lit x -> atPrec ArgPrec $ p $ "0x" ++ showHex x "" + PtrLit ty (PtrLitVal x) -> atPrec ArgPrec $ "Ptr" <+> p ty <+> p (show x) + PtrLit _ NullPtr -> atPrec ArgPrec $ "NullPtr" + PtrLit _ (PtrSnapshot _) -> atPrec ArgPrec "<ptr snapshot>" + where p :: Pretty a => a -> Doc ann + p = pretty + +instance Pretty Device where pretty = fromString . show + +instance Pretty BaseType where pretty = prettyFromPrettyPrec +instance PrettyPrec BaseType where + prettyPrec b = case b of + Scalar sb -> prettyPrec sb + Vector shape sb -> atPrec ArgPrec $ encloseSep "<" ">" "x" $ (pretty <$> shape) ++ [pretty sb] + PtrType ty -> atPrec AppPrec $ "Ptr" <+> pretty ty + +instance Pretty ScalarBaseType where pretty = prettyFromPrettyPrec +instance PrettyPrec ScalarBaseType where + prettyPrec sb = atPrec ArgPrec $ case sb of + Int64Type -> "Int64" + Int32Type -> "Int32" + Float64Type -> "Float64" + Float32Type -> "Float32" + Word8Type -> "Word8" + Word32Type -> "Word32" + Word64Type -> "Word64" + +instance Pretty Explicitness where + pretty expl = pretty (show expl) diff --git a/src/lib/Types/Source.hs b/src/lib/Types/Source.hs index 4249b9a9..b43b81a4 100644 --- a/src/lib/Types/Source.hs +++ b/src/lib/Types/Source.hs @@ -26,13 +26,17 @@ import Data.Foldable import qualified Data.Map.Strict as M import qualified Data.Set as S import Data.Text (Text) -import Data.Text.Prettyprint.Doc (Pretty (..), hardline, (<+>)) import Data.Word +import Data.Text.Prettyprint.Doc (vcat, line, group, parens, nest, align, punctuate, hsep) +import Data.Text (snoc, unsnoc) +import Data.Tuple (swap) import GHC.Generics (Generic (..)) import Data.Store (Store (..)) +import Data.String (fromString) import Err +import PPrint import Name import qualified Types.OpNames as P import IRVariants @@ -650,6 +654,102 @@ data PrimName = | UTuple -- overloaded for type constructor and data constructor, resolved in inference deriving (Show, Eq, Generic) +-- === primitive constructors and operators === + +strToPrimName :: String -> Maybe PrimName +strToPrimName s = M.lookup s primNames + +primNameToStr :: PrimName -> String +primNameToStr prim = case lookup prim $ map swap $ M.toList primNames of + Just s -> s + Nothing -> show prim + +showPrimName :: PrimName -> String +showPrimName prim = primNameToStr prim +{-# NOINLINE showPrimName #-} + +primNames :: M.Map String PrimName +primNames = M.fromList + [ ("ask" , UMAsk), ("mextend", UMExtend) + , ("get" , UMGet), ("put" , UMPut) + , ("while" , UWhile) + , ("linearize", ULinearize), ("linearTranspose", UTranspose) + , ("runReader", URunReader), ("runWriter" , URunWriter), ("runState", URunState) + , ("runIO" , URunIO ), ("catchException" , UCatchException) + , ("iadd" , binary IAdd), ("isub" , binary ISub) + , ("imul" , binary IMul), ("fdiv" , binary FDiv) + , ("fadd" , binary FAdd), ("fsub" , binary FSub) + , ("fmul" , binary FMul), ("idiv" , binary IDiv) + , ("irem" , binary IRem) + , ("fpow" , binary FPow) + , ("and" , binary BAnd), ("or" , binary BOr ) + , ("not" , unary BNot), ("xor" , binary BXor) + , ("shl" , binary BShL), ("shr" , binary BShR) + , ("ieq" , binary (ICmp Equal)), ("feq", binary (FCmp Equal)) + , ("igt" , binary (ICmp Greater)), ("fgt", binary (FCmp Greater)) + , ("ilt" , binary (ICmp Less)), ("flt", binary (FCmp Less)) + , ("fneg" , unary FNeg) + , ("exp" , unary Exp), ("exp2" , unary Exp2) + , ("log" , unary Log), ("log2" , unary Log2), ("log10" , unary Log10) + , ("sin" , unary Sin), ("cos" , unary Cos) + , ("tan" , unary Tan), ("sqrt" , unary Sqrt) + , ("floor", unary Floor), ("ceil" , unary Ceil), ("round", unary Round) + , ("log1p", unary Log1p), ("lgamma", unary LGamma) + , ("erf" , unary Erf), ("erfc" , unary Erfc) + , ("TyKind" , UPrimTC $ P.TypeKind) + , ("Float64" , baseTy $ Scalar Float64Type) + , ("Float32" , baseTy $ Scalar Float32Type) + , ("Int64" , baseTy $ Scalar Int64Type) + , ("Int32" , baseTy $ Scalar Int32Type) + , ("Word8" , baseTy $ Scalar Word8Type) + , ("Word32" , baseTy $ Scalar Word32Type) + , ("Word64" , baseTy $ Scalar Word64Type) + , ("Int32Ptr" , baseTy $ ptrTy $ Scalar Int32Type) + , ("Word8Ptr" , baseTy $ ptrTy $ Scalar Word8Type) + , ("Word32Ptr" , baseTy $ ptrTy $ Scalar Word32Type) + , ("Word64Ptr" , baseTy $ ptrTy $ Scalar Word64Type) + , ("Float32Ptr", baseTy $ ptrTy $ Scalar Float32Type) + , ("PtrPtr" , baseTy $ ptrTy $ ptrTy $ Scalar Word8Type) + , ("Nat" , UNat) + , ("Fin" , UFin) + , ("EffKind" , UEffectRowKind) + , ("NatCon" , UNatCon) + , ("Ref" , UPrimTC $ P.RefType) + , ("HeapType" , UPrimTC $ P.HeapType) + , ("indexRef" , UIndexRef) + , ("alloc" , memOp $ P.IOAlloc) + , ("free" , memOp $ P.IOFree) + , ("ptrOffset", memOp $ P.PtrOffset) + , ("ptrLoad" , memOp $ P.PtrLoad) + , ("ptrStore" , memOp $ P.PtrStore) + , ("throwError" , miscOp $ P.ThrowError) + , ("throwException", miscOp $ P.ThrowException) + , ("dataConTag" , miscOp $ P.SumTag) + , ("toEnum" , miscOp $ P.ToEnum) + , ("outputStream" , miscOp $ P.OutputStream) + , ("cast" , miscOp $ P.CastOp) + , ("bitcast" , miscOp $ P.BitcastOp) + , ("unsafeCoerce" , miscOp $ P.UnsafeCoerce) + , ("garbageVal" , miscOp $ P.GarbageVal) + , ("select" , miscOp $ P.Select) + , ("showAny" , miscOp $ P.ShowAny) + , ("showScalar" , miscOp $ P.ShowScalar) + , ("projNewtype" , UProjNewtype) + , ("applyMethod0" , UApplyMethod 0) + , ("applyMethod1" , UApplyMethod 1) + , ("applyMethod2" , UApplyMethod 2) + , ("explicitApply", UExplicitApply) + , ("monoLit", UMonoLiteral) + ] + where + binary op = UBinOp op + baseTy b = UBaseType b + memOp op = UMemOp op + unary op = UUnOp op + ptrTy ty = PtrType (CPU, ty) + miscOp op = UMiscOp op + + -- === instances === instance Semigroup (SourceMap n) where @@ -862,3 +962,265 @@ deriving instance Ord (UEffectRow n) instance ToJSON SrcId deriving instance ToJSONKey SrcId instance ToJSON LexemeType + +-- === Pretty instances === + + + +instance Pretty CSBlock where + pretty (IndentedBlock _ decls) = nest 2 $ prettyLines decls + pretty (ExprBlock g) = pArg g + +instance Pretty Group where pretty = prettyFromPrettyPrec +instance PrettyPrec Group where + prettyPrec = undefined + -- prettyPrec (CIdentifier n) = atPrec ArgPrec $ fromString n + -- prettyPrec (CPrim prim args) = prettyOpDefault prim args + -- prettyPrec (CParens blk) = + -- atPrec ArgPrec $ "(" <> p blk <> ")" + -- prettyPrec (CBrackets g) = atPrec ArgPrec $ pretty g + -- prettyPrec (CBin op lhs rhs) = + -- atPrec LowestPrec $ pArg lhs <+> p op <+> pArg rhs + -- prettyPrec (CLambda args body) = + -- atPrec LowestPrec $ "\\" <> spaced args <> "." <> p body + -- prettyPrec (CCase scrut alts) = + -- atPrec LowestPrec $ "case " <> p scrut <> " of " <> prettyLines alts + -- prettyPrec g = atPrec ArgPrec $ fromString $ show g + +instance Pretty Bin where + pretty = \case + EvalBinOp name -> pretty name + DepAmpersand -> "&>" + Dot -> "." + DepComma -> ",>" + Colon -> ":" + DoubleColon -> "::" + Dollar -> "$" + ImplicitArrow -> "->>" + FatArrow -> "->>" + Pipe -> "|" + CSEqual -> "=" + +instance Pretty SourceBlock' where + pretty (TopDecl decl) = pretty decl + pretty d = fromString $ show d + +instance Pretty CTopDecl where + pretty (CSDecl ann decl) = annDoc <> pretty decl + where annDoc = case ann of + PlainLet -> mempty + _ -> pretty ann <> " " + pretty d = fromString $ show d + +instance Pretty CSDecl where + pretty = undefined + -- pretty (CLet pat blk) = pArg pat <+> "=" <+> p blk + -- pretty (CBind pat blk) = pArg pat <+> "<-" <+> p blk + -- pretty (CDefDecl (CDef name args maybeAnn blk)) = + -- "def " <> fromString name <> " " <> prettyParamGroups args <+> annDoc + -- <> nest 2 (hardline <> p blk) + -- where annDoc = case maybeAnn of Just (expl, ty) -> p expl <+> pArg ty + -- Nothing -> mempty + -- pretty (CInstance header givens methods name) = + -- name' <> p header <> p givens <> nest 2 (hardline <> p methods) where + -- name' = case name of + -- Nothing -> "instance " + -- (Just n) -> "named-instance " <> p n <> " " + -- pretty (CExpr e) = p e + +instance Pretty PrimName where + pretty primName = pretty $ "%" ++ showPrimName primName + +instance Pretty (UDataDefTrail n) where + pretty (UDataDefTrail bs) = pretty $ unsafeFromNest bs + +instance Pretty (UAnnBinder n l) where + pretty (UAnnBinder _ b ty _) = pretty b <> ":" <> pretty ty + +instance Pretty (UAnn n) where + pretty (UAnn ty) = ":" <> pretty ty + pretty UNoAnn = mempty + +instance Pretty (UMethodDef' n) where + pretty (UMethodDef b rhs) = pretty b <+> "=" <+> pretty rhs + +instance Pretty (UPat' n l) where pretty = prettyFromPrettyPrec +instance PrettyPrec (UPat' n l) where + prettyPrec pat = case pat of + UPatBinder x -> atPrec ArgPrec $ p x + UPatProd xs -> atPrec ArgPrec $ parens $ commaSep (unsafeFromNest xs) + UPatDepPair (PairB x y) -> atPrec ArgPrec $ parens $ p x <> ",> " <> p y + UPatCon con pats -> atPrec AppPrec $ parens $ p con <+> spaced (unsafeFromNest pats) + UPatTable pats -> atPrec ArgPrec $ p pats + where + p :: Pretty a => a -> Doc ann + p = pretty + +instance Pretty (UAlt n) where + pretty (UAlt pat body) = pretty pat <+> "->" <+> pretty body + +instance Pretty (UTopDecl n l) where + pretty = \case + UDataDefDecl (UDataDef nm bs dataCons) bTyCon bDataCons -> + "data" <+> p bTyCon <+> p nm <+> spaced (unsafeFromNest bs) <+> "where" <> nest 2 + (prettyLines (zip (toList $ unsafeFromNest bDataCons) dataCons)) + UStructDecl bTyCon (UStructDef nm bs fields defs) -> + "struct" <+> p bTyCon <+> p nm <+> spaced (unsafeFromNest bs) <+> "where" <> nest 2 + (prettyLines fields <> prettyLines defs) + UInterface params methodTys interfaceName methodNames -> + "interface" <+> p params <+> p interfaceName + <> hardline <> foldMap (<>hardline) methods + where + methods = [ p b <> ":" <> p (unsafeCoerceE ty) + | (b, ty) <- zip (toList $ unsafeFromNest methodNames) methodTys] + UInstance className bs params methods (RightB UnitB) _ -> + "instance" <+> p bs <+> p className <+> spaced params <+> + prettyLines methods + UInstance className bs params methods (LeftB v) _ -> + "named-instance" <+> p v <+> ":" <+> p bs <+> p className <+> p params + <> prettyLines methods + ULocalDecl decl -> p decl + where + p :: Pretty a => a -> Doc ann + p = pretty + +instance Pretty (UDecl' n l) where + pretty = \case + ULet ann b _ rhs -> align $ pretty ann <+> pretty b <+> "=" <> (nest 2 $ group $ line <> pLowest rhs) + UExprDecl expr -> pretty expr + UPass -> "pass" + +instance Pretty (UEffectRow n) where + pretty (UEffectRow x Nothing) = encloseSep "<" ">" "," $ (pretty <$> toList x) + pretty (UEffectRow x (Just y)) = "{" <> (hsep $ punctuate "," (pretty <$> toList x)) <+> "|" <+> pretty y <> "}" + +instance Pretty e => Pretty (WithSrcs e) where pretty (WithSrcs _ _ x) = pretty x +instance PrettyPrec e => PrettyPrec (WithSrcs e) where prettyPrec (WithSrcs _ _ x) = prettyPrec x + +instance Pretty e => Pretty (WithSrc e) where pretty (WithSrc _ x) = pretty x +instance PrettyPrec e => PrettyPrec (WithSrc e) where prettyPrec (WithSrc _ x) = prettyPrec x + +instance PrettyE e => Pretty (WithSrcE e n) where pretty (WithSrcE _ x) = pretty x +instance PrettyPrecE e => PrettyPrec (WithSrcE e n) where prettyPrec (WithSrcE _ x) = prettyPrec x + +instance PrettyB b => Pretty (WithSrcB b n l) where pretty (WithSrcB _ x) = pretty x +instance PrettyPrecB b => PrettyPrec (WithSrcB b n l) where prettyPrec (WithSrcB _ x) = prettyPrec x + +instance PrettyE e => Pretty (SourceNameOr e n) where + pretty (SourceName _ v) = pretty v + pretty (InternalName _ v _) = pretty v + +instance Pretty (SourceOrInternalName c n) where + pretty (SourceOrInternalName sn) = pretty sn + +instance Pretty (ULamExpr n) where pretty = prettyFromPrettyPrec +instance PrettyPrec (ULamExpr n) where + prettyPrec (ULamExpr bs _ _ _ body) = atPrec LowestPrec $ + "\\" <> pretty bs <+> "." <+> indented (pretty body) + +instance Pretty (UPiExpr n) where pretty = prettyFromPrettyPrec +instance PrettyPrec (UPiExpr n) where + prettyPrec (UPiExpr pats appExpl UPure ty) = atPrec LowestPrec $ align $ + pretty pats <+> pretty appExpl <+> pLowest ty + prettyPrec (UPiExpr pats appExpl eff ty) = atPrec LowestPrec $ align $ + pretty pats <+> pretty appExpl <+> pretty eff <+> pLowest ty + +instance Pretty (UTabPiExpr n) where pretty = prettyFromPrettyPrec +instance PrettyPrec (UTabPiExpr n) where + prettyPrec (UTabPiExpr pat ty) = atPrec LowestPrec $ align $ + pretty pat <+> "=>" <+> pLowest ty + +instance Pretty (UDepPairType n) where pretty = prettyFromPrettyPrec +instance PrettyPrec (UDepPairType n) where + -- TODO: print explicitness info + prettyPrec (UDepPairType _ pat ty) = atPrec LowestPrec $ align $ + pretty pat <+> "&>" <+> pLowest ty + +instance Pretty (UBlock' n) where + pretty (UBlock decls result) = + prettyLines (unsafeFromNest decls) <> hardline <> pLowest result + +instance Pretty (UExpr' n) where pretty = prettyFromPrettyPrec +instance PrettyPrec (UExpr' n) where + prettyPrec expr = case expr of + ULit l -> prettyPrec l + UVar v -> atPrec ArgPrec $ p v + ULam lam -> prettyPrec lam + UApp f xs named -> atPrec AppPrec $ pAppArg (pApp f) xs <+> p named + UTabApp f x -> atPrec AppPrec $ pArg f <> "." <> pArg x + UFor dir (UForExpr binder body) -> + atPrec LowestPrec $ kw <+> p binder <> "." + <+> nest 2 (p body) + where kw = case dir of Fwd -> "for" + Rev -> "rof" + UPi piType -> prettyPrec piType + UTabPi piType -> prettyPrec piType + UDepPairTy depPairType -> prettyPrec depPairType + UDepPair lhs rhs -> atPrec ArgPrec $ parens $ + p lhs <+> ",>" <+> p rhs + UHole -> atPrec ArgPrec "_" + UTypeAnn v ty -> atPrec LowestPrec $ + group $ pApp v <> line <> ":" <+> pApp ty + UTabCon xs -> atPrec ArgPrec $ p xs + UPrim prim xs -> atPrec AppPrec $ p (show prim) <+> p xs + UCase e alts -> atPrec LowestPrec $ "case" <+> p e <> + nest 2 (prettyLines alts) + UFieldAccess x (WithSrc _ f) -> atPrec AppPrec $ p x <> "~" <> p f + UNatLit v -> atPrec ArgPrec $ p v + UIntLit v -> atPrec ArgPrec $ p v + UFloatLit v -> atPrec ArgPrec $ p v + UDo block -> atPrec LowestPrec $ p block + where + p :: Pretty a => a -> Doc ann + p = pretty + +instance Pretty SourceBlock where + pretty block = pretty $ ensureNewline (sbText block) where + -- Force the SourceBlock to end in a newline for echoing, even if + -- it was terminated with EOF in the original program. + ensureNewline t = case unsnoc t of + Nothing -> t + Just (_, '\n') -> t + _ -> t `snoc` '\n' + +instance Pretty Output where + pretty = \case + TextOut s -> pretty s + HtmlOut _ -> "<html output>" + SourceInfo _ -> "" + PassInfo _ s -> pretty s + MiscLog s -> pretty s + Error e -> pretty e + +instance Pretty PassName where + pretty x = pretty $ show x + +instance Pretty Result where + pretty (Result (Outputs outs) r) = vcat (map pretty outs) <> maybeErr + where maybeErr = case r of Failure err -> pretty err + Success () -> mempty + +instance Pretty (UBinder' c n l) where pretty = prettyFromPrettyPrec +instance PrettyPrec (UBinder' c n l) where + prettyPrec b = atPrec ArgPrec case b of + UBindSource v -> pretty v + UIgnore -> "_" + UBind v _ -> pretty v + +instance Pretty FieldName' where + pretty = \case + FieldName s -> pretty s + FieldNum n -> pretty n + +instance Pretty (UEffect n) where + pretty eff = case eff of + URWSEffect rws h -> pretty rws <+> pretty h + UExceptionEffect -> "Except" + UIOEffect -> "IO" + +prettyOpDefault :: PrettyPrec a => PrimName -> [a] -> DocPrec ann +prettyOpDefault name args = + case length args of + 0 -> atPrec ArgPrec primName + _ -> atPrec AppPrec $ pAppArg primName args + where primName = pretty name diff --git a/src/lib/Types/Top.hs b/src/lib/Types/Top.hs new file mode 100644 index 00000000..fba64b0e --- /dev/null +++ b/src/lib/Types/Top.hs @@ -0,0 +1,1046 @@ +-- Copyright 2022 Google LLC +-- +-- Use of this source code is governed by a BSD-style +-- license that can be found in the LICENSE file or at +-- https://developers.google.com/open-source/licenses/bsd + +{-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE StrictData #-} + +-- Top-level data types + +module Types.Top where + +import Data.Functor ((<&>)) +import Data.Hashable +import Data.Text.Prettyprint.Doc +import qualified Data.Map.Strict as M +import qualified Data.Set as S + +import GHC.Generics (Generic (..)) +import Data.Store (Store (..)) +import Foreign.Ptr + +import Name +import Util (FileHash, SnocList (..)) +import IRVariants +import PPrint + +import Types.Primitives +import Types.Core +import Types.Source +import Types.Imp + +type TopBlock = TopLam -- used for nullary lambda +type IsDestLam = Bool +data TopLam (r::IR) (n::S) = TopLam IsDestLam (PiType r n) (LamExpr r n) + deriving (Show, Generic) +type STopLam = TopLam SimpIR +type CTopLam = TopLam CoreIR + +data EvalStatus a = Waiting | Running | Finished a + deriving (Show, Eq, Ord, Generic, Functor, Foldable, Traversable) +type TopFunEvalStatus n = EvalStatus (TopFunLowerings n) + +data TopFun (n::S) = + DexTopFun (TopFunDef n) (TopLam SimpIR n) (TopFunEvalStatus n) + | FFITopFun String IFunType + deriving (Show, Generic) + +data TopFunDef (n::S) = + Specialization (SpecializationSpec n) + | LinearizationPrimal (LinearizationSpec n) + -- Tangent functions all take some number of nonlinear args, then a *single* + -- linear arg. This is so that transposition can be an involution - you apply + -- it twice and you get back to the original function. + | LinearizationTangent (LinearizationSpec n) + deriving (Show, Generic) + +newtype TopFunLowerings (n::S) = TopFunLowerings + { topFunObjCode :: FunObjCodeName n } -- TODO: add optimized, imp etc. as needed + deriving (Show, Generic, SinkableE, HoistableE, RenameE, AlphaEqE, AlphaHashableE, Pretty) + +data AtomBinding (r::IR) (n::S) where + LetBound :: DeclBinding r n -> AtomBinding r n + MiscBound :: Type r n -> AtomBinding r n + TopDataBound :: RepVal n -> AtomBinding SimpIR n + SolverBound :: SolverBinding n -> AtomBinding CoreIR n + NoinlineFun :: CType n -> CAtom n -> AtomBinding CoreIR n + FFIFunBound :: CorePiType n -> TopFunName n -> AtomBinding CoreIR n + +deriving instance IRRep r => Show (AtomBinding r n) +deriving via WrapE (AtomBinding r) n instance IRRep r => Generic (AtomBinding r n) + +-- name of function, name of arg +type InferenceArgDesc = (String, String) +data InfVarDesc = + ImplicitArgInfVar InferenceArgDesc + | AnnotationInfVar String -- name of binder + | TypeInstantiationInfVar String -- name of type + | MiscInfVar + deriving (Show, Generic, Eq, Ord) + +data SolverBinding (n::S) = + InfVarBound (CType n) + | SkolemBound (CType n) + | DictBound (CType n) + deriving (Show, Generic) + +-- TODO: Use an IntMap +newtype CustomRules (n::S) = + CustomRules { customRulesMap :: M.Map (AtomName CoreIR n) (AtomRules n) } + deriving (Semigroup, Monoid, Store) +data AtomRules (n::S) = + -- number of implicit args, number of explicit args, linearization function + CustomLinearize Int Int SymbolicZeros (CAtom n) + deriving (Generic) + +-- === envs and modules === + +-- `ModuleEnv` contains data that only makes sense in the context of evaluating +-- a particular module. `TopEnv` contains everything that makes sense "between" +-- evaluating modules. +data Env n = Env + { topEnv :: {-# UNPACK #-} TopEnv n + , moduleEnv :: {-# UNPACK #-} ModuleEnv n } + deriving (Generic) + +newtype EnvFrag (n::S) (l::S) = EnvFrag (RecSubstFrag Binding n l) + deriving (OutFrag) + +data TopEnv (n::S) = TopEnv + { envDefs :: RecSubst Binding n + , envCustomRules :: CustomRules n + , envCache :: Cache n + , envLoadedModules :: LoadedModules n + , envLoadedObjects :: LoadedObjects n } + deriving (Generic) + +data SerializedEnv n = SerializedEnv + { serializedEnvDefs :: RecSubst Binding n + , serializedEnvCustomRules :: CustomRules n + , serializedEnvCache :: Cache n } + deriving (Generic) + +-- TODO: consider splitting this further into `ModuleEnv` (the env that's +-- relevant between top-level decls) and `LocalEnv` (the additional parts of the +-- env that's relevant under a lambda binder). Unlike the Top/Module +-- distinction, there's some overlap. For example, instances can be defined at +-- both the module-level and local level. Similarly, if we start allowing +-- top-level effects in `Main` then we'll have module-level effects and local +-- effects. +data ModuleEnv (n::S) = ModuleEnv + { envImportStatus :: ImportStatus n + , envSourceMap :: SourceMap n + , envSynthCandidates :: SynthCandidates n } + deriving (Generic) + +data Module (n::S) = Module + { moduleSourceName :: ModuleSourceName + , moduleDirectDeps :: S.Set (ModuleName n) + , moduleTransDeps :: S.Set (ModuleName n) -- XXX: doesn't include the module itself + , moduleExports :: SourceMap n + -- these are just the synth candidates required by this + -- module by itself. We'll usually also need those required by the module's + -- (transitive) dependencies, which must be looked up separately. + , moduleSynthCandidates :: SynthCandidates n } + deriving (Show, Generic) + +data LoadedModules (n::S) = LoadedModules + { fromLoadedModules :: M.Map ModuleSourceName (ModuleName n)} + deriving (Show, Generic) + +emptyModuleEnv :: ModuleEnv n +emptyModuleEnv = ModuleEnv emptyImportStatus (SourceMap mempty) mempty + +emptyLoadedModules :: LoadedModules n +emptyLoadedModules = LoadedModules mempty + +data LoadedObjects (n::S) = LoadedObjects + -- the pointer points to the actual runtime function + { fromLoadedObjects :: M.Map (FunObjCodeName n) NativeFunction} + deriving (Show, Generic) + +emptyLoadedObjects :: LoadedObjects n +emptyLoadedObjects = LoadedObjects mempty + +data ImportStatus (n::S) = ImportStatus + { directImports :: S.Set (ModuleName n) + -- XXX: This are cached for efficiency. It's derivable from `directImports`. + , transImports :: S.Set (ModuleName n) } + deriving (Show, Generic) + +data TopEnvFrag n l = TopEnvFrag (EnvFrag n l) (ModuleEnv l) (SnocList (TopEnvUpdate l)) + +data TopEnvUpdate n = + ExtendCache (Cache n) + | AddCustomRule (CAtomName n) (AtomRules n) + | UpdateLoadedModules ModuleSourceName (ModuleName n) + | UpdateLoadedObjects (FunObjCodeName n) NativeFunction + | FinishDictSpecialization (SpecDictName n) [TopLam SimpIR n] + | LowerDictSpecialization (SpecDictName n) [TopLam SimpIR n] + | UpdateTopFunEvalStatus (TopFunName n) (TopFunEvalStatus n) + | UpdateInstanceDef (InstanceName n) (InstanceDef n) + | UpdateTyConDef (TyConName n) (TyConDef n) + | UpdateFieldDef (TyConName n) SourceName (CAtomName n) + +-- TODO: we could add a lot more structure for querying by dict type, caching, etc. +data SynthCandidates n = SynthCandidates + { instanceDicts :: M.Map (ClassName n) [InstanceName n] + , ixInstances :: [InstanceName n] } + deriving (Show, Generic) + +emptyImportStatus :: ImportStatus n +emptyImportStatus = ImportStatus mempty mempty + +-- TODO: figure out the additional top-level context we need -- backend, other +-- compiler flags etc. We can have a map from those to this. + +data Cache (n::S) = Cache + { specializationCache :: EMap SpecializationSpec TopFunName n + , ixDictCache :: EMap AbsDict SpecDictName n + , linearizationCache :: EMap LinearizationSpec (PairE TopFunName TopFunName) n + , transpositionCache :: EMap TopFunName TopFunName n + -- This is memoizing `parseAndGetDeps :: Text -> [ModuleSourceName]`. But we + -- only want to store one entry per module name as a simple cache eviction + -- policy, so we store it keyed on the module name, with the text hash for + -- the validity check. + , parsedDeps :: M.Map ModuleSourceName (FileHash, [ModuleSourceName]) + , moduleEvaluations :: M.Map ModuleSourceName ((FileHash, [ModuleName n]), ModuleName n) + } deriving (Show, Generic) + +-- === runtime function and variable representations === + +type RuntimeEnv = DynamicVarKeyPtrs + +type DexDestructor = FunPtr (IO ()) + +data NativeFunction = NativeFunction + { nativeFunPtr :: FunPtr () + , nativeFunTeardown :: IO () } + +instance Show NativeFunction where + show _ = "<native function>" + +-- Holds pointers to thread-local storage used to simulate dynamically scoped +-- variables, such as the output stream file descriptor. +type DynamicVarKeyPtrs = [(DynamicVar, Ptr ())] + +data DynamicVar = OutStreamDyvar -- TODO: add others as needed + deriving (Enum, Bounded) + +dynamicVarCName :: DynamicVar -> String +dynamicVarCName OutStreamDyvar = "dex_out_stream_dyvar" + +dynamicVarLinkMap :: DynamicVarKeyPtrs -> [(String, Ptr ())] +dynamicVarLinkMap dyvars = dyvars <&> \(v, ptr) -> (dynamicVarCName v, ptr) + +-- === Specialization and generalization === + +type Generalized (r::IR) (e::E) (n::S) = (Abstracted r e n, [Atom r n]) +type Abstracted (r::IR) (e::E) = Abs (Nest (Binder r)) e +type AbsDict = Abstracted CoreIR (Dict CoreIR) + +data SpecializedDictDef n = + SpecializedDict + (AbsDict n) + -- Methods (thunked if nullary), if they're available. + -- We create specialized dict names during simplification, but we don't + -- actually simplify/lower them until we return to TopLevel + (Maybe [TopLam SimpIR n]) + deriving (Show, Generic) + +-- TODO: extend with AD-oriented specializations, backend-specific specializations etc. +data SpecializationSpec (n::S) = + AppSpecialization (AtomVar CoreIR n) (Abstracted CoreIR (ListE CAtom) n) + deriving (Show, Generic) + +type Active = Bool +data LinearizationSpec (n::S) = LinearizationSpec (TopFunName n) [Active] + deriving (Show, Generic) + +-- === bindings - static information we carry about a lexical scope === + +-- TODO: consider making this an open union via a typeable-like class +data Binding (c::C) (n::S) where + AtomNameBinding :: AtomBinding r n -> Binding (AtomNameC r) n + TyConBinding :: Maybe (TyConDef n) -> DotMethods n -> Binding TyConNameC n + DataConBinding :: TyConName n -> Int -> Binding DataConNameC n + ClassBinding :: ClassDef n -> Binding ClassNameC n + InstanceBinding :: InstanceDef n -> CorePiType n -> Binding InstanceNameC n + MethodBinding :: ClassName n -> Int -> Binding MethodNameC n + TopFunBinding :: TopFun n -> Binding TopFunNameC n + FunObjCodeBinding :: CFunction n -> Binding FunObjCodeNameC n + ModuleBinding :: Module n -> Binding ModuleNameC n + -- TODO: add a case for abstracted pointers, as used in `ClosedImpFunction` + PtrBinding :: PtrType -> PtrLitVal -> Binding PtrNameC n + SpecializedDictBinding :: SpecializedDictDef n -> Binding SpecializedDictNameC n + ImpNameBinding :: BaseType -> Binding ImpNameC n + +-- === ToBinding === + +atomBindingToBinding :: AtomBinding r n -> Binding (AtomNameC r) n +atomBindingToBinding b = AtomNameBinding b + +bindingToAtomBinding :: Binding (AtomNameC r) n -> AtomBinding r n +bindingToAtomBinding (AtomNameBinding b) = b + +class (RenameE e, SinkableE e) => ToBinding (e::E) (c::C) | e -> c where + toBinding :: e n -> Binding c n + +instance Color c => ToBinding (Binding c) c where + toBinding = id + +instance IRRep r => ToBinding (AtomBinding r) (AtomNameC r) where + toBinding = atomBindingToBinding + +instance IRRep r => ToBinding (DeclBinding r) (AtomNameC r) where + toBinding = toBinding . LetBound + +instance IRRep r => ToBinding (Type r) (AtomNameC r) where + toBinding = toBinding . MiscBound + +instance ToBinding SolverBinding (AtomNameC CoreIR) where + toBinding = toBinding . SolverBound + +instance IRRep r => ToBinding (IxType r) (AtomNameC r) where + toBinding (IxType t _) = toBinding t + +instance (ToBinding e1 c, ToBinding e2 c) => ToBinding (EitherE e1 e2) c where + toBinding (LeftE e) = toBinding e + toBinding (RightE e) = toBinding e + +instance ToBindersAbs (TopLam r) (Expr r) r where + toAbs (TopLam _ _ lam) = toAbs lam + +-- === GenericE, GenericB === + +instance GenericE SpecializedDictDef where + type RepE SpecializedDictDef = AbsDict `PairE` MaybeE (ListE (TopLam SimpIR)) + fromE (SpecializedDict ab methods) = ab `PairE` methods' + where methods' = case methods of Just xs -> LeftE (ListE xs) + Nothing -> RightE UnitE + {-# INLINE fromE #-} + toE (ab `PairE` methods) = SpecializedDict ab methods' + where methods' = case methods of LeftE (ListE xs) -> Just xs + RightE UnitE -> Nothing + {-# INLINE toE #-} + +instance SinkableE SpecializedDictDef +instance HoistableE SpecializedDictDef +instance AlphaEqE SpecializedDictDef +instance AlphaHashableE SpecializedDictDef +instance RenameE SpecializedDictDef + +instance HasScope Env where + toScope = toScope . envDefs . topEnv + +instance OutMap Env where + emptyOutMap = + Env (TopEnv (RecSubst emptyInFrag) mempty mempty emptyLoadedModules emptyLoadedObjects) + emptyModuleEnv + {-# INLINE emptyOutMap #-} + +instance ExtOutMap Env (RecSubstFrag Binding) where + -- TODO: We might want to reorganize this struct to make this + -- do less explicit sinking etc. It's a hot operation! + extendOutMap (Env (TopEnv defs rules cache loadedM loadedO) moduleEnv) frag = + withExtEvidence frag $ Env + (TopEnv + (defs `extendRecSubst` frag) + (sink rules) + (sink cache) + (sink loadedM) + (sink loadedO)) + (sink moduleEnv) + {-# INLINE extendOutMap #-} + +instance ExtOutMap Env EnvFrag where + extendOutMap = extendEnv + {-# INLINE extendOutMap #-} + +extendEnv :: Distinct l => Env n -> EnvFrag n l -> Env l +extendEnv env (EnvFrag newEnv) = do + case extendOutMap env newEnv of + Env envTop (ModuleEnv imports sm scs) -> do + Env envTop (ModuleEnv imports sm scs) +{-# NOINLINE [1] extendEnv #-} + + +instance GenericE AtomRules where + type RepE AtomRules = (LiftE (Int, Int, SymbolicZeros)) `PairE` CAtom + fromE (CustomLinearize ni ne sz a) = LiftE (ni, ne, sz) `PairE` a + toE (LiftE (ni, ne, sz) `PairE` a) = CustomLinearize ni ne sz a +instance SinkableE AtomRules +instance HoistableE AtomRules +instance AlphaEqE AtomRules +instance RenameE AtomRules + +instance GenericE CustomRules where + type RepE CustomRules = ListE (PairE (AtomName CoreIR) AtomRules) + fromE (CustomRules m) = ListE $ toPairE <$> M.toList m + toE (ListE l) = CustomRules $ M.fromList $ fromPairE <$> l +instance SinkableE CustomRules +instance HoistableE CustomRules +instance AlphaEqE CustomRules +instance RenameE CustomRules + +instance GenericE Cache where + type RepE Cache = + EMap SpecializationSpec TopFunName + `PairE` EMap AbsDict SpecDictName + `PairE` EMap LinearizationSpec (PairE TopFunName TopFunName) + `PairE` EMap TopFunName TopFunName + `PairE` LiftE (M.Map ModuleSourceName (FileHash, [ModuleSourceName])) + `PairE` ListE ( LiftE ModuleSourceName + `PairE` LiftE FileHash + `PairE` ListE ModuleName + `PairE` ModuleName) + fromE (Cache x y z w parseCache evalCache) = + x `PairE` y `PairE` z `PairE` w `PairE` LiftE parseCache `PairE` + ListE [LiftE sourceName `PairE` LiftE hashVal `PairE` ListE deps `PairE` result + | (sourceName, ((hashVal, deps), result)) <- M.toList evalCache ] + {-# INLINE fromE #-} + toE (x `PairE` y `PairE` z `PairE` w `PairE` LiftE parseCache `PairE` ListE evalCache) = + Cache x y z w parseCache + (M.fromList + [(sourceName, ((hashVal, deps), result)) + | LiftE sourceName `PairE` LiftE hashVal `PairE` ListE deps `PairE` result + <- evalCache]) + {-# INLINE toE #-} + +instance SinkableE Cache +instance HoistableE Cache +instance AlphaEqE Cache +instance RenameE Cache +instance Store (Cache n) + +instance Monoid (Cache n) where + mempty = Cache mempty mempty mempty mempty mempty mempty + mappend = (<>) + +instance Semigroup (Cache n) where + -- right-biased instead of left-biased + Cache x1 x2 x3 x4 x5 x6 <> Cache y1 y2 y3 y4 y5 y6 = + Cache (y1<>x1) (y2<>x2) (y3<>x3) (y4<>x4) (x5<>y5) (x6<>y6) + + +instance GenericE SynthCandidates where + type RepE SynthCandidates = ListE (PairE ClassName (ListE InstanceName)) + `PairE` ListE InstanceName + fromE (SynthCandidates xs ys) = ListE xs' `PairE` ListE ys + where xs' = map (\(k,vs) -> PairE k (ListE vs)) (M.toList xs) + {-# INLINE fromE #-} + toE (ListE xs `PairE` ListE ys) = SynthCandidates xs' ys + where xs' = M.fromList $ map (\(PairE k (ListE vs)) -> (k,vs)) xs + {-# INLINE toE #-} + +instance SinkableE SynthCandidates +instance HoistableE SynthCandidates +instance AlphaEqE SynthCandidates +instance AlphaHashableE SynthCandidates +instance RenameE SynthCandidates + +instance IRRep r => GenericE (AtomBinding r) where + type RepE (AtomBinding r) = + EitherE2 (EitherE3 + (DeclBinding r) -- LetBound + (Type r) -- MiscBound + (WhenCore r SolverBinding) -- SolverBound + ) (EitherE3 + (WhenCore r (PairE CType CAtom)) -- NoinlineFun + (WhenSimp r RepVal) -- TopDataBound + (WhenCore r (CorePiType `PairE` TopFunName)) -- FFIFunBound + ) + + fromE = \case + LetBound x -> Case0 $ Case0 x + MiscBound x -> Case0 $ Case1 x + SolverBound x -> Case0 $ Case2 $ WhenIRE x + NoinlineFun t x -> Case1 $ Case0 $ WhenIRE $ PairE t x + TopDataBound repVal -> Case1 $ Case1 $ WhenIRE repVal + FFIFunBound ty v -> Case1 $ Case2 $ WhenIRE $ ty `PairE` v + {-# INLINE fromE #-} + + toE = \case + Case0 x' -> case x' of + Case0 x -> LetBound x + Case1 x -> MiscBound x + Case2 (WhenIRE x) -> SolverBound x + _ -> error "impossible" + Case1 x' -> case x' of + Case0 (WhenIRE (PairE t x)) -> NoinlineFun t x + Case1 (WhenIRE repVal) -> TopDataBound repVal + Case2 (WhenIRE (ty `PairE` v)) -> FFIFunBound ty v + _ -> error "impossible" + _ -> error "impossible" + {-# INLINE toE #-} + + +instance IRRep r => SinkableE (AtomBinding r) +instance IRRep r => HoistableE (AtomBinding r) +instance IRRep r => RenameE (AtomBinding r) +instance IRRep r => AlphaEqE (AtomBinding r) +instance IRRep r => AlphaHashableE (AtomBinding r) + +instance GenericE TopFunDef where + type RepE TopFunDef = EitherE3 SpecializationSpec LinearizationSpec LinearizationSpec + fromE = \case + Specialization s -> Case0 s + LinearizationPrimal s -> Case1 s + LinearizationTangent s -> Case2 s + {-# INLINE fromE #-} + toE = \case + Case0 s -> Specialization s + Case1 s -> LinearizationPrimal s + Case2 s -> LinearizationTangent s + _ -> error "impossible" + {-# INLINE toE #-} + +instance SinkableE TopFunDef +instance HoistableE TopFunDef +instance RenameE TopFunDef +instance AlphaEqE TopFunDef +instance AlphaHashableE TopFunDef + +instance IRRep r => GenericE (TopLam r) where + type RepE (TopLam r) = LiftE Bool `PairE` PiType r `PairE` LamExpr r + fromE (TopLam d x y) = LiftE d `PairE` x `PairE` y + {-# INLINE fromE #-} + toE (LiftE d `PairE` x `PairE` y) = TopLam d x y + {-# INLINE toE #-} + +instance IRRep r => SinkableE (TopLam r) +instance IRRep r => HoistableE (TopLam r) +instance IRRep r => RenameE (TopLam r) +instance IRRep r => AlphaEqE (TopLam r) +instance IRRep r => AlphaHashableE (TopLam r) + +instance GenericE TopFun where + type RepE TopFun = EitherE + (TopFunDef `PairE` TopLam SimpIR `PairE` ComposeE EvalStatus TopFunLowerings) + (LiftE (String, IFunType)) + fromE = \case + DexTopFun def lam status -> LeftE (def `PairE` lam `PairE` ComposeE status) + FFITopFun name ty -> RightE (LiftE (name, ty)) + {-# INLINE fromE #-} + toE = \case + LeftE (def `PairE` lam `PairE` ComposeE status) -> DexTopFun def lam status + RightE (LiftE (name, ty)) -> FFITopFun name ty + {-# INLINE toE #-} + +instance SinkableE TopFun +instance HoistableE TopFun +instance RenameE TopFun +instance AlphaEqE TopFun +instance AlphaHashableE TopFun + +instance GenericE SpecializationSpec where + type RepE SpecializationSpec = + PairE (AtomVar CoreIR) (Abs (Nest (Binder CoreIR)) (ListE CAtom)) + fromE (AppSpecialization fname (Abs bs args)) = PairE fname (Abs bs args) + {-# INLINE fromE #-} + toE (PairE fname (Abs bs args)) = AppSpecialization fname (Abs bs args) + {-# INLINE toE #-} + +instance HasNameHint (SpecializationSpec n) where + getNameHint (AppSpecialization f _) = getNameHint f + +instance SinkableE SpecializationSpec +instance HoistableE SpecializationSpec +instance RenameE SpecializationSpec +instance AlphaEqE SpecializationSpec +instance AlphaHashableE SpecializationSpec + +instance GenericE LinearizationSpec where + type RepE LinearizationSpec = PairE TopFunName (LiftE [Active]) + fromE (LinearizationSpec fname actives) = PairE fname (LiftE actives) + {-# INLINE fromE #-} + toE (PairE fname (LiftE actives)) = LinearizationSpec fname actives + {-# INLINE toE #-} + +instance SinkableE LinearizationSpec +instance HoistableE LinearizationSpec +instance RenameE LinearizationSpec +instance AlphaEqE LinearizationSpec +instance AlphaHashableE LinearizationSpec + +instance GenericE SolverBinding where + type RepE SolverBinding = EitherE3 + CType + CType + CType + fromE = \case + InfVarBound ty -> Case0 ty + SkolemBound ty -> Case1 ty + DictBound ty -> Case2 ty + {-# INLINE fromE #-} + + toE = \case + Case0 ty -> InfVarBound ty + Case1 ty -> SkolemBound ty + Case2 ty -> DictBound ty + _ -> error "impossible" + {-# INLINE toE #-} + +instance SinkableE SolverBinding +instance HoistableE SolverBinding +instance RenameE SolverBinding +instance AlphaEqE SolverBinding +instance AlphaHashableE SolverBinding + +instance GenericE (Binding c) where + type RepE (Binding c) = + EitherE3 + (EitherE6 + (WhenAtomName c AtomBinding) + (WhenC TyConNameC c (MaybeE TyConDef `PairE` DotMethods)) + (WhenC DataConNameC c (TyConName `PairE` LiftE Int)) + (WhenC ClassNameC c (ClassDef)) + (WhenC InstanceNameC c (InstanceDef `PairE` CorePiType)) + (WhenC MethodNameC c (ClassName `PairE` LiftE Int))) + (EitherE4 + (WhenC TopFunNameC c (TopFun)) + (WhenC FunObjCodeNameC c (CFunction)) + (WhenC ModuleNameC c (Module)) + (WhenC PtrNameC c (LiftE (PtrType, PtrLitVal)))) + (EitherE2 + (WhenC SpecializedDictNameC c (SpecializedDictDef)) + (WhenC ImpNameC c (LiftE BaseType))) + + fromE = \case + AtomNameBinding binding -> Case0 $ Case0 $ WhenAtomName binding + TyConBinding dataDef methods -> Case0 $ Case1 $ WhenC $ toMaybeE dataDef `PairE` methods + DataConBinding dataDefName idx -> Case0 $ Case2 $ WhenC $ dataDefName `PairE` LiftE idx + ClassBinding classDef -> Case0 $ Case3 $ WhenC $ classDef + InstanceBinding instanceDef ty -> Case0 $ Case4 $ WhenC $ instanceDef `PairE` ty + MethodBinding className idx -> Case0 $ Case5 $ WhenC $ className `PairE` LiftE idx + TopFunBinding fun -> Case1 $ Case0 $ WhenC $ fun + FunObjCodeBinding cFun -> Case1 $ Case1 $ WhenC $ cFun + ModuleBinding m -> Case1 $ Case2 $ WhenC $ m + PtrBinding ty p -> Case1 $ Case3 $ WhenC $ LiftE (ty,p) + SpecializedDictBinding def -> Case2 $ Case0 $ WhenC $ def + ImpNameBinding ty -> Case2 $ Case1 $ WhenC $ LiftE ty + {-# INLINE fromE #-} + + toE = \case + Case0 (Case0 (WhenAtomName binding)) -> AtomNameBinding binding + Case0 (Case1 (WhenC (def `PairE` methods))) -> TyConBinding (fromMaybeE def) methods + Case0 (Case2 (WhenC (n `PairE` LiftE idx))) -> DataConBinding n idx + Case0 (Case3 (WhenC (classDef))) -> ClassBinding classDef + Case0 (Case4 (WhenC (instanceDef `PairE` ty))) -> InstanceBinding instanceDef ty + Case0 (Case5 (WhenC ((n `PairE` LiftE i)))) -> MethodBinding n i + Case1 (Case0 (WhenC (fun))) -> TopFunBinding fun + Case1 (Case1 (WhenC (f))) -> FunObjCodeBinding f + Case1 (Case2 (WhenC (m))) -> ModuleBinding m + Case1 (Case3 (WhenC ((LiftE (ty,p))))) -> PtrBinding ty p + Case2 (Case0 (WhenC (def))) -> SpecializedDictBinding def + Case2 (Case1 (WhenC ((LiftE ty)))) -> ImpNameBinding ty + _ -> error "impossible" + {-# INLINE toE #-} + +deriving via WrapE (Binding c) n instance Generic (Binding c n) +instance SinkableV Binding +instance HoistableV Binding +instance RenameV Binding +instance Color c => SinkableE (Binding c) +instance Color c => HoistableE (Binding c) +instance Color c => RenameE (Binding c) + +instance Semigroup (SynthCandidates n) where + SynthCandidates xs ys <> SynthCandidates xs' ys' = + SynthCandidates (M.unionWith (<>) xs xs') (ys <> ys') + +instance Monoid (SynthCandidates n) where + mempty = SynthCandidates mempty mempty + +instance GenericB EnvFrag where + type RepB EnvFrag = RecSubstFrag Binding + fromB (EnvFrag frag) = frag + toB frag = EnvFrag frag + +instance SinkableB EnvFrag +instance HoistableB EnvFrag +instance ProvesExt EnvFrag +instance BindsNames EnvFrag +instance RenameB EnvFrag + +instance GenericE TopEnvUpdate where + type RepE TopEnvUpdate = EitherE2 ( + EitherE4 + {- ExtendCache -} Cache + {- AddCustomRule -} (CAtomName `PairE` AtomRules) + {- UpdateLoadedModules -} (LiftE ModuleSourceName `PairE` ModuleName) + {- UpdateLoadedObjects -} (FunObjCodeName `PairE` LiftE NativeFunction) + ) ( EitherE6 + {- FinishDictSpecialization -} (SpecDictName `PairE` ListE (TopLam SimpIR)) + {- LowerDictSpecialization -} (SpecDictName `PairE` ListE (TopLam SimpIR)) + {- UpdateTopFunEvalStatus -} (TopFunName `PairE` ComposeE EvalStatus TopFunLowerings) + {- UpdateInstanceDef -} (InstanceName `PairE` InstanceDef) + {- UpdateTyConDef -} (TyConName `PairE` TyConDef) + {- UpdateFieldDef -} (TyConName `PairE` LiftE SourceName `PairE` CAtomName) + ) + fromE = \case + ExtendCache x -> Case0 $ Case0 x + AddCustomRule x y -> Case0 $ Case1 (x `PairE` y) + UpdateLoadedModules x y -> Case0 $ Case2 (LiftE x `PairE` y) + UpdateLoadedObjects x y -> Case0 $ Case3 (x `PairE` LiftE y) + FinishDictSpecialization x y -> Case1 $ Case0 (x `PairE` ListE y) + LowerDictSpecialization x y -> Case1 $ Case1 (x `PairE` ListE y) + UpdateTopFunEvalStatus x y -> Case1 $ Case2 (x `PairE` ComposeE y) + UpdateInstanceDef x y -> Case1 $ Case3 (x `PairE` y) + UpdateTyConDef x y -> Case1 $ Case4 (x `PairE` y) + UpdateFieldDef x y z -> Case1 $ Case5 (x `PairE` LiftE y `PairE` z) + + toE = \case + Case0 e -> case e of + Case0 x -> ExtendCache x + Case1 (x `PairE` y) -> AddCustomRule x y + Case2 (LiftE x `PairE` y) -> UpdateLoadedModules x y + Case3 (x `PairE` LiftE y) -> UpdateLoadedObjects x y + _ -> error "impossible" + Case1 e -> case e of + Case0 (x `PairE` ListE y) -> FinishDictSpecialization x y + Case1 (x `PairE` ListE y) -> LowerDictSpecialization x y + Case2 (x `PairE` ComposeE y) -> UpdateTopFunEvalStatus x y + Case3 (x `PairE` y) -> UpdateInstanceDef x y + Case4 (x `PairE` y) -> UpdateTyConDef x y + Case5 (x `PairE` LiftE y `PairE` z) -> UpdateFieldDef x y z + _ -> error "impossible" + _ -> error "impossible" + +instance SinkableE TopEnvUpdate +instance HoistableE TopEnvUpdate +instance RenameE TopEnvUpdate + +instance GenericB TopEnvFrag where + type RepB TopEnvFrag = PairB EnvFrag (LiftB (ModuleEnv `PairE` ListE TopEnvUpdate)) + fromB (TopEnvFrag x y (ReversedList z)) = PairB x (LiftB (y `PairE` ListE z)) + toB (PairB x (LiftB (y `PairE` ListE z))) = TopEnvFrag x y (ReversedList z) + +instance RenameB TopEnvFrag +instance HoistableB TopEnvFrag +instance SinkableB TopEnvFrag +instance ProvesExt TopEnvFrag +instance BindsNames TopEnvFrag + +instance OutFrag TopEnvFrag where + emptyOutFrag = TopEnvFrag emptyOutFrag mempty mempty + {-# INLINE emptyOutFrag #-} + catOutFrags (TopEnvFrag frag1 env1 partial1) + (TopEnvFrag frag2 env2 partial2) = + withExtEvidence frag2 $ + TopEnvFrag + (catOutFrags frag1 frag2) + (sink env1 <> env2) + (sinkSnocList partial1 <> partial2) + {-# INLINE catOutFrags #-} + +-- XXX: unlike `ExtOutMap Env EnvFrag` instance, this once doesn't +-- extend the synthesis candidates based on the annotated let-bound names. It +-- only extends synth candidates when they're supplied explicitly. +instance ExtOutMap Env TopEnvFrag where + extendOutMap env (TopEnvFrag (EnvFrag frag) mEnv' otherUpdates) = do + let newerTopEnv = foldl applyUpdate newTopEnv otherUpdates + Env newerTopEnv newModuleEnv + where + Env (TopEnv defs rules cache loadedM loadedO) mEnv = env + + newTopEnv = withExtEvidence frag $ TopEnv + (defs `extendRecSubst` frag) + (sink rules) (sink cache) (sink loadedM) (sink loadedO) + + newModuleEnv = + ModuleEnv + (imports <> imports') + (sm <> sm' <> newImportedSM) + (scs <> scs' <> newImportedSC) + where + ModuleEnv imports sm scs = withExtEvidence frag $ sink mEnv + ModuleEnv imports' sm' scs' = mEnv' + newDirectImports = S.difference (directImports imports') (directImports imports) + newTransImports = S.difference (transImports imports') (transImports imports) + newImportedSM = flip foldMap newDirectImports $ moduleExports . lookupModulePure + newImportedSC = flip foldMap newTransImports $ moduleSynthCandidates . lookupModulePure + + lookupModulePure v = case lookupEnvPure newTopEnv v of ModuleBinding m -> m + +applyUpdate :: TopEnv n -> TopEnvUpdate n -> TopEnv n +applyUpdate e = \case + ExtendCache cache -> e { envCache = envCache e <> cache} + AddCustomRule x y -> e { envCustomRules = envCustomRules e <> CustomRules (M.singleton x y)} + UpdateLoadedModules x y -> e { envLoadedModules = envLoadedModules e <> LoadedModules (M.singleton x y)} + UpdateLoadedObjects x y -> e { envLoadedObjects = envLoadedObjects e <> LoadedObjects (M.singleton x y)} + FinishDictSpecialization dName methods -> do + let SpecializedDictBinding (SpecializedDict dAbs oldMethods) = lookupEnvPure e dName + case oldMethods of + Nothing -> do + let newBinding = SpecializedDictBinding $ SpecializedDict dAbs (Just methods) + updateEnv dName newBinding e + Just _ -> error "shouldn't be adding methods if we already have them" + LowerDictSpecialization dName methods -> do + let SpecializedDictBinding (SpecializedDict dAbs _) = lookupEnvPure e dName + let newBinding = SpecializedDictBinding $ SpecializedDict dAbs (Just methods) + updateEnv dName newBinding e + UpdateTopFunEvalStatus f s -> do + case lookupEnvPure e f of + TopFunBinding (DexTopFun def lam _) -> + updateEnv f (TopFunBinding $ DexTopFun def lam s) e + _ -> error "can't update ffi function impl" + UpdateInstanceDef name def -> do + case lookupEnvPure e name of + InstanceBinding _ ty -> updateEnv name (InstanceBinding def ty) e + UpdateTyConDef name def -> do + let TyConBinding _ methods = lookupEnvPure e name + updateEnv name (TyConBinding (Just def) methods) e + UpdateFieldDef name sn x -> do + let TyConBinding def methods = lookupEnvPure e name + updateEnv name (TyConBinding def (methods <> DotMethods (M.singleton sn x))) e + +updateEnv :: Color c => Name c n -> Binding c n -> TopEnv n -> TopEnv n +updateEnv v rhs env = + env { envDefs = RecSubst $ updateSubstFrag v rhs bs } + where (RecSubst bs) = envDefs env + +lookupEnvPure :: Color c => TopEnv n -> Name c n -> Binding c n +lookupEnvPure env v = lookupTerminalSubstFrag (fromRecSubst $ envDefs $ env) v + +instance GenericE Module where + type RepE Module = LiftE ModuleSourceName + `PairE` ListE ModuleName + `PairE` ListE ModuleName + `PairE` SourceMap + `PairE` SynthCandidates + + fromE (Module name deps transDeps sm sc) = + LiftE name `PairE` ListE (S.toList deps) `PairE` ListE (S.toList transDeps) + `PairE` sm `PairE` sc + {-# INLINE fromE #-} + + toE (LiftE name `PairE` ListE deps `PairE` ListE transDeps + `PairE` sm `PairE` sc) = + Module name (S.fromList deps) (S.fromList transDeps) sm sc + {-# INLINE toE #-} + +instance SinkableE Module +instance HoistableE Module +instance AlphaEqE Module +instance AlphaHashableE Module +instance RenameE Module + +instance GenericE ImportStatus where + type RepE ImportStatus = ListE ModuleName `PairE` ListE ModuleName + fromE (ImportStatus direct trans) = ListE (S.toList direct) + `PairE` ListE (S.toList trans) + {-# INLINE fromE #-} + toE (ListE direct `PairE` ListE trans) = + ImportStatus (S.fromList direct) (S.fromList trans) + {-# INLINE toE #-} + +instance SinkableE ImportStatus +instance HoistableE ImportStatus +instance AlphaEqE ImportStatus +instance AlphaHashableE ImportStatus +instance RenameE ImportStatus + +instance Semigroup (ImportStatus n) where + ImportStatus direct trans <> ImportStatus direct' trans' = + ImportStatus (direct <> direct') (trans <> trans') + +instance Monoid (ImportStatus n) where + mappend = (<>) + mempty = ImportStatus mempty mempty + +instance GenericE LoadedModules where + type RepE LoadedModules = ListE (PairE (LiftE ModuleSourceName) ModuleName) + fromE (LoadedModules m) = + ListE $ M.toList m <&> \(v,md) -> PairE (LiftE v) md + {-# INLINE fromE #-} + toE (ListE pairs) = + LoadedModules $ M.fromList $ pairs <&> \(PairE (LiftE v) md) -> (v, md) + {-# INLINE toE #-} + +instance SinkableE LoadedModules +instance HoistableE LoadedModules +instance AlphaEqE LoadedModules +instance AlphaHashableE LoadedModules +instance RenameE LoadedModules + +instance GenericE LoadedObjects where + type RepE LoadedObjects = ListE (PairE FunObjCodeName (LiftE NativeFunction)) + fromE (LoadedObjects m) = + ListE $ M.toList m <&> \(v,p) -> PairE v (LiftE p) + {-# INLINE fromE #-} + toE (ListE pairs) = + LoadedObjects $ M.fromList $ pairs <&> \(PairE v (LiftE p)) -> (v, p) + {-# INLINE toE #-} + +instance SinkableE LoadedObjects +instance HoistableE LoadedObjects +instance RenameE LoadedObjects + +instance GenericE ModuleEnv where + type RepE ModuleEnv = ImportStatus + `PairE` SourceMap + `PairE` SynthCandidates + fromE (ModuleEnv imports sm sc) = imports `PairE` sm `PairE` sc + {-# INLINE fromE #-} + toE (imports `PairE` sm `PairE` sc) = ModuleEnv imports sm sc + {-# INLINE toE #-} + +instance SinkableE ModuleEnv +instance HoistableE ModuleEnv +instance AlphaEqE ModuleEnv +instance AlphaHashableE ModuleEnv +instance RenameE ModuleEnv + +instance Semigroup (ModuleEnv n) where + ModuleEnv x1 x2 x3 <> ModuleEnv y1 y2 y3 = + ModuleEnv (x1<>y1) (x2<>y2) (x3<>y3) + +instance Monoid (ModuleEnv n) where + mempty = ModuleEnv mempty mempty mempty + +instance Semigroup (LoadedModules n) where + LoadedModules m1 <> LoadedModules m2 = LoadedModules (m2 <> m1) + +instance Monoid (LoadedModules n) where + mempty = LoadedModules mempty + +instance Semigroup (LoadedObjects n) where + LoadedObjects m1 <> LoadedObjects m2 = LoadedObjects (m2 <> m1) + +instance Monoid (LoadedObjects n) where + mempty = LoadedObjects mempty + + +-- === instance === + +prettyRecord :: [(String, Doc ann)] -> Doc ann +prettyRecord xs = foldMap (\(name, val) -> pretty name <> indented val) xs + +instance Pretty (TopEnv n) where + pretty (TopEnv defs rules cache _ _) = + prettyRecord [ ("Defs" , pretty defs) + , ("Rules" , pretty rules) + , ("Cache" , pretty cache) ] + +instance Pretty (CustomRules n) where + pretty _ = "TODO: Rule printing" + +instance Pretty (ImportStatus n) where + pretty imports = pretty $ S.toList $ directImports imports + +instance Pretty (ModuleEnv n) where + pretty (ModuleEnv imports sm sc) = + prettyRecord [ ("Imports" , pretty imports) + , ("Source map" , pretty sm) + , ("Synth candidates", pretty sc) ] + +instance Pretty (Env n) where + pretty (Env env1 env2) = + prettyRecord [ ("Top env" , pretty env1) + , ("Module env", pretty env2)] + +instance Pretty (SolverBinding n) where + pretty (InfVarBound ty) = "Inference variable of type:" <+> pretty ty + pretty (SkolemBound ty) = "Skolem variable of type:" <+> pretty ty + pretty (DictBound ty) = "Dictionary variable of type:" <+> pretty ty + +instance Pretty (Binding c n) where + pretty b = case b of + -- using `unsafeCoerceIRE` here because otherwise we don't have `IRRep` + -- TODO: can we avoid printing needing IRRep? Presumably it's related to + -- manipulating sets or something, which relies on Eq/Ord, which relies on renaming. + AtomNameBinding info -> "Atom name:" <+> pretty (unsafeCoerceIRE @CoreIR info) + TyConBinding dataDef _ -> "Type constructor: " <+> pretty dataDef + DataConBinding tyConName idx -> "Data constructor:" <+> + pretty tyConName <+> "Constructor index:" <+> pretty idx + ClassBinding classDef -> pretty classDef + InstanceBinding instanceDef _ -> pretty instanceDef + MethodBinding className idx -> "Method" <+> pretty idx <+> "of" <+> pretty className + TopFunBinding f -> pretty f + FunObjCodeBinding _ -> "<object file>" + ModuleBinding _ -> "<module>" + PtrBinding _ _ -> "<ptr>" + SpecializedDictBinding _ -> "<specialized-dict-binding>" + ImpNameBinding ty -> "Imp name of type: " <+> pretty ty + +instance Pretty (Module n) where + pretty m = prettyRecord + [ ("moduleSourceName" , pretty $ moduleSourceName m) + , ("moduleDirectDeps" , pretty $ S.toList $ moduleDirectDeps m) + , ("moduleTransDeps" , pretty $ S.toList $ moduleTransDeps m) + , ("moduleExports" , pretty $ moduleExports m) + , ("moduleSynthCandidates", pretty $ moduleSynthCandidates m) ] + +instance Pretty a => Pretty (EvalStatus a) where + pretty = \case + Waiting -> "<waiting>" + Running -> "<running>" + Finished a -> pretty a + +instance Pretty (EnvFrag n l) where + pretty (EnvFrag bindings) = pretty bindings + +instance Pretty (Cache n) where + pretty (Cache _ _ _ _ _ _) = "<cache>" -- TODO + +instance Pretty (SynthCandidates n) where + pretty scs = "instance dicts:" <+> pretty (M.toList $ instanceDicts scs) + +instance Pretty (LoadedModules n) where + pretty _ = "<loaded modules>" + +instance Pretty (TopFunDef n) where + pretty = \case + Specialization s -> pretty s + LinearizationPrimal _ -> "<linearization primal>" + LinearizationTangent _ -> "<linearization tangent>" + +instance Pretty (TopFun n) where + pretty = \case + DexTopFun def lam lowering -> + "Top-level Function" + <> hardline <+> "definition:" <+> pretty def + <> hardline <+> "lambda:" <+> pretty lam + <> hardline <+> "lowering:" <+> pretty lowering + FFITopFun f _ -> pretty f + +instance IRRep r => Pretty (TopLam r n) where + pretty (TopLam _ _ lam) = pretty lam + +instance IRRep r => Pretty (AtomBinding r n) where + pretty binding = case binding of + LetBound b -> pretty b + MiscBound t -> pretty t + SolverBound b -> pretty b + FFIFunBound s _ -> pretty s + NoinlineFun ty _ -> "Top function with type: " <+> pretty ty + TopDataBound (RepVal ty _) -> "Top data with type: " <+> pretty ty + +instance Pretty (SpecializationSpec n) where + pretty (AppSpecialization f (Abs bs (ListE args))) = + "Specialization" <+> pretty f <+> pretty bs <+> pretty args + +instance Hashable InfVarDesc +instance Hashable a => Hashable (EvalStatus a) + +instance Store (SolverBinding n) +instance IRRep r => Store (AtomBinding r n) +instance IRRep r => Store (TopLam r n) +instance Store (SynthCandidates n) +instance Store (Module n) +instance Store (ImportStatus n) +instance Store (TopFunLowerings n) +instance Store a => Store (EvalStatus a) +instance Store (TopFun n) +instance Store (TopFunDef n) +instance Color c => Store (Binding c n) +instance Store (ModuleEnv n) +instance Store (SerializedEnv n) +instance Store InfVarDesc +instance Store (AtomRules n) +instance Store (LinearizationSpec n) +instance Store (SpecializedDictDef n) +instance Store (SpecializationSpec n) diff --git a/src/lib/Util.hs b/src/lib/Util.hs index 8a44e723..4dbc43ed 100644 --- a/src/lib/Util.hs +++ b/src/lib/Util.hs @@ -26,6 +26,7 @@ import Data.Store (Store) import qualified Data.List.NonEmpty as NE import qualified Data.ByteString as BS import Data.Foldable +import Data.Text.Prettyprint.Doc (Pretty (..), pretty) import Data.List.NonEmpty (NonEmpty (..)) import GHC.Generics (Generic) @@ -354,6 +355,11 @@ zipTrees (Leaf x) (Leaf y) = Leaf (x, y) zipTrees (Branch xs) (Branch ys) | length xs == length ys = Branch $ zipWith zipTrees xs ys zipTrees _ _ = error "zip error" +instance Pretty a => Pretty (Tree a) where + pretty = \case + Leaf x -> pretty x + Branch xs -> pretty xs + -- === bytestrings paired with their hash digest === -- TODO: use something other than a string to store the digest diff --git a/src/lib/Vectorize.hs b/src/lib/Vectorize.hs index d6fec397..90e289df 100644 --- a/src/lib/Vectorize.hs +++ b/src/lib/Vectorize.hs @@ -9,7 +9,7 @@ module Vectorize (vectorizeLoops) where import Prelude hiding ((.)) import Data.Word import Data.Functor -import Data.Text.Prettyprint.Doc (Pretty, pretty, viaShow, (<+>)) +import Data.Text.Prettyprint.Doc (viaShow) import Control.Category import Control.Monad.Reader import Control.Monad.State.Strict @@ -26,6 +26,7 @@ import Subst import PPrint import QueryType import Types.Core +import Types.Top import Types.OpNames qualified as P import Types.Primitives import Util (allM, zipWithZ) |