-- TODO: use prelude sum instead once we can differentiate state effect
def sum'(xs:n=>Float) -> Float given (n|Ix) =
  yield_accum (AddMonoid Float) \ref. for i. ref += xs[i]

:p f : (Float) -> Float = \x. x
   jvp f 3.0 2.0
> 2.

:p f = \x. x * x
   jvp f 3.0 1.5
> 9.

:p f = \x. x + x
   jvp f 3.0 2.0
> 4.

:p f = \x. x * x * x
   jvp f 2.0 1.5
> 18.

:p f : (Float) -> Float = \x. x
   transpose_linear(f)(2.0)
> 2.

:p f : (Float) -> Float = \x. x + x
   transpose_linear(f)(1.0)
> 2.

:p f : (Float) -> Float = \x. x + (x + x) * 2.0
   transpose_linear(f)(1.0)
> 5.

:p f : (Float) -> Float = \x. x * 2.0
   transpose_linear(f)(1.0)
> 2.

:p f : (Float) -> Float = \x. 2.0 * x
   transpose_linear(f)(1.0)
> 2.

:p grad (\x. x * x) 1.0
> 2.

:p deriv (\x. 3.0 / x) 2.0
> -0.75

:p deriv (\x. x / 2.0) 3.0
> 0.5

:p f : (given (n|Ix), n=>Float) -> n=>Float = \xs. for i. xs[i] * xs[i]
   jvp f [1.,2.] [3.,4.]
> [6., 16.]

:p jvp transpose [[1.,2.], [3.,4.]] [[10.,20.], [30.,40.]]
> [[10., 30.], [20., 40.]]

:p jvp sum' [1., 2.] [10.0, 20.0]
> 30.

:p f : (Float) -> Float = \x. yield_accum (AddMonoid Float) \ref. ref += x
   jvp f 1.0 1.0
> 1.

:p f = \x. x * x * x
   jvp (\x. jvp f x 1.0) 2.0 1.0
> 12.

:p f = \x. 4.0 * x * x * x
   deriv (\x1. deriv (\x2. deriv f x2) x1) 1.234
> 24.

:p f : (Float) -> (Float, Float) = \x. (x, 2.0 * x)
   transpose_linear(f)((1.0, 3.0))
> 7.

:p f : ((Float, Float)) -> Float = \p.
     (x, y) = p
     x + 2.0 * y
   transpose_linear(f)(1.0)
> (1., 2.)

:p deriv cos 0.0
> 0.

:p deriv sin 0.0
> 1.

:p (sin 1.0, deriv (\x. deriv sin x) 1.0)
> (0.841471, -0.841471)

:p (cos 1.0, deriv (\x'. deriv (\x. deriv sin x) x') 1.0)
> (0.5403023, -0.5403023)

:p check_deriv sin 1.0
> True

:p check_deriv cos 1.0
> True

:p check_deriv exp 2.0
> True

:p check_deriv log 2.0
> True

:p check_deriv (\x. exp (sin (cos x))) 2.0
> True

:p check_deriv (\x. deriv sin x) 1.0
> True

:p check_deriv (\x. deriv cos x) 1.0
> True

:p check_deriv sqrt 4.0
> True

-- badDerivFun : Float -> Float
-- badDerivFun x = x

-- badDerivFun#lin : Float -> (Float, Float -> Float)
-- badDerivFun#lin x = (x, llam t. 2. * t)

-- :p checkDeriv badDerivFun 1.0
-- > False

-- Perturbation confusion test suggested by Barak Pearlmutter
-- https://github.com/HIPS/autograd/issues/4
:p deriv (\x. x * deriv (\y. x * y) 2.0) 1.0
> 2.

tripleit : (Float) -> Float = \x. x + x + x

:p tripleit 1.0
> 3.

:p transpose_linear(tripleit)(1.0)
> 3.

:p transpose_linear(transpose_linear tripleit)(1.0)
> 3.

:p f : (given (n|Ix), Float) -> n=>Float = \x. for i. x
   transpose_linear(f)([1.0, 2.0])
> 3.

:p f : (given (n|Ix), n=>Float) -> n=>Float = \x. for i. x[i] * 2.0
   transpose_linear(f)([1.0, 2.0])
> [2., 4.]

myOtherSquare : (Float) -> Float = \x.
  yield_accum (AddMonoid Float) \w. w += x * x

:p check_deriv myOtherSquare 3.0
> True

:p f : (Float) -> Float = \x. fst (x * x, 2 + 1)
   jvp f 1.0 3.0
> 6.

:p f : (Float) -> Float = \x. x * i_to_f (1 + 1)
   jvp f 1.0 2.0
> 4.

:p f : (Fin 2=>Float) -> Float = \xs. xs[0 @ Fin 2] * xs[1 @ Fin 2]
   jvp f [1., 2.] [3.0, 4.0]
