summaryrefslogtreecommitdiff
path: root/lib/prelude.dx
diff options
context:
space:
mode:
authorAlexey Radul <axch@google.com>2023-02-05 09:17:54 -0500
committeraxch <233710+axch@users.noreply.github.com>2023-02-05 10:21:46 -0500
commit3de19605c138ad666fa7d328b3a931922959c5c9 (patch)
treefaf3f7144a4a055e0faa0ab13b176003f6903d58 /lib/prelude.dx
parentfbb56181aa0a766fe462333cdc80969c288b68c8 (diff)
Rewrite general_integer_power with a `for` because Dex can't differentiate through `while`.
There is still a `while` in there, to compute the number of iterations of the squaring loop, but - It doesn't touch the floating-point input, so can't mess up AD, and - We should replace it with a count-leading-zeros intrinsic anyway. Differentiating through general `while` is a problem because the tape becomes a recursive ADT (namely List), and Dex doesn't support those yet. Fixes #1195.
Diffstat (limited to 'lib/prelude.dx')
-rw-r--r--lib/prelude.dx22
1 files changed, 12 insertions, 10 deletions
diff --git a/lib/prelude.dx b/lib/prelude.dx
index aedd1012..94458e01 100644
--- a/lib/prelude.dx
+++ b/lib/prelude.dx
@@ -2281,6 +2281,12 @@ def is_power_of_2 (x:Nat) : Bool =
then False
else 0 == %and x' (%isub x' (1::NatRep))
+-- This computes the integer part of the binary logarithm of the input.
+-- TODO: natlog2 0 should do something other than underflow the answer.
+-- TODO: Use LLVM ctlz intrinsic instead. It needs a slightly new
+-- code path in ImpToLLVM, because it's the first LLVM intrinsic
+-- we have with a fixed-point argument.
+-- https://llvm.org/docs/LangRef.html#llvm-ctlz-intrinsic
def natlog2 (x:Nat) : Nat =
tmp = yield_state 0 \ans.
cmp <- run_state 1
@@ -2295,22 +2301,18 @@ def natlog2 (x:Nat) : Nat =
unsafe_nat_diff tmp 1 -- TODO: something less horrible
def general_integer_power {a} (times:a->a->a) (one:a) (base:a) (power:Nat) : a =
+ iters = if power == 0 then 0 else 1 + natlog2 power
-- Implements exponentiation by squaring.
-- This could be nicer if there were a way to explicitly
-- specify which typelcass instance to use for Mul.
yield_state one \ans.
pow <- with_state power
z <- with_state base
- while do
- if get pow > 0
- then
- if is_odd (get pow)
- then ans := times (get ans) (get z)
- z := times (get z) (get z)
- pow := intdiv2 (get pow)
- True
- else
- False
+ for _:(Fin iters).
+ if is_odd (get pow)
+ then ans := times (get ans) (get z)
+ z := times (get z) (get z)
+ pow := intdiv2 (get pow)
def intpow {a} [Mul a] (base:a) (power:Nat) : a =
general_integer_power (*) one base power