summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDougal <d.maclaurin@gmail.com>2023-10-23 11:02:47 -0400
committerDougal <d.maclaurin@gmail.com>2023-10-23 11:02:47 -0400
commitd80f318b9ea90f32420c9aa49bc935ae2aed6324 (patch)
tree6bca935ec93e9464b4dfca0fb9c2cd3bb8c44cc3
parentde88bf8bfcf164fc90d603c61539274455498f96 (diff)
Add a `StuckTabApp` case to `Stuck`
-rw-r--r--src/lib/Builder.hs12
-rw-r--r--src/lib/CheapReduction.hs24
-rw-r--r--src/lib/CheckType.hs7
-rw-r--r--src/lib/Imp.hs6
-rw-r--r--src/lib/Inference.hs11
-rw-r--r--src/lib/Linearize.hs2
-rw-r--r--src/lib/OccAnalysis.hs16
-rw-r--r--src/lib/PPrint.hs1
-rw-r--r--src/lib/QueryTypePure.hs1
-rw-r--r--src/lib/Simplify.hs8
-rw-r--r--src/lib/Transpose.hs4
-rw-r--r--src/lib/Types/Core.hs18
-rw-r--r--src/lib/Vectorize.hs2
13 files changed, 86 insertions, 26 deletions
diff --git a/src/lib/Builder.hs b/src/lib/Builder.hs
index 52425566..5574aa44 100644
--- a/src/lib/Builder.hs
+++ b/src/lib/Builder.hs
@@ -129,6 +129,18 @@ buildScopedAssumeNoDecls cont = do
_ -> error "Expected no decl emissions"
{-# INLINE buildScopedAssumeNoDecls #-}
+withReducibleEmissions
+ :: (ScopableBuilder r m, Builder r m, HasNamesE e, SubstE AtomSubstVal e)
+ => String
+ -> (forall o' . (Emits o', DExt o o') => m o' (e o'))
+ -> m o (e o)
+withReducibleEmissions msg cont = do
+ withDecls <- buildScoped cont
+ reduceWithDecls withDecls >>= \case
+ Just t -> return t
+ _ -> throw TypeErr msg
+{-# INLINE withReducibleEmissions #-}
+
-- === "Hoisting" top-level builder class ===
-- `emitHoistedEnv` lets you emit top env fragments, like cache entries or
diff --git a/src/lib/CheapReduction.hs b/src/lib/CheapReduction.hs
index c7fd9658..03d60730 100644
--- a/src/lib/CheapReduction.hs
+++ b/src/lib/CheapReduction.hs
@@ -109,8 +109,14 @@ reduceExprM = \case
case (ty, val) of
(BaseTy (Scalar Word32Type), Con (Lit (Word64Lit v))) -> return $ Con $ Lit $ Word32Lit $ fromIntegral v
_ -> empty
+ TabApp ty tab xs -> do
+ ty' <- substM ty
+ xs' <- mapM substM xs
+ tab' <- substM tab
+ case tab' of
+ Stuck tab'' -> return $ Stuck $ StuckTabApp ty' tab'' xs'
+ _ -> error "not a table" -- what about RepVal?
TopApp _ _ _ -> empty
- TabApp _ _ _ -> empty
Case _ _ _ -> empty
TabCon _ _ _ -> empty
PrimOp _ -> empty
@@ -188,6 +194,11 @@ typeOfApp (Pi piTy) xs = withSubstReaderT $
withInstantiated piTy xs \(EffTy _ ty) -> substM ty
typeOfApp _ _ = error "expected a pi type"
+typeOfTabApp :: (IRRep r, EnvReader m) => Type r n -> [Atom r n] -> m n (Type r n)
+typeOfTabApp (TabPi piTy) xs = withSubstReaderT $
+ withInstantiated piTy xs \ty -> substM ty
+typeOfTabApp _ _ = error "expected a TabPi type"
+
repValAtom :: EnvReader m => SRepVal n -> m n (SAtom n)
repValAtom (RepVal ty tree) = case ty of
ProdTy ts -> case tree of
@@ -220,6 +231,13 @@ reduceUnwrapM = \case
_ -> error "expected a newtype"
_ -> empty
+reduceTabAppM :: IRRep r => Atom r o -> [Atom r o] -> ReducerM i o (Atom r o)
+reduceTabAppM tab xs = case tab of
+ Stuck tab' -> do
+ ty <- typeOfTabApp (getType tab') xs
+ return $ Stuck $ StuckTabApp ty tab' xs
+ _ -> error $ "not a table" ++ pprint tab
+
unwrapNewtypeType :: EnvReader m => NewtypeTyCon n -> m n (NewtypeCon n, Type CoreIR n)
unwrapNewtypeType = \case
Nat -> return (NatCon, IdxRepTy)
@@ -616,6 +634,10 @@ reduceStuck = \case
StuckUnwrap _ x -> do
x' <- reduceStuck x
dropSubst $ reduceUnwrapM x'
+ StuckTabApp _ f xs -> do
+ f' <- reduceStuck f
+ xs' <- mapM substM xs
+ dropSubst $ reduceTabAppM f' xs'
InstantiatedGiven _ f xs -> do
xs' <- mapM substM xs
f' <- reduceStuck f
diff --git a/src/lib/CheckType.hs b/src/lib/CheckType.hs
index 6f3b2093..ad6698a9 100644
--- a/src/lib/CheckType.hs
+++ b/src/lib/CheckType.hs
@@ -320,6 +320,13 @@ instance IRRep r => CheckableE r (Stuck r) where
StuckProject resultTy i x -> do
Project resultTy' i' (Stuck x') <- checkWithEffects Pure $ Project resultTy i (Stuck x)
return $ StuckProject resultTy' i' x'
+ StuckTabApp reqTy f xs -> do
+ reqTy' <- reqTy |: TyKind
+ (f', tabTy) <- checkAndGetType f
+ xs' <- mapM checkE xs
+ ty' <- checkTabApp tabTy xs'
+ checkTypesEq reqTy' ty'
+ return $ StuckTabApp reqTy' f' xs'
InstantiatedGiven resultTy given args -> do
resultTy' <- resultTy |: TyKind
(given', Pi piTy) <- checkAndGetType given
diff --git a/src/lib/Imp.hs b/src/lib/Imp.hs
index 63047979..9ab013f2 100644
--- a/src/lib/Imp.hs
+++ b/src/lib/Imp.hs
@@ -865,9 +865,15 @@ atomToRepVal x = RepVal (getType x) <$> go x where
Stuck (StuckVar v) -> lookupAtomName (atomVarName v) >>= \case
TopDataBound (RepVal _ tree) -> return tree
_ -> error "should only have pointer and data atom names left"
+ -- TODO: I think we want to be able to rule this one out by insisting that
+ -- RepValAtom is itself part of Stuck and it can't represent a product.
Stuck (StuckProject _ i val) -> do
Branch ts <- go $ Stuck val
return $ ts !! i
+ Stuck (StuckTabApp _ f xs) -> do
+ f' <- atomToRepVal $ Stuck f
+ RepVal _ t <- naryIndexRepVal f' (toList xs)
+ return t
-- XXX: We used to have a function called `destToAtom` which loaded the value
-- from the dest. This version is not that. It just lifts a dest into an atom of
diff --git a/src/lib/Inference.hs b/src/lib/Inference.hs
index b89c77bd..276bd923 100644
--- a/src/lib/Inference.hs
+++ b/src/lib/Inference.hs
@@ -1089,17 +1089,6 @@ checkSigmaDependent e@(WithSrcE ctx _) ty = addSrcContext ctx $
"Dependent functions can only be applied to fully evaluated expressions. " ++
"Bind the argument to a name before you apply the function."
-withReducibleEmissions
- :: Zonkable e
- => String
- -> (forall o' . (Emits o', DExt o o') => InfererM i o' (e o'))
- -> InfererM i o (e o)
-withReducibleEmissions msg cont = do
- withDecls <- buildScoped cont
- reduceWithDecls withDecls >>= \case
- Just t -> return t
- _ -> throw TypeErr msg
-
-- === sorting case alternatives ===
data IndexedAlt n = IndexedAlt CaseAltIndex (Alt CoreIR n)
diff --git a/src/lib/Linearize.hs b/src/lib/Linearize.hs
index 4fbae982..42661564 100644
--- a/src/lib/Linearize.hs
+++ b/src/lib/Linearize.hs
@@ -334,7 +334,9 @@ linearizeAtom atom = case atom of
activePrimalIdx v' >>= \case
Nothing -> withZeroT $ return (Var v')
Just idx -> return $ WithTangent (Var v') $ getTangentArg idx
+ -- TODO: buildScoped and reduce the results so we keep expression in non-ANF for type checking purposes
Stuck (StuckProject ty i x) -> linearizeExpr $ Project ty i (Stuck x)
+ Stuck (StuckTabApp t f xs) -> linearizeExpr $ TabApp t (Stuck f) xs
RepValAtom _ -> emitZeroT
where emitZeroT = withZeroT $ renameM atom
diff --git a/src/lib/OccAnalysis.hs b/src/lib/OccAnalysis.hs
index 59b7a438..711374df 100644
--- a/src/lib/OccAnalysis.hs
+++ b/src/lib/OccAnalysis.hs
@@ -255,6 +255,11 @@ instance HasOCC SStuck where
ty' <- occTy ty
return $ StuckVar (AtomVar n ty')
StuckProject t i x -> StuckProject <$> occ a t <*> pure i <*> occ a x
+ StuckTabApp t array ixs -> do
+ t' <- occTy t
+ (a', ixs') <- occIdxs a ixs
+ array' <- occ a' array
+ return $ StuckTabApp t' array' ixs'
instance HasOCC SType where
occ a ty = runOCCMVisitor a $ visitTypePartial ty
@@ -360,7 +365,7 @@ instance HasOCC SExpr where
return $ Block effTy' (Abs decls' ans')
TabApp t array ixs -> do
t' <- occTy t
- (a', ixs') <- go a ixs
+ (a', ixs') <- occIdxs a ixs
array' <- occ a' array
return $ TabApp t' array' ixs'
Case scrut alts (EffTy effs ty) -> do
@@ -376,10 +381,11 @@ instance HasOCC SExpr where
ref' <- occ a ref
PrimOp . RefOp ref' <$> occ a op
expr -> occGeneric a expr
- where
- go acc [] = return (acc, [])
- go acc (ix:ixs) = do
- (acc', ixs') <- go acc ixs
+
+occIdxs :: Access n -> [SAtom n] -> OCCM n (Access n, [SAtom n])
+occIdxs acc [] = return (acc, [])
+occIdxs acc (ix:ixs) = do
+ (acc', ixs') <- occIdxs acc ixs
(summ, ix') <- occurrenceAndSummary ix
return (location summ acc', ix':ixs')
diff --git a/src/lib/PPrint.hs b/src/lib/PPrint.hs
index 7e870889..968193a0 100644
--- a/src/lib/PPrint.hs
+++ b/src/lib/PPrint.hs
@@ -265,6 +265,7 @@ instance IRRep r => PrettyPrec (Stuck r n) where
prettyPrec = \case
StuckVar v -> atPrec ArgPrec $ p v
StuckProject _ i v -> atPrec LowestPrec $ "StuckProject" <+> p i <+> p v
+ StuckTabApp _ f xs -> atPrec AppPrec $ pArg f <> "." <> pArg xs
StuckUnwrap _ v -> atPrec LowestPrec $ "StuckUnwrap" <+> p v
InstantiatedGiven _ v args -> atPrec LowestPrec $ "Given" <+> p v <+> p (toList args)
SuperclassProj _ d' i -> atPrec LowestPrec $ "SuperclassProj" <+> p d' <+> p i
diff --git a/src/lib/QueryTypePure.hs b/src/lib/QueryTypePure.hs
index 99b4687b..258bbb9b 100644
--- a/src/lib/QueryTypePure.hs
+++ b/src/lib/QueryTypePure.hs
@@ -102,6 +102,7 @@ instance IRRep r => HasType r (Stuck r) where
getType = \case
StuckVar (AtomVar _ t) -> t
StuckProject t _ _ -> t
+ StuckTabApp t _ _ -> t
StuckUnwrap t _ -> t
InstantiatedGiven t _ _ -> t
SuperclassProj t _ _ -> t
diff --git a/src/lib/Simplify.hs b/src/lib/Simplify.hs
index b597ffa9..0f71998f 100644
--- a/src/lib/Simplify.hs
+++ b/src/lib/Simplify.hs
@@ -165,6 +165,14 @@ forceStuck stuck cont = withDistinct case stuck of
DepPair l r _ -> forceConstructor ([l, r]!!i) cont
_ -> error "Can't project stuck term"
_ -> error "Can't project stuck term"
+ StuckTabApp _ f xs -> do
+ ty <- substM $ getType stuck
+ xs' <- forM xs \x -> toDataAtomIgnoreRecon =<< substM x
+ forceStuck f \case
+ CCSimpInCore (LiftSimp _ f') -> do
+ result <- naryTabApp f' (sink<$>xs')
+ cont $ CCSimpInCore $ LiftSimp (sink ty) result
+ _ -> error "not a table" -- what about table lambda?
StuckUnwrap _ x -> forceStuck x \case
CCCon (WithSubst s con) -> withSubst s case con of
NewtypeCon _ x' -> forceConstructor x' cont
diff --git a/src/lib/Transpose.hs b/src/lib/Transpose.hs
index 0cde8255..75e14ec7 100644
--- a/src/lib/Transpose.hs
+++ b/src/lib/Transpose.hs
@@ -313,8 +313,8 @@ transposeAtom atom ct = case atom of
return ()
LinRef ref -> emitCTToRef ref ct
LinTrivial -> return ()
- Stuck (StuckProject _ _ _) -> undefined
- -- Stuck (StuckProject _ i' x') -> do
+ Stuck (StuckProject _ _ _) -> error "not implemented"
+ Stuck (StuckTabApp _ _ _) -> error "not implemented"
-- let (idxs, v) = asNaryProj i' x'
-- lookupSubstM (atomVarName v) >>= \case
-- RenameNonlin _ -> error "an error, probably"
diff --git a/src/lib/Types/Core.hs b/src/lib/Types/Core.hs
index 2ef17f6d..95380bbc 100644
--- a/src/lib/Types/Core.hs
+++ b/src/lib/Types/Core.hs
@@ -74,6 +74,7 @@ data Type (r::IR) (n::S) where
data Stuck (r::IR) (n::S) where
StuckVar :: AtomVar r n -> Stuck r n
StuckProject :: Type r n -> Int -> Stuck r n -> Stuck r n
+ StuckTabApp :: Type r n -> Stuck r n -> [Atom r n] -> Stuck r n
StuckUnwrap :: CType n -> CStuck n -> Stuck CoreIR n
InstantiatedGiven :: CType n -> CStuck n -> [CAtom n] -> Stuck CoreIR n
SuperclassProj :: CType n -> Int -> CStuck n -> Stuck CoreIR n
@@ -1552,26 +1553,29 @@ instance IRRep r => AlphaHashableE (Atom r)
instance IRRep r => RenameE (Atom r)
instance IRRep r => GenericE (Stuck r) where
- type RepE (Stuck r) = EitherE5
+ type RepE (Stuck r) = EitherE6
{- StuckVar -} (AtomVar r)
{- StuckProject -} (Type r `PairE` LiftE Int `PairE` Stuck r)
+ {- StuckTabApp -} (Type r `PairE` Stuck r `PairE` ListE (Atom r))
{- StuckUnwrap -} (WhenCore r (CType `PairE` CStuck))
{- InstantiatedGiven -} (WhenCore r (CType `PairE` CStuck `PairE` ListE CAtom))
{- SuperclassProj -} (WhenCore r (CType `PairE` LiftE Int `PairE` CStuck))
fromE = \case
StuckVar v -> Case0 v
StuckProject t i e -> Case1 $ t `PairE` LiftE i `PairE` e
- StuckUnwrap t e -> Case2 $ WhenIRE $ t `PairE` e
- InstantiatedGiven t e xs -> Case3 $ WhenIRE $ t `PairE` e `PairE` ListE xs
- SuperclassProj t i e -> Case4 $ WhenIRE $ t `PairE` LiftE i `PairE` e
+ StuckTabApp t f x -> Case2 $ t `PairE` f `PairE` ListE x
+ StuckUnwrap t e -> Case3 $ WhenIRE $ t `PairE` e
+ InstantiatedGiven t e xs -> Case4 $ WhenIRE $ t `PairE` e `PairE` ListE xs
+ SuperclassProj t i e -> Case5 $ WhenIRE $ t `PairE` LiftE i `PairE` e
{-# INLINE fromE #-}
toE = \case
Case0 v -> StuckVar v
Case1 (t `PairE` LiftE i `PairE` e) -> StuckProject t i e
- Case2 (WhenIRE (t `PairE` e)) -> StuckUnwrap t e
- Case3 (WhenIRE (t `PairE` e `PairE` ListE xs)) -> InstantiatedGiven t e xs
- Case4 (WhenIRE (t `PairE` LiftE i `PairE` e)) -> SuperclassProj t i e
+ Case2 (t `PairE` f `PairE` ListE x) -> StuckTabApp t f x
+ Case3 (WhenIRE (t `PairE` e)) -> StuckUnwrap t e
+ Case4 (WhenIRE (t `PairE` e `PairE` ListE xs)) -> InstantiatedGiven t e xs
+ Case5 (WhenIRE (t `PairE` LiftE i `PairE` e)) -> SuperclassProj t i e
_ -> error "impossible"
{-# INLINE toE #-}
diff --git a/src/lib/Vectorize.hs b/src/lib/Vectorize.hs
index f18fd4fa..3be5058b 100644
--- a/src/lib/Vectorize.hs
+++ b/src/lib/Vectorize.hs
@@ -536,6 +536,8 @@ vectorizeAtom atom = addVectErrCtx "vectorizeAtom" ("Atom:\n" ++ pprint atom) do
_ -> throwVectErr "Invalid projection"
x'' <- reduceProj i x'
return $ VVal ov x''
+ -- TODO: think about this case
+ StuckTabApp _ _ _ -> throwVectErr $ "Cannot vectorize atom: " ++ pprint atom
Con (Lit l) -> return $ VVal Uniform $ Con $ Lit l
_ -> do
subst <- getSubst