'# Dex prelude

'Runs before every Dex program unless an alternative is provided with `--prelude`.

'## Essentials

'### Primitive Types

Type = %TyKind()
Heap = %HeapType()
Effects = %EffKind()
Fields = %LabeledRowKind()

Int64 = %Int64()
Int32 = %Int32()
Float64 = %Float64()
Float32 = %Float32()
Word8 = %Word8()
Word32 = %Word32()
Word64 = %Word64()
Byte = Word8
Char = Byte
Label = %Label()
RawPtr : Type = %Word8Ptr()

-- Default-width aliases used throughout the prelude.
Int = Int32
Float = Float32

def the(a:Type, x:a) -> a = x

interface Data(a:Type)
  do_not_implement_this_interface_for_the_compiler_relies_on_the_invariant_it_protects : (a) -> a

'### Casting

-- Thin wrapper over the `%cast` primitive; the target type `to` is fixed by
-- the caller (either via `to=` or by the expected result type).
def internal_cast(x:from) -> to given (from, to) = %cast(to, x)
def unsafe_coerce(x:from) -> to given (from|Data, to|Data) = %unsafeCoerce(to, x)
def uninitialized_value() -> a given (a|Data) = %garbageVal(a)

def f64_to_f(x: Float64) -> Float = internal_cast x
def f32_to_f(x: Float32) -> Float = internal_cast x
def f_to_f64(x: Float) -> Float64 = internal_cast x
def f_to_f32(x: Float) -> Float32 = internal_cast x
def i64_to_i(x: Int64) -> Int = internal_cast x
def i32_to_i(x: Int32) -> Int = internal_cast x
def w8_to_i(x: Word8) -> Int = internal_cast x
def i_to_i64(x: Int) -> Int64 = internal_cast x
def i_to_i32(x: Int) -> Int32 = internal_cast x
def i_to_w8(x: Int) -> Word8 = internal_cast x
def i_to_w32(x: Int) -> Word32 = internal_cast x
def i_to_w64(x: Int) -> Word64 = internal_cast x
def w32_to_w64(x: Word32)-> Word64 = internal_cast x
def i_to_f(x:Int) -> Float = internal_cast x
def f_to_i(x:Float) -> Int = internal_cast x
def raw_ptr_to_i64(x:RawPtr) -> Int64 = internal_cast x

-- `Nat` is a wrapper around its representation `NatRep` (a Word32);
-- `nat_to_rep`/`rep_to_nat` move between the two.
Nat = %Nat()
NatRep = Word32

def nat_to_rep(x : Nat) -> NatRep = %projNewtype(x)
def rep_to_nat(x : NatRep) -> Nat = %NatCon(x)

-- Conversions out of Nat go through the representation first.
def n_to_w8(x: Nat) -> Word8 = nat_to_rep x | internal_cast
def n_to_w32(x: Nat) -> Word32 = nat_to_rep x | internal_cast
def n_to_w64(x: Nat) -> Word64 = nat_to_rep x | internal_cast
def n_to_i32(x: Nat) -> Int32 = nat_to_rep x | internal_cast
def n_to_i64(x: Nat) -> Int64 = nat_to_rep x | internal_cast
def n_to_f32(x: Nat) -> Float32 = nat_to_rep x | internal_cast
def n_to_f64(x: Nat) -> Float64 = nat_to_rep x | internal_cast
def n_to_f(x: Nat) -> Float = nat_to_rep x | internal_cast

-- Conversions into Nat cast to the representation, then wrap it.
def w8_to_n(x : Word8) -> Nat = internal_cast x | rep_to_nat
def w32_to_n(x : Word32) -> Nat = internal_cast x | rep_to_nat
def w64_to_n(x : Word64) -> Nat = internal_cast x | rep_to_nat
def i32_to_n(x : Int32) -> Nat = internal_cast x | rep_to_nat
def i64_to_n(x : Int64) -> Nat = internal_cast x | rep_to_nat
def f32_to_n(x : Float32) -> Nat = internal_cast x | rep_to_nat
def f64_to_n(x : Float64) -> Nat = internal_cast x | rep_to_nat
def f_to_n(x : Float) -> Nat = internal_cast x | rep_to_nat

interface FromUnsignedInteger(a:Type)
  from_unsigned_integer : (Word64) -> a

instance FromUnsignedInteger(Word8)
  def from_unsigned_integer(x) = internal_cast x

instance FromUnsignedInteger(Word32)
  def from_unsigned_integer(x) = internal_cast x

instance FromUnsignedInteger(Word64)
  def from_unsigned_integer(x) = x

instance FromUnsignedInteger(Int32)
  def from_unsigned_integer(x) = internal_cast x

instance FromUnsignedInteger(Int64)
  def from_unsigned_integer(x) = internal_cast x

instance FromUnsignedInteger(Float32)
  def from_unsigned_integer(x) = internal_cast x

instance FromUnsignedInteger(Float64)
  def from_unsigned_integer(x) = internal_cast x

instance FromUnsignedInteger(Nat)
  def from_unsigned_integer(x) = w64_to_n(x)

interface FromInteger(a:Type)
  from_integer : (Int64) -> a

instance FromInteger(Float32)
  def from_integer(x) = internal_cast x

instance FromInteger(Int32)
  def from_integer(x) = internal_cast x

instance FromInteger(Float64)
  def from_integer(x) = internal_cast x

instance FromInteger(Int64)
  def from_integer(x) = x

'## Bitwise operations

interface Bits(a:Type)
  (.<<.) : (a, Int) -> a
  (.>>.) : (a, Int) -> a
  (.|.) : (a, a) -> a
  (.&.) : (a, a) -> a
  (.^.) : (a, a) -> a

-- The shift amount arrives as an Int and is cast to the operand's width
-- before reaching the shift primitive.
instance Bits(Word8)
  def (.<<.)(x, y) = %shl(x, i_to_w8 y)
  def (.>>.)(x, y) = %shr(x, i_to_w8 y)
  def (.|.)(x, y) = %or(x, y)
  def (.&.)(x, y) = %and(x, y)
  def (.^.)(x, y) = %xor(x, y)

instance Bits(Word32)
  def (.<<.)(x, y) = %shl(x, i_to_w32 y)
  def (.>>.)(x, y) = %shr(x, i_to_w32 y)
  def (.|.)(x, y) = %or(x, y)
  def (.&.)(x, y) = %and(x, y)
  def (.^.)(x, y) = %xor(x, y)

instance Bits(Word64)
  def (.<<.)(x, y) = %shl(x, i_to_w64 y)
  def (.>>.)(x, y) = %shr(x, i_to_w64 y)
  def (.|.)(x, y) = %or(x, y)
  def (.&.)(x, y) = %and(x, y)
  def (.^.)(x, y) = %xor(x, y)

-- Split a Word64 into its 32-bit halves. Narrowing with `internal_cast`
-- keeps the low-order bits (assumes `%cast` truncates — confirm), so the
-- high half must be shifted down first.
-- NOTE(review): these two bodies were previously swapped; any caller that
-- depended on the inverted results (e.g. a seeded bit stream) will now see
-- different values.
def low_word( x : Word64) -> Word32 = internal_cast(x)
def high_word(x : Word64) -> Word32 = internal_cast(x .>>. 32)

'### Basic Arithmetic

