summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDougal <d.maclaurin@gmail.com>2023-11-09 16:50:55 -0500
committerDougal <d.maclaurin@gmail.com>2023-11-09 16:50:55 -0500
commit9672aa3afe288ba110a6c37616fc047515e09f0b (patch)
tree86f4917674e03c3d98df15c9a593f705b71d811c
parentd10cfc591dfe5d04d97d56e72dcde6c0969fe1de (diff)
Fix autodiff using explicit linearity annotations and handle projections efficiently.
-rw-r--r--src/lib/Algebra.hs2
-rw-r--r--src/lib/Builder.hs44
-rw-r--r--src/lib/Inline.hs3
-rw-r--r--src/lib/Linearize.hs321
-rw-r--r--src/lib/PPrint.hs1
-rw-r--r--src/lib/Transpose.hs194
-rw-r--r--src/lib/Types/Primitives.hs1
7 files changed, 257 insertions, 309 deletions
diff --git a/src/lib/Algebra.hs b/src/lib/Algebra.hs
index b3d6d250..5ecc05f7 100644
--- a/src/lib/Algebra.hs
+++ b/src/lib/Algebra.hs
@@ -18,7 +18,7 @@ import Data.Text.Prettyprint.Doc
import Data.List (intersperse)
import Data.Tuple (swap)
-import Builder hiding (sub, add, mul)
+import Builder
import Core
import CheapReduction
import Err
diff --git a/src/lib/Builder.hs b/src/lib/Builder.hs
index b65b1414..0bb8e8fe 100644
--- a/src/lib/Builder.hs
+++ b/src/lib/Builder.hs
@@ -847,6 +847,12 @@ buildRememberDest hint dest cont = do
-- === vector space (ish) type class ===
+emitLin :: (Builder r m, ToExpr e r, Emits n) => e n -> m n (Atom r n)
+emitLin e = case toExpr e of
+ Atom x -> return x
+ expr -> liftM toAtom $ emitDecl noHint LinearLet $ peepholeExpr expr
+{-# INLINE emitLin #-}
+
zeroAt :: (Emits n, SBuilder m) => SType n -> m n (SAtom n)
zeroAt ty = liftEmitBuilder $ go ty where
go :: Emits n => SType n -> BuilderM SimpIR n (SAtom n)
@@ -930,14 +936,17 @@ symbolicTangentNonZero val = do
-- === builder versions of common local ops ===
-neg :: (Builder r m, Emits n) => Atom r n -> m n (Atom r n)
-neg x = emit $ UnOp FNeg x
+fadd :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n)
+fadd x y = emit $ BinOp FAdd x y
-add :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n)
-add x y = emit $ BinOp FAdd x y
+fsub :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n)
+fsub x y = emit $ BinOp FSub x y
-mul :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n)
-mul x y = emit $ BinOp FMul x y
+fmul :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n)
+fmul x y = emit $ BinOp FMul x y
+
+fdiv :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n)
+fdiv x y = emit $ BinOp FDiv x y
iadd :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n)
iadd x y = emit $ BinOp IAdd x y
@@ -945,22 +954,10 @@ iadd x y = emit $ BinOp IAdd x y
imul :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n)
imul x y = emit $ BinOp IMul x y
-div' :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n)
-div' x y = emit $ BinOp FDiv x y
-
-fpow :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n)
-fpow x y = emit $ BinOp FPow x y
-
-sub :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n)
-sub x y = emit $ BinOp FSub x y
-
-flog :: (Builder r m, Emits n) => Atom r n -> m n (Atom r n)
-flog x = emit $ UnOp Log x
-
-fLitLike :: (SBuilder m, Emits n) => Double -> SAtom n -> m n (SAtom n)
+fLitLike :: Double -> SAtom n -> SAtom n
fLitLike x t = case getTyCon t of
- BaseType (Scalar Float64Type) -> return $ toAtom $ Lit $ Float64Lit x
- BaseType (Scalar Float32Type) -> return $ toAtom $ Lit $ Float32Lit $ realToFrac x
+ BaseType (Scalar Float64Type) -> toAtom $ Lit $ Float64Lit x
+ BaseType (Scalar Float32Type) -> toAtom $ Lit $ Float32Lit $ realToFrac x
_ -> error "Expected a floating point scalar"
fromPair :: (Builder r m, Emits n) => Atom r n -> m n (Atom r n, Atom r n)
@@ -1085,6 +1082,11 @@ mkTabApp xs ixs = do
ty <- typeOfTabApp (getType xs) ixs
return $ TabApp ty xs ixs
+mkProject :: (EnvReader m, IRRep r) => Int -> Atom r n -> m n (Expr r n)
+mkProject i x = do
+ ty <- projType i x
+ return $ Project ty i x
+
mkTopApp :: EnvReader m => TopFunName n -> [SAtom n] -> m n (SExpr n)
mkTopApp f xs = do
resultTy <- typeOfTopApp f xs
diff --git a/src/lib/Inline.hs b/src/lib/Inline.hs
index 4104948c..f3ada792 100644
--- a/src/lib/Inline.hs
+++ b/src/lib/Inline.hs
@@ -98,12 +98,14 @@ inlineDeclsSubst = \case
inlineDeclsSubst rest
where
dropOccInfo PlainLet = PlainLet
+ dropOccInfo LinearLet = LinearLet
dropOccInfo InlineLet = InlineLet
dropOccInfo NoInlineLet = NoInlineLet
dropOccInfo (OccInfoPure _) = PlainLet
dropOccInfo (OccInfoImpure _) = PlainLet
resolveWorkConservation PlainLet _ =
NoInline -- No occurrence info, assume the worst
+ resolveWorkConservation LinearLet _ = NoInline
resolveWorkConservation InlineLet _ = NoInline
resolveWorkConservation NoInlineLet _ = NoInline
-- Quick hack to always unconditionally inline renames, until we get
@@ -176,6 +178,7 @@ preInlineUnconditionally = \case
PlainLet -> False -- "Missing occurrence annotation"
InlineLet -> True
NoInlineLet -> False
+ LinearLet -> False
OccInfoPure (UsageInfo s (0, d)) | s <= One && d <= One -> True
OccInfoPure _ -> False
OccInfoImpure _ -> False
diff --git a/src/lib/Linearize.hs b/src/lib/Linearize.hs
index d4ca0417..87054836 100644
--- a/src/lib/Linearize.hs
+++ b/src/lib/Linearize.hs
@@ -16,6 +16,7 @@ import GHC.Stack
import Builder
import Core
+import CheapReduction
import Imp
import IRVariants
import MTL1
@@ -26,7 +27,7 @@ import PPrint
import QueryType
import Types.Core
import Types.Primitives
-import Util (bindM2, enumerate)
+import Util (enumerate)
-- === linearization monad ===
@@ -93,48 +94,39 @@ extendTangentArgss vs' m = local (\(TangentArgs vs) -> TangentArgs $ vs ++ vs')
getTangentArgs :: TangentM o (TangentArgs o)
getTangentArgs = ask
-bindLin
- :: Emits o
- => LinM i o e e
- -> (forall o' m. (Emits o', DExt o o', Builder SimpIR m) => e o' -> m o' (e' o'))
- -> LinM i o e' e'
-bindLin m f = do
- result <- m
- withBoth result f
-
-withBoth
- :: Emits o
+emitBoth
+ :: (Emits o, ToExpr e' SimpIR)
=> WithTangent o e e
- -> (forall o' m. (Emits o', DExt o o', Builder SimpIR m) => e o' -> m o' (e' o'))
- -> PrimalM i o (WithTangent o e' e')
-withBoth (WithTangent x tx) f = do
+ -> (forall o' m. (DExt o o', Builder SimpIR m) => e o' -> m o' (e' o'))
+ -> LinM i o SAtom SAtom
+emitBoth (WithTangent x tx) f = do
Distinct <- getDistinct
- y <- f x
- return $ WithTangent y do
- tx >>= f
+ x' <- emit =<< f x
+ return $ WithTangent x' do
+ tx' <- tx
+ emitLin =<< f tx'
-_withTangentComputation
- :: Emits o
- => WithTangent o e1 e2
- -> (forall o' m. (Emits o', DExt o o', Builder SimpIR m) => e2 o' -> m o' (e2' o'))
- -> PrimalM i o (WithTangent o e1 e2')
-_withTangentComputation (WithTangent x tx) f = do
- Distinct <- getDistinct
- return $ WithTangent x do
- tx >>= f
+emitZeroT :: (Emits o, HasNamesE e', ToExpr e' SimpIR) => e' i -> LinM i o SAtom SAtom
+emitZeroT e = do
+ x <- emit =<< renameM e
+ return $ WithTangent x (zeroLikeT x)
+
+zeroLikeT :: (DExt o o', Emits o', HasType SimpIR e) => e o -> TangentM o' (SAtom o')
+zeroLikeT x = do
+ ty <- sinkM $ getType x
+ zeroAt =<< tangentType ty
fmapLin
:: Emits o
=> (forall o'. e o' -> e' o')
-> LinM i o e e
-> LinM i o e' e'
-fmapLin f m = m `bindLin` (pure . f)
+fmapLin f m = do
+ WithTangent ans tx <- m
+ return $ WithTangent (f ans) (f <$> tx)
-zipLin :: LinM i o e1 e1 -> LinM i o e2 e2 -> LinM i o (PairE e1 e2) (PairE e1 e2)
-zipLin m1 m2 = do
- WithTangent x1 t1 <- m1
- WithTangent x2 t2 <- m2
- return $ WithTangent (PairE x1 x2) do PairE <$> t1 <*> t2
+zipLin :: WithTangent o e1 e1 -> WithTangent o e2 e2 -> WithTangent o (PairE e1 e2) (PairE e1 e2)
+zipLin (WithTangent x1 t1) (WithTangent x2 t2) = WithTangent (PairE x1 x2) do PairE <$> t1 <*> t2
seqLin
:: Traversable f
@@ -325,19 +317,28 @@ linearizeLambdaApp _ _ = error "not implemented"
linearizeAtom :: Emits o => Atom SimpIR i -> LinM i o SAtom SAtom
linearizeAtom (Con con) = linearizePrimCon con
-linearizeAtom atom@(Stuck _ stuck) = case stuck of
- PtrVar _ _ -> emitZeroT
+linearizeAtom (Stuck _ stuck) = linearizeStuck stuck
+
+linearizeStuck :: Emits o => Stuck SimpIR i -> LinM i o SAtom SAtom
+linearizeStuck stuck = case stuck of
Var v -> do
v' <- renameM v
activePrimalIdx v' >>= \case
- Nothing -> withZeroT $ return (toAtom v')
+ Nothing -> zero
Just idx -> return $ WithTangent (toAtom v') $ getTangentArg idx
- -- TODO: buildScoped and reduce the results so we keep expression in non-ANF for type checking purposes
- StuckProject _ _ -> undefined
- StuckTabApp _ _ -> undefined
- RepValAtom _ -> emitZeroT
- where emitZeroT = withZeroT $ renameM atom
-
+ PtrVar _ _ -> zero
+ RepValAtom _ -> zero
+ -- TODO: de-dup with the Expr versions of these
+ StuckProject i x -> do
+ x' <- linearizeStuck x
+ emitBoth x' \x'' -> mkProject i x''
+ StuckTabApp x i -> do
+ pt <- zipLin <$> linearizeStuck x <*> pureLin i
+ emitBoth pt \(PairE x' i') -> mkTabApp x' i'
+ where
+ zero = do
+ atom <- mkStuck =<< renameM stuck
+ return $ WithTangent atom (zeroLikeT atom)
linearizeDecls :: Emits o => Nest SDecl i i' -> LinM i' o e1 e2 -> LinM i o e1 e2
linearizeDecls Empty cont = cont
@@ -388,7 +389,6 @@ linearizeExpr expr = case expr of
where
unitLike :: e n -> UnitE n
unitLike _ = UnitE
- TabApp _ x i -> zipLin (linearizeAtom x) (pureLin i) `bindLin` \(PairE x' i') -> tabApp x' i'
PrimOp op -> linearizeOp op
Case e alts (EffTy effs resultTy) -> do
e' <- renameM e
@@ -417,49 +417,54 @@ linearizeExpr expr = case expr of
applyLinLam linLam
TabCon _ ty xs -> do
ty' <- renameM ty
- seqLin (map linearizeAtom xs) `bindLin` \(ComposeE xs') ->
- emit $ TabCon Nothing (sink ty') xs'
+ pt <- seqLin (map linearizeAtom xs)
+ emitBoth pt \(ComposeE xs') -> return $ TabCon Nothing (sink ty') xs'
+ TabApp _ x i -> do
+ pt <- zipLin <$> linearizeAtom x <*> pureLin i
+ emitBoth pt \(PairE x' i') -> mkTabApp x' i'
Project _ i x -> do
- WithTangent x' tx <- linearizeAtom x
- xi <- proj i x'
- return $ WithTangent xi do
- t <- tx
- proj i t
+ x' <- linearizeAtom x
+ emitBoth x' \x'' -> mkProject i x''
linearizeOp :: Emits o => PrimOp SimpIR i -> LinM i o SAtom SAtom
linearizeOp op = case op of
Hof (TypedHof _ e) -> linearizeHof e
DAMOp _ -> error "shouldn't occur here"
- RefOp ref m -> case m of
- MAsk -> linearizeAtom ref `bindLin` \ref' -> emit $ RefOp ref' MAsk
- MExtend monoid x -> do
- -- TODO: check that we're dealing with a +/0 monoid
- monoid' <- renameM monoid
- zipLin (linearizeAtom ref) (linearizeAtom x) `bindLin` \(PairE ref' x') ->
- emit $ RefOp ref' $ MExtend (sink monoid') x'
- MGet -> linearizeAtom ref `bindLin` \ref' -> emit $ RefOp ref' MGet
- MPut x -> zipLin (linearizeAtom ref) (linearizeAtom x) `bindLin` \(PairE ref' x') ->
- emit $ RefOp ref' $ MPut x'
- IndexRef _ i -> do
- zipLin (la ref) (pureLin i) `bindLin` \(PairE ref' i') ->
- emit =<< mkIndexRef ref' i'
- ProjRef _ i -> la ref `bindLin` \ref' -> emit =<< mkProjRef ref' i
+ RefOp ref m -> do
+ ref' <- linearizeAtom ref
+ case m of
+ MAsk -> emitBoth ref' \ref'' -> return $ RefOp ref'' MAsk
+ MExtend monoid x -> do
+ -- TODO: check that we're dealing with a +/0 monoid
+ monoid' <- renameM monoid
+ x' <- linearizeAtom x
+ emitBoth (zipLin ref' x') \(PairE ref'' x'') ->
+ return $ RefOp ref'' $ MExtend (sink monoid') x''
+ MGet -> emitBoth ref' \ref'' -> return $ RefOp ref'' MGet
+ MPut x -> do
+ x' <- linearizeAtom x
+ emitBoth (zipLin ref' x') \(PairE ref'' x'') -> return $ RefOp ref'' $ MPut x''
+ IndexRef _ i -> do
+ i' <- pureLin i
+ emitBoth (zipLin ref' i') \(PairE ref'' i'') -> mkIndexRef ref'' i''
+ ProjRef _ i -> emitBoth ref' \ref'' -> mkProjRef ref'' i
UnOp uop x -> linearizeUnOp uop x
BinOp bop x y -> linearizeBinOp bop x y
-- XXX: This assumes that pointers are always constants
- MemOp _ -> emitZeroT
+ MemOp _ -> emitZeroT op
MiscOp miscOp -> linearizeMiscOp miscOp
VectorOp _ -> error "not implemented"
- where
- emitZeroT = withZeroT $ emit =<< renameM (PrimOp op)
- la = linearizeAtom
linearizeMiscOp :: Emits o => MiscOp SimpIR i -> LinM i o SAtom SAtom
linearizeMiscOp op = case op of
- SumTag _ -> emitZeroT
- ToEnum _ _ -> emitZeroT
- Select p t f -> (pureLin p `zipLin` la t `zipLin` la f) `bindLin`
- \(p' `PairE` t' `PairE` f') -> emit $ MiscOp $ Select p' t' f'
+ SumTag _ -> zero
+ ToEnum _ _ -> zero
+ Select p t f -> do
+ p' <- pureLin p
+ t' <- linearizeAtom t
+ f' <- linearizeAtom f
+ emitBoth (p' `zipLin` t' `zipLin` f')
+ \(p'' `PairE` t'' `PairE` f'') -> return $ Select p'' t'' f''
CastOp t v -> do
vt <- getType <$> renameM v
t' <- renameM t
@@ -468,92 +473,105 @@ linearizeMiscOp op = case op of
((&&) <$> (vtTangentType `alphaEq` vt)
<*> (tTangentType `alphaEq` t')) >>= \case
True -> do
- linearizeAtom v `bindLin` \v' -> emit $ MiscOp $ CastOp (sink t') v'
+ v' <- linearizeAtom v
+ emitBoth v' \v'' -> return $ CastOp (sink t') v''
False -> do
WithTangent x xt <- linearizeAtom v
yt <- case (vtTangentType, tTangentType) of
(_ , UnitTy) -> return $ UnitVal
(UnitTy, tt ) -> zeroAt tt
_ -> error "Expected at least one side of the CastOp to have a trivial tangent type"
- y <- emit $ MiscOp $ CastOp t' x
+ y <- emit $ CastOp t' x
return $ WithTangent y do xt >> return (sink yt)
BitcastOp _ _ -> notImplemented
UnsafeCoerce _ _ -> notImplemented
GarbageVal _ -> notImplemented
ThrowException _ -> notImplemented
- ThrowError _ -> emitZeroT
- OutputStream -> emitZeroT
+ ThrowError _ -> zero
+ OutputStream -> zero
ShowAny _ -> error "Shouldn't have ShowAny in simplified IR"
ShowScalar _ -> error "Shouldn't have ShowScalar in simplified IR"
- where
- emitZeroT = withZeroT $ emit =<< renameM (PrimOp $ MiscOp op)
- la = linearizeAtom
+ where zero = emitZeroT op
linearizeUnOp :: Emits o => UnOp -> Atom SimpIR i -> LinM i o SAtom SAtom
-linearizeUnOp op x' = do
- WithTangent x tx <- linearizeAtom x'
- let emitZeroT = withZeroT $ emit $ UnOp op x
- case op of
- Exp -> do
- y <- emit $ UnOp Exp x
- return $ WithTangent y (bindM2 mul tx (sinkM y))
- Exp2 -> notImplemented
- Log -> withT (emit $ UnOp Log x) $ (tx >>= (`div'` sink x))
- Log2 -> notImplemented
- Log10 -> notImplemented
- Log1p -> notImplemented
- Sin -> withT (emit $ UnOp Sin x) $ bindM2 mul tx (emit $ UnOp Cos (sink x))
- Cos -> withT (emit $ UnOp Cos x) $ bindM2 mul tx (neg =<< emit (UnOp Sin (sink x)))
- Tan -> notImplemented
- Sqrt -> do
- y <- emit $ UnOp Sqrt x
- return $ WithTangent y do
- denominator <- bindM2 mul (2 `fLitLike` sink y) (sinkM y)
- bindM2 div' tx (pure denominator)
- Floor -> emitZeroT
- Ceil -> emitZeroT
- Round -> emitZeroT
- LGamma -> notImplemented
- Erf -> notImplemented
- Erfc -> notImplemented
- FNeg -> withT (neg x) (neg =<< tx)
- BNot -> emitZeroT
+linearizeUnOp op x'' = do
+ WithTangent x' tx' <- linearizeAtom x''
+ ans' <- emit $ UnOp op x'
+ return $ WithTangent ans' do
+ ans <- sinkM ans'
+ x <- sinkM x'
+ tx <- tx'
+ let zero = zeroLikeT ans
+ case op of
+ Exp -> emitLin $ BinOp FMul tx ans
+ Exp2 -> notImplemented
+ Log -> emitLin $ BinOp FDiv tx x
+ Log2 -> notImplemented
+ Log10 -> notImplemented
+ Log1p -> notImplemented
+ Sin -> do
+ c <- emit $ UnOp Cos x
+ emitLin $ BinOp FMul tx c
+ Cos -> do
+ c <- emit =<< (UnOp FNeg <$> emit (UnOp Sin x))
+ emitLin $ BinOp FMul tx c
+ Tan -> notImplemented
+ Sqrt -> do
+ denominator <- fmul (2 `fLitLike` ans) ans
+ emitLin $ BinOp FDiv tx denominator
+ Floor -> zero
+ Ceil -> zero
+ Round -> zero
+ LGamma -> notImplemented
+ Erf -> notImplemented
+ Erfc -> notImplemented
+ FNeg -> emitLin $ UnOp FNeg tx
+ BNot -> zero
linearizeBinOp :: Emits o => BinOp -> SAtom i -> SAtom i -> LinM i o SAtom SAtom
-linearizeBinOp op x' y' = do
- WithTangent x tx <- linearizeAtom x'
- WithTangent y ty <- linearizeAtom y'
- let emitZeroT = withZeroT $ emit $ BinOp op x y
- case op of
- IAdd -> emitZeroT
- ISub -> emitZeroT
- IMul -> emitZeroT
- IDiv -> emitZeroT
- IRem -> emitZeroT
- ICmp _ -> emitZeroT
- FAdd -> withT (add x y) (bindM2 add tx ty)
- FSub -> withT (sub x y) (bindM2 sub tx ty)
- FMul -> withT (mul x y)
- (bindM2 add (bindM2 mul (referToPrimal x) ty)
- (bindM2 mul tx (referToPrimal y)))
- FDiv -> withT (div' x y) do
- tx' <- bindM2 div' tx (referToPrimal y)
- ty' <- bindM2 div' (bindM2 mul (referToPrimal x) ty)
- (bindM2 mul (referToPrimal y) (referToPrimal y))
- sub tx' ty'
- FPow -> withT (emit $ BinOp FPow x y) do
- px <- referToPrimal x
- py <- referToPrimal y
- c <- (1.0 `fLitLike` py) >>= (sub py) >>= fpow px
- tx' <- bindM2 mul tx (return py)
- ty' <- bindM2 mul (bindM2 mul (return px) ty) (flog px)
- mul c =<< add tx' ty'
- FCmp _ -> emitZeroT
- BAnd -> emitZeroT
- BOr -> emitZeroT
- BXor -> emitZeroT
- BShL -> emitZeroT
- BShR -> emitZeroT
+linearizeBinOp op x'' y'' = do
+ WithTangent x' tx' <- linearizeAtom x''
+ WithTangent y' ty' <- linearizeAtom y''
+ ans' <- emit $ BinOp op x' y'
+ return $ WithTangent ans' do
+ ans <- sinkM ans'
+ x <- referToPrimal x'
+ y <- referToPrimal y'
+ tx <- tx'
+ ty <- ty'
+ let zero = zeroLikeT ans
+ case op of
+ IAdd -> zero
+ ISub -> zero
+ IMul -> zero
+ IDiv -> zero
+ IRem -> zero
+ ICmp _ -> zero
+ FAdd -> emitLin $ BinOp FAdd tx ty
+ FSub -> emitLin $ BinOp FSub tx ty
+ FMul -> do
+ t1 <- emitLin $ BinOp FMul ty x
+ t2 <- emitLin $ BinOp FMul tx y
+ emitLin $ BinOp FAdd t1 t2
+ FDiv -> do
+ t1 <- emitLin $ BinOp FDiv tx y
+ xyy <- fdiv x =<< fmul y y
+ t2 <- emitLin $ BinOp FMul ty xyy
+ emitLin $ BinOp FSub t1 t2
+ FPow -> do
+ ym1 <- fsub y (1.0 `fLitLike` y)
+ xpowym1 <- emit $ BinOp FPow x ym1
+ xlogx <- fmul x =<< emit (UnOp Log x)
+ t1 <- emitLin $ BinOp FMul tx y
+ t2 <- emitLin $ BinOp FMul ty xlogx
+ t12 <- emitLin $ BinOp FAdd t1 t2
+ emitLin $ BinOp FMul xpowym1 t12
+ FCmp _ -> zero
+ BAnd -> zero
+ BOr -> zero
+ BXor -> zero
+ BShL -> zero
+ BShR -> zero
-- This has the same type as `sinkM` and falls back thereto, but recomputes
-- indexing a primal array in the tangent to avoid materializing intermediate
@@ -575,12 +593,12 @@ referToPrimal x = do
linearizePrimCon :: Emits o => Con SimpIR i -> LinM i o SAtom SAtom
linearizePrimCon con = case con of
- Lit _ -> emitZeroT
+ Lit _ -> zero
ProdCon xs -> fmapLin (Con . ProdCon . fromComposeE) $ seqLin (fmap linearizeAtom xs)
SumCon _ _ _ -> notImplemented
- HeapVal -> emitZeroT
+ HeapVal -> zero
DepPair _ _ _ -> notImplemented
- where emitZeroT = withZeroT $ renameM $ Con con
+ where zero = emitZeroT con
linearizeHof :: Emits o => Hof SimpIR i -> LinM i o SAtom SAtom
linearizeHof hof = case hof of
@@ -672,21 +690,6 @@ linearizeEffectFun rws (BinaryLamExpr hB refB body) = do
return (BinaryLamExpr h b body', linLam')
linearizeEffectFun _ _ = error "expect effect function to be a binary lambda"
-withT :: PrimalM i o (e1 o)
- -> (forall o'. (Emits o', DExt o o') => TangentM o' (e2 o'))
- -> PrimalM i o (WithTangent o e1 e2)
-withT p t = do
- p' <- p
- return $ WithTangent p' t
-
-withZeroT :: PrimalM i o (Atom SimpIR o)
- -> PrimalM i o (WithTangent o SAtom SAtom)
-withZeroT p = do
- p' <- p
- return $ WithTangent p' do
- pTy <- return $ getType $ sink p'
- zeroAt =<< tangentType pTy
-
notImplemented :: HasCallStack => a
notImplemented = error "Not implemented"
diff --git a/src/lib/PPrint.hs b/src/lib/PPrint.hs
index 24189448..af1c48f1 100644
--- a/src/lib/PPrint.hs
+++ b/src/lib/PPrint.hs
@@ -970,6 +970,7 @@ instance Pretty LetAnn where
PlainLet -> ""
InlineLet -> "%inline"
NoInlineLet -> "%noinline"
+ LinearLet -> "%linear"
OccInfoPure u -> p u <> line
OccInfoImpure u -> p u <> ", impure" <> line
diff --git a/src/lib/Transpose.hs b/src/lib/Transpose.hs
index e312de43..302ca9e4 100644
--- a/src/lib/Transpose.hs
+++ b/src/lib/Transpose.hs
@@ -9,8 +9,6 @@ module Transpose (transpose, transposeTopFun) where
import Data.Foldable
import Data.Functor
import Control.Category ((>>>))
-import Control.Monad.Reader
-import qualified Data.Set as S
import GHC.Stack
import Builder
@@ -18,7 +16,6 @@ import Core
import Err
import Imp
import IRVariants
-import MTL1
import Name
import Subst
import QueryType
@@ -37,7 +34,7 @@ transpose lam ct = liftEmitBuilder $ runTransposeM do
{-# SCC transpose #-}
runTransposeM :: TransposeM n n a -> BuilderM SimpIR n a
-runTransposeM cont = runReaderT1 (ListE []) $ runSubstReaderT idSubst $ cont
+runTransposeM cont = runSubstReaderT idSubst $ cont
transposeTopFun
:: (MonadFail1 m, EnvReader m)
@@ -73,20 +70,15 @@ unpackLinearLamExpr lam@(LamExpr bs body) = do
-- === transposition monad ===
+type AtomTransposeSubstVal = TransposeSubstVal (AtomNameC SimpIR)
data TransposeSubstVal c n where
RenameNonlin :: Name c n -> TransposeSubstVal c n
-- accumulator references corresponding to non-ref linear variables
- LinRef :: SAtom n -> TransposeSubstVal (AtomNameC SimpIR) n
+ LinRef :: SAtom n -> AtomTransposeSubstVal n
-- as an optimization, we don't make references for trivial vector spaces
- LinTrivial :: TransposeSubstVal (AtomNameC SimpIR) n
+ LinTrivial :: AtomTransposeSubstVal n
-type LinRegions = ListE SAtomVar
-
-type TransposeM a = SubstReaderT TransposeSubstVal
- (ReaderT1 LinRegions (BuilderM SimpIR)) a
-
-type TransposeM' a = SubstReaderT AtomSubstVal
- (ReaderT1 LinRegions (BuilderM SimpIR)) a
+type TransposeM a = SubstReaderT TransposeSubstVal (BuilderM SimpIR) a
-- TODO: it might make sense to replace substNonlin/isLin
-- with a single `trySubtNonlin :: e i -> Maybe (e o)`.
@@ -99,30 +91,6 @@ substNonlin e = do
RenameNonlin v' -> v'
_ -> error "not a nonlinear expression") e
--- TODO: Can we generalize onNonLin to accept SubstReaderT Name instead of
--- SubstReaderT AtomSubstVal? For that to work, we need another combinator,
--- that lifts a SubstReader AtomSubstVal into a SubstReader Name, because
--- effectsSubstE is currently typed as SubstReader AtomSubstVal.
--- Then we can presumably recode substNonlin as `onNonLin substM`. We may
--- be able to do that anyway, except we will then need to restrict the type
--- of substNonlin to require `SubstE AtomSubstVal e`; but that may be fine.
-onNonLin :: HasCallStack
- => TransposeM' i o a -> TransposeM i o a
-onNonLin cont = do
- subst <- getSubst
- let subst' = newSubst (\v -> case subst ! v of
- RenameNonlin v' -> Rename v'
- _ -> error "not a nonlinear expression")
- liftSubstReaderT $ runSubstReaderT subst' cont
-
-isLin :: HoistableE e => e i -> TransposeM i o Bool
-isLin e = do
- substVals <- mapM lookupSubstM $ freeAtomVarsList @SimpIR e
- return $ flip any substVals \case
- LinTrivial -> True
- LinRef _ -> True
- RenameNonlin _ -> False
-
withAccumulator
:: Emits o
=> SType o
@@ -147,43 +115,42 @@ emitCTToRef ref ct = do
baseMonoid <- tangentBaseMonoidFor (getType ct)
void $ emit $ RefOp ref $ MExtend baseMonoid ct
-getLinRegions :: TransposeM i o [SAtomVar o]
-getLinRegions = asks fromListE
-
-extendLinRegions :: SAtomVar o -> TransposeM i o a -> TransposeM i o a
-extendLinRegions v cont = local (\(ListE vs) -> ListE (v:vs)) cont
-
-- === actual pass ===
-transposeWithDecls :: Emits o => Nest SDecl i i' -> SExpr i' -> SAtom o -> TransposeM i o ()
+transposeWithDecls :: forall i i' o. Emits o => Nest SDecl i i' -> SExpr i' -> SAtom o -> TransposeM i o ()
transposeWithDecls Empty atom ct = transposeExpr atom ct
-transposeWithDecls (Nest (Let b (DeclBinding _ expr)) rest) result ct =
- substExprIfNonlin expr >>= \case
- Nothing -> do
- ty' <- substNonlin $ getType expr
- ctExpr <- withAccumulator ty' \refSubstVal ->
- extendSubst (b @> refSubstVal) $
- transposeWithDecls rest result (sink ct)
- transposeExpr expr ctExpr
- Just nonlinExpr -> do
- v <- emitToVar nonlinExpr
- extendSubst (b @> RenameNonlin (atomVarName v)) $
- transposeWithDecls rest result ct
-
-substExprIfNonlin :: SExpr i -> TransposeM i o (Maybe (SExpr o))
-substExprIfNonlin expr =
- isLin expr >>= \case
- True -> return Nothing
- False -> do
- onNonLin (substM $ getEffects expr) >>= isLinEff >>= \case
- True -> return Nothing
- False -> Just <$> substNonlin expr
+transposeWithDecls (Nest (Let b (DeclBinding ann expr)) rest) result ct = case ann of
+ LinearLet -> do
+ ty' <- substNonlin $ getType expr
+ case expr of
+ Project _ i x -> do
+ continue =<< projectLinearRef x \ref -> emitLin =<< mkProjRef ref (ProjectProduct i)
+ TabApp _ x i -> do
+ continue =<< projectLinearRef x \ref -> do
+ i' <- substNonlin i
+ emitLin =<< mkIndexRef ref i'
+ _ -> do
+ ctExpr <- withAccumulator ty' \refSubstVal -> continue refSubstVal
+ transposeExpr expr ctExpr
+ _ -> do
+ v <- substNonlin expr >>= emitToVar
+ continue $ RenameNonlin (atomVarName v)
+ where
+ continue :: forall o'. (Emits o', Ext o o') => AtomTransposeSubstVal o' -> TransposeM i o' ()
+ continue substVal = do
+ ct' <- sinkM ct
+ extendSubst (b @> substVal) $ transposeWithDecls rest result ct'
-isLinEff :: EffectRow SimpIR o -> TransposeM i o Bool
-isLinEff effs@(EffectRow _ NoTail) = do
- regions <- fmap atomVarName <$> getLinRegions
- let effRegions = freeAtomVarsList effs
- return $ not $ null $ S.fromList effRegions `S.intersection` S.fromList regions
+projectLinearRef
+ :: Emits o
+ => SAtom i -> (SAtom o -> TransposeM i o (SAtom o))
+ -> TransposeM i o (AtomTransposeSubstVal o)
+projectLinearRef x f = do
+ Stuck _ (Var v) <- return x
+ lookupSubstM (atomVarName v) >>= \case
+ RenameNonlin _ -> error "nonlinear"
+ LinRef ref -> LinRef <$> f ref
+ LinTrivial -> return LinTrivial
getTransposedTopFun :: EnvReader m => TopFunName n -> m n (Maybe (TopFunName n))
getTransposedTopFun f = do
@@ -200,44 +167,23 @@ transposeExpr expr ct = case expr of
xsNonlin' <- mapM substNonlin xsNonlin
ct' <- naryTopApp fT (xsNonlin' ++ [ct])
transposeAtom xLin ct'
- -- TODO: Instead, should we handle table application like nonlinear
- -- expressions, where we just project the reference?
- TabApp _ x i -> do
- i' <- substNonlin i
- case x of
- Stuck _ stuck -> case stuck of
- Var v -> do
- lookupSubstM (atomVarName v) >>= \case
- RenameNonlin _ -> error "shouldn't happen"
- LinRef ref -> do
- refProj <- indexRef ref i'
- emitCTToRef refProj ct
- LinTrivial -> return ()
- StuckProject _ _ -> undefined
- StuckTabApp _ _ -> undefined
- PtrVar _ _ -> error "not tangent"
- RepValAtom _ -> error "not tangent"
- _ -> error $ "shouldn't occur: " ++ pprint x
PrimOp op -> transposeOp op ct
Case e alts _ -> do
- linearScrutinee <- isLin e
- case linearScrutinee of
- True -> notImplemented
- False -> do
- e' <- substNonlin e
- void $ buildCase e' UnitTy \i v -> do
- v' <- emitToVar v
- Abs b body <- return $ alts !! i
- extendSubst (b @> RenameNonlin (atomVarName v')) do
- transposeExpr body (sink ct)
- return UnitVal
+ e' <- substNonlin e
+ void $ buildCase e' UnitTy \i v -> do
+ v' <- emitToVar v
+ Abs b body <- return $ alts !! i
+ extendSubst (b @> RenameNonlin (atomVarName v')) do
+ transposeExpr body (sink ct)
+ return UnitVal
TabCon _ ty es -> do
TabTy d b _ <- return ty
idxTy <- substNonlin $ IxType (binderType b) d
forM_ (enumerate es) \(ordinalIdx, e) -> do
i <- unsafeFromOrdinal idxTy (IdxRepVal $ fromIntegral ordinalIdx)
tabApp ct i >>= transposeAtom e
- Project _ _ _ -> undefined
+ TabApp _ _ _ -> error "should have been handled by reference projection"
+ Project _ _ _ -> error "should have been handled by reference projection"
transposeOp :: Emits o => PrimOp SimpIR i -> SAtom o -> TransposeM i o ()
transposeOp op ct = case op of
@@ -262,18 +208,21 @@ transposeOp op ct = case op of
ProjRef _ _ -> notImplemented
Hof (TypedHof _ hof) -> transposeHof hof ct
MiscOp miscOp -> transposeMiscOp miscOp ct
- UnOp FNeg x -> transposeAtom x =<< neg ct
+ UnOp FNeg x -> transposeAtom x =<< (emitLin $ UnOp FNeg ct)
UnOp _ _ -> notLinear
BinOp FAdd x y -> transposeAtom x ct >> transposeAtom y ct
- BinOp FSub x y -> transposeAtom x ct >> (transposeAtom y =<< neg ct)
+ BinOp FSub x y -> transposeAtom x ct >> (transposeAtom y =<< (emitLin $ UnOp FNeg ct))
+ -- XXX: linear argument to FMul is always first
BinOp FMul x y -> do
- xLin <- isLin x
- if xLin
- then transposeAtom x =<< mul ct =<< substNonlin y
- else transposeAtom y =<< mul ct =<< substNonlin x
- BinOp FDiv x y -> transposeAtom x =<< div' ct =<< substNonlin y
+ y' <- substNonlin y
+ tx <- emitLin $ BinOp FMul ct y'
+ transposeAtom x tx
+ BinOp FDiv x y -> do
+ y' <- substNonlin y
+ tx <- emitLin $ BinOp FDiv ct y'
+ transposeAtom x tx
BinOp _ _ _ -> notLinear
- MemOp _ -> notLinear
+ MemOp _ -> notLinear
VectorOp _ -> unreachable
where
notLinear = error $ "Can't transpose a non-linear operation: " ++ pprint op
@@ -291,10 +240,9 @@ transposeMiscOp op _ = case op of
BitcastOp _ _ -> notImplemented
UnsafeCoerce _ _ -> notImplemented
GarbageVal _ -> notImplemented
- ShowAny _ -> error "Shouldn't have ShowAny in simplified IR"
- ShowScalar _ -> error "Shouldn't have ShowScalar in simplified IR"
- where
- notLinear = error $ "Can't transpose a non-linear operation: " ++ show op
+ ShowAny _ -> notLinear
+ ShowScalar _ -> notLinear
+ where notLinear = error $ "Can't transpose a non-linear operation: " ++ show op
transposeAtom :: HasCallStack => Emits o => SAtom i -> SAtom o -> TransposeM i o ()
transposeAtom atom ct = case atom of
@@ -308,16 +256,9 @@ transposeAtom atom ct = case atom of
return ()
LinRef ref -> emitCTToRef ref ct
LinTrivial -> return ()
- StuckProject _ _ -> error "not implemented"
- StuckTabApp _ _ -> error "not implemented"
- -- let (idxs, v) = asNaryProj i' x'
- -- lookupSubstM (atomVarName v) >>= \case
- -- RenameNonlin _ -> error "an error, probably"
- -- LinRef ref -> do
- -- ref' <- applyProjectionsRef (toList idxs) ref
- -- emitCTToRef ref' ct
- -- LinTrivial -> return ()
- RepValAtom _ -> error "not implemented"
+ StuckProject _ _ -> error "not linear"
+ StuckTabApp _ _ -> error "not linear"
+ RepValAtom _ -> error "not linear"
where notTangent = error $ "Not a tangent atom: " ++ pprint atom
transposeHof :: Emits o => Hof SimpIR i -> SAtom o -> TransposeM i o ()
@@ -333,8 +274,7 @@ transposeHof hof ct = case hof of
(ctBody, ctState) <- fromPair ct
(_, cts) <- (fromPair =<<) $ emitRunState noHint ctState \h ref -> do
extendSubst (hB@>RenameNonlin (atomVarName h)) $ extendSubst (refB@>RenameNonlin (atomVarName ref)) $
- extendLinRegions h $
- transposeExpr body (sink ctBody)
+ transposeExpr body (sink ctBody)
return UnitVal
transposeAtom s cts
RunReader r (BinaryLamExpr hB refB body) -> do
@@ -342,8 +282,7 @@ transposeHof hof ct = case hof of
baseMonoid <- tangentBaseMonoidFor accumTy
(_, ct') <- (fromPair =<<) $ emitRunWriter noHint accumTy baseMonoid \h ref -> do
extendSubst (hB@>RenameNonlin (atomVarName h)) $ extendSubst (refB@>RenameNonlin (atomVarName ref)) $
- extendLinRegions h $
- transposeExpr body (sink ct)
+ transposeExpr body (sink ct)
return UnitVal
transposeAtom r ct'
RunWriter Nothing _ (BinaryLamExpr hB refB body)-> do
@@ -351,8 +290,7 @@ transposeHof hof ct = case hof of
(ctBody, ctEff) <- fromPair ct
void $ emitRunReader noHint ctEff \h ref -> do
extendSubst (hB@>RenameNonlin (atomVarName h)) $ extendSubst (refB@>RenameNonlin (atomVarName ref)) $
- extendLinRegions h $
- transposeExpr body (sink ctBody)
+ transposeExpr body (sink ctBody)
return UnitVal
_ -> notImplemented
diff --git a/src/lib/Types/Primitives.hs b/src/lib/Types/Primitives.hs
index d85d8824..83ba3ffb 100644
--- a/src/lib/Types/Primitives.hs
+++ b/src/lib/Types/Primitives.hs
@@ -64,6 +64,7 @@ data LetAnn =
| InlineLet
-- Binding explicitly tagged "do not inline"
| NoInlineLet
+ | LinearLet
-- Bound expression is pure, and the binding's occurrences are summarized by
-- the UsageInfo
| OccInfoPure UsageInfo