> 10.

:p f : ((Float, Float)) -> Float = \p.
     (x,y) = p
     x * y
   jvp f (1., 2.) (3.0, 4.0)
> 10.

:p f : (given (n|Ix), n=>Float) -> n=>Float = \xs. for i. xs[i] * xs[i]
   jvp f [1.,2.] [3.,4.]
> [6., 16.]

:p jvp sum' [1., 2.] [3.0, 4.0]
> 7.

:p grad sum' [1.,2.]
> [1., 1.]

vec = [1.]

:p jvp (\x. vec) [1.] [1.]
> [0.]

:p grad (\p.
     (x, y) = p
     vdot x y) ([1.,2.], [3.,4.])
> ([3., 4.], [1., 2.])

:p f : (Float) -> Float = \x.
     y = x * 2.0
     yield_accum (AddMonoid Float) \a.
       a += x * 2.0
       a += y
   grad f 1.0
> 4.

:p f : (Float) -> Float = \x.
     x2 = x * x
     with_reader x \xr.
       with_reader x2 \x2r.
         5.0 * (ask x2r) + 4.0 * (ask xr) + 2.0
   check_deriv f 2.0
> True

:p f : (Float) -> Float = \x. yield_state x \xr.
     for i:(Fin 2). xr := get xr * get xr
   check_deriv f 2.0
> True

-- :p
--   f = \rec.
--     ({x=x, y=y, z=z}) = rec
--     x * y * i_to_f z
--   (check_deriv (\x. f {x=x, y=4.0, z=5}) 2.0, check_deriv (\y. f {x=2.0, y=y, z=5}) 4.0)
-- > (True, True)

-- TODO: Re-enable once the big PR is merged
-- :p
--   f = \x.
--     y = for i:(Fin 10). { x=x * (IToF $ ordinal i) }
--     z = for i.
--       ({ x=x, ... }) = y.i
--       x
--     sum' z
--   checkDeriv f 1.0
-- > True

-- :p
--   f = \x. for i:(Fin 4). { x=x * x * (n_to_f $ ordinal i) }
--   jvp f 2.0 1.0
-- > [{x = 0.}, {x = 4.}, {x = 8.}, {x = 12.}]

:p f = \x. max(0.0, x)
   (check_deriv_base f 1.0, check_deriv_base f (-1.0))
> (True, True)

:p xs = for i:(Fin 2). 2.0
   f = \x. sum xs
   check_deriv f 1.0
> True

:p f = \c.
     v = for i:(Fin 2). 2.0
     (c, v) = yield_state (c, v) \r.
       for i:(Fin 2).
         (c, v) = get r
         r := (c + sum v, v)
     c
   check_deriv f 1.0
> True

-- Test reference indexing
:p f = \x.
     i = unsafe_from_ordinal(n=Fin 3, 0)
     mat = yield_state zero \m.
       m!i!i := x
     mat[i,i]
   check_deriv f 1.0
> True

-- Regression test for bug #841, linearization through triangular table
func = \x:Float.
  table = for i:(Fin 2). for j:(..i). x
  tmp = (0 @ (Fin 2))
  sum table[tmp]

snd(linearize func 1.0)(2.0)
> 2.

-- Nested AD, examples from #713
def func2(x:m=>Float) -> Float given (m|Ix) =
  exp (0.5 * sum for i. sq x[i])

def hvpf(x:m=>Float, v:m=>Float) -> m=>Float given (m|Ix) =
  (dot x v) .* (grad func2 x) + func2(x) .* v

x = [0.1, 0.2, 0.3]
v = [0.2, 0.3, 0.4]

func2 x ~~ 1.072508
> True

-- analytic
result = hvpf x v
result ~~ [0.235952, 0.364653, 0.493354]
> True

-- finite diff over reverse
eps = 0.0001
(grad func2 (x + eps .* v) - grad func2 x) / eps ~~ [0.235885, 0.364631, 0.493228]
> True

-- reverse over forward
grad (\x. jvp func2 x v) x ~~ result
> True

-- reverse over reverse
grad (\x. dot (grad func2 x) v) x ~~ result
> True

-- forward over reverse
jvp (\x. grad func2 x) x v ~~ result
> True

-- Regression test for bug #848, AD through state effect over case
def min'(a:Float, b:Float) -> Float =
  yield_state a \s.
    best = get s
    new_best = select (best < b) best 2.0
    s := new_best