'#### Add
Things that can be added.
This defines the `Add` [group](https://en.wikipedia.org/wiki/Group_(mathematics)) and its operators.

interface Add(a|Data)
  (+) : (a, a) -> a
  zero : a

interface Sub(a|Add)
  (-) : (a, a) -> a

instance Add(Float64)
  def (+)(x, y) = %fadd(x, y)
  zero = 0

instance Sub(Float64)
  def (-)(x, y) = %fsub(x, y)

instance Add(Float32)
  def (+)(x, y) = %fadd(x, y)
  zero = 0

instance Sub(Float32)
  def (-)(x, y) = %fsub(x, y)

instance Add(Int64)
  def (+)(x, y) = %iadd(x, y)
  zero = 0

instance Sub(Int64)
  def (-)(x, y) = %isub(x, y)

instance Add(Int32)
  def (+)(x, y) = %iadd(x, y)
  zero = 0

instance Sub(Int32)
  def (-)(x, y) = %isub(x, y)

instance Add(Word8)
  def (+)(x, y) = %iadd(x, y)
  zero = 0

instance Sub(Word8)
  def (-)(x, y) = %isub(x, y)

instance Add(Word32)
  def (+)(x, y) = %iadd(x, y)
  zero = 0

instance Sub(Word32)
  def (-)(x, y) = %isub(x, y)

instance Add(Word64)
  def (+)(x, y) = %iadd(x, y)
  zero = 0

instance Sub(Word64)
  def (-)(x, y) = %isub(x, y)

-- Nat gets Add but deliberately no Sub instance.
instance Add(Nat)
  def (+)(x, y) = rep_to_nat %iadd(nat_to_rep x, nat_to_rep y)
  zero = 0

instance Add(())
  def (+)(x, y) = ()
  zero = ()

instance Sub(())
  def (-)(x, y) = ()

'#### Mul
Things that can be multiplied.
This defines the `Mul` [Monoid](https://en.wikipedia.org/wiki/Monoid), and its operator.
-- Multiplicative structure: a binary product plus its identity `one`.
interface Mul(a|Data)
  (*) : (a, a) -> a
  one : a

instance Mul(Float64)
  def (*)(x, y) = %fmul(x, y)
  one = f_to_f64(1.0)

instance Mul(Float32)
  def (*)(x, y) = %fmul(x, y)
  one = f_to_f32(1.0)

instance Mul(Int64)
  def (*)(x, y) = %imul(x, y)
  one = 1

instance Mul(Int32)
  def (*)(x, y) = %imul(x, y)
  one = 1

instance Mul(Word8)
  def (*)(x, y) = %imul(x, y)
  one = 1

instance Mul(Word32)
  def (*)(x, y) = %imul(x, y)
  one = 1

instance Mul(Word64)
  def (*)(x, y) = %imul(x, y)
  one = 1

-- Nat multiplies in its Word32 representation and re-wraps the product.
instance Mul(Nat)
  def (*)(x, y) =
    xr = nat_to_rep x
    yr = nat_to_rep y
    rep_to_nat %imul(xr, yr)
  one = 1

instance Mul(())
  def (*)(x, y) = ()
  one = ()

'#### Integral
Integer-like things.

-- Integer division (`idiv`) and remainder (`rem`).
interface Integral(a)
  idiv : (a,a)->a
  rem : (a,a)->a

instance Integral(Int64)
  def idiv(x, y) = %idiv(x, y)
  def rem(x, y) = %irem(x, y)

instance Integral(Int32)
  def idiv(x, y) = %idiv(x, y)
  def rem(x, y) = %irem(x, y)

instance Integral(Word8)
  def idiv(x, y) = %idiv(x, y)
  def rem(x, y) = %irem(x, y)

instance Integral(Word32)
  def idiv(x, y) = %idiv(x, y)
  def rem(x, y) = %irem(x, y)

instance Integral(Word64)
  def idiv(x, y) = %idiv(x, y)
  def rem(x, y) = %irem(x, y)

-- Nat divides in its representation type, like Mul(Nat) above.
instance Integral(Nat)
  def idiv(x, y) =
    xr = nat_to_rep x
    yr = nat_to_rep y
    rep_to_nat %idiv(xr, yr)
  def rem(x, y) =
    xr = nat_to_rep x
    yr = nat_to_rep y
    rep_to_nat %irem(xr, yr)

'#### Fractional
Rational-like things.
Includes floating point and two field rational representations.
interface Fractional(a)
  divide : (a, a) -> a

instance Fractional(Float64)
  def divide(x, y) = %fdiv(x, y)

instance Fractional(Float32)
  def divide(x, y) = %fdiv(x, y)

'## Index set interface and instances

-- An index set: a finite type with a size and a bijection to `0..size-1`.
interface Ix(n|Data)
  size' : () -> Nat
  ordinal : (n) -> Nat
  unsafe_from_ordinal : (Nat) -> n

def size(n|Ix) -> Nat = size'(n=n)

def Fin(n:Nat) -> Type = %Fin(n)

-- version of subtraction on Nats that clamps at zero
def (-|)(x: Nat, y:Nat) -> Nat =
  x' = nat_to_rep x
  y' = nat_to_rep y
  requires_clamp = %ilt(x', y')
  rep_to_nat %select(requires_clamp, 0, (%isub(x', y')))

-- Nat subtraction with no clamping; the caller must ensure x >= y.
def unsafe_nat_diff(x:Nat, y:Nat) -> Nat =
  x' = nat_to_rep x
  y' = nat_to_rep y
  rep_to_nat %isub(x', y')

-- `(i..)` parses as `RangeFrom _ i`
-- TODO: need a way to indicate constructor as private
struct RangeFrom(q:Type, i:q) =
  val : Nat

-- `(i<..)` parses as `RangeFromExc _ i`
struct RangeFromExc(q:Type, i:q) =
  val : Nat

-- `(..i)` parses as `RangeTo _ i`
struct RangeTo(q:Type, i:q) =
  val : Nat

-- NOTE(review): source text between "(.." and "n=>Nat" was lost when this
-- file was extracted (an HTML-tag-like "<...>" span appears to have been
-- stripped). Later definitions use `(..<i)` and `RangeToExc`, so that span
-- held at least their declaration, plus the head of a definition whose tail
-- "n=>Nat = for i. ordinal i" survives. Restore it from upstream; the
-- surviving fragment is kept verbatim (commented out) below.
-- `(.. n=>Nat = for i. ordinal i

'## Arithmetic instances for table types

-- Tables add, subtract and multiply elementwise.
instance Add(n=>a) given (a|Add, n|Ix)
  def (+)(xs, ys) = for i. xs[i] + ys[i]
  zero = for _. zero

instance Sub(n=>a) given (a|Sub, n|Ix)
  def (-)(xs, ys) = for i. xs[i] - ys[i]

instance Add((i:n) => (i..) => a) given (a|Add, n|Ix) -- Upper triangular tables
  def (+)(xs, ys) = for i. xs[i] + ys[i]
  zero = for _. zero

instance Sub((i:n) => (i..) => a) given (a|Sub, n|Ix) -- Upper triangular tables
  def (-)(xs, ys) = for i. xs[i] - ys[i]

instance Add((i:n) => (..i) => a) given (a|Add, n|Ix) -- Lower triangular tables
  def (+)(xs, ys) = for i. xs[i] + ys[i]
  zero = for _. zero

instance Sub((i:n) => (..i) => a) given (a|Sub, n|Ix) -- Lower triangular tables
  def (-)(xs, ys) = for i. xs[i] - ys[i]

instance Add((i:n) => (..<i) => a) given (a|Add, n|Ix)
  def (+)(xs, ys) = for i. xs[i] + ys[i]
  zero = for _. zero

instance Sub((i:n) => (..<i) => a) given (a|Sub, n|Ix)
  def (-)(xs, ys) = for i. xs[i] - ys[i]

instance Add((i:n) => (i<..) => a) given (a|Add, n|Ix)
  def (+)(xs, ys) = for i. xs[i] + ys[i]
  zero = for _. zero

instance Sub((i:n) => (i<..) => a) given (a|Sub, n|Ix)
  def (-)(xs, ys) = for i. xs[i] - ys[i]

instance Mul(n=>a) given (a|Mul, n|Ix)
  def (*)(xs, ys) = for i. xs[i] * ys[i]
  one = for _. one

'## Basic polymorphic functions and types

def fst(pair:(a, b)) -> a given (a, b) = pair.0
def snd(pair:(a, b)) -> b given (a, b) = pair.1

def swap(pair:(a, b)) -> (b, a) given (a, b) =
  (x, y) = pair
  (y, x)

instance Add((a, b)) given (a|Add, b|Add)
  def (+)(x, y) =
    (x1, x2) = x
    (y1, y2) = y
    (x1 + y1, x2 + y2)
  zero = (zero, zero)

instance Sub((a, b)) given (a|Sub, b|Sub)
  def (-)(x, y) =
    (x1, x2) = x
    (y1, y2) = y
    (x1 - y1, x2 - y2)

-- Pairs are ordered row-major: the second component varies fastest.
instance Ix((a, b)) given (a|Ix, b|Ix)
  def size'() = size a * size b
  def ordinal(pair) =
    (i, j) = pair
    (ordinal i * size b) + ordinal j
  def unsafe_from_ordinal(o) =
    bs = size b
    (unsafe_from_ordinal(n=a, idiv(o, bs)), unsafe_from_ordinal(n=b, rem(o, bs)))

-- Triples reuse the pair instance via the nesting (i, (j, k)).
instance Ix((a, b, c)) given (a|Ix, b|Ix, c|Ix)
  def size'() = size a * size b * size c
  def ordinal(tup) =
    (i, j, k) = tup
    ordinal((i,(j,k)))
  def unsafe_from_ordinal(o) =
    (i, (j, k)) = unsafe_from_ordinal o
    (i, j, k)

'## Vector spaces

interface VSpace(a|Add|Sub)
  (.*) : (Float, a) -> a

def (*.)(v:a, s:Float) -> a given (a|VSpace) = s .* v
def (/)( v:a, s:Float) -> a given (a|VSpace) = divide(1.0, s) .* v
def neg( v:a) -> a given (a|VSpace) = (-1.0) .* v

instance VSpace(Float)
  def (.*)(x, y) = x * y

instance VSpace(n=>a) given (a|VSpace, n|Ix)
  def (.*)(s, xs) = for i. s .* xs[i]

instance VSpace((a, b)) given (a|VSpace, b|VSpace)
  def (.*)(s, pair) =
    (a, b) = pair
    (s .* a, s .* b)

instance VSpace((i:n) => (..i) => a) given (n|Ix, a|VSpace)
  def (.*)(s, xs) = for i. s .* xs[i]

instance VSpace((i:n) => (i..) => a) given (n|Ix, a|VSpace)
  def (.*)(s, xs) = for i. s .* xs[i]

instance VSpace((i:n) => (..<i) => a) given (n|Ix, a|VSpace)
  def (.*)(s, xs) = for i. s .* xs[i]

instance VSpace((i:n) => (i<..) => a) given (n|Ix, a|VSpace)
  def (.*)(s, xs) = for i. s .* xs[i]

instance VSpace(())
  def (.*)(_, _) = ()

'## Boolean type

data Bool =
  False
  True

-- Bools are converted through their Word8 constructor tag so that the
-- bitwise primitives can implement the logical operators.
def b_to_w8(x:Bool) -> Word8 = %dataConTag(x)
def w8_to_b(x:Word8) -> Bool = %toEnum(Bool, x)

def (&&)(x:Bool, y:Bool) -> Bool =
  x' = b_to_w8 x
  y' = b_to_w8 y
  w8_to_b $ %and(x', y')

def (||)(x:Bool, y:Bool) -> Bool =
  x' = b_to_w8 x
  y' = b_to_w8 y
  w8_to_b $ %or(x', y')

def not(x:Bool) -> Bool =
  x' = b_to_w8 x
  w8_to_b $ %not(x')

'## More Boolean operations
TODO: move these with the others?

-- Can't use `%select` because it lowers to `ISelect`, which requires
-- `a` to be a `BaseTy`.
def select(p:Bool, x:a, y:a) -> a given (a) =
  case p of
    True -> x
    False -> y

def b_to_i(x:Bool) -> Int = w8_to_i(b_to_w8 x)
def b_to_n(x:Bool) -> Nat = w8_to_n(b_to_w8 x)
def b_to_f(x:Bool) -> Float = i_to_f(b_to_i x)

'## Ordering
TODO: move this down to with `Ord`?

data Ordering =
  LT
  EQ
  GT

def o_to_w8(x:Ordering) -> Word8 = %dataConTag(x)

'## Sum types
A [sum type, or tagged union](https://en.wikipedia.org/wiki/Tagged_union) can hold values from a fixed set of types, distinguished by tags.
For those familiar with the C language, they can be thought of as a combination of an `enum` with a `union`.
Here we define several basic kinds, and some operators on them.
data Maybe(a:Type) =
  Nothing
  Just(a)

def is_nothing(x:Maybe a) -> Bool given (a) =
  case x of
    Nothing -> True
    Just(_) -> False

def is_just(x:Maybe a) -> Bool given (a) = not $ is_nothing x

-- Eliminate a Maybe: `d` for Nothing, `f` applied to the payload for Just.
def maybe(d:b, f:(a)->b, x:Maybe a) -> b given (a, b) =
  case x of
    Nothing -> d
    Just(x') -> f x'

data Either(a:Type, b:Type) =
  Left(a)
  Right(b)

-- All Left values are ordered before all Right values.
instance Ix(Either(a, b)) given (a|Ix, b|Ix)
  def size'() = size a + size b
  def ordinal(i) = case i of
    Left(ai) -> ordinal ai
    Right(bi) -> ordinal bi + size a
  def unsafe_from_ordinal(o) =
    as = nat_to_rep $ size a
    o' = nat_to_rep o
    -- TODO: Reshuffle the prelude to be able to use (<) here
    case w8_to_b $ %ilt(o', as) of
      True -> Left $ unsafe_from_ordinal(n=a, o)
      -- TODO: Reshuffle the prelude to be able to use `diff_nat` here
      False -> Right $ unsafe_from_ordinal(n=b, rep_to_nat (%isub(o', as)))

'## Subtraction on Nats

-- TODO: think more about the right API here
def unsafe_i_to_n(x:Int) -> Nat = rep_to_nat $ internal_cast x
def n_to_i(x:Nat) -> Int = internal_cast (nat_to_rep x)

-- Nothing for negative inputs, otherwise the same value as a Nat.
def i_to_n(x:Int) -> Maybe Nat =
  if w8_to_b $ %ilt(x, 0::Int) then Nothing else Just $ unsafe_i_to_n x

'## Fencepost index sets

-- `Post segment` has one more element than `segment`: the boundaries
-- around and between its elements.
struct Post(segment:Type) =
  val : Nat

instance Ix(Post segment) given (segment|Ix)
  def size'() = size segment + 1
  def ordinal(i) = i.val
  def unsafe_from_ordinal(i) = Post(i)

def left_post(i:n) -> Post n given (n|Ix) =
  unsafe_from_ordinal(n=Post n, ordinal i)

def right_post(i:n) -> Post n given (n|Ix) =
  unsafe_from_ordinal(n=Post n, ordinal i + 1)

interface NonEmpty(n|Ix)
  first_ix : n

-- `size n >= 1` for a NonEmpty index set, so the Nat difference cannot
-- underflow; stay in Nat instead of round-tripping through Int.
def last_ix() ->> n given (n|NonEmpty) =
  unsafe_from_ordinal(unsafe_nat_diff(size n, 1))

instance NonEmpty(Post n) given (n|Ix)
  first_ix = unsafe_from_ordinal(n=Post n, 0)

instance NonEmpty(())
  first_ix = unsafe_from_ordinal(0)

'### Monoid
A [monoid](https://en.wikipedia.org/wiki/Monoid) is a thing that has an associative binary operator and an identity element.
This is a very useful and general class of things.
It includes:
 - Addition and Multiplication of Numbers
 - Boolean Logic
 - Concatenation of Lists (including strings)
Monoids support `fold` operations, and similar.

interface Monoid(a|Data)
  mempty : a
  (<>) : (a, a) -> a

instance Monoid(n=>a) given (a|Monoid, n|Ix)
  mempty = for i. mempty
  def (<>)(x, y) = for i. x[i] <> y[i]

named-instance AndMonoid : Monoid(Bool)
  mempty = True
  def (<>)(x, y) = x && y

named-instance OrMonoid : Monoid(Bool)
  mempty = False
  def (<>)(x, y) = x || y

named-instance AddMonoid(a|Add) -> Monoid(a)
  mempty = zero
  def (<>)(x, y) = x + y

named-instance MulMonoid(a|Mul) -> Monoid(a)
  mempty = one
  def (<>)(x, y) = x * y

'## Effects

def Ref(r:Heap, a|Data) -> Type = %Ref(r, a)
def get(ref:Ref h s) -> {State h} s given (h, s) = %get(ref)
def (:=)(ref:Ref h s, x:s) -> {State h} () given (h, s) = %put(ref, x)

def ask(ref:Ref h r) -> {Read h} r given (h, r) = %ask(ref)

data AccumMonoidData(h:Heap, w:Type) = UnsafeMkAccumMonoidData(b:Type, Monoid b)

interface AccumMonoid(h:Heap, w)
  getAccumMonoidData : AccumMonoidData(h, w)

-- A table of accumulators reuses the element type's accumulation monoid.
instance AccumMonoid(h, n=>w) given (n|Ix, h, w) (am:AccumMonoid(h, w))
  getAccumMonoidData =
    UnsafeMkAccumMonoidData(b, bm) = %applyMethod0(am)
    UnsafeMkAccumMonoidData(b, bm)

def (+=)(ref:Ref h w, x:w) -> {Accum h} () given (h, w) (am:AccumMonoid(h, w)) =
  UnsafeMkAccumMonoidData(b, bm) = %applyMethod0(am)
  empty = %applyMethod0(bm)
  %mextend(ref, empty, \x:b y:b. %applyMethod1(bm, x, y), x)

def (!)(ref: Ref h (n=>a), i:n) -> Ref h a given (n|Ix, a|Data, h) = %indexRef(ref, i)

def fst_ref(ref: Ref h (a,b)) -> Ref h a given (b, a|Data, h) = ref.0
def snd_ref(ref: Ref h (a,b)) -> Ref h b given (a, b|Data, h) = ref.1

def run_reader(
    init:r,
    action:(given (h), Ref h r) -> {Read h|eff} a
    ) -> {|eff} a given (r|Data, a, eff) =
  def explicitAction(h':Heap, ref:Ref h' r) -> {Read h'|eff} a = action ref
  %runReader(init, explicitAction)

def with_reader(
    init:r,
    action: (given (h), Ref(h,r)) -> {Read h|eff} a
    ) -> {|eff} a given (r|Data, a, eff) =
  run_reader(init, action)

def MonoidLifter(b:Type, w:Type) -> Type = (given (h) (AccumMonoid(h, b))) ->> AccumMonoid(h, w)

named-instance mk_accum_monoid (given (h, w), d:AccumMonoidData(h, w)) -> AccumMonoid(h, w)
  getAccumMonoidData = d

-- Runs `action` with an accumulator reference combined using `bm`; returns
-- the action's result together with the final accumulated value.
def run_accum(
    bm:Monoid b,
    action: (given (h) (AccumMonoid(h, b)), Ref h w) -> {Accum h|eff} a
    ) -> {|eff} (a, w) given (a, b, w|Data, eff) (MonoidLifter(b,w)) =
  empty = %applyMethod0(bm)
  def explicitAction(h':Heap, ref:Ref h' w) -> {Accum h'|eff} a =
    accumMonoidData : AccumMonoidData h' b = UnsafeMkAccumMonoidData b bm
    accumBaseMonoid = mk_accum_monoid accumMonoidData
    %explicitApply(action, h', accumBaseMonoid, ref)
  %runWriter(empty, \x:b y:b. %applyMethod1(bm, x, y), explicitAction)

def yield_accum(
    m:Monoid b,
    action: (given (h) (AccumMonoid(h, b)), Ref h w) -> {Accum h|eff} a
    ) -> {|eff} w given (a, b, w|Data, eff) (MonoidLifter b w) =
  snd $ run_accum(m, action)

def run_state(
    init:s,
    action: (given (h), Ref h s) -> {State h |eff} a
    ) -> {|eff} (a,s) given (a, s|Data, eff) =
  def explicitAction(h':Heap, ref:Ref h' s) -> {State h'|eff} a = action ref
  %runState(init, explicitAction)

def with_state(
    init:s,
    action: (given (h), Ref h s) -> {State h |eff} a
    ) -> {|eff} a given (a, s|Data, eff) =
  fst $ run_state(init, action)

def yield_state(
    init:s,
    action: (given (h), Ref h s) -> {State h |eff} a
    ) -> {|eff} s given (a, s|Data, eff) =
  snd $ run_state(init, action)

def unsafe_io(
    f:()->{IO|eff} a
    ) -> {|eff} a given (a, eff) =
  f' : (() -> {IO|eff} a) = \. f()
  %runIO(f')

def unreachable() -> a given (a|Data) = unsafe_io \. %throwError(a)

'## Type classes

'### Eq and Ord

'#### Eq
Equatable.
Things that can be compared with other things to tell whether they are equal.

interface Eq(a|Data)
  (==) : (a, a) -> Bool

def (/=)(x:a, y:a) -> Bool given (a|Eq) = not $ x == y

'#### Ord
Orderable / Comparable.
Things that can be placed in a total order,
i.e. things that can be compared to other things to find whether they are larger, smaller or equal in value.

'We take the standard falsehood and pretend that this applies to Floats, even though strictly speaking this is not true, as our floats follow [IEEE754](https://en.wikipedia.org/wiki/IEEE_754) and thus have `NaN < 1.0 == false` and `1.0 < NaN == false`.
interface Ord(a|Eq)
  (>) : (a, a) -> Bool
  (<) : (a, a) -> Bool

-- Non-strict comparisons derived from the strict ones plus equality.
def (<=)(x:a, y:a) -> Bool given (a|Ord) = x<y || x==y
def (>=)(x:a, y:a) -> Bool given (a|Ord) = x>y || x==y

instance Eq(Float64)
  def (==)(x, y) = w8_to_b $ %feq(x, y)

instance Eq(Float32)
  def (==)(x, y) = w8_to_b $ %feq(x, y)

instance Eq(Int64)
  def (==)(x, y) = w8_to_b $ %ieq(x, y)

instance Eq(Int32)
  def (==)(x, y) = w8_to_b $ %ieq(x, y)

instance Eq(Word8)
  def (==)(x, y) = w8_to_b $ %ieq(x, y)

instance Eq(Word32)
  def (==)(x, y) = w8_to_b $ %ieq(x, y)

instance Eq(Word64)
  def (==)(x, y) = w8_to_b $ %ieq(x, y)

instance Eq(Bool)
  def (==)(x, y) = b_to_w8 x == b_to_w8 y

instance Eq(())
  def (==)(_, _) = True

instance Eq(Either(a, b)) given (a|Eq, b|Eq)
  def (==)(x, y) = case x of
    Left(x) -> case y of
      Left( y) -> x == y
      Right(y) -> False
    Right(x) -> case y of
      Left( y) -> False
      Right(y) -> x == y

instance Eq(Maybe a) given (a|Eq)
  def (==)(x, y) = case x of
    Just(x) -> case y of
      Just(y) -> x == y
      Nothing -> False
    Nothing -> case y of
      Just(y) -> False
      Nothing -> True

instance Eq(RawPtr)
  def (==)(x, y) = raw_ptr_to_i64 x == raw_ptr_to_i64 y

instance Ord(Float64)
  def (>)(x, y) = w8_to_b $ %fgt(x, y)
  def (<)(x, y) = w8_to_b $ %flt(x, y)

instance Ord(Float32)
  def (>)(x, y) = w8_to_b $ %fgt(x, y)
  def (<)(x, y) = w8_to_b $ %flt(x, y)

instance Ord(Int64)
  def (>)(x, y) = w8_to_b $ %igt(x, y)
  def (<)(x, y) = w8_to_b $ %ilt(x, y)

instance Ord(Int32)
  def (>)(x, y) = w8_to_b $ %igt(x, y)
  def (<)(x, y) = w8_to_b $ %ilt(x, y)

instance Ord(Word8)
  def (>)(x, y) = w8_to_b $ %igt(x, y)
  def (<)(x, y) = w8_to_b $ %ilt(x, y)

instance Ord(Word32)
  def (>)(x, y) = w8_to_b $ %igt(x, y)
  def (<)(x, y) = w8_to_b $ %ilt(x, y)

instance Ord(Word64)
  def (>)(x, y) = w8_to_b $ %igt(x, y)
  def (<)(x, y) = w8_to_b $ %ilt(x, y)

instance Ord(())
  def (>)(x, y) = False
  def (<)(x, y) = False

instance Eq((a, b)) given (a|Eq, b|Eq)
  def (==)(p1, p2) =
    (x1, y1) = p1
    (x2, y2) = p2
    x1 == x2 && y1 == y2

-- Lexicographic: first components decide, second breaks ties.
instance Ord((a, b)) given (a|Ord, b|Ord)
  def (>)(p1, p2) =
    (x1, y1) = p1
    (x2, y2) = p2
    x1 > x2 || (x1 == x2 && y1 > y2)
  def (<)(p1, p2) =
    (x1, y1) = p1
    (x2, y2) = p2
    x1 < x2 || (x1 == x2 && y1 < y2)

instance Eq(Ordering)
  def (==)(x, y) = o_to_w8 x == o_to_w8 y

instance Eq(Nat)
  def (==)(x, y) = nat_to_rep x == nat_to_rep y

instance Ord(Nat)
  def (>)(x, y) = nat_to_rep x > nat_to_rep y
  def (<)(x, y) = nat_to_rep x < nat_to_rep y

-- TODO: we want Eq and Ord for all index sets, not just `Fin n`
instance Eq(Fin n) given (n)
  def (==)(x, y) = ordinal x == ordinal y

instance Ord(Fin n) given (n)
  def (>)(x, y) = ordinal x > ordinal y
  def (<)(x, y) = ordinal x < ordinal y

instance Ix(Bool)
  def size'() = 2
  def ordinal(b) = case b of
    False -> 0
    True -> 1
  def unsafe_from_ordinal(i) = i > 0

-- `Nothing` takes the last ordinal, after every `Just`.
instance Ix(Maybe a) given (a|Ix)
  def size'() = size a + 1
  def ordinal(i) = case i of
    Just(ai) -> ordinal ai
    Nothing -> size a
  def unsafe_from_ordinal(o) = case o == size a of
    False -> Just $ unsafe_from_ordinal o
    True -> Nothing

instance NonEmpty(Bool)
  first_ix = unsafe_from_ordinal 0

instance NonEmpty((a,b)) given (a|NonEmpty, b|NonEmpty)
  first_ix = unsafe_from_ordinal 0

instance NonEmpty(Either(a,b)) given (a|NonEmpty, b|Ix)
  first_ix = unsafe_from_ordinal 0

-- The below instance is valid, but causes "multiple candidate dictionaries"
-- errors if both Left and Right are NonEmpty.
-- instance NonEmpty (a|b) given {a b} [Ix a, NonEmpty b]
--   first_ix = unsafe_from_ordinal _ 0

instance NonEmpty(Maybe a) given (a|Ix)
  first_ix = unsafe_from_ordinal 0

-- Threads a carry through `n` in index order, collecting one output per
-- index; returns (final carry, outputs).
def scan(
    init:a,
    body:(n, a)->(a,b)
    ) -> (a, n=>b) given (a|Data, b, n|Ix) =
  swap $ run_state(init) \s. for i.
    c = get s
    (c', y) = body(i, c)
    s := c'
    y

def fold(init:a, body:(n,a)->a) -> a given (n|Ix, a|Data) =
  fst $ scan init \i x. (body(i, x), ())

def compare(x:a, y:a) -> Ordering given (a|Ord) =
  if x < y then LT else if x == y then EQ else GT

-- The first non-EQ result wins, which gives lexicographic comparison.
instance Monoid(Ordering)
  mempty = EQ
  def (<>)(x, y) = case x of
    LT -> LT
    GT -> GT
    EQ -> y

instance Eq(n=>a) given (n|Ix, a|Eq)
  def (==)(xs, ys) =
    yield_accum AndMonoid \ref.
      for i. ref += xs[i] == ys[i]

instance Ord(n=>a) given (n|Ix, a|Ord)
  def (>)(xs, ys) =
    f: Ordering = fold EQ $ \i c. c <> compare(xs[i], ys[i])
    f == GT
  def (<)(xs, ys) =
    f: Ordering = fold EQ $ \i c. c <> compare(xs[i], ys[i])
    f == LT

'## Subset class

interface Subset(subset, superset)
  inject' : (subset) -> superset
  project' : (superset) -> Maybe subset
  unsafe_project' : (superset) -> subset

-- wrappers with more helpful implicit arg names
def inject(x:from) -> to given (to, from) (Subset(from, to)) = inject'(x)
def project(x:from) -> Maybe to given (to, from) (Subset(to, from)) = project'(x)
def unsafe_project(x:from) -> to given (to, from) (Subset(to, from)) = unsafe_project'(x)

-- Transitivity: a ⊆ b and b ⊆ c gives a ⊆ c.
instance Subset(a, c) given (a, b, c) (Subset(a, b), Subset(b, c))
  def inject'(x) = inject $ inject(to=b, x)
  def project'(x) = case project(to=b, x) of
    Nothing -> Nothing
    Just(y)-> project y
  def unsafe_project'(x) = unsafe_project $ unsafe_project(to=b, x)

def unsafe_project_rangefrom(j:q) -> RangeFrom(q, i) given (q|Ix, i:q) =
  RangeFrom unsafe_nat_diff(ordinal j, ordinal i)

instance Subset(RangeFrom(q, i), q) given (q|Ix, i:q)
  def inject'(j) = unsafe_from_ordinal $ j.val + ordinal i
  def project'(j) =
    j' = ordinal j
    i' = ordinal i
    if j' < i' then Nothing else Just $ RangeFrom $ unsafe_nat_diff(j', i')
  def unsafe_project'(j) = RangeFrom unsafe_nat_diff(ordinal j, ordinal i)

instance Subset(RangeFromExc(q, i), q) given (q|Ix, i:q)
  def inject'(j) = unsafe_from_ordinal $ j.val + ordinal i + 1
  def project'(j) =
    j' = ordinal j
    i' = ordinal i
    if j' <= i' then Nothing else Just $ RangeFromExc unsafe_nat_diff(j', i' + 1)
  def unsafe_project'(j) = RangeFromExc unsafe_nat_diff(ordinal j, ordinal i + 1)

instance Subset(RangeTo(q, i), q) given (q|Ix, i:q)
  def inject'(j) = unsafe_from_ordinal j.val
  def project'(j) =
    j' = ordinal j
    i' = ordinal i
    if j' > i' then Nothing else Just $ RangeTo j'
  def unsafe_project'(j) = RangeTo (ordinal j)

instance Subset(RangeToExc(q, i), q) given (q|Ix, i:q)
  def inject'(j) = unsafe_from_ordinal j.val
  def project'(j) =
    j' = ordinal j
    i' = ordinal i
    if j' >= i' then Nothing else Just $ RangeToExc j'
  def unsafe_project'(j) = RangeToExc (ordinal j)

instance Subset(RangeToExc(q, i), RangeTo(q, i)) given (q|Ix, i:q)
  def inject'(j) = unsafe_from_ordinal j.val
  def project'(j) =
    j' = ordinal j
    i' = ordinal i
    if j' >= i' then Nothing else Just $ RangeToExc j'
  def unsafe_project'(j) = RangeToExc (ordinal j)

'## Elementary/Special Functions
This is more or less the standard [LibM fare](https://en.wikipedia.org/wiki/C_mathematical_functions).
Roughly it lines up with some definitions of the set of [Elementary](https://en.wikipedia.org/wiki/Elementary_function) and/or [Special](https://en.wikipedia.org/wiki/Special_functions) functions.
In truth, nothing is elementary or special except that we humans have decided it is.
Many, but not all of these functions are [Transcendental](https://en.wikipedia.org/wiki/Transcendental_function).

interface Floating(a:Type)
  exp : (a) -> a
  exp2 : (a) -> a
  log : (a) -> a
  log2 : (a) -> a
  log10 : (a) -> a
  log1p : (a) -> a
  sin : (a) -> a
  cos : (a) -> a
  tan : (a) -> a
  sinh : (a) -> a
  cosh : (a) -> a
  tanh : (a) -> a
  floor : (a) -> a
  ceil : (a) -> a
  round : (a) -> a
  sqrt : (a) -> a
  pow : (a, a) -> a
  lgamma : (a) -> a
  erf : (a) -> a
  erfc : (a) -> a

def lbeta(x:a, y:a) -> a given (a|Sub|Floating) = lgamma x + lgamma y - lgamma (x + y)

-- Todo: better numerics for very large and small values.
-- Using %exp here to avoid circular definition problems.
def float32_sinh(x:Float32) -> Float32 = %fdiv(%fsub(%exp(x), %exp(%fsub(0.0,x))), 2.0)
def float32_cosh(x:Float32) -> Float32 = %fdiv(%fadd(%exp(x), %exp(%fsub(0.0,x))), 2.0)

-- tanh via ex = exp(-2|x|), which lies in (0, 1], so numerator and
-- denominator stay finite. The direct (e^x - e^-x)/(e^x + e^-x) form
-- becomes inf/inf = NaN once %exp overflows for large |x|.
-- tanh is odd, so the sign is restored at the end.
def float32_tanh(x:Float32) -> Float32 =
  is_neg = %flt(x, 0.0)
  ax = %select(is_neg, %fsub(0.0, x), x)
  ex = %exp(%fsub(0.0, %fadd(ax, ax)))
  t = %fdiv(%fsub(1.0, ex), %fadd(1.0, ex))
  %select(is_neg, %fsub(0.0, t), t)

-- Todo: unify this with float32 functions.
def float64_sinh(x:Float64) -> Float64 = %fdiv(%fsub(%exp(x), %exp(%fsub(f_to_f64 0.0, x))), f_to_f64 2.0)
def float64_cosh(x:Float64) -> Float64 = %fdiv(%fadd(%exp(x), %exp(%fsub(f_to_f64 0.0, x))), f_to_f64 2.0)
-- tanh(x) = 1 - 2 / (exp(2x) + 1).
-- Unlike (e^x - e^-x) / (e^x + e^-x), this never evaluates inf / inf:
-- the result is 1.0 when exp(2x) overflows and -1.0 when it underflows.
def float64_tanh(x:Float64) -> Float64 =
  %fsub(f_to_f64 1.0, %fdiv(f_to_f64 2.0, %fadd(%exp(%fmul(f_to_f64 2.0, x)), f_to_f64 1.0)))

instance Floating(Float64)
  def exp(x) = %exp(x)
  def exp2(x) = %exp2(x)
  def log(x) = %log(x)
  def log2(x) = %log2(x)
  def log10(x) = %log10(x)
  def log1p(x) = %log1p(x)
  def sin(x) = %sin(x)
  def cos(x) = %cos(x)
  def tan(x) = %tan(x)
  def sinh(x) = float64_sinh(x)
  def cosh(x) = float64_cosh(x)
  def tanh(x) = float64_tanh(x)
  def floor(x) = %floor(x)
  def ceil(x) = %ceil(x)
  def round(x) = %round(x)
  def sqrt(x) = %sqrt(x)
  def pow(x,y) = %fpow(x,y)
  def lgamma(x)= %lgamma(x)
  def erf(x) = %erf(x)
  def erfc(x) = %erfc(x)

instance Floating(Float32)
  def exp(x) = %exp(x)
  def exp2(x) = %exp2(x)
  def log(x) = %log(x)
  def log2(x) = %log2(x)
  def log10(x) = %log10(x)
  def log1p(x) = %log1p(x)
  def sin(x) = %sin(x)
  def cos(x) = %cos(x)
  def tan(x) = %tan(x)
  def sinh(x) = float32_sinh(x)
  def cosh(x) = float32_cosh(x)
  def tanh(x) = float32_tanh(x)
  def floor(x) = %floor(x)
  def ceil(x) = %ceil(x)
  def round(x) = %round(x)
  def sqrt(x) = %sqrt(x)
  def pow(x,y) = %fpow(x, y)
  def lgamma(x)= %lgamma(x)
  def erf(x) = %erf(x)
  def erfc(x) = %erfc(x)

'## Raw pointer operations

-- A typed wrapper around a RawPtr. `a` only records the element type.
struct Ptr(a:Type) =
  val : RawPtr

def cast_ptr(ptr: Ptr a) -> Ptr b given (a, b) = Ptr(ptr.val)

-- Types that can be written to / read from raw memory.
-- `storage_size` is the number of bytes occupied by one element.
interface Storable(a|Data)
  store : (Ptr a, a) -> {IO} ()
  load : (Ptr a) -> {IO} a
  storage_size : () -> Nat

instance Storable(Word8)
  def store(ptr, x) = %ptrStore(ptr.val, x)
  def load(ptr) = %ptrLoad(ptr.val)
  def storage_size() = 1

instance Storable(Int32)
  def store(ptr, x) = %ptrStore(internal_cast(to=%Int32Ptr(), ptr.val), x)
  def load(ptr) = %ptrLoad(internal_cast(to=%Int32Ptr(), ptr.val))
  def storage_size() = 4

-- The cast must be applied to the underlying RawPtr (`ptr.val`), not to
-- the `Ptr` struct, as in every other instance here.
instance Storable(Word32)
  def store(ptr, x) = %ptrStore(internal_cast(to=%Word32Ptr(), ptr.val), x)
  def load(ptr) = %ptrLoad(internal_cast(to=%Word32Ptr(), ptr.val))
  def storage_size() = 4

instance Storable(Float32)
  def store(ptr, x) = %ptrStore(internal_cast(to=%Float32Ptr(), ptr.val), x)
  def load(ptr) = %ptrLoad(internal_cast(to=%Float32Ptr(), ptr.val))
  def storage_size() = 4

-- Nat is stored through its NatRep representation.
instance Storable(Nat)
  def store(ptr, x) = store(Ptr(ptr.val), nat_to_rep x)
  def load(ptr) = rep_to_nat $ load(Ptr(ptr.val))
  def storage_size() = storage_size(a=NatRep)

-- A pointer is stored as its RawPtr. Both store and load go through
-- `ptr.val`, the address of the slot being accessed.
instance Storable(Ptr a) given (a)
  def store(ptr, x) = %ptrStore(internal_cast(to=%PtrPtr(), ptr.val), x.val)
  def load(ptr) = Ptr(%ptrLoad(internal_cast(to=%PtrPtr(), ptr.val)))
  def storage_size() = 8  -- TODO: something more portable?

-- TODO: Storable instances for other types

-- Allocates room for `n` elements of type `a` (n * storage_size bytes).
def malloc(n:Nat) -> {IO} (Ptr a) given (a|Storable) =
  numBytes = storage_size(a=a) * n
  Ptr(%alloc(nat_to_rep numBytes))

def free(ptr:Ptr a) -> {IO} () given (a) = %free(ptr.val)

-- Pointer arithmetic in units of elements: advances by i * storage_size bytes.
def (+>>)(ptr:Ptr a, i:Nat) -> Ptr a given (a|Storable) =
  i' = nat_to_rep $ i * storage_size(a=a)
  Ptr(%ptrOffset(ptr.val, i'))

-- TODO: consider making a Storable instance for tables instead
-- Writes tab[i] to element offset `ordinal i`.
def store_table(ptr: Ptr a, tab:n=>a) -> {IO} () given (a|Storable, n|Ix) =
  for_ i. store(ptr +>> ordinal i, tab[i])

-- Copies n elements from src to dest, one element at a time.
def memcpy(dest:Ptr a, src:Ptr a, n:Nat) -> {IO} () given (a|Storable) =
  for_ i:(Fin n).
    i' = ordinal i
    store(dest +>> i', load $ src +>> i')

-- TODO: generalize these brackets to allow other effects
-- TODO: make sure that freeing happens even if there are run-time errors
-- Allocates n elements, runs `action` on the buffer, frees it, and
-- returns the action's result.
def with_alloc(n:Nat, action: (Ptr a) -> {IO} b) -> {IO} b given (a|Storable, b) =
  ptr = malloc n
  result = action ptr
  free ptr
  result

-- Copies `xs` into a temporary buffer and runs `action` on it.
def with_table_ptr(xs:n=>a, action: (Ptr a) -> {IO} b) -> {IO} b given (a|Storable, b, n|Ix) =
  ptr <- with_alloc(size n)
  for i. store(ptr +>> ordinal i, xs[i])
  action ptr

-- Reads `size n` consecutive elements starting at `ptr`.
def table_from_ptr(ptr:Ptr a) -> {IO} n=>a given (a|Storable, n|Ix) =
  for i. load $ ptr +>> ordinal i

'## Miscellaneous common utilities

pi : Float = 3.141592653589793

def id(x:a) -> a given (a) = x
def dup(x:a) -> (a, a) given (a) = (x, x)
def map(f:(a)->{|eff} b, xs: n=>a) -> {|eff} (n=>b) given (a, b, n|Ix, eff) =
  for i. f xs[i]

-- map, flipped so that the function goes last
def each(xs: n=>a, f:(a)->{|eff} b) -> {|eff} (n=>b) given (a, b, n|Ix, eff) =
  for i. f xs[i]

def zip(xs:n=>a, ys:n=>b) -> (n=>(a,b)) given (a, b, n|Ix) =
  for i. (xs[i], ys[i])

def unzip(xys:n=>(a,b)) -> (n=>a , n=>b) given (a, b, n|Ix) = (each xys fst, each xys snd)

def fanout(x:a) -> n=>a given (n|Ix, a) = for i. x

def sq(x:a) -> a given (a|Mul) = x * x
def abs(x:a) -> a given (a|Sub|Ord) = select(x > zero, x, zero - x)
-- Remainder with the sign of y (for integral types).
def mod(x:a, y:a) -> a given (a|Add|Integral) = rem(y + rem(x, y), y)

-- Left-to-right and right-to-left function composition.
def (>>>)(f:(a) -> b, g:(b) -> c) -> (a) -> c given (a, b, c) = \x. g(f(x))
def (<<<)(f:(b) -> c, g:(a) -> b) -> (a) -> c given (a, b, c) = \x. f(g(x))

'## Table Operations

-- Element-wise lifting of Floating to tables.
instance Floating(n=>a) given (a|Floating, n|Ix)
  def exp(x) = each x exp
  def exp2(x) = each x exp2
  def log(x) = each x log
  def log2(x) = each x log2
  def log10(x) = each x log10
  def log1p(x) = each x log1p
  def sin(x) = each x sin
  def cos(x) = each x cos
  def tan(x) = each x tan
  def sinh(x) = each x sinh
  def cosh(x) = each x cosh
  def tanh(x) = each x tanh
  def floor(x) = each x floor
  def ceil(x) = each x ceil
  def round(x) = each x round
  def sqrt(x) = each x sqrt
  def pow(x, y) = for i. pow(x[i], y[i])
  def lgamma(x) = each x lgamma
  def erf(x) = each x erf
  def erfc(x) = each x erfc

'### Reductions

-- `combine` should be a commutative and associative, and form a
-- commutative monoid with `identity`
def reduce(identity:a, combine:(a,a)->a, xs:n=>a) -> a given (a|Data, n|Ix) =
  -- TODO: implement with the accumulator effect
  fold identity \i c. combine(c, xs[i])

-- TODO: call this `scan` and call the current `scan` something else
-- Returns every intermediate state produced by `body`, one per index.
def scan'(init:a, body:(n,a)->a) -> n=>a given (a|Data, n|Ix) =
  snd $ scan init \i x. dup(body(i, x))

def fsum(xs:n=>Float) -> Float given (n|Ix) =
  yield_accum(AddMonoid Float) \ref. for i. ref += xs[i]

def sum(xs:n=>v) -> v given (n|Ix, v|Add) = reduce(zero, (+), xs)
def prod(xs:n=>v) -> v given (n|Ix, v|Mul) = reduce(one , (*), xs)
def mean(xs:n=>v) -> v given (n|Ix, v|VSpace) = sum xs / n_to_f (size n)
-- Computed as sqrt(E[x^2] - E[x]^2).
def std(xs:n=>v) -> v given (n|Ix, v|Mul|Sub|VSpace|Floating) =
  sqrt $ mean (each xs sq) - sq (mean xs)
def any(xs:n=>Bool) -> Bool given (n|Ix) = reduce(False, (||), xs)
def all(xs:n=>Bool) -> Bool given (n|Ix) = reduce(True , (&&), xs)

'### apply_n

-- Applies f to x, n times.
def apply_n(n:Nat, x:a, f:(a) -> a) -> a given (a|Data) =
  yield_state x \ref. for _:(Fin n).
    ref := f (get ref)

'## cumulative sum
TODO: Move this to be with reductions?
It's a kind of `scan`.

-- Inclusive prefix sum: element i is xs[0] + ... + xs[i].
def cumsum(xs: n=>a) -> n=>a given (n|Ix, a|Add) =
  total <- with_state zero
  for i.
    newTotal = get total + xs[i]
    total := newTotal
    newTotal

-- Exclusive prefix sum: element i is the sum of the elements before i.
def cumsum_low(xs: n=>a) -> n=>a given (n|Ix, a|Add) =
  total <- with_state zero
  for i.
    oldTotal = get total
    total := oldTotal + xs[i]
    oldTotal

'## Automatic differentiation

'### AD operations

-- TODO: add vector space constraints
def linearize(f:(a)->b, x:a) -> (b, (a)->b) given (a, b) = %linearize(\x. f x, x)
def jvp(f:(a)->b, x:a, t:a) -> b given (a, b) = (snd $ linearize(f, x))(t)
def transpose_linear(f:(a)->b) -> (b)->a given (a, b) = \ct. %linearTranspose(\x. f x, ct)
def vjp(f:(a)->b, x:a) -> (b, (b)->a) given (a, b) =
  (y, df) = linearize(f, x)
  (y, transpose_linear df)
def grad(f:(a)->Float, x:a) -> a given (a) = (snd vjp(f, x))(1.0)
def deriv(f:(Float)->Float, x:Float) -> Float = jvp(f, x, 1.0)
def deriv_rev(f:(Float)->Float, x:Float) -> Float = (snd vjp(f, x))(1.0)

-- XXX: Watch out when editing this data type! We depend on its structure
-- deep inside the compiler (mostly in linearization and during rule registration).
data SymbolicTangent(a) =
  ZeroTangent
  SomeTangent(a)

-- Materializes a symbolic tangent, turning ZeroTangent into `zero`.
def someTangent(x:SymbolicTangent a) -> a given (a|VSpace) =
  case x of
    ZeroTangent -> zero
    SomeTangent(x') -> x'

'### Approximate Equality
TODO: move this outside the AD section to be with equality?

-- allclose(atol, rtol, x, y): are x and y equal within the given
-- absolute and relative tolerances?
interface HasAllClose(a)
  allclose : (a, a, a, a) -> Bool

interface HasDefaultTolerance(a)
  default_atol : a
  default_rtol : a

-- Approximate equality using the type's default tolerances.
def (~~)(x:a, y:a) -> Bool given (a|HasAllClose|HasDefaultTolerance) =
  allclose(default_atol, default_rtol, x, y)

-- |x - y| <= atol + rtol * |y|  (note: asymmetric in x and y)
instance HasAllClose(Float32)
  def allclose(atol, rtol, x, y) = abs (x - y) <= (atol + rtol * abs y)

instance HasAllClose(Float64)
  def allclose(atol, rtol, x, y) = abs (x - y) <= (atol + rtol * abs y)

instance HasDefaultTolerance(Float32)
  default_atol = f_to_f32 0.00001
  default_rtol = f_to_f32 0.0001

instance HasDefaultTolerance(Float64)
  default_atol = f_to_f64 0.00000001
  default_rtol = f_to_f64 0.00001

-- Component-wise comparison of pairs. The tolerances are pairs too, and each
-- component is compared with its own tolerance. (Previously this ignored
-- `atol`/`rtol` and always used the default tolerances via `~~`.)
-- The HasDefaultTolerance constraints are kept for backward compatibility.
instance HasAllClose((a, b)) given (a|HasDefaultTolerance|HasAllClose, b|HasDefaultTolerance|HasAllClose)
  def allclose(atol, rtol, pair1, pair2) =
    (x1, x2) = pair1
    (y1, y2) = pair2
    (atol1, atol2) = atol
    (rtol1, rtol2) = rtol
    allclose(atol1, rtol1, x1, y1) && allclose(atol2, rtol2, x2, y2)

instance HasDefaultTolerance((a, b)) given (a|HasDefaultTolerance,b|HasDefaultTolerance)
  default_atol = (default_atol, default_atol)
  default_rtol = (default_rtol, default_rtol)

-- Element-wise comparison of tables, with per-element tolerances.
instance HasAllClose(n=>t) given (n|Ix, t|HasAllClose)
  def allclose(atol, rtol, a, b) = all for i:n. allclose(atol[i], rtol[i], a[i], b[i])

instance HasDefaultTolerance(n=>t) given (n|Ix, t|HasDefaultTolerance)
  default_atol = for i. default_atol
  default_rtol = for i. default_rtol

'### AD Checking tools

-- Compares forward- and reverse-mode derivatives at x against a central
-- finite difference with step 0.01.
def check_deriv_base(f:(Float)->Float, x:Float) -> Bool =
  eps = 0.01
  ansFwd = deriv( f, x)
  ansRev = deriv_rev( f, x)
  ansNumeric = (f (x + eps) - f (x - eps)) / (2. * eps)
  ansFwd ~~ ansNumeric && ansRev ~~ ansNumeric

-- Checks the first and second derivatives.
def check_deriv(f:(Float)->Float, x:Float) -> Bool =
  check_deriv_base(f, x) && check_deriv_base(\x. deriv(f, x), x)

'## Length-erased lists

-- A table whose length is a runtime value rather than part of the type.
data List(a) =
  AsList(n:Nat, elements:(Fin n => a))

instance Eq(List a) given (a|Eq)
  def (==)(xsList, ysList) =
    AsList(nx,xs) = xsList
    AsList(ny,ys) = ysList
    if nx /= ny then False else all for i:(Fin nx). xs[i] == ys[unsafe_from_ordinal (ordinal i)]

-- Reindexes by ordinal without checking that the sizes agree.
def unsafe_cast_table(xs:from=>a) -> to=>a given (to|Ix, from|Ix, a) =
  for i. xs[unsafe_from_ordinal (ordinal i)]

def to_list(xs:n=>a) -> List a given (n|Ix, a) =
  n' = size n
  AsList(_, unsafe_cast_table(to=Fin n', xs))

-- Concatenation: the first nx elements come from x, the rest from y.
instance Monoid(List a) given (a|Data)
  mempty = AsList(_, [])
  def (<>)(x, y) =
    AsList(nx,xs) = x
    AsList(ny,ys) = y
    nz = nx + ny
    to_list for i:(Fin nz).
      i' = ordinal i
      case i' < nx of
        True -> xs[unsafe_from_ordinal i']
        False -> ys[unsafe_from_ordinal $ unsafe_nat_diff(i', nx)]

named-instance ListMonoid (a|Data) -> Monoid(List a)
  mempty = mempty
  def (<>)(x, y) = x <> y

-- TODO Eliminate or reimplement this operation, since it costs O(n)
-- where n is the length of the list held in the reference.
def append(list: Ref(h, List a), x:a) -> {Accum h} () given (a|Data, h) (AccumMonoid(h, List a)) =
  list += to_list [x]

-- TODO: replace `slice` with this?
-- Elements of xs from `start` (inclusive) up to `end` (exclusive).
def post_slice(xs:n=>a, start:Post n, end:Post n) -> List a given (n|Ix, a) =
  slice_size = unsafe_nat_diff(ordinal end, ordinal start)
  to_list for i:(Fin slice_size).
    xs[unsafe_from_ordinal(n=n, ordinal i + ordinal start)]

'## Dynamic buffer

-- A growable array in raw memory. The three fields point at the current
-- length, the current capacity, and the pointer to the element storage.
struct DynBuffer(a) =
  size : Ptr Nat
  max_size : Ptr Nat
  buffer : Ptr (Ptr a)

-- Runs `action` with a fresh buffer (initial capacity 256), then frees the
-- element storage and returns the action's result.
def with_dynamic_buffer(action: (DynBuffer a) -> {IO} b) -> {IO} b given (a|Storable, b) =
  initMaxSize = 256
  sizePtr <- with_alloc 1
  store(sizePtr, 0)
  maxSizePtr <- with_alloc 1
  store(maxSizePtr, initMaxSize)
  bufferPtr <- with_alloc 1
  store(bufferPtr, malloc initMaxSize)
  result = action $ DynBuffer(sizePtr, maxSizePtr, bufferPtr)
  free $ load bufferPtr
  result

-- Ensures capacity for `sizeDelta` more elements. When growing, it
-- reallocates to the next power of two >= the new size and copies the old
-- contents over.
def maybe_increase_buffer_size(db: DynBuffer a, sizeDelta:Nat) -> {IO} () given (a|Storable) =
  size = load db.size
  max_size = load db.max_size
  bufPtr = load db.buffer
  newSize = sizeDelta + size
  if newSize > max_size then
    -- TODO: maybe this should use integer arithmetic?
    roundedUp = f_to_n $ 2.0 `pow` (ceil $ log2 $ n_to_f newSize)
    -- Float rounding (n_to_f loses precision above 2^24) can make roundedUp
    -- smaller than newSize; never allocate less than is needed.
    newMaxSize = select(roundedUp < newSize, newSize, roundedUp)
    newBufPtr = malloc newMaxSize
    memcpy(newBufPtr, bufPtr, size)
    free bufPtr
    store(db.max_size, newMaxSize)
    store(db.buffer , newBufPtr)

def add_at_nat_ptr(ptr: Ptr Nat, n:Nat) -> {IO} () = store(ptr, load ptr + n)

-- Appends all elements of `new` to the buffer, growing it if needed.
def extend_dynamic_buffer(buf: DynBuffer a, new:List a) -> {IO} () given (a|Storable) =
  AsList(n, xs) = new
  maybe_increase_buffer_size(buf, n)
  bufPtr = load buf.buffer
  size = load buf.size
  store_table(bufPtr +>> size, xs)
  add_at_nat_ptr(buf.size, n)

-- Copies the current contents of the buffer out into a List.
def load_dynamic_buffer(buf: DynBuffer a) -> {IO} (List a) given (a|Storable) =
  bufPtr = load buf.buffer
  size = load buf.size
  AsList(size, table_from_ptr bufPtr)

def push_dynamic_buffer(buf: DynBuffer a, x:a) -> {IO} () given (a|Storable) =
  extend_dynamic_buffer(buf, to_list [x])

'## Strings and Characters

String : Type = List Char

-- Reads n characters starting at `ptr`.
def string_from_char_ptr(n:Word32, ptr:Ptr Char) -> {IO} String =
  AsList(rep_to_nat n, table_from_ptr ptr)

-- TODO. This is ASCII code point. It really should be Int32 for Unicode codepoint
def codepoint(c:Char) -> Int = w8_to_i c

struct CString =
  ptr : RawPtr

-- TODO: check the string contains no nulls
-- Runs `action` on a NUL-terminated temporary copy of `s`.
def with_c_string(s:String, action: (CString) -> {IO} a) -> {IO} a given (a) =
  AsList(n, s') = s <> "\NUL"
  with_table_ptr s' \ptr. action CString(ptr.val)

'### Show interface
For things that can be shown.
`show` gives a string representation of its input.
No particular promises are made to exactly what that
representation will contain.
In particular it is **not** promised to be parseable.
Nor does it promise a particular level of precision for numeric values.

interface Show(a)
  show : (a) -> String

instance Show(String)
  def show(x) = x

foreign "showInt32" showInt32 : (Int32) -> {IO} (Word32, RawPtr)
instance Show(Int32)
  def show(x) = unsafe_io \.
    (n, ptr) = showInt32 x
    string_from_char_ptr(n, Ptr ptr)

foreign "showInt64" showInt64 : (Int64) -> {IO} (Word32, RawPtr)
instance Show(Int64)
  def show(x) = unsafe_io \.
    (n, ptr) = showInt64 x
    string_from_char_ptr(n, Ptr ptr)

instance Show(Nat)
  def show(x) = show $ n_to_i64 x

foreign "showFloat32" showFloat32 : (Float32) -> {IO} (Word32, RawPtr)
instance Show(Float32)
  def show(x) = unsafe_io \.
    (n, ptr) = showFloat32 x
    string_from_char_ptr(n, Ptr ptr)

foreign "showFloat64" showFloat64 : (Float64) -> {IO} (Word32, RawPtr)
instance Show(Float64)
  def show(x) = unsafe_io \.
    (n, ptr) = showFloat64 x
    string_from_char_ptr(n, Ptr ptr)

instance Show(())
  def show(_) = "()"

instance Show((a, b)) given (a|Show, b|Show)
  def show(x) =
    (a, b) = x
    "(" <> show a <> ", " <> show b <> ")"

instance Show((a, b, c)) given (a|Show, b|Show, c|Show)
  def show(x) =
    (a, b, c) = x
    "(" <> show a <> ", " <> show b <> ", " <> show c <> ")"

instance Show((a, b, c, d)) given (a|Show, b|Show, c|Show, d|Show)
  def show(x) =
    (a, b, c, d) = x
    "(" <> show a <> ", " <> show b <> ", " <> show c <> ", " <> show d <> ")"

'### Parse interface
For types that can be parsed from a `String`.
interface Parse(a)
  parseString : (String) -> Maybe a

foreign "strtof" strtofFFI : (RawPtr, RawPtr) -> {IO} Float

-- Parses with C's strtof. Succeeds only if strtof consumed the whole
-- string (the end pointer is exactly str_len bytes past the start).
instance Parse(Float)
  def parseString(str) = unsafe_io \.
    AsList(str_len, _) = str
    with_c_string str \cStr.
      with_alloc 1 \end_ptr:(Ptr (Ptr Char)).
        result = strtofFFI(cStr.ptr, end_ptr.val)
        str_end_ptr = load end_ptr
        consumed = raw_ptr_to_i64 str_end_ptr.val - raw_ptr_to_i64 cStr.ptr
        if consumed == (n_to_i64 str_len) then Just result else Nothing

'## Floating-point helper functions
TODO: Move these to be with Elementary/Special functions. Or move those to be here.

-- 1.0 for positive x, -1.0 for negative x, otherwise x itself
-- (so zeros and NaN are returned unchanged).
def sign(x:Float) -> Float =
  case x > 0.0 of
    True -> 1.0
    False -> case x < 0.0 of
      True -> -1.0
      False -> x

-- Returns a value with the magnitude of `a` and the sign bit of `b`,
-- in the manner of C's copysignf.
-- It works on the IEEE bit patterns rather than comparing `b` with 0.0, so:
--  * the sign of a zero `b` is respected (copysign(pi, 0.0) == pi,
--    copysign(pi, -0.0) == -pi);
--  * a negative `a` gets a positive result when `b` is positive.
def copysign(a:Float, b:Float) -> Float =
  magnitude_bits = %bitcast(Word32, a) .&. 2147483647  -- 0x7FFFFFFF: everything but the sign bit
  sign_bit = %bitcast(Word32, b) .&. 2147483648        -- 0x80000000: only the sign bit
  %bitcast(Float, magnitude_bits .|. sign_bit)

-- Todo: use IEEE floating-point builtins.
infinity = 1.0 / 0.0
nan = 0.0 / 0.0

-- Todo: use IEEE floating-point builtins.
def isinf(x:Float) -> Bool = (x == infinity) || (x == -infinity)

-- NaN is the only value for which both comparisons with itself are false.
def isnan(x:Float) -> Bool = not (x >= x && x <= x)

-- Todo: use IEEE-754R 5.11: Floating Point Comparison Relation cmpUnordered.
def either_is_nan(x:Float, y:Float) -> Bool = (isnan x) || (isnan y)

'## File system operations

FilePath : Type = String

def is_null_raw_ptr(ptr:RawPtr) -> Bool = raw_ptr_to_i64 ptr == 0

def from_nullable_raw_ptr(ptr:RawPtr) -> Maybe (Ptr a) given (a) =
  if is_null_raw_ptr ptr then Nothing else Just $ Ptr ptr

def c_string_ptr(s:CString) -> Maybe (Ptr Char) = from_nullable_raw_ptr s.ptr

data StreamMode =
  ReadMode
  WriteMode

-- A C FILE* handle. The mode is tracked in the type.
struct Stream(mode:StreamMode) =
  ptr : RawPtr

'### Stream IO

foreign "fopen" fopenFFI : (RawPtr, RawPtr) -> {IO} RawPtr
foreign "fclose" fcloseFFI : (RawPtr) -> {IO} Int64
foreign "fwrite" fwriteFFI : (RawPtr, Int64, Int64, RawPtr) -> {IO} Int64
foreign "fread" freadFFI : (RawPtr, Int64, Int64, RawPtr) -> {IO} Int64
foreign "fflush" fflushFFI : (RawPtr) -> {IO} Int64

-- Opens `path` with mode "r" or "w". The returned pointer may be null on
-- failure (see `with_file`, which checks for it).
def fopen(path:String, mode:StreamMode) -> {IO} (Stream mode) =
  modeStr = case mode of
    ReadMode -> "r"
    WriteMode -> "w"
  with_c_string path \cPath.
    with_c_string modeStr \cMode.
      Stream $ fopenFFI(cPath.ptr, cMode.ptr)

def fclose(stream:Stream mode) -> {IO} () given (mode) =
  fcloseFFI stream.ptr
  ()

-- Writes the whole string, then flushes the stream.
def fwrite(stream:Stream WriteMode, s:String) -> {IO} () =
  AsList(n, s') = s
  with_table_ptr s' \ptr.
    fwriteFFI(ptr.val, i_to_i64 1, n_to_i64 n, stream.ptr)
  fflushFFI stream.ptr
  ()

'### Iteration
TODO: move this out of the file-system section

-- Repeatedly runs `body` until it returns False.
def while(body: () -> {|eff} Bool) -> {|eff} () given (eff) =
  body' : () -> {|eff} Word8 = \. b_to_w8 $ body()
  %while(body')

data IterResult(a|Data) =
  Continue
  Done(a)

-- TODO: can we improve effect inference so we don't need this?
def lift_state(ref: Ref(h, c), f:(a) -> {|eff} b, x:a) -> {State h|eff} b given (a, b, c, h, eff) =
  f x

-- A little iteration combinator
-- Calls body(0), body(1), ... until one returns Done, and returns that
-- value.
def iter(body: (Nat) -> {|eff} IterResult a) -> {|eff} a given (a|Data, eff) =
  result = yield_state Nothing \resultRef.
    i <- with_state 0
    while \.
      continue = is_nothing $ get resultRef
      if continue then
        case lift_state(resultRef, (\x. lift_state(i, body, x)), get i) of
          Continue -> i := get i + 1
          Done(result) -> resultRef := Just result
      continue
  case result of
    Just(ans) -> ans
    Nothing -> unreachable()

-- Like `iter`, but returns `fallback` once maxIters iterations have run
-- without a Done.
def bounded_iter(maxIters:Nat, fallback:a, body:(Nat) -> {|eff} IterResult a) -> {|eff} a given (a|Data, eff) =
  iter \i. if i >= maxIters then Done fallback else body i

'### Environment Variables

-- Copies a NUL-terminated C string into a String; Nothing for a null pointer.
def from_c_string(s:CString) -> {IO} (Maybe String) =
  case c_string_ptr s of
    Nothing -> Nothing
    Just(ptr) -> Just do
      buf <- with_dynamic_buffer
      i <- iter
      c = load $ ptr +>> i
      if c == '\NUL'
        then Done $ load_dynamic_buffer buf
        else
          push_dynamic_buffer(buf, c)
          Continue

foreign "getenv" getenvFFI : (RawPtr) -> {IO} RawPtr

def get_env(name:String) -> {IO} Maybe String =
  cStr <- with_c_string name
  getenvFFI cStr.ptr | CString | from_c_string

def check_env(name:String) -> {IO} Bool = is_just $ get_env name

'### More Stream IO

-- Reads the stream to the end in chunks of n bytes. A short read (fewer
-- than n bytes) is treated as end of input.
def fread(stream:Stream ReadMode) -> {IO} String =
  -- TODO: allow reading longer files!
  -- NOTE(review): the chunked loop below already reads past 4096 bytes;
  -- this TODO may be stale -- confirm.
  n = 4096
  ptr:(Ptr Char) <- with_alloc n
  buf <- with_dynamic_buffer
  iter \_.
    numRead = i_to_w32 $ i64_to_i $ freadFFI(ptr.val, 1, n_to_i64 n, stream.ptr)
    extend_dynamic_buffer(buf, string_from_char_ptr(numRead, ptr))
    if numRead == n_to_w32 n then Continue else Done ()
  load_dynamic_buffer buf

'### Print

def get_output_stream() -> {IO} Stream WriteMode = Stream $ %outputStream()

-- Writes `s` followed by a newline to the output stream.
@noinline
def print(s:String) -> {IO} () =
  stream = get_output_stream()
  fwrite(stream, s)
  fwrite(stream, "\n")

'### Shelling Out

foreign "popen" popenFFI : (RawPtr, RawPtr) -> {IO} RawPtr
foreign "pclose" pcloseFFI : (RawPtr) -> {IO} Int32
foreign "remove" removeFFI : (RawPtr) -> {IO} Int64
foreign "mkstemp" mkstempFFI : (RawPtr) -> {IO} Int32
foreign "close" closeFFI : (Int32) -> {IO} Int32

-- Runs `command` through popen and returns everything it wrote to stdout.
-- The pipe is closed with pclose afterwards. Without that, each call leaked
-- the stream and left the child process unreaped.
def shell_out(command:String) -> {IO} String =
  modeStr = "r"
  with_c_string command \command'.
    with_c_string modeStr \modeStr'.
      pipe = Stream $ popenFFI(command'.ptr, modeStr'.ptr)
      output = fread pipe
      pcloseFFI pipe.ptr
      output

'## Partial functions
A partial function in this context is a function that can error,
i.e. a function that is not actually defined for all of its supposed domain.
Not to be confused with a partially applied function.

'### Error throwing

-- Prints `s` and then throws a runtime error.
@noinline
def error(s:String) -> a given (a|Data) = unsafe_io \.
  print s
  %throwError(a)

def todo() ->> a given (a|Data) = error "TODO: implement it!"

'### File Operations

def delete_file(f:FilePath) -> {IO} () =
  s <- with_c_string(f)
  removeFFI s.ptr
  ()

-- Opens `f`, runs `action` on the stream, and closes it. Raises `error`
-- if the file cannot be opened.
def with_file(f:FilePath, mode:StreamMode, action:(Stream mode) -> {IO} a) -> {IO} a given (a|Data) =
  stream = fopen(f, mode)
  if is_null_raw_ptr stream.ptr
    then error $ "Unable to open file: " <> f
    else
      result = action stream
      fclose stream
      result

def write_file(f:FilePath, s:String) -> {IO} () =
  with_file(f, WriteMode) \stream. fwrite(stream, s)

def read_file(f:FilePath) -> {IO} String =
  with_file(f, ReadMode) \stream. fread stream

-- True if `f` can be opened for reading.
def has_file(f:FilePath) -> {IO} Bool =
  stream = fopen(f, ReadMode)
  result = not (is_null_raw_ptr stream.ptr)
  if result then fclose stream
  result

'### Temporary Files

-- Creates an empty file from the mkstemp template "/tmp/dex-XXXXXX" and
-- returns its path. 15 is the length of that template.
def new_temp_file() -> {IO} FilePath =
  s <- with_c_string "/tmp/dex-XXXXXX"
  fd = mkstempFFI s.ptr
  closeFFI fd
  string_from_char_ptr(15, (Ptr s.ptr))

def with_temp_file(action: (FilePath) -> {IO} a) -> {IO} a given (a) =
  tmpFile = new_temp_file()
  result = action tmpFile
  delete_file tmpFile
  result

def with_temp_files(action: (n=>FilePath) -> {IO} a) -> {IO} a given (n|Ix, a) =
  tmpFiles = for i. new_temp_file()
  result = action tmpFiles
  for i. delete_file tmpFiles[i]
  result

'### Table operations

@noinline
def from_ordinal_error(i:Nat, upper:Nat) -> String =
  "Ordinal index out of range:" <> show i <> " >= " <> show upper

-- Bounds-checked conversion from an ordinal to an index of type n.
def from_ordinal(i:Nat) -> n given (n|Ix) =
  case i < size n of
    True -> unsafe_from_ordinal i
    False -> error $ from_ordinal_error(i, size n)

-- TODO: should this be called `from_ordinal`?
-- Safe counterpart of unsafe_from_ordinal: Nothing when i is not below size n.
def to_ix(i:Nat) -> Maybe n given (n|Ix) =
  if i < size n then Just $ unsafe_from_ordinal i else Nothing

-- TODO: could make an `unsafeCastIndex` and this could avoid the runtime copy
-- TODO: safe (runtime-checked) and unsafe versions
-- Reindexes `xs` by ordinal into another index set of the same size.
-- Raises `error` when the sizes differ.
def cast_table(xs:to=>a) -> from=>a given (from|Ix, to|Ix, a|Data) =
  from_size = size from
  to_size = size to
  if from_size == to_size
    then unsafe_cast_table xs
    else error $ "Table size mismatch in cast: " <> show from_size <> " vs " <> show to_size

-- Both of these are aliases for the bounds-checked `from_ordinal`.
def asidx(i:Nat) -> n given (n|Ix) = from_ordinal i
def (@)(i:Nat, n|Ix) -> n = from_ordinal i

-- A window of `xs`: element j of the result is xs at ordinal (ordinal j + start).
def slice(xs:n=>a, start:Nat, m|Ix) -> m=>a given (n|Ix, a) =
  for j:m.
    offset = ordinal j + start
    xs[from_ordinal offset]

-- The element at ordinal 0.
def head(xs:n=>a) -> a given (n|Ix, a) = xs[from_ordinal 0]

-- Everything from ordinal `start` onwards, as a List.
def tail(xs:n=>a, start:Nat) -> List a given (n|Ix, a) =
  remaining = size n -| start
  suffix = slice(xs, start, Fin remaining)
  to_list suffix

'## Pseudorandom number generator utilities
Dex does not use a stateful random number generator.
Rather it uses what is known as a splittable random number generator, which is based on a hash function.
Dex's PRNG system is modelled directly after
[JAX's](https://github.com/google/jax/blob/master/design_notes/prng.md),
which is based on a well established but shockingly underused idea from the functional programming community:
the splittable PRNG.
It's a good idea for many reasons, but it's especially helpful in a parallel setting.
If you want to read more, [Splittable pseudorandom number generators using cryptographic hashing](http://publications.lib.chalmers.se/records/fulltext/183348/local_183348.pdf)
describes the splitting model itself and
[D.E. Shaw Research's counter-based PRNG](http://www.thesalmons.org/john/random123/papers/random123sc11.pdf)
proposes the particular hash function we use.
'### Key functions

-- TODO: newtype
Key = Word64

-- The Threefry-2x32 counter-based hash (20 rounds: 5 blocks of 4 rotations,
-- with key injection after each block). Mixes key `k` and counter `count`
-- into a new 64-bit value.
@noinline
def threefry_2x32(k:Word64, count:Word64) -> Word64 =
  -- Based on jax's threefry_2x32 by Matt Johnson and Peter Hawkins
  rotations1 = [13, 15, 26, 6]
  rotations2 = [17, 29, 16, 24]
  k0 = low_word k
  k1 = high_word k
  -- TODO: add a fromHex
  k2 = k0 .^. k1 .^. (n_to_w32 466688986)  -- 0x1BD11BDA
  x = low_word count
  y = high_word count
  x = x + k0
  y = y + k1
  rotations = [rotations1, rotations2]
  ks = [k1, k2, k0]
  (x, y) = yield_state (x, y) \ref.
    for i:(Fin 5).
      for j.
        (x, y) = get ref
        rotationIndex = unsafe_from_ordinal (ordinal i `mod` 2)
        rot = rotations[rotationIndex, j]
        x = x + y
        y = (y .<<. rot) .|. (y .>>. (32 - rot))  -- rotate left by rot
        y = x .^. y
        ref := (x, y)
      (x, y) = get ref
      x = x + ks[unsafe_from_ordinal (ordinal i `mod` 3)]
      y = y + ks[unsafe_from_ordinal (((ordinal i)+1) `mod` 3)] + n_to_w32 ((ordinal i)+1)
      ref := (x, y)
  (w32_to_w64 x .<<. 32) .|. (w32_to_w64 y)

def hash(x:Key, y:Nat) -> Key =
  y64 = n_to_w64 y
  threefry_2x32(x, y64)

def new_key(x:Nat) -> Key = hash(0, x)
def many(f:(Key)->a, k:Key, i:n) -> a given (a, n|Ix) = f hash(k, ordinal i)
def ixkey(k:Key, i:n) -> Key given (n|Ix) = hash(k, ordinal i)
def split_key(k:Key) -> Fin n => Key given (n:Nat) = for i. ixkey(k, i)

'### Sample Generators
These functions generate samples from different distributions.
For example, `rand_mat(n, m, rand, k)` gives an n-by-m floating-point matrix
whose elements are drawn i.i.d. from the uniform distribution on [0, 1).
Note that additional standard distributions are provided by the `stats` library.

-- Uniform sample in [0, 1): put 23 random mantissa bits under the exponent
-- of 1.0 to get a float in [1, 2), then subtract 1.
def rand(k:Key) -> Float =
  exponent_bits = 1065353216  -- 1065353216 = 127 << 23
  mantissa_bits = (high_word k .&. 8388607)  -- 8388607 == (1 << 23) - 1
  bits = exponent_bits .|. mantissa_bits
  %bitcast(Float, bits) - 1.0

def rand_vec(n:Nat, f: (Key) -> a, k: Key) -> Fin n => a given (a) =
  for i:(Fin n). f ixkey(k, i)

def rand_mat(n:Nat, m:Nat, f: (Key) -> a, k: Key) -> Fin n => Fin m => a given (a) =
  for i j. f ixkey(k, (i, j))

-- Standard normal sample via the Box-Muller transform.
def randn(k:Key) -> Float =
  [k1, k2] = split_key k
  -- rand is uniform between 0 and 1, but implemented such that it rounds to 0
  -- (in float32) once every few million draws, but never rounds to 1.
  u1 = 1.0 - (rand k1)
  u2 = rand k2
  sqrt ((-2.0) * log u1) * cos (2.0 * pi * u2)

-- TODO: Make this better...
def rand_int(k:Key) -> Nat = w64_to_n k `mod` 2147483647

def randn_vec(k:Key) -> n=>Float given (n|Ix) =
  for i. randn (ixkey(k, i))

-- Uniformly random index of n.
def rand_idx(k:Key) -> n given (n|Ix) =
  rand k * n_to_f (size n) | floor | f_to_n | unsafe_from_ordinal

'## Inner product typeclass

interface InnerProd(v|VSpace)
  inner_prod : (v, v) -> Float

instance InnerProd(Float)
  def inner_prod(x, y) = x * y

instance InnerProd(n=>a) given (a|InnerProd, n|Ix)
  def inner_prod(x, y) = sum for i. inner_prod(x[i], y[i])

'## Arbitrary
Type class for generating example values

interface Arbitrary(a)
  arb : (Key) -> a

instance Arbitrary(Bool)
  def arb(key) = key .&. 1 == 0

instance Arbitrary(Float32)
  def arb(key) = randn key

instance Arbitrary(Int32)
  def arb(key) = f_to_i $ randn key * 5.0

-- `randn` can be negative, and a negative float has no Nat value, so take
-- the magnitude before converting.
-- NOTE(review): f_to_n presumably lowers to an unsigned conversion (e.g.
-- LLVM fptoui), which yields poison for negative inputs.
instance Arbitrary(Nat)
  def arb(key) = f_to_n $ abs (randn key) * 5.0

instance Arbitrary(n=>a) given (n|Ix, a|Arbitrary)
  def arb(key) = for i. arb $ ixkey(key, i)

instance Arbitrary((i:n)=>(..<i) => a) given (n|Ix, a|Arbitrary)
  def arb(key) = for i. arb $ ixkey(key, i)

instance Arbitrary((i:n)=>(..i) => a) given (n|Ix, a|Arbitrary)
  def arb(key) = for i. arb $ ixkey(key, i)

instance Arbitrary((i:n)=>(i..) => a) given (n|Ix, a|Arbitrary)
  def arb(key) = for i. arb $ ixkey(key, i)

instance Arbitrary((i:n)=>(i<..) => a) given (n|Ix, a|Arbitrary)
  def arb(key) = for i. arb $ ixkey(key, i)

instance Arbitrary((a, b)) given (a|Arbitrary, b|Arbitrary)
  def arb(key) =
    [k1, k2] = split_key key
    (arb k1, arb k2)

instance Arbitrary(Fin n) given (n)
  def arb(key) = rand_idx key

'## Ord on Arrays

'### Searching

'Assuming `xs` is sorted in ascending order, returns the highest index `i`
such that `xs[i] <= x`, or `Nothing` if `xs` is empty or `x < xs[0]`.

-- Binary search keeping xs[low] <= x and (high == size n or x < xs[high]).
def search_sorted(xs:n=>a, x:a) -> Maybe n given (n|Ix, a|Ord) =
  if size n == 0
    then Nothing
    else if x < xs[from_ordinal 0]
      then Nothing
      else
        low <- with_state(0)
        high <- with_state(size n)
        _ <- iter
        numLeft = n_to_i (get high) - n_to_i (get low)
        if numLeft == 1
          then Done $ Just $ from_ordinal $ get low
          else
            centerIx = get low + unsafe_i_to_n (numLeft `idiv` 2)
            if x < xs[from_ordinal centerIx]
              then high := centerIx
              else low := centerIx
            Continue

'### min / max etc

def min_by(f:(a)->o, x:a, y:a) -> a given (o|Ord, a) = select(f x < f y, x, y)
def max_by(f:(a)->o, x:a, y:a) -> a given (o|Ord, a) = select(f x > f y, x, y)
def min(x1: o, x2: o) -> o given (o|Ord) = min_by(id, x1, x2)
def max(x1: o, x2: o) -> o given (o|Ord) = max_by(id, x1, x2)

-- These index xs at ordinal 0 for the initial value, so xs must be non-empty.
def minimum_by(f:(a)->o, xs:n=>a) -> a given (a|Data, o|Ord, n|Ix) =
  reduce(xs[0@_], \x y. min_by(f, x, y), xs)
def maximum_by(f:(a)->o, xs:n=>a) -> a given (a|Data, o|Ord, n|Ix) =
  reduce(xs[0@_], \x y. max_by(f, x, y), xs)

def minimum(xs:n=>o) -> o given (n|Ix, o|Ord) = minimum_by(id, xs)
def maximum(xs:n=>o) -> o given (n|Ix, o|Ord) = maximum_by(id, xs)

'### argmin/argmax

-- TODO: put in same section as `searchsorted`
-- Index of the element preferred by `comp`. Reduces over (index, value)
-- pairs, so xs must be non-empty.
def argscan(comp:(o,o)->Bool, xs:n=>o) -> n given (o|Ord, n|Ix) =
  zeroth = (0@_, xs[0@_])
  compare = \p1 p2.
    (idx1, x1) = p1
    (idx2, x2) = p2
    select(comp(x1, x2), (idx1, x1), (idx2, x2))
  zipped = for i. (i, xs[i])
  fst $ reduce(zeroth, compare, zipped)

def argmin(xs:n=>o) -> n given (n|Ix, o|Ord) = argscan((<), xs)
def argmax(xs:n=>o) -> n given (n|Ix, o|Ord) = argscan((>), xs)

def lexical_order(compareElements:(n,n)->Bool, compareLengths: (Nat,Nat)->Bool, xList:List n, yList:List n) -> Bool given (n|Ord) =
  -- Orders Lists according to the order of their elements,
  -- in the same way a dictionary does.
  -- For example, this lets us sort Strings.
  --
  -- More precisely, it returns True iff compareElements xs[i] ys[i] is true
  -- at the first location they differ. If one list is a prefix of the
  -- other, it returns compareLengths applied to the two lengths.
  --
  -- This function operates serially and short-circuits
  -- at the first difference.  One could also write this
  -- function as a parallel reduction, but it would be
  -- wasteful in the case where there is an early difference,
  -- because we can't short circuit.
  AsList(nx, xs) = xList
  AsList(ny, ys) = yList
  iter \i.
    case i == min(nx, ny) of
      True -> Done $ compareLengths(nx, ny)
      False ->
        xi = xs[unsafe_from_ordinal i]
        yi = ys[unsafe_from_ordinal i]
        case compareElements(xi, yi) of
          True -> Done True
          False -> case xi == yi of
            True -> Continue
            False -> Done False

instance Ord(List n) given (n|Ord)
  def (>)(xs, ys) = lexical_order((>), (>), xs, ys)
  def (<)(xs, ys) = lexical_order((<), (<), xs, ys)

'### clip

-- Clamps x into [low, high].
def clip(bounds:(a,a), x:a) -> a given (a|Ord) =
  (low,high) = bounds
  min(high, max(low, x))

'## Trigonometric functions
TODO: these should be with the other Elementary/Special Functions
### atan/atan2

def atan_inner(x:Float) -> Float =
  -- From "Computing accurate Horner form approximations to
  -- special functions in finite precision arithmetic"
  -- https://arxiv.org/abs/1508.03211
  -- Only accurate in the range [-1, 1]
  s = x * x
  r = 0.0027856871
  r = r * s - 0.0158660002
  r = r * s + 0.042472221
  r = r * s - 0.0749753043
  r = r * s + 0.106448799
  r = r * s - 0.142070308
  r = r * s + 0.199934542
  r = r * s - 0.333331466
  r = r * s
  r * x + x

def min_and_max(x:a, y:a) -> (a, a) given (a|Ord) =
  select(x < y, (x, y), (y, x))  -- get both with one comparison.

def atan2(y:Float, x:Float) -> Float =
  -- Based off of the Tensorflow implementation at
  -- github.com/tensorflow/mlir-hlo/blob/master/lib/
  -- Dialect/mhlo/transforms/legalize_trigonometric_to_approximation.cc#L147
  -- With a fix to the nan propagation.
  abs_x = abs x
  abs_y = abs y
  (min_abs_x_y, max_abs_x_y) = min_and_max(abs_x, abs_y)
  a = atan_inner (min_abs_x_y / max_abs_x_y)
  a = select(abs_x <= abs_y, (pi / 2.0) -a, a)
  a = select(x < 0.0, pi - a, a)
  t = select(x < 0.0, pi, 0.0)
  a = select(y == 0.0, t, a)
  t = select(x < 0.0, 3.0 * pi / 4.0, pi / 4.0)
  a = select(isinf x && isinf y, t, a)  -- Handle infinite inputs.
  -- `a` is non-negative here; copysign is meant to give it the sign of y.
  -- NOTE(review): correct results for y == 0.0 (e.g. atan2(0.0, -1.0) == pi)
  -- rely on copysign keeping the magnitude when its second argument is zero.
  a = copysign(a, y)
  select(either_is_nan(x, y), nan, a)  -- Propagate NaNs.

def atan(x:Float) -> Float = atan2(x, 1.0)

'## Miscellaneous utilities
TODO: all of these should be in some other section

-- Mirror an index: ordinal k maps to ordinal (size n - 1 - k).
def reflect(i:n) -> n given (n|Ix) =
  unsafe_from_ordinal $ unsafe_nat_diff(size n, ordinal i + 1)

def reverse(x:n=>a) -> n=>a given (n|Ix, a) =
  for i. x[reflect i]

def wrap_periodic(n|Ix, i:Nat) -> n =
  unsafe_from_ordinal(n=n, i `mod` size n)

-- Copies xs into the first size n slots of an m-table and fills the rest with x.
def pad_to(m|Ix, x:a, xs:n=>a) -> m=>a given (n|Ix, a) =
  n' = size n
  for i.
    i' = ordinal i
    case i' < n' of
      True -> xs[i'@_]
      False -> x

def idiv_ceil(x:Nat, y:Nat) -> Nat =
  x `idiv` y + b_to_n (x `rem` y /= 0)

def intdiv2(x:Nat) -> Nat =
  rep_to_nat $ %shr(nat_to_rep x, 1 :: NatRep)

def intpow2(power:Nat) -> Nat =
  rep_to_nat $ %shl(1 :: NatRep, nat_to_rep power)

def is_odd(x:Nat) -> Bool = rem(x, 2) == 1
def is_even(x:Nat) -> Bool = rem(x, 2) == 0

def is_power_of_2(x:Nat) -> Bool =
  -- A fast trick based on bitwise AND: x & (x - 1) clears the lowest set bit,
  -- so it is zero exactly when x has a single bit set.
  -- Note: (.&.) is not applied to Nat anywhere in this file, so we use %and
  -- on Nat's representation instead. TODO: Make (.&.) work on Nat.
  x' = nat_to_rep x
  if x' == 0 then False else 0 == %and(x', (%isub(x', 1::NatRep)))

-- This computes the integer part of the binary logarithm of the input.
-- TODO: natlog2 0 should do something other than underflow the answer.
-- TODO: Use LLVM ctlz intrinsic instead. It needs a slightly new
-- code path in ImpToLLVM, because it's the first LLVM intrinsic
-- we have with a fixed-point argument.
-- https://llvm.org/docs/LangRef.html#llvm-ctlz-intrinsic
-- Implementation: count how many times `x` can be halved before it reaches
-- zero (its bit length) and subtract one. Halving the input, rather than
-- doubling a comparison value, means nothing can overflow. Doubling a
-- Word32 counter wraps to 0 once x >= 2^31, and `x >= 0` then loops forever.
def natlog2(x:Nat) -> Nat =
  bit_length = yield_state 0 \ans.
    rest <- with_state x
    while \.
      if (get rest) > 0
        then
          ans := (get ans) + 1
          rest := intdiv2 (get rest)
          True
        else False
  unsafe_nat_diff(bit_length, 1)

-- Raises `base` to `power` by repeated squaring, using `times` as the
-- multiplication and `one` as its identity (the result when power == 0).
def general_integer_power(
    times:(a,a)->a,
    one:a,
    base:a,
    power:Nat
    ) -> a given (a|Data) =
  -- One loop iteration per bit of `power`.
  iters = if power == 0 then 0 else 1 + natlog2 power
  -- Implements exponentiation by squaring.
  -- This could be nicer if there were a way to explicitly
  -- specify which typeclass instance to use for Mul.
  yield_state one \ans.
    pow <- with_state power
    z <- with_state base
    for _:(Fin iters).
      if is_odd (get pow) then ans := times(get ans, get z)
      z := times(get z, get z)
      pow := intdiv2 (get pow)

-- `base` raised to a Nat power using the type's own `Mul`.
def intpow(base:a, power:Nat) -> a given (a|Mul) =
  general_integer_power((*), one, base, power)

-- Unwraps a `Just`. The match is partial: there is no `Nothing` case.
def from_just(x:Maybe a) -> a given (a) = case x of Just(x') -> x'

-- True iff `f` holds for at least one element of `xs`.
def any_sat(f:(a)->Bool, xs:n=>a) -> Bool given (a, n|Ix) = any(each xs f)

-- `Just` the table of unwrapped values if every entry is `Just`, else `Nothing`.
def seq_maybes(xs: n=>Maybe a) -> Maybe (n => a) given (n|Ix, a) =
  -- is it possible to implement this safely? (i.e. without using partial
  -- functions)
  case any_sat(is_nothing, xs) of
    True -> Nothing
    False -> Just $ each xs from_just

-- Looks for `query` in `xs`. Every element is visited and each match
-- overwrites the state, so the result is the LAST matching index
-- (or Nothing when there is no match).
def linear_search(xs:n=>a, query:a) -> Maybe n given (n|Ix, a|Eq) =
  yield_state Nothing \ref. for i.
    case xs[i] == query of
      True -> ref := Just i
      False -> ()

def list_length(l:List a) -> Nat given (a) =
  AsList(n, _) = l
  n

-- This is for efficiency (rather than using `<>` repeatedly)
-- TODO: we want this for any monoid but this implementation won't work.
-- Walks the output positions once; for each one, the inner `while` skips
-- lists that are exhausted (or empty) before taking the next element.
def concat(lists:n=>(List a)) -> List a given (a, n|Ix) =
  totalSize = sum for i. list_length lists[i]
  to_list $ with_state 0 \listIdx.
    eltIdx <- with_state 0
    for i:(Fin totalSize).
      while \.
        continue = get eltIdx >= list_length (lists[(get listIdx)@_])
        if continue
          then
            eltIdx := 0
            listIdx := get listIdx + 1
          else ()
        continue
      AsList(_, xs) = lists[(get listIdx)@_]
      eltIdxVal = get eltIdx
      eltIdx := eltIdxVal + 1
      xs[eltIdxVal@_]

-- Keeps the `Just` payloads of `xs`, in index order.
-- First pass: record the indices of the `Just` entries and count them.
-- Second pass: read those entries back out.
def cat_maybes(xs:n=>Maybe a) -> List a given (n|Ix, a|Data) =
  (num_res, res_inds) = yield_state (0::Nat, for i:n. Nothing) \ref.
    for i. case xs[i] of
      Just(_) ->
        ix = get ref.0
        ref.1 ! (unsafe_from_ordinal ix) := Just i
        ref.0 := ix + 1
      Nothing -> ()
  to_list $ for i:(Fin num_res).
    case res_inds[unsafe_from_ordinal $ ordinal i] of
      Just(j) -> case xs[j] of
        Just(x) -> x
        Nothing -> todo -- Impossible
      Nothing -> todo -- Impossible

-- Elements of `xs` satisfying `condition`, in index order.
def filter(xs:n=>a, condition:(a)->Bool) -> List a given (a|Data, n|Ix) =
  cat_maybes $ for i. if condition xs[i] then Just xs[i] else Nothing

-- Indices of the elements of `xs` satisfying `condition`, in index order.
def arg_filter(xs:n=>a, condition:(a)->Bool) -> List n given (a|Data, n|Ix) =
  cat_maybes $ for i. if condition xs[i] then Just i else Nothing

-- TODO: use `ix_offset : [Ix n] -> n -> Int -> Maybe n` instead
-- The index just before `i`, or Nothing when `i` has ordinal 0.
def prev_ix(i:n) -> Maybe n given (n|Ix) =
  case i_to_n (n_to_i (ordinal i) - 1) of
    Nothing -> Nothing
    Just(i_prev) -> unsafe_from_ordinal(i_prev) | Just

-- Splits `source` at '\n' characters; each line excludes its newline.
-- NOTE(review): one line is produced per '\n', so any text after the final
-- newline is dropped — confirm this is intended.
def lines(source:String) -> List String =
  AsList(_, s) = source
  AsList(num_lines, newline_ixs) = cat_maybes for i_char. if s[i_char] == '\n' then Just(i_char) else Nothing
  to_list for i_line:(Fin num_lines).
    start = case prev_ix i_line of
      Nothing -> first_ix
      Just(i) -> right_post newline_ixs[i]
    end = left_post newline_ixs[i_line]
    post_slice(s, start, end)

'## Probability

-- cdf should include 0.0 but not 1.0
-- Draws one index by locating the sample `rand key` (presumably uniform in
-- [0, 1)) in `cdf`. The match only handles `Just`, relying on the cdf
-- starting at 0.0.
def categorical_from_cdf(cdf: n=>Float, key: Key) -> n given (n|Ix) =
  r = rand key
  case search_sorted(cdf, r) of
    Just(i) -> i

-- Scales `xs` so that it sums to one.
def normalize_pdf(xs: d=>Float) -> d=>Float given (d|Ix) = xs / sum xs

-- Turns unnormalized log-probabilities into the cdf consumed by
-- `categorical_from_cdf`. The max is subtracted before `exp` so the
-- largest term is exp(0) = 1.
def cdf_for_categorical(logprobs: n=>Float) -> n=>Float given (n|Ix) =
  maxLogProb = maximum logprobs
  cumsum_low $ normalize_pdf $ for i. exp(logprobs[i] - maxLogProb)

-- Samples an index with probability proportional to exp(logprobs).
def categorical(logprobs: n=>Float, key: Key) -> n given (n|Ix) =
  categorical_from_cdf(cdf_for_categorical logprobs, key)

-- batch variant to share the work of forming the cumsum
-- (alternatively we could rely on hoisting of loop constants)
-- Output position i draws with key `ixkey(key, i)`.
def categorical_batch(logprobs: n=>Float, key: Key) -> m=>n given (n|Ix, m|Ix) =
  cdf = cdf_for_categorical logprobs
  for i. categorical_from_cdf(cdf, ixkey(key, i))

-- log(sum(exp x)), shifted by max x so the exponentials cannot overflow.
def logsumexp(x: n=>Float) -> Float given (n|Ix) =
  m = maximum x
  m + (log $ sum for i. exp (x[i] - m))

def logsoftmax(x: n=>Float) -> n=>Float given (n|Ix) =
  lse = logsumexp x
  for i. x[i] - lse

def softmax(x: n=>Float) -> n=>Float given (n|Ix) =
  m = maximum x
  e = for i. exp (x[i] - m)
  s = sum e
  for i. e[i] / s

'## Polynomials
TODO: Move this somewhere else

def evalpoly(coefficients:n=>v, x:Float) -> v given (n|Ix, v|VSpace) =
  -- Evaluate a polynomial at x. Same as Numpy's polyval.
  -- Horner's rule: coefficients are ordered from highest degree to lowest.
  fold zero \i c. coefficients[i] + x .* c

'## TestMode

-- TODO: move this to be in Testing Helpers
def dex_test_mode() -> Bool = unsafe_io \. check_env "DEX_TEST_MODE"

'## Exception effect

-- TODO: move `error` and `todo` to here.
-- Runs `f`, returning `Just` its result, or `Nothing` if it throws.
def catch(f:() -> {Except|eff} a) -> {|eff} Maybe a given (a, eff) =
  f' : (() -> {Except|eff} a) = \. f()
  %catchException(f')

def throw() -> {Except} a given (a) =
  %throwException(a)

-- Throws when `b` is False.
def assert(b:Bool) -> {Except} () =
  if not b then throw()

'### Misc instances that require `error`

instance Subset(a, Either(a,b)) given (a|Data, b|Data)
  def inject'(x) = Left x
  def project'(x) = case x of
    Left( y) -> Just y
    Right(x) -> Nothing
  def unsafe_project'(x) = case x of
    Left( x) -> x
    Right(x) -> error "Can't project Right branch to Left branch"

instance Subset(b, Either(a,b)) given (a|Data, b|Data)
  def inject'(x) = Right x
  def project'(x) = case x of
    Left( x) -> Nothing
    Right(y) -> Just y
  def unsafe_project'(x) = case x of
    Left( x) -> error "Can't project Left branch to Right branch"
    Right(x) -> x

'## Testing Helpers

-- -- Reliably causes a segfault if pointers aren't initialized to zero.
-- -- TODO: add this test when we cache modules
-- justSomeDataToTestCaching = toList for i:(Fin 100).
--   if ordinal i == 0
--     then Left (toList [1,2,3])
--     else Right 1

'### Index set for tables

-- Base-`size b` digits of `k`, least significant first: the digit at
-- position i is (k / base^i) mod base.
def int_to_reversed_digits(k:Nat) -> a=>b given (a|Ix, b|Ix) =
  base = size b
  snd $ scan k \_ cur_k.
    next_k = cur_k `idiv` base
    digit = cur_k `mod` base
    (next_k, unsafe_from_ordinal(n=b, digit))

-- Inverse of `int_to_reversed_digits`: sum of ordinal(digits[j]) * base^j.
def reversed_digits_to_int(digits: a=>b) -> Nat given (a|Ix, b|Ix) =
  base = size b
  fst $ fold (0, 1) \j pair.
    (cur_k, cur_base) = pair
    next_k = cur_k + ordinal digits[j] * cur_base
    next_base = cur_base * base
    (next_k, next_base)

instance Ix(a=>b) given (a|Ix, b|Ix)
  -- 0@a is the least significant digit,
  -- while (size a - 1)@a is the most significant digit.
  def size'() = size b `intpow` size a
  def ordinal(i) = reversed_digits_to_int i
  def unsafe_from_ordinal(i) = int_to_reversed_digits i

instance NonEmpty(a=>b) given (a|Ix, b|NonEmpty)
  first_ix = unsafe_from_ordinal 0

'### stack

-- TODO: replace `DynBuffer` with this?
-- A growable stack stored in state references: `size_ref` holds the
-- logical size and `buf_ref` a List used as the backing buffer, whose
-- capacity may exceed the logical size.
-- NOTE(review): the unsafe_coerce calls below assume a List is laid out at
-- runtime as (length, buffer ref) — confirm if List's representation changes.
struct Stack(h:Heap, a|Data) =
  size_ref : Ref h Nat
  buf_ref  : Ref h (List a)

  -- Number of elements currently stored.
  def size() -> {State h} Nat = get self.size_ref

  -- The backing buffer ref, read out of the List's second component.
  def unsafe_get_buffer() -> {State h} (Ref(h, Fin 0 => a)) =
    get $ snd_ref $ unsafe_coerce(to=Ref h (Nat, Ref h (Fin 0 => a)), self.buf_ref)

  -- Current capacity, read out of the List's first component.
  def buf_size() -> {State h} Nat =
    get $ fst_ref $ unsafe_coerce(to=Ref h (Nat, Ref h (Fin 0 => a)), self.buf_ref)

  -- Grows the buffer to the smallest power of two >= req_size (if needed),
  -- copying the first size() elements across.
  def ensure_size_at_least(req_size:Nat) -> {State h} () =
    if req_size > self.buf_size() then
      -- Capacity is computed with integer doubling. The previous Float32
      -- formula 2.0 `pow` ceil(log2(n_to_f req_size)) rounds log2(2^k + 1)
      -- to exactly k for k >= 21, which allocated a buffer smaller than
      -- req_size. NatRep is Word32, so doubling past 2^31 would wrap to 0.
      if req_size > 2147483648 then error "Stack: requested size exceeds 2^31 elements"
      new_buf_size = yield_state 1 \cap_ref.
        while \.
          if (get cap_ref) < req_size
            then
              cap_ref := (get cap_ref) * 2
              True
            else False
      buf = self.unsafe_get_buffer()
      logical_size = self.size()
      cur_data = get $ unsafe_coerce(to=Ref(h, Fin logical_size => a), buf)
      self.buf_ref := to_list for i:(Fin new_buf_size).
        case to_ix(n=Fin logical_size, ordinal i) of
          Just(i') -> cur_data[i']
          Nothing -> uninitialized_value()

  -- The current contents as a List of length size().
  def read() -> {State h} (List a) =
    n = self.size()
    buf = unsafe_coerce(to=Ref(h, Fin n => a), self.unsafe_get_buffer())
    AsList(n, get buf)

  @noinline
  def push(x:a) -> {State h} () =
    n_old = self.size()
    n_new = n_old + 1
    self.ensure_size_at_least(n_new)
    buf = self.unsafe_get_buffer()
    buf ! (unsafe_from_ordinal n_old) := x
    self.size_ref := n_new

  -- Appends all of `x`, writing it as one slice starting at the old size.
  @noinline
  def extend(x:n=>a) -> {State h} () given (n|Ix) =
    n_old = self.size()
    n_new = n_old + size n
    self.ensure_size_at_least(n_new)
    buf = self.unsafe_get_buffer()
    buf_slice = unsafe_coerce(to=Ref(h,n=>a), buf ! (unsafe_from_ordinal n_old))
    buf_slice := x
    self.size_ref := n_new

  -- Removes and returns the last element; Nothing when empty.
  def pop() -> {State h} Maybe a =
    n_old = self.size()
    case n_old == 0 of
      True -> Nothing
      False ->
        n_new = unsafe_nat_diff(n_old, 1)
        buf = self.unsafe_get_buffer()
        self.size_ref := n_new
        Just $ get buf!(unsafe_from_ordinal n_new)

stack_init_size = 16

-- Runs `action` with an empty stack whose buffer starts with
-- `stack_init_size` uninitialized slots.
def with_stack(
    a|Data,
    action:(given (h:Heap), Stack(h, a)) -> {State h|eff} r
    ) -> {|eff} r given (eff, r) =
  init_stack = to_list for i:(Fin stack_init_size). uninitialized_value()
  with_state (0, init_stack) \ref . action(Stack(ref.0, ref.1))

def stack_extend_internal(stack:Stack(h, Char), x:Fin n=>Char) -> {State h} () given (n, h) =
  stack.extend(x)

def stack_push_internal(stack:Stack(h, Char), x:Char) -> {State h} () given (h) =
  stack.push(x)

-- Runs `f` on a fresh Char stack and returns everything it pushed.
def with_stack_internal(f:(given (h:Heap), Stack(h, Char)) -> {State h} ()) -> List Char =
  with_stack Char \stack.
    f stack
    stack.read()

def show_any(x:a) -> String given (a) = unsafe_coerce(to=String, %showAny(x))

-- Reinterprets `x` as a table over `m`; errors unless the sizes match.
def coerce_table(m|Ix, x:n=>a) -> m => a given (n|Ix, a|Data) =
  if size m == size n
    then unsafe_coerce(to=m=>a, x)
    else error "mismatched sizes in table coercion"

'### Linear Algebra

-- `size n` evenly spaced points low, low + dx, ... with
-- dx = (high - low) / size n; `high` itself is not included.
def linspace(n|Ix, low:Float, high:Float) -> n=>Float =
  dx = (high - low) / n_to_f (size n)
  for i:n. low + n_to_f (ordinal i) * dx

def transpose(x:n=>m=>a) -> m=>n=>a given (n|Ix, m|Ix, a) =
  for i j. x[j,i]

def vdot(x:n=>Float, y:n=>Float) -> Float given (n|Ix) =
  fsum for i. x[i] * y[i]

-- Linear combination of the vectors `vs` with weights `s`.
def dot(s:n=>Float, vs:n=>v) -> v given (n|Ix, v|VSpace) =
  sum for j. s[j] .* vs[j]

def naive_matmul(x: l=>m=>Float, y: m=>n=>Float) -> (l=>n=>Float) given (l|Ix, m|Ix, n|Ix) =
  for i k. fsum for j. x[i,j] * y[j,k]

-- A `FullTileIx` type represents `tile_ix`th full tile (of size
-- `tile_size`) iterating over the index set `n`.
-- This type is only well formed when (tile_ix + 1) * tile_size <= size n,
-- i.e. every injected ordinal tile_size * tile_ix + k (k < tile_size) is
-- a valid index of `n`.
struct FullTileIx(n|Ix, tile_size:Nat, tile_ix:Nat) =
  unwrap : Fin tile_size

instance Ix(FullTileIx(n, tile_size, tile_ix)) given (n|Ix, tile_size:Nat, tile_ix:Nat)
  def size'() = tile_size
  def ordinal(i) = ordinal i.unwrap
  def unsafe_from_ordinal(i) = FullTileIx $ unsafe_from_ordinal i

instance Subset(FullTileIx(n, tile_size, tile_ix), n) given (n|Ix, tile_size:Nat, tile_ix:Nat)
  def inject'(i) = unsafe_from_ordinal $ tile_size * tile_ix + ordinal i.unwrap
  def project'(i) = todo
  def unsafe_project'(i) = todo

-- A `CodaIx` type represents the last few elements of the index set `n`,
-- as might be left over after iterating by tiles.
-- This type is only well formed when size n == coda_offset + coda_size
struct CodaIx(n|Ix, coda_offset:Nat, coda_size:Nat) =
  unwrap : Fin coda_size

-- A coda index is just its position inside the coda.
instance Ix(CodaIx(n, coda_offset, coda_size)) given (n|Ix, coda_offset:Nat, coda_size:Nat)
  def size'() = coda_size
  def ordinal(c) = ordinal(c.unwrap)
  def unsafe_from_ordinal(k) = CodaIx(unsafe_from_ordinal(k))

-- Embeds the coda into `n` by shifting past the `coda_offset` tiled elements.
instance Subset(CodaIx(n, coda_offset, coda_size), n) given (n|Ix, coda_offset:Nat, coda_size:Nat)
  def inject'(c) = unsafe_from_ordinal(ordinal(c.unwrap) + coda_offset)
  def project'(c) = todo
  def unsafe_project'(c) = todo

-- Splits `n` into `size n idiv tile_size` full tiles followed by one coda of
-- the remaining elements. `body` is called once per tile (in order) and then
-- once for the coda, each time with an index set that injects into `n`.
def tile(
    n|Ix,
    tile_size: Nat,
    body:(m:Type, given () (Ix m, Subset(m, n))) -> {|eff} ()
    ) -> {|eff} () given (eff) =
  full_tiles = size n `idiv` tile_size
  leftover = size n `rem` tile_size
  leftover_start = full_tiles * tile_size
  for_ t:(Fin full_tiles).
    t_ord = ordinal t
    body (FullTileIx(n, tile_size, t_ord))
  body (CodaIx(n, leftover_start, leftover))

-- matmul. Better symbol to use? `@`?
-- Tiled over all three dimensions; the accumulation order (l tiles, then
-- n tiles, then m tiles) is what determines the floating-point summation order.
def (**)(
    x: l=>m=>Float,
    y: m=>n=>Float
    ) -> l=>n=>Float given (l|Ix, m|Ix, n|Ix) =
  -- Tile sizes picked for axch's laptop
  tile_l = 32
  tile_n = 128
  tile_m = 8
  yield_accum (AddMonoid Float) \acc.
    tile(l, tile_l) \ls.
      tile(n, tile_n) \ns.
        tile(m, tile_m) \ms.
          for_ a:ls.
            row = inject(to=l, a)
            for_ b:ns.
              col = inject(to=n, b)
              for_ c:ms.
                k = inject(to=m, c)
                acc!row!col += x[row,k] * y[k,col]

-- Matrix-vector product: one `vdot` per row of `mat`.
def (**.)(mat: n=>m=>Float, v: m=>Float) -> (n=>Float) given (n|Ix, m|Ix) =
  for row. vdot(mat[row], v)

-- Vector-matrix product, computed as (transpose mat) **. v.
def (.**)(v: n=>Float, mat: n=>m=>Float) -> (m=>Float) given (n|Ix, m|Ix) =
  transpose(mat) **. v

-- Bilinear form x^T . mat . y.
def inner(x:n=>Float, mat:n=>m=>Float, y:m=>Float) -> Float given (n|Ix, m|Ix) =
  fsum for ij.
    (i,j) = ij
    x[i] * mat[i,j] * y[j]

-- Identity matrix: `one` on the diagonal, `zero` elsewhere.
def eye() ->> n=>n=>a given (n|Ix, a|Add|Mul) =
  for i j.
    if ordinal i == ordinal j then one else zero