diff options
author | Dougal <d.maclaurin@gmail.com> | 2023-10-22 21:06:15 -0400 |
---|---|---|
committer | Dougal <d.maclaurin@gmail.com> | 2023-10-22 21:06:15 -0400 |
commit | de88bf8bfcf164fc90d603c61539274455498f96 (patch) | |
tree | f7ec9424568ee4573ac1a26056f6efe55a3252a1 | |
parent | 0ff323354d426aead464b8a597d98a283da370be (diff) |
Factor out the way Simplify handles ACase.
-rw-r--r-- | src/lib/Simplify.hs | 209 |
1 files changed, 98 insertions, 111 deletions
diff --git a/src/lib/Simplify.hs b/src/lib/Simplify.hs index 09c96654..b597ffa9 100644 --- a/src/lib/Simplify.hs +++ b/src/lib/Simplify.hs @@ -104,6 +104,75 @@ tryAsDataAtom atom = do where notData = error $ "Not runtime-representable data: " ++ pprint atom +data WithSubst (e::E) (o::S) where + WithSubst :: Subst AtomSubstVal i o -> e i -> WithSubst e o + +data ConcreteCAtom (n::S) = + CCCon (WithSubst CAtom n) -- can't be Stuck or SimpInCore + | CCSimpInCore (SimpInCore n) -- can't be ACase + | CCNoInlineFun (CAtomVar n) (CType n) (CAtom n) + | CCFFIFun (CorePiType n) (TopFunName n) + +-- Yields to the continuation a term with a concrete CoreIR constructor, +-- or LiftSimpFun, liftSimp, or TabLam. +forceConstructor + :: Emits o + => CAtom i + -> (forall o' i'. (DExt o o', Emits o') => ConcreteCAtom o'-> SimplifyM i' o' (CAtom o')) + -> SimplifyM i o (CAtom o) +forceConstructor atom cont = withDistinct case atom of + Stuck stuck -> forceStuck stuck cont + SimpInCore lifted -> case lifted of + ACase e alts resultTy -> do + e' <- substM e + resultTy' <- substM resultTy + defuncCase e' resultTy' \i x -> do + Abs b body <- return $ alts !! i + extendSubst (b@>SubstVal x) do + forceConstructor body cont + _ -> do + lifted' <- substM lifted + cont $ CCSimpInCore lifted' + _ -> do + Distinct <- getDistinct + subst <- getSubst + cont $ CCCon $ WithSubst subst atom + +forceStuck + :: Emits o + => CStuck i + -> (forall o' i'. (DExt o o', Emits o') => ConcreteCAtom o'-> SimplifyM i' o' (CAtom o')) + -> SimplifyM i o (CAtom o) +forceStuck stuck cont = withDistinct case stuck of + StuckVar v -> lookupSubstM (atomVarName v) >>= \case + SubstVal x -> dropSubst $ forceConstructor x cont + Rename v' -> lookupAtomName v' >>= \case + LetBound (DeclBinding _ (Atom x)) -> dropSubst $ forceConstructor x cont + NoinlineFun t f -> do + v'' <- toAtomVar v' + cont $ CCNoInlineFun v'' t f + FFIFunBound t f -> cont $ CCFFIFun t f + _ -> error "shouldn't have other CVars left" + -- TODO: figure out how to de-dup these cases with their Expr counterpart + StuckProject _ i x -> do + ty <- substM $ getType stuck + forceStuck x \case + CCSimpInCore (LiftSimp _ x') -> do + x'' <- proj i x' + cont $ CCSimpInCore $ LiftSimp (sink ty) x'' + CCCon (WithSubst s con) -> withSubst s case con of + ProdVal xs -> forceConstructor (xs!!i) cont + DepPair l r _ -> forceConstructor ([l, r]!!i) cont + _ -> error "Can't project stuck term" + _ -> error "Can't project stuck term" + StuckUnwrap _ x -> forceStuck x \case + CCCon (WithSubst s con) -> withSubst s case con of + NewtypeCon _ x' -> forceConstructor x' cont + _ -> error "can't unwrap stuck term" + _ -> error "can't unwrap stuck term" + InstantiatedGiven _ _ _ -> error "shouldn't have this left" + SuperclassProj _ _ _ -> error "shouldn't have this left" + forceTabLam :: Emits n => TabLamExpr n -> SimplifyM i n (SAtom n) forceTabLam (PairE ixTy (Abs b ab)) = buildFor (getNameHint b) Fwd ixTy \v -> do @@ -315,8 +384,7 @@ simplifyExpr expr = confuseGHC >>= \_ -> case expr of simplifyApp ty' f xs' TabApp _ f xs -> do xs' <- mapM simplifyAtom xs - f' <- simplifyAtom f - simplifyTabApp f' xs' + simplifyTabApp f xs' Atom x -> simplifyAtom x PrimOp op -> simplifyOp op ApplyMethod (EffTy _ ty) dict i xs -> do @@ -379,6 +447,7 @@ defuncCaseCore scrut resultTy cont = do let xCoreTy = altBinderTys !! i x' <- liftSimpAtom (sink xCoreTy) x cont i x' + -- TODO: we should use forceConstructor here Nothing -> case trySelectBranch scrut of Just (i, arg) -> getDistinct >>= \Distinct -> cont i arg Nothing -> go scrut where @@ -449,61 +518,21 @@ simplifyAlt split ty cont = do simplifyApp :: forall i o. Emits o => CType o -> CAtom i -> [CAtom o] -> SimplifyM i o (CAtom o) -simplifyApp resultTy f xs = case f of - Lam (CoreLamExpr _ lam) -> fast lam - _ -> slow =<< simplifyAtomAndInline f - where - fast :: LamExpr CoreIR i' -> SimplifyM i' o (CAtom o) - fast lam = withInstantiated lam xs \body -> simplifyExpr body - - slow :: CAtom o -> SimplifyM i o (CAtom o) - slow = \case - Lam (CoreLamExpr _ lam) -> dropSubst $ fast lam - SimpInCore (ACase e alts _) -> dropSubst do - defuncCase e resultTy \i x -> do - Abs b body <- return $ alts !! i - extendSubst (b@>SubstVal x) do - xs' <- mapM sinkM xs - simplifyApp (sink resultTy) body xs' - SimpInCore (LiftSimpFun _ lam) -> do - xs' <- mapM toDataAtomIgnoreRecon xs - result <- instantiate lam xs' >>= emitExpr - liftSimpAtom resultTy result - Var v -> do - lookupAtomName (atomVarName v) >>= \case - NoinlineFun _ _ -> simplifyTopFunApp v xs - FFIFunBound _ f' -> do - xs' <- mapM toDataAtomIgnoreRecon xs - liftSimpAtom resultTy =<< naryTopApp f' xs' - b -> error $ "Should only have noinline functions left " ++ pprint b - atom -> error $ "Unexpected function: " ++ pprint atom - --- | Like `simplifyAtom`, but will try to inline function definitions found --- in the environment. The only exception is when we're going to differentiate --- and the function has a custom derivative rule defined. --- TODO(dougalm): do we still need this? -simplifyAtomAndInline :: CAtom i -> SimplifyM i o (CAtom o) -simplifyAtomAndInline atom = confuseGHC >>= \_ -> case atom of - Var v -> do - env <- getSubst - case env ! atomVarName v of - Rename v' -> doInline =<< toAtomVar v' - SubstVal (Var v') -> doInline v' - SubstVal x -> return x - -- This is a hack because we weren't normalize the unwrapping of - -- `unit_type_scale` in `plot.dx`. We need a better system for deciding how to - -- normalize and inline. - Stuck (StuckProject _ i x) -> do - x' <- simplifyStuck x >>= reduceProj i - dropSubst $ simplifyAtomAndInline x' - _ -> simplifyAtom atom >>= \case - Var v -> doInline v - ans -> return ans - where - doInline v = do - lookupAtomName (atomVarName v) >>= \case - LetBound (DeclBinding _ (Atom x)) -> dropSubst $ simplifyAtomAndInline x - _ -> return $ Var v +simplifyApp resultTy f xs = forceConstructor f \f' -> do + xs' <- mapM sinkM xs + case f' of + CCCon (WithSubst s (Lam (CoreLamExpr _ lam))) -> + withSubst s $ withInstantiated lam xs' \body -> + simplifyExpr body + CCSimpInCore (LiftSimpFun _ lam) -> do + xs'' <- mapM toDataAtomIgnoreRecon xs' + result <- instantiate lam xs'' >>= emitExpr + liftSimpAtom (sink resultTy) result + CCNoInlineFun v _ _ -> simplifyTopFunApp v xs' + CCFFIFun _ f'' -> do + xs'' <- mapM toDataAtomIgnoreRecon xs' + liftSimpAtom (sink resultTy) =<< naryTopApp f'' xs'' + _ -> error "not a function" simplifyTopFunApp :: Emits n => CAtomVar n -> [CAtom n] -> SimplifyM i n (CAtom n) simplifyTopFunApp fName xs = do @@ -547,33 +576,23 @@ specializedFunCoreDefinition (AppSpecialization f (Abs bs staticArgs)) = do naryApp f' staticArgs' simplifyTabApp :: forall i o. Emits o - => CAtom o -> [CAtom o] -> SimplifyM i o (CAtom o) -simplifyTabApp f [] = return f -simplifyTabApp f@(SimpInCore sic) xs = case sic of - TabLam _ _ -> do - case fromNaryTabLam (length xs) f of + => CAtom i -> [CAtom o] -> SimplifyM i o (CAtom o) +simplifyTabApp f [] = simplifyAtom f +simplifyTabApp f xs = forceConstructor f \case + CCSimpInCore sic@(TabLam _ _) -> do + case fromNaryTabLam (length xs) (SimpInCore sic) of Just (bsCount, ab) -> do - let (xsPref, xsRest) = splitAt bsCount xs + (xsPref, xsRest) <- splitAt bsCount <$> mapM sinkM xs xsPref' <- mapM toDataAtomIgnoreRecon xsPref block' <- instantiate ab xsPref' atom <- emitDecls block' - simplifyTabApp atom xsRest + dropSubst $ simplifyTabApp atom xsRest Nothing -> error "should never happen" - ACase e alts ty -> dropSubst do - resultTy <- typeOfTabApp ty xs - defuncCase e resultTy \i x -> do - Abs b body <- return $ alts !! i - extendSubst (b@>SubstVal x) do - xs' <- mapM sinkM xs - body' <- substM body - simplifyTabApp body' xs' - LiftSimp _ f' -> do - fTy <- return $ getType f - resultTy <- typeOfTabApp fTy xs - xs' <- mapM toDataAtomIgnoreRecon xs + CCSimpInCore (LiftSimp fTy f') -> do + resultTy <- typeOfTabApp fTy (sink<$>xs) + xs' <- mapM (toDataAtomIgnoreRecon . sink) xs liftSimpAtom resultTy =<< naryTabApp f' xs' - LiftSimpFun _ _ -> error "not implemented" -simplifyTabApp f _ = error $ "Unexpected table: " ++ pprint f + _ -> error "not a table" simplifyIxType :: IxType CoreIR o -> SimplifyM i o (IxType SimpIR o) simplifyIxType (IxType t ixDict) = do @@ -625,40 +644,8 @@ ixMethodType method absDict = do let allBs = extraArgBs >>> methodArgs return $ PiType allBs (EffTy Pure resultTy) --- TODO: do we even need this, or is it just a glorified `SubstM`? simplifyAtom :: CAtom i -> SimplifyM i o (CAtom o) -simplifyAtom atom = confuseGHC >>= \_ -> case atom of - Stuck e -> simplifyStuck e - Lam _ -> substM atom - DepPair x y ty -> DepPair <$> simplifyAtom x <*> simplifyAtom y <*> substM ty - Con con -> Con <$> traverseOp con substM simplifyAtom (error "unexpected lambda") - Eff eff -> Eff <$> substM eff - PtrVar t v -> PtrVar t <$> substM v - DictCon _ -> substM atom - NewtypeCon _ _ -> substM atom - SimpInCore _ -> substM atom - TypeAsAtom _ -> substM atom - -simplifyStuck :: CStuck i -> SimplifyM i o (CAtom o) -simplifyStuck = \case - StuckVar v -> simplifyVar v - StuckProject _ i x -> reduceProj i =<< simplifyStuck x - stuck -> substM (Stuck stuck) - -simplifyVar :: AtomVar CoreIR i -> SimplifyM i o (CAtom o) -simplifyVar v = do - env <- getSubst - case env ! atomVarName v of - SubstVal x -> return x - Rename v' -> do - AtomNameBinding bindingInfo <- lookupEnv v' - let ty = getType bindingInfo - case bindingInfo of - -- Functions get inlined only at application sites - LetBound (DeclBinding _ _) | isFun -> return $ Var $ AtomVar v' ty - where isFun = case ty of Pi _ -> True; _ -> False - LetBound (DeclBinding _ (Atom x)) -> dropSubst $ simplifyAtom x - _ -> return $ Var $ AtomVar v' ty +simplifyAtom = substM -- Assumes first order (args/results are "data", allowing newtypes), monormophic simplifyLam |