diff options
author | Dougal <d.maclaurin@gmail.com> | 2023-11-09 21:41:14 -0500 |
---|---|---|
committer | Dougal <d.maclaurin@gmail.com> | 2023-11-09 21:41:14 -0500 |
commit | 1b2d252b45592476d32d1e91d69b5def8150d834 (patch) | |
tree | f6c51a467487df7d736ed073a748c3d7c5661fa9 | |
parent | 9672aa3afe288ba110a6c37616fc047515e09f0b (diff) |
Add some missing linearity annotations.
We really need to build the linearity checker.
-rw-r--r-- | src/lib/Builder.hs | 19 | ||||
-rw-r--r-- | src/lib/Inference.hs | 2 | ||||
-rw-r--r-- | src/lib/Linearize.hs | 18 | ||||
-rw-r--r-- | src/lib/Simplify.hs | 2 | ||||
-rw-r--r-- | src/lib/Transpose.hs | 18 |
5 files changed, 28 insertions, 31 deletions
diff --git a/src/lib/Builder.hs b/src/lib/Builder.hs index 0bb8e8fe..caebe262 100644 --- a/src/lib/Builder.hs +++ b/src/lib/Builder.hs @@ -765,23 +765,23 @@ mkTypedHof hof = do effTy <- effTyOfHof hof return $ TypedHof effTy hof -buildForAnn - :: (Emits n, ScopableBuilder r m) +mkFor + :: (ScopableBuilder r m) => NameHint -> ForAnn -> IxType r n -> (forall l. (Emits l, DExt n l) => AtomVar r l -> m l (Atom r l)) - -> m n (Atom r n) -buildForAnn hint ann (IxType iTy ixDict) body = do + -> m n (Expr r n) +mkFor hint ann (IxType iTy ixDict) body = do lam <- withFreshBinder hint iTy \b -> do let v = binderVar b body' <- buildBlock $ body $ sink v return $ LamExpr (UnaryNest b) body' - emitHof $ For ann (IxType iTy ixDict) lam + liftM toExpr $ mkTypedHof $ For ann (IxType iTy ixDict) lam buildFor :: (Emits n, ScopableBuilder r m) => NameHint -> Direction -> IxType r n -> (forall l. (Emits l, DExt n l) => AtomVar r l -> m l (Atom r l)) -> m n (Atom r n) -buildFor hint dir ty body = buildForAnn hint dir ty body +buildFor hint ann ty body = mkFor hint ann ty body >>= emit buildMap :: (Emits n, ScopableBuilder SimpIR m) => SAtom n @@ -853,6 +853,10 @@ emitLin e = case toExpr e of expr -> liftM toAtom $ emitDecl noHint LinearLet $ peepholeExpr expr {-# INLINE emitLin #-} +emitHofLin :: (Builder r m, Emits n) => Hof r n -> m n (Atom r n) +emitHofLin hof = mkTypedHof hof >>= emitLin +{-# INLINE emitHofLin #-} + zeroAt :: (Emits n, SBuilder m) => SType n -> m n (SAtom n) zeroAt ty = liftEmitBuilder $ go ty where go :: Emits n => SType n -> BuilderM SimpIR n (SAtom n) @@ -1100,9 +1104,8 @@ mkApplyMethod d i xs = do mkInstanceDict :: EnvReader m => InstanceName n -> [CAtom n] -> m n (CDict n) mkInstanceDict instanceName args = do instanceDef@(InstanceDef className _ _ _ _) <- lookupInstanceDef instanceName - sourceName <- getSourceName <$> lookupClassDef className PairE (ListE params) _ <- instantiate instanceDef args - let ty = toType $ DictType sourceName className params + ty <- toType <$> dictType className params return $ toDict $ InstanceDict ty instanceName args mkCase :: (EnvReader m, IRRep r) => Atom r n -> Type r n -> [Alt r n] -> m n (Expr r n) diff --git a/src/lib/Inference.hs b/src/lib/Inference.hs index b9253952..48a672c8 100644 --- a/src/lib/Inference.hs +++ b/src/lib/Inference.hs @@ -2013,7 +2013,7 @@ generalizeDict ty dict = do result <- liftEnvReaderT $ liftInfererM $ generalizeDictRec ty dict case result of Failure e -> error $ "Failed to generalize " ++ pprint dict - ++ " to " ++ pprint ty ++ " because " ++ pprint e + ++ " to " ++ show ty ++ " because " ++ pprint e Success ans -> return ans generalizeDictRec :: CType n -> CDict n -> InfererM i n (CDict n) diff --git a/src/lib/Linearize.hs b/src/lib/Linearize.hs index 87054836..98bbb7d3 100644 --- a/src/lib/Linearize.hs +++ b/src/lib/Linearize.hs @@ -270,7 +270,7 @@ applyLinLam :: Emits o => SLam i -> SubstReaderT AtomSubstVal TangentM i o (Atom applyLinLam (LamExpr bs body) = do TangentArgs args <- liftSubstReaderT $ getTangentArgs extendSubst (bs @@> ((Rename . atomVarName) <$> args)) do - substM body >>= emit + substM body >>= emitLin -- === actual linearization passs === @@ -299,7 +299,7 @@ linearizeTopLam (TopLam False _ (LamExpr bs body)) actives = do ts <- getUnpacked $ toAtom $ sink $ binderVar bTangent let substFrag = bsRecon @@> map (SubstVal . sink) xs <.> bsTangent @@> map (SubstVal . sink) ts - emit =<< applySubst substFrag tangentBody + emitLin =<< applySubst substFrag tangentBody return $ LamExpr (bs' >>> BinaryNest bResidual bTangent) tangentBody' return (primalFun, tangentFun) (,) <$> asTopLam primalFun <*> asTopLam tangentFun @@ -358,7 +358,7 @@ linearizeDecls (Nest (Let b (DeclBinding ann expr)) rest) cont = do WithTangent pRest tfRest <- linearizeDecls rest cont return $ WithTangent pRest do t <- tf - vt <- emitDecl (getNameHint b) ann (Atom t) + vt <- emitDecl (getNameHint b) LinearLet (Atom t) extendTangentArgs vt $ tfRest @@ -410,7 +410,7 @@ linearizeExpr expr = case expr of (primal, residualss) <- fromPair result resultTangentType <- tangentType resultTy' return $ WithTangent primal do - buildCase (sink residualss) (sink resultTangentType) \i residuals -> do + emitLin =<< buildCase' (sink residualss) (sink resultTangentType) \i residuals -> do ObligateRecon _ (Abs bs linLam) <- return $ sinkList recons !! i residuals' <- unpackTelescope bs residuals withSubstReaderT $ extendSubst (bs @@> (SubstVal <$> residuals')) do @@ -613,13 +613,13 @@ linearizeHof hof = case hof of TrivialRecon linLam' -> return $ WithTangent primalsAux do Abs ib'' linLam'' <- sinkM (Abs ib' linLam') - withSubstReaderT $ buildFor noHint d (sink ixTy) \i' -> do + withSubstReaderT $ emitLin =<< mkFor noHint d (sink ixTy) \i' -> do extendSubst (ib''@>Rename (atomVarName i')) $ applyLinLam linLam'' ReconWithData reconAbs -> do primals <- buildMap primalsAux getFst return $ WithTangent primals do Abs ib'' (Abs bs linLam') <- sinkM (Abs ib' reconAbs) - withSubstReaderT $ buildFor noHint d (sink ixTy) \i' -> do + withSubstReaderT $ emitLin =<< mkFor noHint d (sink ixTy) \i' -> do extendSubst (ib''@> Rename (atomVarName i')) do residuals' <- tabApp (sink primalsAux) (toAtom i') >>= getSnd >>= unpackTelescope bs extendSubst (bs @@> (SubstVal <$> residuals')) $ @@ -636,7 +636,7 @@ linearizeHof hof = case hof of tanEffLam <- buildEffLam noHint tt \h ref -> extendTangentArgss [h, ref] do withSubstReaderT $ applyLinLam $ sink linLam - emitHof $ RunReader rLin' tanEffLam + emitHofLin $ RunReader rLin' tanEffLam RunState Nothing sInit lam -> do WithTangent sInit' sLin <- linearizeAtom sInit (lam', recon) <- linearizeEffectFun State lam @@ -649,7 +649,7 @@ linearizeHof hof = case hof of tanEffLam <- buildEffLam noHint tt \h ref -> extendTangentArgss [h, ref] do withSubstReaderT $ applyLinLam $ sink linLam - emitHof $ RunState Nothing sLin' tanEffLam + emitHofLin $ RunState Nothing sLin' tanEffLam RunWriter Nothing bm lam -> do -- TODO: check it's actually the 0/+ monoid (or should we just build that in?) bm' <- renameM bm @@ -663,7 +663,7 @@ linearizeHof hof = case hof of tanEffLam <- buildEffLam noHint tt \h ref -> extendTangentArgss [h, ref] do withSubstReaderT $ applyLinLam $ sink linLam - emitHof $ RunWriter Nothing bm'' tanEffLam + emitHofLin $ RunWriter Nothing bm'' tanEffLam RunIO body -> do (body', recon) <- linearizeExprDefunc body primalAux <- emitHof $ RunIO body' diff --git a/src/lib/Simplify.hs b/src/lib/Simplify.hs index 9997df23..129039bd 100644 --- a/src/lib/Simplify.hs +++ b/src/lib/Simplify.hs @@ -1056,7 +1056,7 @@ exceptToMaybeExpr expr = case expr of return $ JustAtom ty x' PrimOp (Hof (TypedHof _ (For ann ixTy' (UnaryLamExpr b body)))) -> do ixTy <- substM ixTy' - maybes <- buildForAnn (getNameHint b) ann ixTy \i -> do + maybes <- buildFor (getNameHint b) ann ixTy \i -> do extendSubst (b@>Rename (atomVarName i)) $ exceptToMaybeExpr body catMaybesE maybes PrimOp (MiscOp (ThrowException _)) -> do diff --git a/src/lib/Transpose.hs b/src/lib/Transpose.hs index 302ca9e4..10c87d37 100644 --- a/src/lib/Transpose.hs +++ b/src/lib/Transpose.hs @@ -80,16 +80,12 @@ data TransposeSubstVal c n where type TransposeM a = SubstReaderT TransposeSubstVal (BuilderM SimpIR) a --- TODO: it might make sense to replace substNonlin/isLin --- with a single `trySubtNonlin :: e i -> Maybe (e o)`. --- But for that we need a way to traverse names, like a monadic --- version of `substE`. -substNonlin :: (SinkableE e, RenameE e, HasCallStack) => e i -> TransposeM i o (e o) +substNonlin :: (PrettyE e, SinkableE e, RenameE e, HasCallStack) => e i -> TransposeM i o (e o) substNonlin e = do subst <- getSubst fmapRenamingM (\v -> case subst ! v of RenameNonlin v' -> v' - _ -> error "not a nonlinear expression") e + _ -> error $ "not a nonlinear expression: " ++ pprint e) e withAccumulator :: Emits o @@ -113,7 +109,7 @@ withAccumulator ty cont = do emitCTToRef :: (Emits n, Builder SimpIR m) => SAtom n -> SAtom n -> m n () emitCTToRef ref ct = do baseMonoid <- tangentBaseMonoidFor (getType ct) - void $ emit $ RefOp ref $ MExtend baseMonoid ct + void $ emitLin $ RefOp ref $ MExtend baseMonoid ct -- === actual pass === @@ -190,7 +186,7 @@ transposeOp op ct = case op of DAMOp _ -> error "unreachable" -- TODO: rule out statically RefOp refArg m -> do refArg' <- substNonlin refArg - let emitEff = emit . RefOp refArg' + let emitEff = emitLin . RefOp refArg' case m of MAsk -> do baseMonoid <- tangentBaseMonoidFor (getType ct) @@ -251,9 +247,7 @@ transposeAtom atom ct = case atom of PtrVar _ _ -> notTangent Var v -> do lookupSubstM (atomVarName v) >>= \case - RenameNonlin _ -> - -- XXX: we seem to need this case, but it feels like it should be an error! - return () + RenameNonlin _ -> error "nonlinear" LinRef ref -> emitCTToRef ref ct LinTrivial -> return () StuckProject _ _ -> error "not linear" @@ -266,7 +260,7 @@ transposeHof hof ct = case hof of For ann ixTy' lam -> do UnaryLamExpr b body <- return lam ixTy <- substNonlin ixTy' - void $ buildForAnn (getNameHint b) (flipDir ann) ixTy \i -> do + void $ emitLin =<< mkFor (getNameHint b) (flipDir ann) ixTy \i -> do ctElt <- tabApp (sink ct) (toAtom i) extendSubst (b@>RenameNonlin (atomVarName i)) $ transposeExpr body ctElt return UnitVal |