diff options
author | Alexey Radul <axch@google.com> | 2023-02-05 09:17:54 -0500 |
---|---|---|
committer | axch <233710+axch@users.noreply.github.com> | 2023-02-05 10:21:46 -0500 |
commit | 3de19605c138ad666fa7d328b3a931922959c5c9 (patch) | |
tree | faf3f7144a4a055e0faa0ab13b176003f6903d58 /lib/prelude.dx | |
parent | fbb56181aa0a766fe462333cdc80969c288b68c8 (diff) |
Rewrite general_integer_power with a `for` because Dex can't differentiate through `while`.
There is still a `while` in there, to compute the number of iterations
of the squaring loop, but
- It doesn't touch the floating-point input, so can't mess up AD, and
- We should replace it with a count-leading-zeros intrinsic anyway.
Differentiating through general `while` is a problem because the tape
becomes a recursive ADT (namely List), and Dex doesn't support those yet.
Fixes #1195.
Diffstat (limited to 'lib/prelude.dx')
-rw-r--r-- | lib/prelude.dx | 22 |
1 files changed, 12 insertions, 10 deletions
diff --git a/lib/prelude.dx b/lib/prelude.dx index aedd1012..94458e01 100644 --- a/lib/prelude.dx +++ b/lib/prelude.dx @@ -2281,6 +2281,12 @@ def is_power_of_2 (x:Nat) : Bool = then False else 0 == %and x' (%isub x' (1::NatRep)) +-- This computes the integer part of the binary logarithm of the input. +-- TODO: natlog2 0 should do something other than underflow the answer. +-- TODO: Use LLVM ctlz intrinsic instead. It needs a slightly new +-- code path in ImpToLLVM, because it's the first LLVM intrinsic +-- we have with a fixed-point argument. +-- https://llvm.org/docs/LangRef.html#llvm-ctlz-intrinsic def natlog2 (x:Nat) : Nat = tmp = yield_state 0 \ans. cmp <- run_state 1 @@ -2295,22 +2301,18 @@ def natlog2 (x:Nat) : Nat = unsafe_nat_diff tmp 1 -- TODO: something less horrible def general_integer_power {a} (times:a->a->a) (one:a) (base:a) (power:Nat) : a = + iters = if power == 0 then 0 else 1 + natlog2 power -- Implements exponentiation by squaring. -- This could be nicer if there were a way to explicitly -- specify which typelcass instance to use for Mul. yield_state one \ans. pow <- with_state power z <- with_state base - while do - if get pow > 0 - then - if is_odd (get pow) - then ans := times (get ans) (get z) - z := times (get z) (get z) - pow := intdiv2 (get pow) - True - else - False + for _:(Fin iters). + if is_odd (get pow) + then ans := times (get ans) (get z) + z := times (get z) (get z) + pow := intdiv2 (get pow) def intpow {a} [Mul a] (base:a) (power:Nat) : a = general_integer_power (*) one base power |