diff options
author | Dougal <d.maclaurin@gmail.com> | 2023-10-23 11:02:47 -0400 |
---|---|---|
committer | Dougal <d.maclaurin@gmail.com> | 2023-10-23 11:02:47 -0400 |
commit | d80f318b9ea90f32420c9aa49bc935ae2aed6324 (patch) | |
tree | 6bca935ec93e9464b4dfca0fb9c2cd3bb8c44cc3 | |
parent | de88bf8bfcf164fc90d603c61539274455498f96 (diff) |
Add a `StuckTabApp` case to `Stuck`
-rw-r--r-- | src/lib/Builder.hs | 12 | ||||
-rw-r--r-- | src/lib/CheapReduction.hs | 24 | ||||
-rw-r--r-- | src/lib/CheckType.hs | 7 | ||||
-rw-r--r-- | src/lib/Imp.hs | 6 | ||||
-rw-r--r-- | src/lib/Inference.hs | 11 | ||||
-rw-r--r-- | src/lib/Linearize.hs | 2 | ||||
-rw-r--r-- | src/lib/OccAnalysis.hs | 16 | ||||
-rw-r--r-- | src/lib/PPrint.hs | 1 | ||||
-rw-r--r-- | src/lib/QueryTypePure.hs | 1 | ||||
-rw-r--r-- | src/lib/Simplify.hs | 8 | ||||
-rw-r--r-- | src/lib/Transpose.hs | 4 | ||||
-rw-r--r-- | src/lib/Types/Core.hs | 18 | ||||
-rw-r--r-- | src/lib/Vectorize.hs | 2 |
13 files changed, 86 insertions, 26 deletions
diff --git a/src/lib/Builder.hs b/src/lib/Builder.hs index 52425566..5574aa44 100644 --- a/src/lib/Builder.hs +++ b/src/lib/Builder.hs @@ -129,6 +129,18 @@ buildScopedAssumeNoDecls cont = do _ -> error "Expected no decl emissions" {-# INLINE buildScopedAssumeNoDecls #-} +withReducibleEmissions + :: (ScopableBuilder r m, Builder r m, HasNamesE e, SubstE AtomSubstVal e) + => String + -> (forall o' . (Emits o', DExt o o') => m o' (e o')) + -> m o (e o) +withReducibleEmissions msg cont = do + withDecls <- buildScoped cont + reduceWithDecls withDecls >>= \case + Just t -> return t + _ -> throw TypeErr msg +{-# INLINE withReducibleEmissions #-} + -- === "Hoisting" top-level builder class === -- `emitHoistedEnv` lets you emit top env fragments, like cache entries or diff --git a/src/lib/CheapReduction.hs b/src/lib/CheapReduction.hs index c7fd9658..03d60730 100644 --- a/src/lib/CheapReduction.hs +++ b/src/lib/CheapReduction.hs @@ -109,8 +109,14 @@ reduceExprM = \case case (ty, val) of (BaseTy (Scalar Word32Type), Con (Lit (Word64Lit v))) -> return $ Con $ Lit $ Word32Lit $ fromIntegral v _ -> empty + TabApp ty tab xs -> do + ty' <- substM ty + xs' <- mapM substM xs + tab' <- substM tab + case tab' of + Stuck tab'' -> return $ Stuck $ StuckTabApp ty' tab'' xs' + _ -> error "not a table" -- what about RepVal? TopApp _ _ _ -> empty - TabApp _ _ _ -> empty Case _ _ _ -> empty TabCon _ _ _ -> empty PrimOp _ -> empty @@ -188,6 +194,11 @@ typeOfApp (Pi piTy) xs = withSubstReaderT $ withInstantiated piTy xs \(EffTy _ ty) -> substM ty typeOfApp _ _ = error "expected a pi type" +typeOfTabApp :: (IRRep r, EnvReader m) => Type r n -> [Atom r n] -> m n (Type r n) +typeOfTabApp (TabPi piTy) xs = withSubstReaderT $ + withInstantiated piTy xs \ty -> substM ty +typeOfTabApp _ _ = error "expected a TabPi type" + repValAtom :: EnvReader m => SRepVal n -> m n (SAtom n) repValAtom (RepVal ty tree) = case ty of ProdTy ts -> case tree of @@ -220,6 +231,13 @@ reduceUnwrapM = \case _ -> error "expected a newtype" _ -> empty +reduceTabAppM :: IRRep r => Atom r o -> [Atom r o] -> ReducerM i o (Atom r o) +reduceTabAppM tab xs = case tab of + Stuck tab' -> do + ty <- typeOfTabApp (getType tab') xs + return $ Stuck $ StuckTabApp ty tab' xs + _ -> error $ "not a table" ++ pprint tab + unwrapNewtypeType :: EnvReader m => NewtypeTyCon n -> m n (NewtypeCon n, Type CoreIR n) unwrapNewtypeType = \case Nat -> return (NatCon, IdxRepTy) @@ -616,6 +634,10 @@ reduceStuck = \case StuckUnwrap _ x -> do x' <- reduceStuck x dropSubst $ reduceUnwrapM x' + StuckTabApp _ f xs -> do + f' <- reduceStuck f + xs' <- mapM substM xs + dropSubst $ reduceTabAppM f' xs' InstantiatedGiven _ f xs -> do xs' <- mapM substM xs f' <- reduceStuck f diff --git a/src/lib/CheckType.hs b/src/lib/CheckType.hs index 6f3b2093..ad6698a9 100644 --- a/src/lib/CheckType.hs +++ b/src/lib/CheckType.hs @@ -320,6 +320,13 @@ instance IRRep r => CheckableE r (Stuck r) where StuckProject resultTy i x -> do Project resultTy' i' (Stuck x') <- checkWithEffects Pure $ Project resultTy i (Stuck x) return $ StuckProject resultTy' i' x' + StuckTabApp reqTy f xs -> do + reqTy' <- reqTy |: TyKind + (f', tabTy) <- checkAndGetType f + xs' <- mapM checkE xs + ty' <- checkTabApp tabTy xs' + checkTypesEq reqTy' ty' + return $ StuckTabApp reqTy' f' xs' InstantiatedGiven resultTy given args -> do resultTy' <- resultTy |: TyKind (given', Pi piTy) <- checkAndGetType given diff --git a/src/lib/Imp.hs b/src/lib/Imp.hs index 63047979..9ab013f2 100644 --- a/src/lib/Imp.hs +++ b/src/lib/Imp.hs @@ -865,9 +865,15 @@ atomToRepVal x = RepVal (getType x) <$> go x where Stuck (StuckVar v) -> lookupAtomName (atomVarName v) >>= \case TopDataBound (RepVal _ tree) -> return tree _ -> error "should only have pointer and data atom names left" + -- TODO: I think we want to be able to rule this one out by insisting that + -- RepValAtom is itself part of Stuck and it can't represent a product. Stuck (StuckProject _ i val) -> do Branch ts <- go $ Stuck val return $ ts !! i + Stuck (StuckTabApp _ f xs) -> do + f' <- atomToRepVal $ Stuck f + RepVal _ t <- naryIndexRepVal f' (toList xs) + return t -- XXX: We used to have a function called `destToAtom` which loaded the value -- from the dest. This version is not that. It just lifts a dest into an atom of diff --git a/src/lib/Inference.hs b/src/lib/Inference.hs index b89c77bd..276bd923 100644 --- a/src/lib/Inference.hs +++ b/src/lib/Inference.hs @@ -1089,17 +1089,6 @@ checkSigmaDependent e@(WithSrcE ctx _) ty = addSrcContext ctx $ "Dependent functions can only be applied to fully evaluated expressions. " ++ "Bind the argument to a name before you apply the function." -withReducibleEmissions - :: Zonkable e - => String - -> (forall o' . (Emits o', DExt o o') => InfererM i o' (e o')) - -> InfererM i o (e o) -withReducibleEmissions msg cont = do - withDecls <- buildScoped cont - reduceWithDecls withDecls >>= \case - Just t -> return t - _ -> throw TypeErr msg - -- === sorting case alternatives === data IndexedAlt n = IndexedAlt CaseAltIndex (Alt CoreIR n) diff --git a/src/lib/Linearize.hs b/src/lib/Linearize.hs index 4fbae982..42661564 100644 --- a/src/lib/Linearize.hs +++ b/src/lib/Linearize.hs @@ -334,7 +334,9 @@ linearizeAtom atom = case atom of activePrimalIdx v' >>= \case Nothing -> withZeroT $ return (Var v') Just idx -> return $ WithTangent (Var v') $ getTangentArg idx + -- TODO: buildScoped and reduce the results so we keep expression in non-ANF for type checking purposes Stuck (StuckProject ty i x) -> linearizeExpr $ Project ty i (Stuck x) + Stuck (StuckTabApp t f xs) -> linearizeExpr $ TabApp t (Stuck f) xs RepValAtom _ -> emitZeroT where emitZeroT = withZeroT $ renameM atom diff --git a/src/lib/OccAnalysis.hs b/src/lib/OccAnalysis.hs index 59b7a438..711374df 100644 --- a/src/lib/OccAnalysis.hs +++ b/src/lib/OccAnalysis.hs @@ -255,6 +255,11 @@ instance HasOCC SStuck where ty' <- occTy ty return $ StuckVar (AtomVar n ty') StuckProject t i x -> StuckProject <$> occ a t <*> pure i <*> occ a x + StuckTabApp t array ixs -> do + t' <- occTy t + (a', ixs') <- occIdxs a ixs + array' <- occ a' array + return $ StuckTabApp t' array' ixs' instance HasOCC SType where occ a ty = runOCCMVisitor a $ visitTypePartial ty @@ -360,7 +365,7 @@ instance HasOCC SExpr where return $ Block effTy' (Abs decls' ans') TabApp t array ixs -> do t' <- occTy t - (a', ixs') <- go a ixs + (a', ixs') <- occIdxs a ixs array' <- occ a' array return $ TabApp t' array' ixs' Case scrut alts (EffTy effs ty) -> do @@ -376,10 +381,11 @@ instance HasOCC SExpr where ref' <- occ a ref PrimOp . RefOp ref' <$> occ a op expr -> occGeneric a expr - where - go acc [] = return (acc, []) - go acc (ix:ixs) = do - (acc', ixs') <- go acc ixs + +occIdxs :: Access n -> [SAtom n] -> OCCM n (Access n, [SAtom n]) +occIdxs acc [] = return (acc, []) +occIdxs acc (ix:ixs) = do + (acc', ixs') <- occIdxs acc ixs (summ, ix') <- occurrenceAndSummary ix return (location summ acc', ix':ixs') diff --git a/src/lib/PPrint.hs b/src/lib/PPrint.hs index 7e870889..968193a0 100644 --- a/src/lib/PPrint.hs +++ b/src/lib/PPrint.hs @@ -265,6 +265,7 @@ instance IRRep r => PrettyPrec (Stuck r n) where prettyPrec = \case StuckVar v -> atPrec ArgPrec $ p v StuckProject _ i v -> atPrec LowestPrec $ "StuckProject" <+> p i <+> p v + StuckTabApp _ f xs -> atPrec AppPrec $ pArg f <> "." <> pArg xs StuckUnwrap _ v -> atPrec LowestPrec $ "StuckUnwrap" <+> p v InstantiatedGiven _ v args -> atPrec LowestPrec $ "Given" <+> p v <+> p (toList args) SuperclassProj _ d' i -> atPrec LowestPrec $ "SuperclassProj" <+> p d' <+> p i diff --git a/src/lib/QueryTypePure.hs b/src/lib/QueryTypePure.hs index 99b4687b..258bbb9b 100644 --- a/src/lib/QueryTypePure.hs +++ b/src/lib/QueryTypePure.hs @@ -102,6 +102,7 @@ instance IRRep r => HasType r (Stuck r) where getType = \case StuckVar (AtomVar _ t) -> t StuckProject t _ _ -> t + StuckTabApp t _ _ -> t StuckUnwrap t _ -> t InstantiatedGiven t _ _ -> t SuperclassProj t _ _ -> t diff --git a/src/lib/Simplify.hs b/src/lib/Simplify.hs index b597ffa9..0f71998f 100644 --- a/src/lib/Simplify.hs +++ b/src/lib/Simplify.hs @@ -165,6 +165,14 @@ forceStuck stuck cont = withDistinct case stuck of DepPair l r _ -> forceConstructor ([l, r]!!i) cont _ -> error "Can't project stuck term" _ -> error "Can't project stuck term" + StuckTabApp _ f xs -> do + ty <- substM $ getType stuck + xs' <- forM xs \x -> toDataAtomIgnoreRecon =<< substM x + forceStuck f \case + CCSimpInCore (LiftSimp _ f') -> do + result <- naryTabApp f' (sink<$>xs') + cont $ CCSimpInCore $ LiftSimp (sink ty) result + _ -> error "not a table" -- what about table lambda? StuckUnwrap _ x -> forceStuck x \case CCCon (WithSubst s con) -> withSubst s case con of NewtypeCon _ x' -> forceConstructor x' cont diff --git a/src/lib/Transpose.hs b/src/lib/Transpose.hs index 0cde8255..75e14ec7 100644 --- a/src/lib/Transpose.hs +++ b/src/lib/Transpose.hs @@ -313,8 +313,8 @@ transposeAtom atom ct = case atom of return () LinRef ref -> emitCTToRef ref ct LinTrivial -> return () - Stuck (StuckProject _ _ _) -> undefined - -- Stuck (StuckProject _ i' x') -> do + Stuck (StuckProject _ _ _) -> error "not implemented" + Stuck (StuckTabApp _ _ _) -> error "not implemented" -- let (idxs, v) = asNaryProj i' x' -- lookupSubstM (atomVarName v) >>= \case -- RenameNonlin _ -> error "an error, probably" diff --git a/src/lib/Types/Core.hs b/src/lib/Types/Core.hs index 2ef17f6d..95380bbc 100644 --- a/src/lib/Types/Core.hs +++ b/src/lib/Types/Core.hs @@ -74,6 +74,7 @@ data Type (r::IR) (n::S) where data Stuck (r::IR) (n::S) where StuckVar :: AtomVar r n -> Stuck r n StuckProject :: Type r n -> Int -> Stuck r n -> Stuck r n + StuckTabApp :: Type r n -> Stuck r n -> [Atom r n] -> Stuck r n StuckUnwrap :: CType n -> CStuck n -> Stuck CoreIR n InstantiatedGiven :: CType n -> CStuck n -> [CAtom n] -> Stuck CoreIR n SuperclassProj :: CType n -> Int -> CStuck n -> Stuck CoreIR n @@ -1552,26 +1553,29 @@ instance IRRep r => AlphaHashableE (Atom r) instance IRRep r => RenameE (Atom r) instance IRRep r => GenericE (Stuck r) where - type RepE (Stuck r) = EitherE5 + type RepE (Stuck r) = EitherE6 {- StuckVar -} (AtomVar r) {- StuckProject -} (Type r `PairE` LiftE Int `PairE` Stuck r) + {- StuckTabApp -} (Type r `PairE` Stuck r `PairE` ListE (Atom r)) {- StuckUnwrap -} (WhenCore r (CType `PairE` CStuck)) {- InstantiatedGiven -} (WhenCore r (CType `PairE` CStuck `PairE` ListE CAtom)) {- SuperclassProj -} (WhenCore r (CType `PairE` LiftE Int `PairE` CStuck)) fromE = \case StuckVar v -> Case0 v StuckProject t i e -> Case1 $ t `PairE` LiftE i `PairE` e - StuckUnwrap t e -> Case2 $ WhenIRE $ t `PairE` e - InstantiatedGiven t e xs -> Case3 $ WhenIRE $ t `PairE` e `PairE` ListE xs - SuperclassProj t i e -> Case4 $ WhenIRE $ t `PairE` LiftE i `PairE` e + StuckTabApp t f x -> Case2 $ t `PairE` f `PairE` ListE x + StuckUnwrap t e -> Case3 $ WhenIRE $ t `PairE` e + InstantiatedGiven t e xs -> Case4 $ WhenIRE $ t `PairE` e `PairE` ListE xs + SuperclassProj t i e -> Case5 $ WhenIRE $ t `PairE` LiftE i `PairE` e {-# INLINE fromE #-} toE = \case Case0 v -> StuckVar v Case1 (t `PairE` LiftE i `PairE` e) -> StuckProject t i e - Case2 (WhenIRE (t `PairE` e)) -> StuckUnwrap t e - Case3 (WhenIRE (t `PairE` e `PairE` ListE xs)) -> InstantiatedGiven t e xs - Case4 (WhenIRE (t `PairE` LiftE i `PairE` e)) -> SuperclassProj t i e + Case2 (t `PairE` f `PairE` ListE x) -> StuckTabApp t f x + Case3 (WhenIRE (t `PairE` e)) -> StuckUnwrap t e + Case4 (WhenIRE (t `PairE` e `PairE` ListE xs)) -> InstantiatedGiven t e xs + Case5 (WhenIRE (t `PairE` LiftE i `PairE` e)) -> SuperclassProj t i e _ -> error "impossible" {-# INLINE toE #-} diff --git a/src/lib/Vectorize.hs b/src/lib/Vectorize.hs index f18fd4fa..3be5058b 100644 --- a/src/lib/Vectorize.hs +++ b/src/lib/Vectorize.hs @@ -536,6 +536,8 @@ vectorizeAtom atom = addVectErrCtx "vectorizeAtom" ("Atom:\n" ++ pprint atom) do _ -> throwVectErr "Invalid projection" x'' <- reduceProj i x' return $ VVal ov x'' + -- TODO: think about this case + StuckTabApp _ _ _ -> throwVectErr $ "Cannot vectorize atom: " ++ pprint atom Con (Lit l) -> return $ VVal Uniform $ Con $ Lit l _ -> do subst <- getSubst |