summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDougal <d.maclaurin@gmail.com>2023-11-09 21:41:14 -0500
committerDougal <d.maclaurin@gmail.com>2023-11-09 21:41:14 -0500
commit1b2d252b45592476d32d1e91d69b5def8150d834 (patch)
treef6c51a467487df7d736ed073a748c3d7c5661fa9
parent9672aa3afe288ba110a6c37616fc047515e09f0b (diff)
Add some missing linearity annotations.
We really need to build the linearity checker.
-rw-r--r--src/lib/Builder.hs19
-rw-r--r--src/lib/Inference.hs2
-rw-r--r--src/lib/Linearize.hs18
-rw-r--r--src/lib/Simplify.hs2
-rw-r--r--src/lib/Transpose.hs18
5 files changed, 28 insertions, 31 deletions
diff --git a/src/lib/Builder.hs b/src/lib/Builder.hs
index 0bb8e8fe..caebe262 100644
--- a/src/lib/Builder.hs
+++ b/src/lib/Builder.hs
@@ -765,23 +765,23 @@ mkTypedHof hof = do
effTy <- effTyOfHof hof
return $ TypedHof effTy hof
-buildForAnn
- :: (Emits n, ScopableBuilder r m)
+mkFor
+ :: (ScopableBuilder r m)
=> NameHint -> ForAnn -> IxType r n
-> (forall l. (Emits l, DExt n l) => AtomVar r l -> m l (Atom r l))
- -> m n (Atom r n)
-buildForAnn hint ann (IxType iTy ixDict) body = do
+ -> m n (Expr r n)
+mkFor hint ann (IxType iTy ixDict) body = do
lam <- withFreshBinder hint iTy \b -> do
let v = binderVar b
body' <- buildBlock $ body $ sink v
return $ LamExpr (UnaryNest b) body'
- emitHof $ For ann (IxType iTy ixDict) lam
+ liftM toExpr $ mkTypedHof $ For ann (IxType iTy ixDict) lam
buildFor :: (Emits n, ScopableBuilder r m)
=> NameHint -> Direction -> IxType r n
-> (forall l. (Emits l, DExt n l) => AtomVar r l -> m l (Atom r l))
-> m n (Atom r n)
-buildFor hint dir ty body = buildForAnn hint dir ty body
+buildFor hint ann ty body = mkFor hint ann ty body >>= emit
buildMap :: (Emits n, ScopableBuilder SimpIR m)
=> SAtom n
@@ -853,6 +853,10 @@ emitLin e = case toExpr e of
expr -> liftM toAtom $ emitDecl noHint LinearLet $ peepholeExpr expr
{-# INLINE emitLin #-}
+emitHofLin :: (Builder r m, Emits n) => Hof r n -> m n (Atom r n)
+emitHofLin hof = mkTypedHof hof >>= emitLin
+{-# INLINE emitHofLin #-}
+
zeroAt :: (Emits n, SBuilder m) => SType n -> m n (SAtom n)
zeroAt ty = liftEmitBuilder $ go ty where
go :: Emits n => SType n -> BuilderM SimpIR n (SAtom n)
@@ -1100,9 +1104,8 @@ mkApplyMethod d i xs = do
mkInstanceDict :: EnvReader m => InstanceName n -> [CAtom n] -> m n (CDict n)
mkInstanceDict instanceName args = do
instanceDef@(InstanceDef className _ _ _ _) <- lookupInstanceDef instanceName
- sourceName <- getSourceName <$> lookupClassDef className
PairE (ListE params) _ <- instantiate instanceDef args
- let ty = toType $ DictType sourceName className params
+ ty <- toType <$> dictType className params
return $ toDict $ InstanceDict ty instanceName args
mkCase :: (EnvReader m, IRRep r) => Atom r n -> Type r n -> [Alt r n] -> m n (Expr r n)
diff --git a/src/lib/Inference.hs b/src/lib/Inference.hs
index b9253952..48a672c8 100644
--- a/src/lib/Inference.hs
+++ b/src/lib/Inference.hs
@@ -2013,7 +2013,7 @@ generalizeDict ty dict = do
result <- liftEnvReaderT $ liftInfererM $ generalizeDictRec ty dict
case result of
Failure e -> error $ "Failed to generalize " ++ pprint dict
- ++ " to " ++ pprint ty ++ " because " ++ pprint e
+ ++ " to " ++ show ty ++ " because " ++ pprint e
Success ans -> return ans
generalizeDictRec :: CType n -> CDict n -> InfererM i n (CDict n)
diff --git a/src/lib/Linearize.hs b/src/lib/Linearize.hs
index 87054836..98bbb7d3 100644
--- a/src/lib/Linearize.hs
+++ b/src/lib/Linearize.hs
@@ -270,7 +270,7 @@ applyLinLam :: Emits o => SLam i -> SubstReaderT AtomSubstVal TangentM i o (Atom
applyLinLam (LamExpr bs body) = do
TangentArgs args <- liftSubstReaderT $ getTangentArgs
extendSubst (bs @@> ((Rename . atomVarName) <$> args)) do
- substM body >>= emit
+ substM body >>= emitLin
-- === actual linearization passs ===
@@ -299,7 +299,7 @@ linearizeTopLam (TopLam False _ (LamExpr bs body)) actives = do
ts <- getUnpacked $ toAtom $ sink $ binderVar bTangent
let substFrag = bsRecon @@> map (SubstVal . sink) xs
<.> bsTangent @@> map (SubstVal . sink) ts
- emit =<< applySubst substFrag tangentBody
+ emitLin =<< applySubst substFrag tangentBody
return $ LamExpr (bs' >>> BinaryNest bResidual bTangent) tangentBody'
return (primalFun, tangentFun)
(,) <$> asTopLam primalFun <*> asTopLam tangentFun
@@ -358,7 +358,7 @@ linearizeDecls (Nest (Let b (DeclBinding ann expr)) rest) cont = do
WithTangent pRest tfRest <- linearizeDecls rest cont
return $ WithTangent pRest do
t <- tf
- vt <- emitDecl (getNameHint b) ann (Atom t)
+ vt <- emitDecl (getNameHint b) LinearLet (Atom t)
extendTangentArgs vt $
tfRest
@@ -410,7 +410,7 @@ linearizeExpr expr = case expr of
(primal, residualss) <- fromPair result
resultTangentType <- tangentType resultTy'
return $ WithTangent primal do
- buildCase (sink residualss) (sink resultTangentType) \i residuals -> do
+ emitLin =<< buildCase' (sink residualss) (sink resultTangentType) \i residuals -> do
ObligateRecon _ (Abs bs linLam) <- return $ sinkList recons !! i
residuals' <- unpackTelescope bs residuals
withSubstReaderT $ extendSubst (bs @@> (SubstVal <$> residuals')) do
@@ -613,13 +613,13 @@ linearizeHof hof = case hof of
TrivialRecon linLam' ->
return $ WithTangent primalsAux do
Abs ib'' linLam'' <- sinkM (Abs ib' linLam')
- withSubstReaderT $ buildFor noHint d (sink ixTy) \i' -> do
+ withSubstReaderT $ emitLin =<< mkFor noHint d (sink ixTy) \i' -> do
extendSubst (ib''@>Rename (atomVarName i')) $ applyLinLam linLam''
ReconWithData reconAbs -> do
primals <- buildMap primalsAux getFst
return $ WithTangent primals do
Abs ib'' (Abs bs linLam') <- sinkM (Abs ib' reconAbs)
- withSubstReaderT $ buildFor noHint d (sink ixTy) \i' -> do
+ withSubstReaderT $ emitLin =<< mkFor noHint d (sink ixTy) \i' -> do
extendSubst (ib''@> Rename (atomVarName i')) do
residuals' <- tabApp (sink primalsAux) (toAtom i') >>= getSnd >>= unpackTelescope bs
extendSubst (bs @@> (SubstVal <$> residuals')) $
@@ -636,7 +636,7 @@ linearizeHof hof = case hof of
tanEffLam <- buildEffLam noHint tt \h ref ->
extendTangentArgss [h, ref] do
withSubstReaderT $ applyLinLam $ sink linLam
- emitHof $ RunReader rLin' tanEffLam
+ emitHofLin $ RunReader rLin' tanEffLam
RunState Nothing sInit lam -> do
WithTangent sInit' sLin <- linearizeAtom sInit
(lam', recon) <- linearizeEffectFun State lam
@@ -649,7 +649,7 @@ linearizeHof hof = case hof of
tanEffLam <- buildEffLam noHint tt \h ref ->
extendTangentArgss [h, ref] do
withSubstReaderT $ applyLinLam $ sink linLam
- emitHof $ RunState Nothing sLin' tanEffLam
+ emitHofLin $ RunState Nothing sLin' tanEffLam
RunWriter Nothing bm lam -> do
-- TODO: check it's actually the 0/+ monoid (or should we just build that in?)
bm' <- renameM bm
@@ -663,7 +663,7 @@ linearizeHof hof = case hof of
tanEffLam <- buildEffLam noHint tt \h ref ->
extendTangentArgss [h, ref] do
withSubstReaderT $ applyLinLam $ sink linLam
- emitHof $ RunWriter Nothing bm'' tanEffLam
+ emitHofLin $ RunWriter Nothing bm'' tanEffLam
RunIO body -> do
(body', recon) <- linearizeExprDefunc body
primalAux <- emitHof $ RunIO body'
diff --git a/src/lib/Simplify.hs b/src/lib/Simplify.hs
index 9997df23..129039bd 100644
--- a/src/lib/Simplify.hs
+++ b/src/lib/Simplify.hs
@@ -1056,7 +1056,7 @@ exceptToMaybeExpr expr = case expr of
return $ JustAtom ty x'
PrimOp (Hof (TypedHof _ (For ann ixTy' (UnaryLamExpr b body)))) -> do
ixTy <- substM ixTy'
- maybes <- buildForAnn (getNameHint b) ann ixTy \i -> do
+ maybes <- buildFor (getNameHint b) ann ixTy \i -> do
extendSubst (b@>Rename (atomVarName i)) $ exceptToMaybeExpr body
catMaybesE maybes
PrimOp (MiscOp (ThrowException _)) -> do
diff --git a/src/lib/Transpose.hs b/src/lib/Transpose.hs
index 302ca9e4..10c87d37 100644
--- a/src/lib/Transpose.hs
+++ b/src/lib/Transpose.hs
@@ -80,16 +80,12 @@ data TransposeSubstVal c n where
type TransposeM a = SubstReaderT TransposeSubstVal (BuilderM SimpIR) a
--- TODO: it might make sense to replace substNonlin/isLin
--- with a single `trySubtNonlin :: e i -> Maybe (e o)`.
--- But for that we need a way to traverse names, like a monadic
--- version of `substE`.
-substNonlin :: (SinkableE e, RenameE e, HasCallStack) => e i -> TransposeM i o (e o)
+substNonlin :: (PrettyE e, SinkableE e, RenameE e, HasCallStack) => e i -> TransposeM i o (e o)
substNonlin e = do
subst <- getSubst
fmapRenamingM (\v -> case subst ! v of
RenameNonlin v' -> v'
- _ -> error "not a nonlinear expression") e
+ _ -> error $ "not a nonlinear expression: " ++ pprint e) e
withAccumulator
:: Emits o
@@ -113,7 +109,7 @@ withAccumulator ty cont = do
emitCTToRef :: (Emits n, Builder SimpIR m) => SAtom n -> SAtom n -> m n ()
emitCTToRef ref ct = do
baseMonoid <- tangentBaseMonoidFor (getType ct)
- void $ emit $ RefOp ref $ MExtend baseMonoid ct
+ void $ emitLin $ RefOp ref $ MExtend baseMonoid ct
-- === actual pass ===
@@ -190,7 +186,7 @@ transposeOp op ct = case op of
DAMOp _ -> error "unreachable" -- TODO: rule out statically
RefOp refArg m -> do
refArg' <- substNonlin refArg
- let emitEff = emit . RefOp refArg'
+ let emitEff = emitLin . RefOp refArg'
case m of
MAsk -> do
baseMonoid <- tangentBaseMonoidFor (getType ct)
@@ -251,9 +247,7 @@ transposeAtom atom ct = case atom of
PtrVar _ _ -> notTangent
Var v -> do
lookupSubstM (atomVarName v) >>= \case
- RenameNonlin _ ->
- -- XXX: we seem to need this case, but it feels like it should be an error!
- return ()
+ RenameNonlin _ -> error "nonlinear"
LinRef ref -> emitCTToRef ref ct
LinTrivial -> return ()
StuckProject _ _ -> error "not linear"
@@ -266,7 +260,7 @@ transposeHof hof ct = case hof of
For ann ixTy' lam -> do
UnaryLamExpr b body <- return lam
ixTy <- substNonlin ixTy'
- void $ buildForAnn (getNameHint b) (flipDir ann) ixTy \i -> do
+ void $ emitLin =<< mkFor (getNameHint b) (flipDir ann) ixTy \i -> do
ctElt <- tabApp (sink ct) (toAtom i)
extendSubst (b@>RenameNonlin (atomVarName i)) $ transposeExpr body ctElt
return UnitVal