diff options
author | Dougal <d.maclaurin@gmail.com> | 2023-11-09 16:50:55 -0500 |
---|---|---|
committer | Dougal <d.maclaurin@gmail.com> | 2023-11-09 16:50:55 -0500 |
commit | 9672aa3afe288ba110a6c37616fc047515e09f0b (patch) | |
tree | 86f4917674e03c3d98df15c9a593f705b71d811c | |
parent | d10cfc591dfe5d04d97d56e72dcde6c0969fe1de (diff) |
Fix autodiff using explicit linearity annotations and handle projections efficiently.
-rw-r--r-- | src/lib/Algebra.hs | 2 | ||||
-rw-r--r-- | src/lib/Builder.hs | 44 | ||||
-rw-r--r-- | src/lib/Inline.hs | 3 | ||||
-rw-r--r-- | src/lib/Linearize.hs | 321 | ||||
-rw-r--r-- | src/lib/PPrint.hs | 1 | ||||
-rw-r--r-- | src/lib/Transpose.hs | 194 | ||||
-rw-r--r-- | src/lib/Types/Primitives.hs | 1 |
7 files changed, 257 insertions, 309 deletions
diff --git a/src/lib/Algebra.hs b/src/lib/Algebra.hs index b3d6d250..5ecc05f7 100644 --- a/src/lib/Algebra.hs +++ b/src/lib/Algebra.hs @@ -18,7 +18,7 @@ import Data.Text.Prettyprint.Doc import Data.List (intersperse) import Data.Tuple (swap) -import Builder hiding (sub, add, mul) +import Builder import Core import CheapReduction import Err diff --git a/src/lib/Builder.hs b/src/lib/Builder.hs index b65b1414..0bb8e8fe 100644 --- a/src/lib/Builder.hs +++ b/src/lib/Builder.hs @@ -847,6 +847,12 @@ buildRememberDest hint dest cont = do -- === vector space (ish) type class === +emitLin :: (Builder r m, ToExpr e r, Emits n) => e n -> m n (Atom r n) +emitLin e = case toExpr e of + Atom x -> return x + expr -> liftM toAtom $ emitDecl noHint LinearLet $ peepholeExpr expr +{-# INLINE emitLin #-} + zeroAt :: (Emits n, SBuilder m) => SType n -> m n (SAtom n) zeroAt ty = liftEmitBuilder $ go ty where go :: Emits n => SType n -> BuilderM SimpIR n (SAtom n) @@ -930,14 +936,17 @@ symbolicTangentNonZero val = do -- === builder versions of common local ops === -neg :: (Builder r m, Emits n) => Atom r n -> m n (Atom r n) -neg x = emit $ UnOp FNeg x +fadd :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n) +fadd x y = emit $ BinOp FAdd x y -add :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n) -add x y = emit $ BinOp FAdd x y +fsub :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n) +fsub x y = emit $ BinOp FSub x y -mul :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n) -mul x y = emit $ BinOp FMul x y +fmul :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n) +fmul x y = emit $ BinOp FMul x y + +fdiv :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n) +fdiv x y = emit $ BinOp FDiv x y iadd :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n) iadd x y = emit $ BinOp IAdd x y @@ -945,22 +954,10 @@ iadd x y = emit $ BinOp IAdd x y imul :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n) imul x y = emit $ BinOp IMul x y -div' :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n) -div' x y = emit $ BinOp FDiv x y - -fpow :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n) -fpow x y = emit $ BinOp FPow x y - -sub :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n) -sub x y = emit $ BinOp FSub x y - -flog :: (Builder r m, Emits n) => Atom r n -> m n (Atom r n) -flog x = emit $ UnOp Log x - -fLitLike :: (SBuilder m, Emits n) => Double -> SAtom n -> m n (SAtom n) +fLitLike :: Double -> SAtom n -> SAtom n fLitLike x t = case getTyCon t of - BaseType (Scalar Float64Type) -> return $ toAtom $ Lit $ Float64Lit x - BaseType (Scalar Float32Type) -> return $ toAtom $ Lit $ Float32Lit $ realToFrac x + BaseType (Scalar Float64Type) -> toAtom $ Lit $ Float64Lit x + BaseType (Scalar Float32Type) -> toAtom $ Lit $ Float32Lit $ realToFrac x _ -> error "Expected a floating point scalar" fromPair :: (Builder r m, Emits n) => Atom r n -> m n (Atom r n, Atom r n) @@ -1085,6 +1082,11 @@ mkTabApp xs ixs = do ty <- typeOfTabApp (getType xs) ixs return $ TabApp ty xs ixs +mkProject :: (EnvReader m, IRRep r) => Int -> Atom r n -> m n (Expr r n) +mkProject i x = do + ty <- projType i x + return $ Project ty i x + mkTopApp :: EnvReader m => TopFunName n -> [SAtom n] -> m n (SExpr n) mkTopApp f xs = do resultTy <- typeOfTopApp f xs diff --git a/src/lib/Inline.hs b/src/lib/Inline.hs index 4104948c..f3ada792 100644 --- a/src/lib/Inline.hs +++ b/src/lib/Inline.hs @@ -98,12 +98,14 @@ inlineDeclsSubst = \case inlineDeclsSubst rest where dropOccInfo PlainLet = PlainLet + dropOccInfo LinearLet = LinearLet dropOccInfo InlineLet = InlineLet dropOccInfo NoInlineLet = NoInlineLet dropOccInfo (OccInfoPure _) = PlainLet dropOccInfo (OccInfoImpure _) = PlainLet resolveWorkConservation PlainLet _ = NoInline -- No occurrence info, assume the worst + resolveWorkConservation LinearLet _ = NoInline resolveWorkConservation InlineLet _ = NoInline resolveWorkConservation NoInlineLet _ = NoInline -- Quick hack to always unconditionally inline renames, until we get @@ -176,6 +178,7 @@ preInlineUnconditionally = \case PlainLet -> False -- "Missing occurrence annotation" InlineLet -> True NoInlineLet -> False + LinearLet -> False OccInfoPure (UsageInfo s (0, d)) | s <= One && d <= One -> True OccInfoPure _ -> False OccInfoImpure _ -> False diff --git a/src/lib/Linearize.hs b/src/lib/Linearize.hs index d4ca0417..87054836 100644 --- a/src/lib/Linearize.hs +++ b/src/lib/Linearize.hs @@ -16,6 +16,7 @@ import GHC.Stack import Builder import Core +import CheapReduction import Imp import IRVariants import MTL1 @@ -26,7 +27,7 @@ import PPrint import QueryType import Types.Core import Types.Primitives -import Util (bindM2, enumerate) +import Util (enumerate) -- === linearization monad === @@ -93,48 +94,39 @@ extendTangentArgss vs' m = local (\(TangentArgs vs) -> TangentArgs $ vs ++ vs') getTangentArgs :: TangentM o (TangentArgs o) getTangentArgs = ask -bindLin - :: Emits o - => LinM i o e e - -> (forall o' m. (Emits o', DExt o o', Builder SimpIR m) => e o' -> m o' (e' o')) - -> LinM i o e' e' -bindLin m f = do - result <- m - withBoth result f - -withBoth - :: Emits o +emitBoth + :: (Emits o, ToExpr e' SimpIR) => WithTangent o e e - -> (forall o' m. (Emits o', DExt o o', Builder SimpIR m) => e o' -> m o' (e' o')) - -> PrimalM i o (WithTangent o e' e') -withBoth (WithTangent x tx) f = do + -> (forall o' m. (DExt o o', Builder SimpIR m) => e o' -> m o' (e' o')) + -> LinM i o SAtom SAtom +emitBoth (WithTangent x tx) f = do Distinct <- getDistinct - y <- f x - return $ WithTangent y do - tx >>= f + x' <- emit =<< f x + return $ WithTangent x' do + tx' <- tx + emitLin =<< f tx' -_withTangentComputation - :: Emits o - => WithTangent o e1 e2 - -> (forall o' m. (Emits o', DExt o o', Builder SimpIR m) => e2 o' -> m o' (e2' o')) - -> PrimalM i o (WithTangent o e1 e2') -_withTangentComputation (WithTangent x tx) f = do - Distinct <- getDistinct - return $ WithTangent x do - tx >>= f +emitZeroT :: (Emits o, HasNamesE e', ToExpr e' SimpIR) => e' i -> LinM i o SAtom SAtom +emitZeroT e = do + x <- emit =<< renameM e + return $ WithTangent x (zeroLikeT x) + +zeroLikeT :: (DExt o o', Emits o', HasType SimpIR e) => e o -> TangentM o' (SAtom o') +zeroLikeT x = do + ty <- sinkM $ getType x + zeroAt =<< tangentType ty fmapLin :: Emits o => (forall o'. e o' -> e' o') -> LinM i o e e -> LinM i o e' e' -fmapLin f m = m `bindLin` (pure . f) +fmapLin f m = do + WithTangent ans tx <- m + return $ WithTangent (f ans) (f <$> tx) -zipLin :: LinM i o e1 e1 -> LinM i o e2 e2 -> LinM i o (PairE e1 e2) (PairE e1 e2) -zipLin m1 m2 = do - WithTangent x1 t1 <- m1 - WithTangent x2 t2 <- m2 - return $ WithTangent (PairE x1 x2) do PairE <$> t1 <*> t2 +zipLin :: WithTangent o e1 e1 -> WithTangent o e2 e2 -> WithTangent o (PairE e1 e2) (PairE e1 e2) +zipLin (WithTangent x1 t1) (WithTangent x2 t2) = WithTangent (PairE x1 x2) do PairE <$> t1 <*> t2 seqLin :: Traversable f @@ -325,19 +317,28 @@ linearizeLambdaApp _ _ = error "not implemented" linearizeAtom :: Emits o => Atom SimpIR i -> LinM i o SAtom SAtom linearizeAtom (Con con) = linearizePrimCon con -linearizeAtom atom@(Stuck _ stuck) = case stuck of - PtrVar _ _ -> emitZeroT +linearizeAtom (Stuck _ stuck) = linearizeStuck stuck + +linearizeStuck :: Emits o => Stuck SimpIR i -> LinM i o SAtom SAtom +linearizeStuck stuck = case stuck of Var v -> do v' <- renameM v activePrimalIdx v' >>= \case - Nothing -> withZeroT $ return (toAtom v') + Nothing -> zero Just idx -> return $ WithTangent (toAtom v') $ getTangentArg idx - -- TODO: buildScoped and reduce the results so we keep expression in non-ANF for type checking purposes - StuckProject _ _ -> undefined - StuckTabApp _ _ -> undefined - RepValAtom _ -> emitZeroT - where emitZeroT = withZeroT $ renameM atom - + PtrVar _ _ -> zero + RepValAtom _ -> zero + -- TODO: de-dup with the Expr versions of these + StuckProject i x -> do + x' <- linearizeStuck x + emitBoth x' \x'' -> mkProject i x'' + StuckTabApp x i -> do + pt <- zipLin <$> linearizeStuck x <*> pureLin i + emitBoth pt \(PairE x' i') -> mkTabApp x' i' + where + zero = do + atom <- mkStuck =<< renameM stuck + return $ WithTangent atom (zeroLikeT atom) linearizeDecls :: Emits o => Nest SDecl i i' -> LinM i' o e1 e2 -> LinM i o e1 e2 linearizeDecls Empty cont = cont @@ -388,7 +389,6 @@ linearizeExpr expr = case expr of where unitLike :: e n -> UnitE n unitLike _ = UnitE - TabApp _ x i -> zipLin (linearizeAtom x) (pureLin i) `bindLin` \(PairE x' i') -> tabApp x' i' PrimOp op -> linearizeOp op Case e alts (EffTy effs resultTy) -> do e' <- renameM e @@ -417,49 +417,54 @@ linearizeExpr expr = case expr of applyLinLam linLam TabCon _ ty xs -> do ty' <- renameM ty - seqLin (map linearizeAtom xs) `bindLin` \(ComposeE xs') -> - emit $ TabCon Nothing (sink ty') xs' + pt <- seqLin (map linearizeAtom xs) + emitBoth pt \(ComposeE xs') -> return $ TabCon Nothing (sink ty') xs' + TabApp _ x i -> do + pt <- zipLin <$> linearizeAtom x <*> pureLin i + emitBoth pt \(PairE x' i') -> mkTabApp x' i' Project _ i x -> do - WithTangent x' tx <- linearizeAtom x - xi <- proj i x' - return $ WithTangent xi do - t <- tx - proj i t + x' <- linearizeAtom x + emitBoth x' \x'' -> mkProject i x'' linearizeOp :: Emits o => PrimOp SimpIR i -> LinM i o SAtom SAtom linearizeOp op = case op of Hof (TypedHof _ e) -> linearizeHof e DAMOp _ -> error "shouldn't occur here" - RefOp ref m -> case m of - MAsk -> linearizeAtom ref `bindLin` \ref' -> emit $ RefOp ref' MAsk - MExtend monoid x -> do - -- TODO: check that we're dealing with a +/0 monoid - monoid' <- renameM monoid - zipLin (linearizeAtom ref) (linearizeAtom x) `bindLin` \(PairE ref' x') -> - emit $ RefOp ref' $ MExtend (sink monoid') x' - MGet -> linearizeAtom ref `bindLin` \ref' -> emit $ RefOp ref' MGet - MPut x -> zipLin (linearizeAtom ref) (linearizeAtom x) `bindLin` \(PairE ref' x') -> - emit $ RefOp ref' $ MPut x' - IndexRef _ i -> do - zipLin (la ref) (pureLin i) `bindLin` \(PairE ref' i') -> - emit =<< mkIndexRef ref' i' - ProjRef _ i -> la ref `bindLin` \ref' -> emit =<< mkProjRef ref' i + RefOp ref m -> do + ref' <- linearizeAtom ref + case m of + MAsk -> emitBoth ref' \ref'' -> return $ RefOp ref'' MAsk + MExtend monoid x -> do + -- TODO: check that we're dealing with a +/0 monoid + monoid' <- renameM monoid + x' <- linearizeAtom x + emitBoth (zipLin ref' x') \(PairE ref'' x'') -> + return $ RefOp ref'' $ MExtend (sink monoid') x'' + MGet -> emitBoth ref' \ref'' -> return $ RefOp ref'' MGet + MPut x -> do + x' <- linearizeAtom x + emitBoth (zipLin ref' x') \(PairE ref'' x'') -> return $ RefOp ref'' $ MPut x'' + IndexRef _ i -> do + i' <- pureLin i + emitBoth (zipLin ref' i') \(PairE ref'' i'') -> mkIndexRef ref'' i'' + ProjRef _ i -> emitBoth ref' \ref'' -> mkProjRef ref'' i UnOp uop x -> linearizeUnOp uop x BinOp bop x y -> linearizeBinOp bop x y -- XXX: This assumes that pointers are always constants - MemOp _ -> emitZeroT + MemOp _ -> emitZeroT op MiscOp miscOp -> linearizeMiscOp miscOp VectorOp _ -> error "not implemented" - where - emitZeroT = withZeroT $ emit =<< renameM (PrimOp op) - la = linearizeAtom linearizeMiscOp :: Emits o => MiscOp SimpIR i -> LinM i o SAtom SAtom linearizeMiscOp op = case op of - SumTag _ -> emitZeroT - ToEnum _ _ -> emitZeroT - Select p t f -> (pureLin p `zipLin` la t `zipLin` la f) `bindLin` - \(p' `PairE` t' `PairE` f') -> emit $ MiscOp $ Select p' t' f' + SumTag _ -> zero + ToEnum _ _ -> zero + Select p t f -> do + p' <- pureLin p + t' <- linearizeAtom t + f' <- linearizeAtom f + emitBoth (p' `zipLin` t' `zipLin` f') + \(p'' `PairE` t'' `PairE` f'') -> return $ Select p'' t'' f'' CastOp t v -> do vt <- getType <$> renameM v t' <- renameM t @@ -468,92 +473,105 @@ linearizeMiscOp op = case op of ((&&) <$> (vtTangentType `alphaEq` vt) <*> (tTangentType `alphaEq` t')) >>= \case True -> do - linearizeAtom v `bindLin` \v' -> emit $ MiscOp $ CastOp (sink t') v' + v' <- linearizeAtom v + emitBoth v' \v'' -> return $ CastOp (sink t') v'' False -> do WithTangent x xt <- linearizeAtom v yt <- case (vtTangentType, tTangentType) of (_ , UnitTy) -> return $ UnitVal (UnitTy, tt ) -> zeroAt tt _ -> error "Expected at least one side of the CastOp to have a trivial tangent type" - y <- emit $ MiscOp $ CastOp t' x + y <- emit $ CastOp t' x return $ WithTangent y do xt >> return (sink yt) BitcastOp _ _ -> notImplemented UnsafeCoerce _ _ -> notImplemented GarbageVal _ -> notImplemented ThrowException _ -> notImplemented - ThrowError _ -> emitZeroT - OutputStream -> emitZeroT + ThrowError _ -> zero + OutputStream -> zero ShowAny _ -> error "Shouldn't have ShowAny in simplified IR" ShowScalar _ -> error "Shouldn't have ShowScalar in simplified IR" - where - emitZeroT = withZeroT $ emit =<< renameM (PrimOp $ MiscOp op) - la = linearizeAtom + where zero = emitZeroT op linearizeUnOp :: Emits o => UnOp -> Atom SimpIR i -> LinM i o SAtom SAtom -linearizeUnOp op x' = do - WithTangent x tx <- linearizeAtom x' - let emitZeroT = withZeroT $ emit $ UnOp op x - case op of - Exp -> do - y <- emit $ UnOp Exp x - return $ WithTangent y (bindM2 mul tx (sinkM y)) - Exp2 -> notImplemented - Log -> withT (emit $ UnOp Log x) $ (tx >>= (`div'` sink x)) - Log2 -> notImplemented - Log10 -> notImplemented - Log1p -> notImplemented - Sin -> withT (emit $ UnOp Sin x) $ bindM2 mul tx (emit $ UnOp Cos (sink x)) - Cos -> withT (emit $ UnOp Cos x) $ bindM2 mul tx (neg =<< emit (UnOp Sin (sink x))) - Tan -> notImplemented - Sqrt -> do - y <- emit $ UnOp Sqrt x - return $ WithTangent y do - denominator <- bindM2 mul (2 `fLitLike` sink y) (sinkM y) - bindM2 div' tx (pure denominator) - Floor -> emitZeroT - Ceil -> emitZeroT - Round -> emitZeroT - LGamma -> notImplemented - Erf -> notImplemented - Erfc -> notImplemented - FNeg -> withT (neg x) (neg =<< tx) - BNot -> emitZeroT +linearizeUnOp op x'' = do + WithTangent x' tx' <- linearizeAtom x'' + ans' <- emit $ UnOp op x' + return $ WithTangent ans' do + ans <- sinkM ans' + x <- sinkM x' + tx <- tx' + let zero = zeroLikeT ans + case op of + Exp -> emitLin $ BinOp FMul tx ans + Exp2 -> notImplemented + Log -> emitLin $ BinOp FDiv tx x + Log2 -> notImplemented + Log10 -> notImplemented + Log1p -> notImplemented + Sin -> do + c <- emit $ UnOp Cos x + emitLin $ BinOp FMul tx c + Cos -> do + c <- emit =<< (UnOp FNeg <$> emit (UnOp Sin x)) + emitLin $ BinOp FMul tx c + Tan -> notImplemented + Sqrt -> do + denominator <- fmul (2 `fLitLike` ans) ans + emitLin $ BinOp FDiv tx denominator + Floor -> zero + Ceil -> zero + Round -> zero + LGamma -> notImplemented + Erf -> notImplemented + Erfc -> notImplemented + FNeg -> emitLin $ UnOp FNeg tx + BNot -> zero linearizeBinOp :: Emits o => BinOp -> SAtom i -> SAtom i -> LinM i o SAtom SAtom -linearizeBinOp op x' y' = do - WithTangent x tx <- linearizeAtom x' - WithTangent y ty <- linearizeAtom y' - let emitZeroT = withZeroT $ emit $ BinOp op x y - case op of - IAdd -> emitZeroT - ISub -> emitZeroT - IMul -> emitZeroT - IDiv -> emitZeroT - IRem -> emitZeroT - ICmp _ -> emitZeroT - FAdd -> withT (add x y) (bindM2 add tx ty) - FSub -> withT (sub x y) (bindM2 sub tx ty) - FMul -> withT (mul x y) - (bindM2 add (bindM2 mul (referToPrimal x) ty) - (bindM2 mul tx (referToPrimal y))) - FDiv -> withT (div' x y) do - tx' <- bindM2 div' tx (referToPrimal y) - ty' <- bindM2 div' (bindM2 mul (referToPrimal x) ty) - (bindM2 mul (referToPrimal y) (referToPrimal y)) - sub tx' ty' - FPow -> withT (emit $ BinOp FPow x y) do - px <- referToPrimal x - py <- referToPrimal y - c <- (1.0 `fLitLike` py) >>= (sub py) >>= fpow px - tx' <- bindM2 mul tx (return py) - ty' <- bindM2 mul (bindM2 mul (return px) ty) (flog px) - mul c =<< add tx' ty' - FCmp _ -> emitZeroT - BAnd -> emitZeroT - BOr -> emitZeroT - BXor -> emitZeroT - BShL -> emitZeroT - BShR -> emitZeroT +linearizeBinOp op x'' y'' = do + WithTangent x' tx' <- linearizeAtom x'' + WithTangent y' ty' <- linearizeAtom y'' + ans' <- emit $ BinOp op x' y' + return $ WithTangent ans' do + ans <- sinkM ans' + x <- referToPrimal x' + y <- referToPrimal y' + tx <- tx' + ty <- ty' + let zero = zeroLikeT ans + case op of + IAdd -> zero + ISub -> zero + IMul -> zero + IDiv -> zero + IRem -> zero + ICmp _ -> zero + FAdd -> emitLin $ BinOp FAdd tx ty + FSub -> emitLin $ BinOp FSub tx ty + FMul -> do + t1 <- emitLin $ BinOp FMul ty x + t2 <- emitLin $ BinOp FMul tx y + emitLin $ BinOp FAdd t1 t2 + FDiv -> do + t1 <- emitLin $ BinOp FDiv tx y + xyy <- fdiv x =<< fmul y y + t2 <- emitLin $ BinOp FMul ty xyy + emitLin $ BinOp FSub t1 t2 + FPow -> do + ym1 <- fsub y (1.0 `fLitLike` y) + xpowym1 <- emit $ BinOp FPow x ym1 + xlogx <- fmul x =<< emit (UnOp Log x) + t1 <- emitLin $ BinOp FMul tx y + t2 <- emitLin $ BinOp FMul ty xlogx + t12 <- emitLin $ BinOp FAdd t1 t2 + emitLin $ BinOp FMul xpowym1 t12 + FCmp _ -> zero + BAnd -> zero + BOr -> zero + BXor -> zero + BShL -> zero + BShR -> zero -- This has the same type as `sinkM` and falls back thereto, but recomputes -- indexing a primal array in the tangent to avoid materializing intermediate @@ -575,12 +593,12 @@ referToPrimal x = do linearizePrimCon :: Emits o => Con SimpIR i -> LinM i o SAtom SAtom linearizePrimCon con = case con of - Lit _ -> emitZeroT + Lit _ -> zero ProdCon xs -> fmapLin (Con . ProdCon . fromComposeE) $ seqLin (fmap linearizeAtom xs) SumCon _ _ _ -> notImplemented - HeapVal -> emitZeroT + HeapVal -> zero DepPair _ _ _ -> notImplemented - where emitZeroT = withZeroT $ renameM $ Con con + where zero = emitZeroT con linearizeHof :: Emits o => Hof SimpIR i -> LinM i o SAtom SAtom linearizeHof hof = case hof of @@ -672,21 +690,6 @@ linearizeEffectFun rws (BinaryLamExpr hB refB body) = do return (BinaryLamExpr h b body', linLam') linearizeEffectFun _ _ = error "expect effect function to be a binary lambda" -withT :: PrimalM i o (e1 o) - -> (forall o'. (Emits o', DExt o o') => TangentM o' (e2 o')) - -> PrimalM i o (WithTangent o e1 e2) -withT p t = do - p' <- p - return $ WithTangent p' t - -withZeroT :: PrimalM i o (Atom SimpIR o) - -> PrimalM i o (WithTangent o SAtom SAtom) -withZeroT p = do - p' <- p - return $ WithTangent p' do - pTy <- return $ getType $ sink p' - zeroAt =<< tangentType pTy - notImplemented :: HasCallStack => a notImplemented = error "Not implemented" diff --git a/src/lib/PPrint.hs b/src/lib/PPrint.hs index 24189448..af1c48f1 100644 --- a/src/lib/PPrint.hs +++ b/src/lib/PPrint.hs @@ -970,6 +970,7 @@ instance Pretty LetAnn where PlainLet -> "" InlineLet -> "%inline" NoInlineLet -> "%noinline" + LinearLet -> "%linear" OccInfoPure u -> p u <> line OccInfoImpure u -> p u <> ", impure" <> line diff --git a/src/lib/Transpose.hs b/src/lib/Transpose.hs index e312de43..302ca9e4 100644 --- a/src/lib/Transpose.hs +++ b/src/lib/Transpose.hs @@ -9,8 +9,6 @@ module Transpose (transpose, transposeTopFun) where import Data.Foldable import Data.Functor import Control.Category ((>>>)) -import Control.Monad.Reader -import qualified Data.Set as S import GHC.Stack import Builder @@ -18,7 +16,6 @@ import Core import Err import Imp import IRVariants -import MTL1 import Name import Subst import QueryType @@ -37,7 +34,7 @@ transpose lam ct = liftEmitBuilder $ runTransposeM do {-# SCC transpose #-} runTransposeM :: TransposeM n n a -> BuilderM SimpIR n a -runTransposeM cont = runReaderT1 (ListE []) $ runSubstReaderT idSubst $ cont +runTransposeM cont = runSubstReaderT idSubst $ cont transposeTopFun :: (MonadFail1 m, EnvReader m) @@ -73,20 +70,15 @@ unpackLinearLamExpr lam@(LamExpr bs body) = do -- === transposition monad === +type AtomTransposeSubstVal = TransposeSubstVal (AtomNameC SimpIR) data TransposeSubstVal c n where RenameNonlin :: Name c n -> TransposeSubstVal c n -- accumulator references corresponding to non-ref linear variables - LinRef :: SAtom n -> TransposeSubstVal (AtomNameC SimpIR) n + LinRef :: SAtom n -> AtomTransposeSubstVal n -- as an optimization, we don't make references for trivial vector spaces - LinTrivial :: TransposeSubstVal (AtomNameC SimpIR) n + LinTrivial :: AtomTransposeSubstVal n -type LinRegions = ListE SAtomVar - -type TransposeM a = SubstReaderT TransposeSubstVal - (ReaderT1 LinRegions (BuilderM SimpIR)) a - -type TransposeM' a = SubstReaderT AtomSubstVal - (ReaderT1 LinRegions (BuilderM SimpIR)) a +type TransposeM a = SubstReaderT TransposeSubstVal (BuilderM SimpIR) a -- TODO: it might make sense to replace substNonlin/isLin -- with a single `trySubtNonlin :: e i -> Maybe (e o)`. @@ -99,30 +91,6 @@ substNonlin e = do RenameNonlin v' -> v' _ -> error "not a nonlinear expression") e --- TODO: Can we generalize onNonLin to accept SubstReaderT Name instead of --- SubstReaderT AtomSubstVal? For that to work, we need another combinator, --- that lifts a SubstReader AtomSubstVal into a SubstReader Name, because --- effectsSubstE is currently typed as SubstReader AtomSubstVal. --- Then we can presumably recode substNonlin as `onNonLin substM`. We may --- be able to do that anyway, except we will then need to restrict the type --- of substNonlin to require `SubstE AtomSubstVal e`; but that may be fine. -onNonLin :: HasCallStack - => TransposeM' i o a -> TransposeM i o a -onNonLin cont = do - subst <- getSubst - let subst' = newSubst (\v -> case subst ! v of - RenameNonlin v' -> Rename v' - _ -> error "not a nonlinear expression") - liftSubstReaderT $ runSubstReaderT subst' cont - -isLin :: HoistableE e => e i -> TransposeM i o Bool -isLin e = do - substVals <- mapM lookupSubstM $ freeAtomVarsList @SimpIR e - return $ flip any substVals \case - LinTrivial -> True - LinRef _ -> True - RenameNonlin _ -> False - withAccumulator :: Emits o => SType o @@ -147,43 +115,42 @@ emitCTToRef ref ct = do baseMonoid <- tangentBaseMonoidFor (getType ct) void $ emit $ RefOp ref $ MExtend baseMonoid ct -getLinRegions :: TransposeM i o [SAtomVar o] -getLinRegions = asks fromListE - -extendLinRegions :: SAtomVar o -> TransposeM i o a -> TransposeM i o a -extendLinRegions v cont = local (\(ListE vs) -> ListE (v:vs)) cont - -- === actual pass === -transposeWithDecls :: Emits o => Nest SDecl i i' -> SExpr i' -> SAtom o -> TransposeM i o () +transposeWithDecls :: forall i i' o. Emits o => Nest SDecl i i' -> SExpr i' -> SAtom o -> TransposeM i o () transposeWithDecls Empty atom ct = transposeExpr atom ct -transposeWithDecls (Nest (Let b (DeclBinding _ expr)) rest) result ct = - substExprIfNonlin expr >>= \case - Nothing -> do - ty' <- substNonlin $ getType expr - ctExpr <- withAccumulator ty' \refSubstVal -> - extendSubst (b @> refSubstVal) $ - transposeWithDecls rest result (sink ct) - transposeExpr expr ctExpr - Just nonlinExpr -> do - v <- emitToVar nonlinExpr - extendSubst (b @> RenameNonlin (atomVarName v)) $ - transposeWithDecls rest result ct - -substExprIfNonlin :: SExpr i -> TransposeM i o (Maybe (SExpr o)) -substExprIfNonlin expr = - isLin expr >>= \case - True -> return Nothing - False -> do - onNonLin (substM $ getEffects expr) >>= isLinEff >>= \case - True -> return Nothing - False -> Just <$> substNonlin expr +transposeWithDecls (Nest (Let b (DeclBinding ann expr)) rest) result ct = case ann of + LinearLet -> do + ty' <- substNonlin $ getType expr + case expr of + Project _ i x -> do + continue =<< projectLinearRef x \ref -> emitLin =<< mkProjRef ref (ProjectProduct i) + TabApp _ x i -> do + continue =<< projectLinearRef x \ref -> do + i' <- substNonlin i + emitLin =<< mkIndexRef ref i' + _ -> do + ctExpr <- withAccumulator ty' \refSubstVal -> continue refSubstVal + transposeExpr expr ctExpr + _ -> do + v <- substNonlin expr >>= emitToVar + continue $ RenameNonlin (atomVarName v) + where + continue :: forall o'. (Emits o', Ext o o') => AtomTransposeSubstVal o' -> TransposeM i o' () + continue substVal = do + ct' <- sinkM ct + extendSubst (b @> substVal) $ transposeWithDecls rest result ct' -isLinEff :: EffectRow SimpIR o -> TransposeM i o Bool -isLinEff effs@(EffectRow _ NoTail) = do - regions <- fmap atomVarName <$> getLinRegions - let effRegions = freeAtomVarsList effs - return $ not $ null $ S.fromList effRegions `S.intersection` S.fromList regions +projectLinearRef + :: Emits o + => SAtom i -> (SAtom o -> TransposeM i o (SAtom o)) + -> TransposeM i o (AtomTransposeSubstVal o) +projectLinearRef x f = do + Stuck _ (Var v) <- return x + lookupSubstM (atomVarName v) >>= \case + RenameNonlin _ -> error "nonlinear" + LinRef ref -> LinRef <$> f ref + LinTrivial -> return LinTrivial getTransposedTopFun :: EnvReader m => TopFunName n -> m n (Maybe (TopFunName n)) getTransposedTopFun f = do @@ -200,44 +167,23 @@ transposeExpr expr ct = case expr of xsNonlin' <- mapM substNonlin xsNonlin ct' <- naryTopApp fT (xsNonlin' ++ [ct]) transposeAtom xLin ct' - -- TODO: Instead, should we handle table application like nonlinear - -- expressions, where we just project the reference? - TabApp _ x i -> do - i' <- substNonlin i - case x of - Stuck _ stuck -> case stuck of - Var v -> do - lookupSubstM (atomVarName v) >>= \case - RenameNonlin _ -> error "shouldn't happen" - LinRef ref -> do - refProj <- indexRef ref i' - emitCTToRef refProj ct - LinTrivial -> return () - StuckProject _ _ -> undefined - StuckTabApp _ _ -> undefined - PtrVar _ _ -> error "not tangent" - RepValAtom _ -> error "not tangent" - _ -> error $ "shouldn't occur: " ++ pprint x PrimOp op -> transposeOp op ct Case e alts _ -> do - linearScrutinee <- isLin e - case linearScrutinee of - True -> notImplemented - False -> do - e' <- substNonlin e - void $ buildCase e' UnitTy \i v -> do - v' <- emitToVar v - Abs b body <- return $ alts !! i - extendSubst (b @> RenameNonlin (atomVarName v')) do - transposeExpr body (sink ct) - return UnitVal + e' <- substNonlin e + void $ buildCase e' UnitTy \i v -> do + v' <- emitToVar v + Abs b body <- return $ alts !! i + extendSubst (b @> RenameNonlin (atomVarName v')) do + transposeExpr body (sink ct) + return UnitVal TabCon _ ty es -> do TabTy d b _ <- return ty idxTy <- substNonlin $ IxType (binderType b) d forM_ (enumerate es) \(ordinalIdx, e) -> do i <- unsafeFromOrdinal idxTy (IdxRepVal $ fromIntegral ordinalIdx) tabApp ct i >>= transposeAtom e - Project _ _ _ -> undefined + TabApp _ _ _ -> error "should have been handled by reference projection" + Project _ _ _ -> error "should have been handled by reference projection" transposeOp :: Emits o => PrimOp SimpIR i -> SAtom o -> TransposeM i o () transposeOp op ct = case op of @@ -262,18 +208,21 @@ transposeOp op ct = case op of ProjRef _ _ -> notImplemented Hof (TypedHof _ hof) -> transposeHof hof ct MiscOp miscOp -> transposeMiscOp miscOp ct - UnOp FNeg x -> transposeAtom x =<< neg ct + UnOp FNeg x -> transposeAtom x =<< (emitLin $ UnOp FNeg ct) UnOp _ _ -> notLinear BinOp FAdd x y -> transposeAtom x ct >> transposeAtom y ct - BinOp FSub x y -> transposeAtom x ct >> (transposeAtom y =<< neg ct) + BinOp FSub x y -> transposeAtom x ct >> (transposeAtom y =<< (emitLin $ UnOp FNeg ct)) + -- XXX: linear argument to FMul is always first BinOp FMul x y -> do - xLin <- isLin x - if xLin - then transposeAtom x =<< mul ct =<< substNonlin y - else transposeAtom y =<< mul ct =<< substNonlin x - BinOp FDiv x y -> transposeAtom x =<< div' ct =<< substNonlin y + y' <- substNonlin y + tx <- emitLin $ BinOp FMul ct y' + transposeAtom x tx + BinOp FDiv x y -> do + y' <- substNonlin y + tx <- emitLin $ BinOp FDiv ct y' + transposeAtom x tx BinOp _ _ _ -> notLinear - MemOp _ -> notLinear + MemOp _ -> notLinear VectorOp _ -> unreachable where notLinear = error $ "Can't transpose a non-linear operation: " ++ pprint op @@ -291,10 +240,9 @@ transposeMiscOp op _ = case op of BitcastOp _ _ -> notImplemented UnsafeCoerce _ _ -> notImplemented GarbageVal _ -> notImplemented - ShowAny _ -> error "Shouldn't have ShowAny in simplified IR" - ShowScalar _ -> error "Shouldn't have ShowScalar in simplified IR" - where - notLinear = error $ "Can't transpose a non-linear operation: " ++ show op + ShowAny _ -> notLinear + ShowScalar _ -> notLinear + where notLinear = error $ "Can't transpose a non-linear operation: " ++ show op transposeAtom :: HasCallStack => Emits o => SAtom i -> SAtom o -> TransposeM i o () transposeAtom atom ct = case atom of @@ -308,16 +256,9 @@ transposeAtom atom ct = case atom of return () LinRef ref -> emitCTToRef ref ct LinTrivial -> return () - StuckProject _ _ -> error "not implemented" - StuckTabApp _ _ -> error "not implemented" - -- let (idxs, v) = asNaryProj i' x' - -- lookupSubstM (atomVarName v) >>= \case - -- RenameNonlin _ -> error "an error, probably" - -- LinRef ref -> do - -- ref' <- applyProjectionsRef (toList idxs) ref - -- emitCTToRef ref' ct - -- LinTrivial -> return () - RepValAtom _ -> error "not implemented" + StuckProject _ _ -> error "not linear" + StuckTabApp _ _ -> error "not linear" + RepValAtom _ -> error "not linear" where notTangent = error $ "Not a tangent atom: " ++ pprint atom transposeHof :: Emits o => Hof SimpIR i -> SAtom o -> TransposeM i o () @@ -333,8 +274,7 @@ transposeHof hof ct = case hof of (ctBody, ctState) <- fromPair ct (_, cts) <- (fromPair =<<) $ emitRunState noHint ctState \h ref -> do extendSubst (hB@>RenameNonlin (atomVarName h)) $ extendSubst (refB@>RenameNonlin (atomVarName ref)) $ - extendLinRegions h $ - transposeExpr body (sink ctBody) + transposeExpr body (sink ctBody) return UnitVal transposeAtom s cts RunReader r (BinaryLamExpr hB refB body) -> do @@ -342,8 +282,7 @@ transposeHof hof ct = case hof of baseMonoid <- tangentBaseMonoidFor accumTy (_, ct') <- (fromPair =<<) $ emitRunWriter noHint accumTy baseMonoid \h ref -> do extendSubst (hB@>RenameNonlin (atomVarName h)) $ extendSubst (refB@>RenameNonlin (atomVarName ref)) $ - extendLinRegions h $ - transposeExpr body (sink ct) + transposeExpr body (sink ct) return UnitVal transposeAtom r ct' RunWriter Nothing _ (BinaryLamExpr hB refB body)-> do @@ -351,8 +290,7 @@ transposeHof hof ct = case hof of (ctBody, ctEff) <- fromPair ct void $ emitRunReader noHint ctEff \h ref -> do extendSubst (hB@>RenameNonlin (atomVarName h)) $ extendSubst (refB@>RenameNonlin (atomVarName ref)) $ - extendLinRegions h $ - transposeExpr body (sink ctBody) + transposeExpr body (sink ctBody) return UnitVal _ -> notImplemented diff --git a/src/lib/Types/Primitives.hs b/src/lib/Types/Primitives.hs index d85d8824..83ba3ffb 100644 --- a/src/lib/Types/Primitives.hs +++ b/src/lib/Types/Primitives.hs @@ -64,6 +64,7 @@ data LetAnn = | InlineLet -- Binding explicitly tagged "do not inline" | NoInlineLet + | LinearLet -- Bound expression is pure, and the binding's occurrences are summarized by -- the UsageInfo | OccInfoPure UsageInfo |