grad (\x . (min' x (x+1))) 1.0
> 1.

grad (\x . [x][argmin [x]]) 1.0
> 1.

-------------------- Custom linearization --------------------

@noinline
f = \x:Float. x * 2

fLin = \x:Float. (f x, \xt:Float. xt * 4)

custom-linearization f fLin

deriv (\x. f (x * 2) * 5) 1.0
> 40.

deriv f 1.0
> 4.

-- Custom derivative is preserved even when we defunctionalize for
flip deriv 1.0 \x:Float.
  farr = for i:(Fin 5). f
  farr[unsafe_from_ordinal 1] x
> Error: variable not in scope: flip
>
> flip deriv 1.0 \x:Float.
> ^^^^^

-- Custom derivative is preserved even when we defunctionalize case
(\x f. deriv f x) 1.0 \x:Float.
  total = sum $ for i:(Fin 10). ordinal i
  f2 = case total > 10 of
    True -> f
    False -> \x:Float. x
  f2 x
> 4.

@noinline
g = \x:(Float, Float). 4 .* x

custom-linearization g \x. (g x, g x)
> Type error:Expected the custom linearization to have type:
>
> ((x.1:(Float32, Float32)) -> ((Float32, Float32)
> , ((v#0:(Float32, Float32)) -> (Float32
> , Float32))))
>
> but it has type:
>
> ((x.1:(Float32, Float32)) -> ((Float32, Float32), (Float32, Float32)))

data IHaveNoTangentType = MkIHaveNoTangentType

@noinline
noInputTangent = \x:IHaveNoTangentType. 2.0

custom-linearization noInputTangent \x. (noInputTangent x, 0.0)
> Type error:No tangent type for: IHaveNoTangentType

@noinline
noOutputTangent = \x:Float.
  MkIHaveNoTangentType

custom-linearization noOutputTangent \x. (noOutputTangent x, 0.0)
> Type error:No tangent type for: IHaveNoTangentType

-- Functions of multiple arguments should be supported
@noinline
h = \x:Float y:Float. x * 2 + y

hLin = \x y. (h x y, \xt:Float yt:Float. xt + yt)

custom-linearization h hLin

deriv (\x. h x x) 1.0 == deriv (\x. x + x) 1.0
> True

-- Index-polymorphic custom linearization
@noinline
def w(x:n=>Float) -> n=>Float given (n|Ix) = 2 .* x

def wLin(x:n=>Float) -> (n=>Float, (n=>Float) -> n=>Float) given (n|Ix) =
  (w x, \xt:(n=>Float). 4 .* xt)

custom-linearization w wLin

jvp (\x. w x) [1.0, 1.0] [1.0, 0.0]
> [4., 0.]

jvp w [1.0, 1.0] [1.0, 0.0]
> [4., 0.]

------ standalone functions ------

@noinline
def xy2(x:Float, y:Float) -> Float = x * y * y

:p check_deriv (\x. xy2 x x) 2.0
> True

:p check_deriv (\x. xy2 x 3.0) 2.0
> True

:p check_deriv (\x. xy2 2.0 3.0) 2.0
> True

:p check_deriv (\x. xy2 5.0 x) 2.0
> True

------ Symbolic tangents ------

def ST(a) = SymbolicTangent(a)

@noinline
def q(x:Float, y:Float) -> Float = x * 2 + y

def qLin(x:Float, y:Float) -> _ =
  def lin(xt:ST Float, yt: ST Float) -> Float =
    case yt of
      ZeroTangent -> someTangent xt * 4
      SomeTangent yt' -> someTangent xt * 2 + yt'
  (q x y, lin)

custom-linearization q qLin
> Type error:Expected the custom linearization to have type:
>
> ((x.1:Float32,y:Float32) -> (Float32, ((v#0:Float32,v#1:Float32) -> Float32)))
>
> but it has type:
>
> ((x.1:Float32,y:Float32) -> (Float32
> , ((xt:(SymbolicTangent Float32),yt:(SymbolicTangent
> Float32)) -> Float32)))

custom-linearization-symbolic q qLin

-- Incorrect derivative, based on the ZeroTangent branch
deriv (\x. q x 1.0) 1.0
> 4.

-- Correct derivative, based on the SomeTangent branch
deriv (\x. q x x) 1.0
> 3.

------ Check custom linearization of matmul ------

amat = for i:(Fin 100) j:(Fin 100). n_to_f $ ordinal (i, j)

-- The derivative of matmul should give the same answers as a direct
-- matmul (this checks that the custom derivative is not too busted).
def mmp'(m1:l=>m=>Float, m2:m=>n=>Float) -> l=>n=>Float given (l|Ix, m|Ix, n|Ix) =
  jvp (\m. m1 ** m) m2 m2

:p mmp'(amat, amat) ~~ naive_matmul(amat, amat)
> True

-- Let's check the other orientation too
def mmp''(m1:l=>m=>Float, m2:m=>n=>Float) -> l=>n=>Float given (l|Ix, m|Ix, n|Ix) =
  jvp (\m. m ** m2) m1 m1

:p mmp''(amat, amat) ~~ naive_matmul(amat, amat)
> True