diff options
-rw-r--r-- | lib/prelude.dx | 1583 | ||||
-rw-r--r-- | lib/sort.dx | 42 | ||||
-rw-r--r-- | makefile | 4 | ||||
-rw-r--r-- | src/lib/ConcreteSyntax.hs | 123 | ||||
-rw-r--r-- | src/lib/Lexing.hs | 87 | ||||
-rw-r--r-- | tests/adt-tests.dx | 12 | ||||
-rw-r--r-- | tests/eval-tests.dx | 42 | ||||
-rw-r--r-- | tests/exception-tests.dx | 32 | ||||
-rw-r--r-- | tests/monad-tests.dx | 14 | ||||
-rw-r--r-- | tests/parser-tests.dx | 134 | ||||
-rw-r--r-- | tests/print-tests.dx | 2 | ||||
-rw-r--r-- | tests/shadow-tests.dx | 4 | ||||
-rw-r--r-- | tests/sort-tests.dx | 1 | ||||
-rw-r--r-- | tests/type-tests.dx | 38 | ||||
-rw-r--r-- | tests/uexpr-tests.dx | 36 |
15 files changed, 1116 insertions, 1038 deletions
diff --git a/lib/prelude.dx b/lib/prelude.dx index e0dcbdeb..d9f9def7 100644 --- a/lib/prelude.dx +++ b/lib/prelude.dx @@ -28,455 +28,455 @@ RawPtr : Type = %Word8Ptr() Int = Int32 Float = Float32 -def the (a:Type, x:a) -> a = x +def the(a:Type, x:a) -> a = x -interface Data (a:Type) +interface Data(a:Type) do_not_implement_this_interface_for_the_compiler_relies_on_the_invariant_it_protects : (a) -> a '### Casting -def internal_cast (x:from) -> to given (from, to) = +def internal_cast(x:from) -> to given (from, to) = %cast(to, x) -def unsafe_coerce (x:from) -> to given (from|Data, to|Data) = %unsafeCoerce(to, x) -def uninitialized_value () -> a given (a|Data) = %garbageVal(a) - -def f64_to_f (x : Float64) -> Float = internal_cast x -def f32_to_f (x : Float32) -> Float = internal_cast x -def f_to_f64 (x : Float) -> Float64 = internal_cast x -def f_to_f32 (x : Float) -> Float32 = internal_cast x -def i64_to_i (x : Int64) -> Int = internal_cast x -def i32_to_i (x : Int32) -> Int = internal_cast x -def w8_to_i (x : Word8) -> Int = internal_cast x -def i_to_i64 (x : Int) -> Int64 = internal_cast x -def i_to_i32 (x : Int) -> Int32 = internal_cast x -def i_to_w8 (x : Int) -> Word8 = internal_cast x -def i_to_w32 (x : Int) -> Word32 = internal_cast x -def i_to_w64 (x : Int) -> Word64 = internal_cast x -def w32_to_w64 (x : Word32)-> Word64 = internal_cast x -def i_to_f (x:Int) -> Float = internal_cast x -def f_to_i (x:Float) -> Int = internal_cast x -def raw_ptr_to_i64 (x:RawPtr) -> Int64 = internal_cast x +def unsafe_coerce(x:from) -> to given (from|Data, to|Data) = %unsafeCoerce(to, x) +def uninitialized_value() -> a given (a|Data) = %garbageVal(a) + +def f64_to_f(x: Float64) -> Float = internal_cast x +def f32_to_f(x: Float32) -> Float = internal_cast x +def f_to_f64(x: Float) -> Float64 = internal_cast x +def f_to_f32(x: Float) -> Float32 = internal_cast x +def i64_to_i(x: Int64) -> Int = internal_cast x +def i32_to_i(x: Int32) -> Int = internal_cast x +def w8_to_i(x: Word8) -> Int = internal_cast x +def i_to_i64(x: Int) -> Int64 = internal_cast x +def i_to_i32(x: Int) -> Int32 = internal_cast x +def i_to_w8(x: Int) -> Word8 = internal_cast x +def i_to_w32(x: Int) -> Word32 = internal_cast x +def i_to_w64(x: Int) -> Word64 = internal_cast x +def w32_to_w64(x: Word32)-> Word64 = internal_cast x +def i_to_f(x:Int) -> Float = internal_cast x +def f_to_i(x:Float) -> Int = internal_cast x +def raw_ptr_to_i64(x:RawPtr) -> Int64 = internal_cast x Nat = %Nat() NatRep = Word32 -def nat_to_rep (x : Nat) -> NatRep = %projNewtype(x) -def rep_to_nat (x : NatRep) -> Nat = %NatCon(x) - -def n_to_w8 (x : Nat) -> Word8 = nat_to_rep x | internal_cast -def n_to_w32 (x : Nat) -> Word32 = nat_to_rep x | internal_cast -def n_to_w64 (x : Nat) -> Word64 = nat_to_rep x | internal_cast -def n_to_i32 (x : Nat) -> Int32 = nat_to_rep x | internal_cast -def n_to_i64 (x : Nat) -> Int64 = nat_to_rep x | internal_cast -def n_to_f32 (x : Nat) -> Float32 = nat_to_rep x | internal_cast -def n_to_f64 (x : Nat) -> Float64 = nat_to_rep x | internal_cast -def n_to_f (x : Nat) -> Float = nat_to_rep x | internal_cast - -def w8_to_n (x : Word8) -> Nat = internal_cast x | rep_to_nat -def w32_to_n (x : Word32) -> Nat = internal_cast x | rep_to_nat -def w64_to_n (x : Word64) -> Nat = internal_cast x | rep_to_nat -def i32_to_n (x : Int32) -> Nat = internal_cast x | rep_to_nat -def i64_to_n (x : Int64) -> Nat = internal_cast x | rep_to_nat -def f32_to_n (x : Float32) -> Nat = internal_cast x | rep_to_nat -def f64_to_n (x : Float64) -> Nat = internal_cast x | rep_to_nat -def f_to_n (x : Float) -> Nat = internal_cast x | rep_to_nat - -interface FromUnsignedInteger (a:Type) +def nat_to_rep(x : Nat) -> NatRep = %projNewtype(x) +def rep_to_nat(x : NatRep) -> Nat = %NatCon(x) + +def n_to_w8(x: Nat) -> Word8 = nat_to_rep x | internal_cast +def n_to_w32(x: Nat) -> Word32 = nat_to_rep x | internal_cast +def n_to_w64(x: Nat) -> Word64 = nat_to_rep x | internal_cast +def n_to_i32(x: Nat) -> Int32 = nat_to_rep x | internal_cast +def n_to_i64(x: Nat) -> Int64 = nat_to_rep x | internal_cast +def n_to_f32(x: Nat) -> Float32 = nat_to_rep x | internal_cast +def n_to_f64(x: Nat) -> Float64 = nat_to_rep x | internal_cast +def n_to_f(x: Nat) -> Float = nat_to_rep x | internal_cast + +def w8_to_n(x : Word8) -> Nat = internal_cast x | rep_to_nat +def w32_to_n(x : Word32) -> Nat = internal_cast x | rep_to_nat +def w64_to_n(x : Word64) -> Nat = internal_cast x | rep_to_nat +def i32_to_n(x : Int32) -> Nat = internal_cast x | rep_to_nat +def i64_to_n(x : Int64) -> Nat = internal_cast x | rep_to_nat +def f32_to_n(x : Float32) -> Nat = internal_cast x | rep_to_nat +def f64_to_n(x : Float64) -> Nat = internal_cast x | rep_to_nat +def f_to_n(x : Float) -> Nat = internal_cast x | rep_to_nat + +interface FromUnsignedInteger(a:Type) from_unsigned_integer : (Word64) -> a -instance FromUnsignedInteger Word8 - def from_unsigned_integer (x) = internal_cast x +instance FromUnsignedInteger(Word8) + def from_unsigned_integer(x) = internal_cast x -instance FromUnsignedInteger Word32 - def from_unsigned_integer (x) = internal_cast x +instance FromUnsignedInteger(Word32) + def from_unsigned_integer(x) = internal_cast x -instance FromUnsignedInteger Word64 - def from_unsigned_integer (x) = x +instance FromUnsignedInteger(Word64) + def from_unsigned_integer(x) = x -instance FromUnsignedInteger Int32 - def from_unsigned_integer (x) = internal_cast x +instance FromUnsignedInteger(Int32) + def from_unsigned_integer(x) = internal_cast x -instance FromUnsignedInteger Int64 - def from_unsigned_integer (x) = internal_cast x +instance FromUnsignedInteger(Int64) + def from_unsigned_integer(x) = internal_cast x -instance FromUnsignedInteger Float32 - def from_unsigned_integer (x) = internal_cast x +instance FromUnsignedInteger(Float32) + def from_unsigned_integer(x) = internal_cast x -instance FromUnsignedInteger Float64 - def from_unsigned_integer (x) = internal_cast x +instance FromUnsignedInteger(Float64) + def from_unsigned_integer(x) = internal_cast x -instance FromUnsignedInteger Nat - def from_unsigned_integer (x) = w64_to_n(x) +instance FromUnsignedInteger(Nat) + def from_unsigned_integer(x) = w64_to_n(x) -interface FromInteger (a:Type) +interface FromInteger(a:Type) from_integer : (Int64) -> a -instance FromInteger Float32 - def from_integer (x) = internal_cast x +instance FromInteger(Float32) + def from_integer(x) = internal_cast x -instance FromInteger Int32 - def from_integer (x) = internal_cast x +instance FromInteger(Int32) + def from_integer(x) = internal_cast x -instance FromInteger Float64 - def from_integer (x) = internal_cast x +instance FromInteger(Float64) + def from_integer(x) = internal_cast x -instance FromInteger Int64 - def from_integer (x) = x +instance FromInteger(Int64) + def from_integer(x) = x '## Bitwise operations -interface Bits (a:Type) +interface Bits(a:Type) (.<<.) : (a, Int) -> a (.>>.) : (a, Int) -> a (.|.) : (a, a) -> a (.&.) : (a, a) -> a (.^.) : (a, a) -> a -instance Bits Word8 - def (.<<.) (x, y) = %shl(x, i_to_w8 y) - def (.>>.) (x, y) = %shr(x, i_to_w8 y) - def (.|.) (x, y) = %or( x, y) - def (.&.) (x, y) = %and(x, y) - def (.^.) (x, y) = %xor(x, y) - -instance Bits Word32 - def (.<<.) (x, y) = %shl(x, i_to_w32 y) - def (.>>.) (x, y) = %shr(x, i_to_w32 y) - def (.|.) (x, y) = %or( x, y) - def (.&.) (x, y) = %and(x, y) - def (.^.) (x, y) = %xor(x, y) - -instance Bits Word64 - def (.<<.) (x, y) = %shl(x, i_to_w64 y) - def (.>>.) (x, y) = %shr(x, i_to_w64 y) - def (.|.) (x, y) = %or( x ,y) - def (.&.) (x, y) = %and(x ,y) - def (.^.) (x, y) = %xor(x ,y) - -def low_word (x : Word64) -> Word32 = internal_cast(x .>>. 32) -def high_word (x : Word64) -> Word32 = internal_cast(x) +instance Bits(Word8) + def (.<<.)(x, y) = %shl(x, i_to_w8 y) + def (.>>.)(x, y) = %shr(x, i_to_w8 y) + def (.|.)(x, y) = %or( x, y) + def (.&.)(x, y) = %and(x, y) + def (.^.)(x, y) = %xor(x, y) + +instance Bits(Word32) + def (.<<.)(x, y) = %shl(x, i_to_w32 y) + def (.>>.)(x, y) = %shr(x, i_to_w32 y) + def (.|.)(x, y) = %or( x, y) + def (.&.)(x, y) = %and(x, y) + def (.^.)(x, y) = %xor(x, y) + +instance Bits(Word64) + def (.<<.)(x, y) = %shl(x, i_to_w64 y) + def (.>>.)(x, y) = %shr(x, i_to_w64 y) + def (.|.)(x, y) = %or( x ,y) + def (.&.)(x, y) = %and(x ,y) + def (.^.)(x, y) = %xor(x ,y) + +def low_word( x : Word64) -> Word32 = internal_cast(x .>>. 32) +def high_word(x : Word64) -> Word32 = internal_cast(x) '### Basic Arithmetic #### Add Things that can be added. This defines the `Add` [group](https://en.wikipedia.org/wiki/Group_(mathematics)) and its operators. -interface Add (a|Data) +interface Add(a|Data) (+) : (a, a) -> a zero : a -interface Sub (a|Add) +interface Sub(a|Add) (-) : (a, a) -> a -instance Add Float64 - def (+) (x, y) = %fadd(x, y) +instance Add(Float64) + def (+)(x, y) = %fadd(x, y) zero = 0 -instance Sub Float64 - def (-) (x, y) = %fsub(x, y) +instance Sub(Float64) + def (-)(x, y) = %fsub(x, y) -instance Add Float32 - def (+) (x, y) = %fadd(x, y) +instance Add(Float32) + def (+)(x, y) = %fadd(x, y) zero = 0 -instance Sub Float32 - def (-) (x, y) = %fsub(x, y) +instance Sub(Float32) + def (-)(x, y) = %fsub(x, y) -instance Add Int64 - def (+) (x, y) = %iadd(x, y) +instance Add(Int64) + def (+)(x, y) = %iadd(x, y) zero = 0 -instance Sub Int64 - def (-) (x, y) = %isub(x, y) +instance Sub(Int64) + def (-)(x, y) = %isub(x, y) -instance Add Int32 - def (+) (x, y) = %iadd(x, y) +instance Add(Int32) + def (+)(x, y) = %iadd(x, y) zero = 0 -instance Sub Int32 - def (-) (x, y) = %isub(x, y) +instance Sub(Int32) + def (-)(x, y) = %isub(x, y) -instance Add Word8 - def (+) (x, y) = %iadd(x, y) +instance Add(Word8) + def (+)(x, y) = %iadd(x, y) zero = 0 -instance Sub Word8 - def (-) (x, y) = %isub(x, y) +instance Sub(Word8) + def (-)(x, y) = %isub(x, y) -instance Add Word32 - def (+) (x, y) = %iadd(x, y) +instance Add(Word32) + def (+)(x, y) = %iadd(x, y) zero = 0 -instance Sub Word32 - def (-) (x, y) = %isub(x, y) +instance Sub(Word32) + def (-)(x, y) = %isub(x, y) -instance Add Word64 - def (+) (x, y) = %iadd(x, y) +instance Add(Word64) + def (+)(x, y) = %iadd(x, y) zero = 0 -instance Sub Word64 - def (-) (x, y) = %isub(x, y) +instance Sub(Word64) + def (-)(x, y) = %isub(x, y) -instance Add Nat - def (+) (x, y) = rep_to_nat %iadd(nat_to_rep x, nat_to_rep y) +instance Add(Nat) + def (+)(x, y) = rep_to_nat %iadd(nat_to_rep x, nat_to_rep y) zero = 0 -instance Add () +instance Add(()) def (+)(x, y) = () zero = () -instance Sub () +instance Sub(()) def (-)(x, y) = () '#### Mul Things that can be multiplied. This defines the `Mul` [Monoid](https://en.wikipedia.org/wiki/Monoid), and its operator. -interface Mul (a|Data) +interface Mul(a|Data) (*) : (a, a) -> a one : a -instance Mul Float64 - def (*) (x, y) = %fmul(x, y) +instance Mul(Float64) + def (*)(x, y) = %fmul(x, y) one = f_to_f64 1.0 -instance Mul Float32 - def (*) (x, y) = %fmul(x, y) +instance Mul(Float32) + def (*)(x, y) = %fmul(x, y) one = f_to_f32 1.0 -instance Mul Int64 - def (*) (x, y) = %imul(x, y) +instance Mul(Int64) + def (*)(x, y) = %imul(x, y) one = 1 -instance Mul Int32 - def (*) (x, y) = %imul(x, y) +instance Mul(Int32) + def (*)(x, y) = %imul(x, y) one = 1 -instance Mul Word8 - def (*) (x, y) = %imul(x, y) +instance Mul(Word8) + def (*)(x, y) = %imul(x, y) one = 1 -instance Mul Word32 - def (*) (x, y) = %imul(x, y) +instance Mul(Word32) + def (*)(x, y) = %imul(x, y) one = 1 -instance Mul Word64 - def (*) (x, y) = %imul(x, y) +instance Mul(Word64) + def (*)(x, y) = %imul(x, y) one = 1 -instance Mul Nat - def (*) (x, y) = rep_to_nat %imul(nat_to_rep x, nat_to_rep y) +instance Mul(Nat) + def(*)(x, y) = rep_to_nat %imul(nat_to_rep x, nat_to_rep y) one = 1 -instance Mul () - def (*) (x, y) = () +instance Mul(()) + def (*)(x, y) = () one = () '#### Integral Integer-like things. -interface Integral (a) +interface Integral(a) idiv : (a,a)->a rem : (a,a)->a -instance Integral Int64 - def idiv (x, y) = %idiv(x, y) - def rem (x, y) = %irem(x, y) +instance Integral(Int64) + def idiv(x, y) = %idiv(x, y) + def rem(x, y) = %irem(x, y) -instance Integral Int32 - def idiv (x, y) = %idiv(x, y) - def rem (x, y) = %irem(x, y) +instance Integral(Int32) + def idiv(x, y) = %idiv(x, y) + def rem(x, y) = %irem(x, y) -instance Integral Word8 - def idiv (x, y) = %idiv(x, y) - def rem (x, y) = %irem(x, y) +instance Integral(Word8) + def idiv(x, y) = %idiv(x, y) + def rem(x, y) = %irem(x, y) -instance Integral Word32 - def idiv (x, y) = %idiv(x, y) - def rem (x, y) = %irem(x, y) +instance Integral(Word32) + def idiv(x, y) = %idiv(x, y) + def rem(x, y) = %irem(x, y) -instance Integral Word64 - def idiv (x, y) = %idiv(x, y) - def rem (x, y) = %irem(x, y) +instance Integral(Word64) + def idiv(x, y) = %idiv(x, y) + def rem(x, y) = %irem(x, y) -instance Integral Nat - def idiv (x, y) = rep_to_nat %idiv(nat_to_rep x, (nat_to_rep y)) - def rem (x, y) = rep_to_nat %irem(nat_to_rep x, (nat_to_rep y)) +instance Integral(Nat) + def idiv(x, y) = rep_to_nat %idiv(nat_to_rep x, (nat_to_rep y)) + def rem(x, y) = rep_to_nat %irem(nat_to_rep x, (nat_to_rep y)) '#### Fractional Rational-like things. Includes floating point and two field rational representations. -interface Fractional (a) +interface Fractional(a) divide : (a, a) -> a -instance Fractional Float64 - def divide (x, y) = %fdiv(x, y) +instance Fractional(Float64) + def divide(x, y) = %fdiv(x, y) -instance Fractional Float32 - def divide (x, y) = %fdiv(x, y) +instance Fractional(Float32) + def divide(x, y) = %fdiv(x, y) '## Index set interface and instances -interface Ix (n|Data) +interface Ix(n|Data) size' : () -> Nat ordinal : (n) -> Nat unsafe_from_ordinal : (Nat) -> n -def size (n|Ix) -> Nat = size'(n=n) +def size(n|Ix) -> Nat = size'(n=n) -def Fin (n:Nat) -> Type = %Fin(n) +def Fin(n:Nat) -> Type = %Fin(n) -- version of subtraction on Nats that clamps at zero -def (-|) (x: Nat, y:Nat) -> Nat = +def (-|)(x: Nat, y:Nat) -> Nat = x' = nat_to_rep x y' = nat_to_rep y requires_clamp = %ilt(x', y') rep_to_nat %select(requires_clamp, 0, (%isub(x', y'))) -def unsafe_nat_diff (x:Nat, y:Nat) -> Nat = +def unsafe_nat_diff(x:Nat, y:Nat) -> Nat = x' = nat_to_rep x y' = nat_to_rep y rep_to_nat %isub(x', y') -- `(i..)` parses as `RangeFrom _ i` -- TODO: need to a way to indicate `.new` as private -struct RangeFrom (q:Type, i:q) = val : Nat +struct RangeFrom(q:Type, i:q) = val : Nat -- `(i<..)` parses as `RangeFromExc _ i` -struct RangeFromExc (q:Type, i:q) = val : Nat +struct RangeFromExc(q:Type, i:q) = val : Nat -- `(..i)` parses as `RangeTo _ i` -struct RangeTo (q:Type, i:q) = val : Nat +struct RangeTo(q:Type, i:q) = val : Nat -- `(..<i)` parses as `RangeToExc _ i` -struct RangeToExc (q:Type, i:q) = val : Nat +struct RangeToExc(q:Type, i:q) = val : Nat -instance Ix RangeFrom(q, i) given (q|Ix, i:q) - def size' () = unsafe_nat_diff(size q, ordinal i) +instance Ix(RangeFrom q i) given (q|Ix, i:q) + def size'() = unsafe_nat_diff(size q, ordinal i) def ordinal(j) = j.val def unsafe_from_ordinal(j) = RangeFrom.new(j) -instance Ix RangeFromExc(q, i) given (q|Ix, i:q) - def size' () = unsafe_nat_diff(size q, ordinal i + 1) +instance Ix(RangeFromExc q i) given (q|Ix, i:q) + def size'() = unsafe_nat_diff(size q, ordinal i + 1) def ordinal(j) = j.val def unsafe_from_ordinal(j) = RangeFromExc.new(j) -instance Ix RangeTo(q, i) given (q|Ix, i:q) - def size' () = ordinal i + 1 +instance Ix(RangeTo q i) given (q|Ix, i:q) + def size'() = ordinal i + 1 def ordinal(j) = j.val def unsafe_from_ordinal(j) = RangeTo.new(j) -instance Ix RangeToExc(q, i) given (q|Ix, i:q) - def size' () = ordinal i +instance Ix(RangeToExc q i) given (q|Ix, i:q) + def size'() = ordinal i def ordinal(j) = j.val def unsafe_from_ordinal(j) = RangeToExc.new(j) -instance Ix () - def size' () = 1 +instance Ix(()) + def size'() = 1 def ordinal(_) = 0 def unsafe_from_ordinal(_) = () -def iota (n|Ix) -> n=>Nat = for i. ordinal i +def iota(n|Ix) -> n=>Nat = for i. ordinal i '## Arithmetic instances for table types -instance Add (n=>a) given (a|Add, n|Ix) - def (+) (xs, ys) = for i. xs[i] + ys[i] +instance Add(n=>a) given (a|Add, n|Ix) + def (+)(xs, ys) = for i. xs[i] + ys[i] zero = for _. zero -instance Sub (n=>a) given (a|Sub, n|Ix) - def (-) (xs, ys) = for i. xs[i] - ys[i] +instance Sub(n=>a) given (a|Sub, n|Ix) + def (-)(xs, ys) = for i. xs[i] - ys[i] -instance Add ((i:n) => (i..) => a) given (a|Add, n|Ix) -- Upper triangular tables - def (+) (xs, ys) = for i. xs[i] + ys[i] +instance Add((i:n) => (i..) => a) given (a|Add, n|Ix) -- Upper triangular tables + def (+)(xs, ys) = for i. xs[i] + ys[i] zero = for _. zero -instance Sub ((i:n) => (i..) => a) given (a|Sub, n|Ix) -- Upper triangular tables - def (-) (xs, ys) = for i. xs[i] - ys[i] +instance Sub((i:n) => (i..) => a) given (a|Sub, n|Ix) -- Upper triangular tables + def (-)(xs, ys) = for i. xs[i] - ys[i] -instance Add ((i:n) => (..i) => a) given (a|Add, n|Ix) -- Lower triangular tables - def (+) (xs, ys) = for i. xs[i] + ys[i] +instance Add((i:n) => (..i) => a) given (a|Add, n|Ix) -- Lower triangular tables + def (+)(xs, ys) = for i. xs[i] + ys[i] zero = for _. zero -instance Sub ((i:n) => (..i) => a) given (a|Sub, n|Ix) -- Lower triangular tables - def (-) (xs, ys) = for i. xs[i] - ys[i] +instance Sub((i:n) => (..i) => a) given (a|Sub, n|Ix) -- Lower triangular tables + def (-)(xs, ys) = for i. xs[i] - ys[i] -instance Add ((i:n) => (..<i) => a) given (a|Add, n|Ix) - def (+) (xs, ys) = for i. xs[i] + ys[i] +instance Add((i:n) => (..<i) => a) given (a|Add, n|Ix) + def (+)(xs, ys) = for i. xs[i] + ys[i] zero = for _. zero -instance Sub ((i:n) => (..<i) => a) given (a|Sub, n|Ix) - def (-) (xs, ys) = for i. xs[i] - ys[i] +instance Sub((i:n) => (..<i) => a) given (a|Sub, n|Ix) + def (-)(xs, ys) = for i. xs[i] - ys[i] -instance Add ((i:n) => (i<..) => a) given (a|Add, n|Ix) - def (+) (xs, ys) = for i. xs[i] + ys[i] +instance Add((i:n) => (i<..) => a) given (a|Add, n|Ix) + def (+)(xs, ys) = for i. xs[i] + ys[i] zero = for _. zero -instance Sub ((i:n) => (i<..) => a) given (a|Sub, n|Ix) - def (-) (xs, ys) = for i. xs[i] - ys[i] +instance Sub((i:n) => (i<..) => a) given (a|Sub, n|Ix) + def (-)(xs, ys) = for i. xs[i] - ys[i] -instance Mul (n=>a) given (a|Mul, n|Ix) - def (*) (xs, ys) = for i. xs[i] * ys[i] +instance Mul(n=>a) given (a|Mul, n|Ix) + def (*)(xs, ys) = for i. xs[i] * ys[i] one = for _. one '## Basic polymorphic functions and types -def fst (pair:(a, b)) -> a given (a, b) = +def fst(pair:(a, b)) -> a given (a, b) = (x, _) = pair x -def snd (pair:(a, b)) -> b given (a, b) = +def snd(pair:(a, b)) -> b given (a, b) = (_, y) = pair y -def swap (pair:(a, b)) -> (b, a) given (a, b) = +def swap(pair:(a, b)) -> (b, a) given (a, b) = (x, y) = pair (y, x) -instance Add (a, b) given (a|Add, b|Add) - def (+) (x, y) = +instance Add((a, b)) given (a|Add, b|Add) + def (+)(x, y) = (x1, x2) = x (y1, y2) = y (x1 + y1, x2 + y2) zero = (zero, zero) -instance Sub (a, b) given (a|Sub, b|Sub) - def (-) (x, y) = +instance Sub((a, b)) given (a|Sub, b|Sub) + def(-)(x, y) = (x1, x2) = x (y1, y2) = y (x1 - y1, x2 - y2) -instance Ix (a, b) given (a|Ix, b|Ix) - def size' () = size a * size b - def ordinal (pair) = +instance Ix((a, b)) given (a|Ix, b|Ix) + def size'() = size a * size b + def ordinal(pair) = (i, j) = pair (ordinal i * size b) + ordinal j - def unsafe_from_ordinal (o) = + def unsafe_from_ordinal(o) = bs = size b (unsafe_from_ordinal(n=a, idiv(o, bs)), unsafe_from_ordinal(n=b, rem(o, bs))) '## Vector spaces -interface VSpace (a|Add|Sub) +interface VSpace(a|Add|Sub) (.*) : (Float, a) -> a -def (*.) (v:a, s:Float) -> a given (a|VSpace) = s .* v -def (/) (v:a, s:Float) -> a given (a|VSpace) = divide(1.0, s) .* v -def neg (v:a) -> a given (a|VSpace) = (-1.0) .* v +def (*.)(v:a, s:Float) -> a given (a|VSpace) = s .* v +def (/)( v:a, s:Float) -> a given (a|VSpace) = divide(1.0, s) .* v +def neg( v:a) -> a given (a|VSpace) = (-1.0) .* v -instance VSpace Float - def (.*) (x, y) = x * y +instance VSpace(Float) + def (.*)(x, y) = x * y -instance VSpace (n=>a) given (a|VSpace, n|Ix) - def (.*) (s, xs) = for i. s .* xs[i] +instance VSpace(n=>a) given (a|VSpace, n|Ix) + def (.*)(s, xs) = for i. s .* xs[i] -instance VSpace (a, b) given (a|VSpace, b|VSpace) - def (.*) (s, pair) = +instance VSpace((a, b)) given (a|VSpace, b|VSpace) + def (.*)(s, pair) = (a, b) = pair (s .* a, s .* b) -instance VSpace ((i:n) => (..i) => a) given (n|Ix, a|VSpace) - def (.*) (s, xs) = for i. s .* xs[i] +instance VSpace((i:n) => (..i) => a) given (n|Ix, a|VSpace) + def (.*)(s, xs) = for i. s .* xs[i] -instance VSpace ((i:n) => (i..) => a) given (n|Ix, a|VSpace) - def (.*) (s, xs) = for i. s .* xs[i] +instance VSpace((i:n) => (i..) => a) given (n|Ix, a|VSpace) + def (.*)(s, xs) = for i. s .* xs[i] -instance VSpace ((i:n) => (..<i) => a) given (n|Ix, a|VSpace) - def (.*) (s, xs) = for i. s .* xs[i] +instance VSpace((i:n) => (..<i) => a) given (n|Ix, a|VSpace) + def (.*)(s, xs) = for i. s .* xs[i] -instance VSpace ((i:n) => (i<..) => a) given (n|Ix, a|VSpace) - def (.*) (s, xs) = for i. s .* xs[i] +instance VSpace((i:n) => (i<..) => a) given (n|Ix, a|VSpace) + def (.*)(s, xs) = for i. s .* xs[i] -instance VSpace () - def (.*) (_, _) = () +instance VSpace(()) + def (.*)(_, _) = () '## Boolean type @@ -484,21 +484,21 @@ data Bool = False True -def b_to_w8 (x:Bool) -> Word8 = %dataConTag(x) +def b_to_w8(x:Bool) -> Word8 = %dataConTag(x) -def w8_to_b (x:Word8) -> Bool = %toEnum(Bool, x) +def w8_to_b(x:Word8) -> Bool = %toEnum(Bool, x) -def (&&) (x:Bool, y:Bool) -> Bool = +def (&&)(x:Bool, y:Bool) -> Bool = x' = b_to_w8 x y' = b_to_w8 y w8_to_b $ %and(x', y') -def (||) (x:Bool, y:Bool) -> Bool = +def (||)(x:Bool, y:Bool) -> Bool = x' = b_to_w8 x y' = b_to_w8 y w8_to_b $ %or(x', y') -def not (x:Bool) -> Bool = +def not(x:Bool) -> Bool = x' = b_to_w8 x w8_to_b $ %not(x') @@ -507,14 +507,14 @@ TODO: move these with the others? -- Can't use `%select` because it lowers to `ISelect`, which requires -- `a` to be a `BaseTy`. -def select (p:Bool, x:a, y:a) -> a given (a) = +def select(p:Bool, x:a, y:a) -> a given (a) = case p of True -> x False -> y -def b_to_i (x:Bool) -> Int = w8_to_i(b_to_w8 x) -def b_to_n (x:Bool) -> Nat = w8_to_n(b_to_w8 x) -def b_to_f (x:Bool) -> Float = i_to_f(b_to_i x) +def b_to_i(x:Bool) -> Int = w8_to_i(b_to_w8 x) +def b_to_n(x:Bool) -> Nat = w8_to_n(b_to_w8 x) +def b_to_f(x:Bool) -> Float = i_to_f(b_to_i x) '## Ordering TODO: move this down to with `Ord`? @@ -524,39 +524,39 @@ data Ordering = EQ GT -def o_to_w8 (x:Ordering) -> Word8 = %dataConTag(x) +def o_to_w8(x:Ordering) -> Word8 = %dataConTag(x) '## Sum types A [sum type, or tagged union](https://en.wikipedia.org/wiki/Tagged_union) can hold values from a fixed set of types, distinguished by tags. For those familiar with the C language, they can be though of as a combination of an `enum` with a `union`. Here we define several basic kinds, and some operators on them. -data Maybe (a:Type) = +data Maybe(a:Type) = Nothing Just(a) -def is_nothing (x:Maybe a) -> Bool given (a) = +def is_nothing(x:Maybe a) -> Bool given (a) = case x of Nothing -> True Just(_) -> False -def is_just (x:Maybe a) -> Bool given (a) = not $ is_nothing x +def is_just(x:Maybe a) -> Bool given (a) = not $ is_nothing x -def maybe (d:b, f:(a)->b, x:Maybe a) -> b given (a, b) = +def maybe(d:b, f:(a)->b, x:Maybe a) -> b given (a, b) = case x of Nothing -> d Just(x') -> f x' -data Either (a:Type, b:Type) = +data Either(a:Type, b:Type) = Left(a) Right(b) -instance Ix (Either(a, b)) given (a|Ix, b|Ix) - def size' () = size a + size b - def ordinal (i) = case i of +instance Ix(Either(a, b)) given (a|Ix, b|Ix) + def size'() = size a + size b + def ordinal(i) = case i of Left(ai) -> ordinal ai Right(bi) -> ordinal bi + size a - def unsafe_from_ordinal (o) = + def unsafe_from_ordinal(o) = as = nat_to_rep $ size a o' = nat_to_rep o -- TODO: Reshuffle the prelude to be able to use (<) here @@ -568,43 +568,43 @@ instance Ix (Either(a, b)) given (a|Ix, b|Ix) '## Subtraction on Nats -- TODO: think more about the right API here -def unsafe_i_to_n (x:Int) -> Nat = +def unsafe_i_to_n(x:Int) -> Nat = rep_to_nat $ internal_cast x -def n_to_i (x:Nat) -> Int = +def n_to_i(x:Nat) -> Int = internal_cast (nat_to_rep x) -def i_to_n (x:Int) -> Maybe Nat = +def i_to_n(x:Int) -> Maybe Nat = if w8_to_b $ %ilt(x, 0::Int) then Nothing else Just $ unsafe_i_to_n x '## Fencepost index sets -struct Post (segment:Type) = +struct Post(segment:Type) = val : Nat -instance Ix (Post segment) given (segment|Ix) - def size' () = size segment + 1 - def ordinal (i) = i.val - def unsafe_from_ordinal (i) = Post.new(i) +instance Ix(Post segment) given (segment|Ix) + def size'() = size segment + 1 + def ordinal(i) = i.val + def unsafe_from_ordinal(i) = Post.new(i) -def left_post (i:n) -> Post n given (n|Ix) = +def left_post(i:n) -> Post n given (n|Ix) = unsafe_from_ordinal(n=Post n, ordinal i) -def right_post (i:n) -> Post n given (n|Ix) = +def right_post(i:n) -> Post n given (n|Ix) = unsafe_from_ordinal(n=Post n, ordinal i + 1) -interface NonEmpty (n|Ix) +interface NonEmpty(n|Ix) first_ix : n -def last_ix () ->> n given (n|NonEmpty) = +def last_ix() ->> n given (n|NonEmpty) = unsafe_from_ordinal(unsafe_i_to_n(n_to_i(size n) - 1)) -instance NonEmpty (Post n) given (n|Ix) +instance NonEmpty(Post n) given (n|Ix) first_ix = unsafe_from_ordinal(n=Post n, 0) -instance NonEmpty () +instance NonEmpty(()) first_ix = unsafe_from_ordinal(0) '### Monoid @@ -616,41 +616,41 @@ It includes: - Concatenation of Lists (including strings) Monoids support `fold` operations, and similar. -interface Monoid (a|Data) +interface Monoid(a|Data) mempty : a (<>) : (a, a) -> a -instance Monoid (n=>a) given (a|Monoid, n|Ix) +instance Monoid(n=>a) given (a|Monoid, n|Ix) mempty = for i. mempty - def (<>) (x, y) = for i. x[i] <> y[i] + def (<>)(x, y) = for i. x[i] <> y[i] -named-instance AndMonoid : Monoid Bool +named-instance AndMonoid : Monoid(Bool) mempty = True - def (<>) (x, y) = x && y + def (<>)(x, y) = x && y -named-instance OrMonoid : Monoid Bool +named-instance OrMonoid : Monoid(Bool) mempty = False - def (<>) (x, y) = x || y + def (<>)(x, y) = x || y -named-instance AddMonoid (a|Add) -> Monoid a +named-instance AddMonoid(a|Add) -> Monoid(a) mempty = zero - def (<>) (x, y) = x + y + def (<>)(x, y) = x + y -named-instance MulMonoid (a|Mul) -> Monoid a +named-instance MulMonoid(a|Mul) -> Monoid(a) mempty = one - def (<>) (x, y) = x * y + def (<>)(x, y) = x * y '## Effects -def Ref (r:Heap, a|Data) -> Type = %Ref(r, a) -def get (ref:Ref h s) -> {State h} s given (h, s) = %get(ref) -def (:=) (ref:Ref h s, x:s) -> {State h} () given (h, s) = %put(ref, x) +def Ref(r:Heap, a|Data) -> Type = %Ref(r, a) +def get(ref:Ref h s) -> {State h} s given (h, s) = %get(ref) +def (:=)(ref:Ref h s, x:s) -> {State h} () given (h, s) = %put(ref, x) -def ask (ref:Ref h r) -> {Read h} r given (h, r) = %ask(ref) +def ask(ref:Ref h r) -> {Read h} r given (h, r) = %ask(ref) -data AccumMonoidData(h:Heap, w:Type) = UnsafeMkAccumMonoidData (b:Type, Monoid b) +data AccumMonoidData(h:Heap, w:Type) = UnsafeMkAccumMonoidData(b:Type, Monoid b) -interface AccumMonoid (h:Heap, w) +interface AccumMonoid(h:Heap, w) getAccumMonoidData : AccumMonoidData(h, w) instance AccumMonoid(h, n=>w) given (n|Ix, h, w) (am:AccumMonoid(h, w)) @@ -658,15 +658,15 @@ instance AccumMonoid(h, n=>w) given (n|Ix, h, w) (am:AccumMonoid(h, w)) UnsafeMkAccumMonoidData(b, bm) = %applyMethod0(am) UnsafeMkAccumMonoidData(b, bm) -def (+=) (ref:Ref h w, x:w) -> {Accum h} () +def (+=)(ref:Ref h w, x:w) -> {Accum h} () given (h, w) (am:AccumMonoid(h, w)) = UnsafeMkAccumMonoidData(b, bm) = %applyMethod0(am) empty = %applyMethod0(bm) %mextend(ref, empty, \x:b y:b. %applyMethod1(bm, x, y), x) -def (!) (ref: Ref h (n=>a), i:n) -> Ref h a given (n|Ix, a|Data, h) = %indexRef(ref, i) -def fst_ref (ref: Ref h (a,b)) -> Ref h a given (a|Data, b, h) = %fstRef(ref) -def snd_ref (ref: Ref h (a,b)) -> Ref h b given (a, b|Data, h) = %sndRef(ref) +def (!)(ref: Ref h (n=>a), i:n) -> Ref h a given (n|Ix, a|Data, h) = %indexRef(ref, i) +def fst_ref(ref: Ref h (a,b)) -> Ref h a given (a|Data, b, h) = %fstRef(ref) +def snd_ref(ref: Ref h (a,b)) -> Ref h b given (a, b|Data, h) = %sndRef(ref) def run_reader( init:r, @@ -681,7 +681,7 @@ def with_reader( ) -> {|eff} a given (r|Data, a, eff) = run_reader(init, action) -def MonoidLifter (b:Type, w:Type) -> Type = +def MonoidLifter(b:Type, w:Type) -> Type = (given (h) (AccumMonoid(h, b))) ->> AccumMonoid(h, w) named-instance mk_accum_monoid (given (h, w), d:AccumMonoidData(h, w)) -> AccumMonoid(h, w) @@ -692,7 +692,7 @@ def run_accum( action: (given (h) (AccumMonoid(h, b)), Ref h w) -> {Accum h|eff} a ) -> {|eff} (a, w) given (a, b, w|Data, eff) (MonoidLifter(b,w)) = empty = %applyMethod0(bm) - def explicitAction (h':Heap, ref:Ref h' w) -> {Accum h'|eff} a = + def explicitAction(h':Heap, ref:Ref h' w) -> {Accum h'|eff} a = accumMonoidData : AccumMonoidData h' b = UnsafeMkAccumMonoidData b bm accumBaseMonoid = mk_accum_monoid accumMonoidData %explicitApply(action, h', accumBaseMonoid, ref) @@ -708,7 +708,7 @@ def run_state( init:s, action: (given (h), Ref h s) -> {State h |eff} a ) -> {|eff} (a,s) given (a, s|Data, eff) = - def explicitAction (h':Heap, ref:Ref h' s) -> {State h'|eff} a = action ref + def explicitAction(h':Heap, ref:Ref h' s) -> {State h'|eff} a = action ref %runState(init, explicitAction) def with_state( @@ -723,13 +723,13 @@ def yield_state( ) -> {|eff} s given (a, s|Data, eff) = snd $ run_state(init, action) -def unsafe_io ( +def unsafe_io( f:()->{IO|eff} a ) -> {|eff} a given (a, eff) = f' : (() -> {IO|eff} a) = \. f() %runIO(f') -def unreachable () -> a given (a|Data) = unsafe_io \. %throwError(a) +def unreachable() -> a given (a|Data) = unsafe_io \. %throwError(a) '## Type classes @@ -739,10 +739,10 @@ def unreachable () -> a given (a|Data) = unsafe_io \. %throwError(a) Equatable. Things that we can tell if they are equal or not to other things. -interface Eq (a|Data) +interface Eq(a|Data) (==) : (a, a) -> Bool -def (/=) (x:a, y:a) -> Bool given (a|Eq) = not $ x == y +def (/=)(x:a, y:a) -> Bool given (a|Eq) = not $ x == y '#### Ord Orderable / Comparable. @@ -751,42 +751,42 @@ i.e. things that can be compared to other things to find if larger, smaller or e 'We take the standard false-hood and pretend that this applies to Floats, even though strictly speaking this not true as our floats follow [IEEE754](https://en.wikipedia.org/wiki/IEEE_754), and thus have `NaN < 1.0 == false` and `1.0 < NaN == false`. -interface Ord (a|Eq) +interface Ord(a|Eq) (>) : (a, a) -> Bool (<) : (a, a) -> Bool -def (<=) (x:a, y:a) -> Bool given (a|Ord) = x<y || x==y -def (>=) (x:a, y:a) -> Bool given (a|Ord) = x>y || x==y +def (<=)(x:a, y:a) -> Bool given (a|Ord) = x<y || x==y +def (>=)(x:a, y:a) -> Bool given (a|Ord) = x>y || x==y -instance Eq Float64 - def (==) (x, y) = w8_to_b $ %feq(x, y) +instance Eq(Float64) + def (==)(x, y) = w8_to_b $ %feq(x, y) -instance Eq Float32 - def (==) (x, y) = w8_to_b $ %feq(x, y) +instance Eq(Float32) + def (==)(x, y) = w8_to_b $ %feq(x, y) -instance Eq Int64 - def (==) (x, y) = w8_to_b $ %ieq(x, y) +instance Eq(Int64) + def (==)(x, y) = w8_to_b $ %ieq(x, y) -instance Eq Int32 - def (==) (x, y) = w8_to_b $ %ieq(x, y) +instance Eq(Int32) + def (==)(x, y) = w8_to_b $ %ieq(x, y) -instance Eq Word8 - def (==) (x, y) = w8_to_b $ %ieq(x, y) +instance Eq(Word8) + def (==)(x, y) = w8_to_b $ %ieq(x, y) -instance Eq Word32 - def (==) (x, y) = w8_to_b $ %ieq(x, y) +instance Eq(Word32) + def (==)(x, y) = w8_to_b $ %ieq(x, y) -instance Eq Word64 - def (==) (x, y) = w8_to_b $ %ieq(x, y) +instance Eq(Word64) + def (==)(x, y) = w8_to_b $ %ieq(x, y) -instance Eq Bool - def (==) (x, y) = b_to_w8 x == b_to_w8 y +instance Eq(Bool) + def (==)(x, y) = b_to_w8 x == b_to_w8 y -instance Eq () - def (==) (_, _) = True +instance Eq(()) + def (==)(_, _) = True -instance Eq (Either(a, b)) given (a|Eq, b|Eq) - def (==) (x, y) = case x of +instance Eq(Either(a, b)) given (a|Eq, b|Eq) + def (==)(x, y) = case x of Left(x) -> case y of Left( y) -> x == y Right(y) -> False @@ -794,8 +794,8 @@ instance Eq (Either(a, b)) given (a|Eq, b|Eq) Left( y) -> False Right(y) -> x == y -instance Eq (Maybe a) given (a|Eq) - def (==) (x, y) = case x of +instance Eq(Maybe a) given (a|Eq) + def (==)(x, y) = case x of Just(x) -> case y of Just(y) -> x == y Nothing -> False @@ -803,85 +803,85 @@ instance Eq (Maybe a) given (a|Eq) Just(y) -> False Nothing -> True -instance Eq RawPtr - def (==) (x, y) = raw_ptr_to_i64 x == raw_ptr_to_i64 y +instance Eq(RawPtr) + def (==)(x, y) = raw_ptr_to_i64 x == raw_ptr_to_i64 y -instance Ord Float64 - def (>) (x, y) = w8_to_b $ %fgt(x, y) - def (<) (x, y) = w8_to_b $ %flt(x, y) +instance Ord(Float64) + def (>)(x, y) = w8_to_b $ %fgt(x, y) + def (<)(x, y) = w8_to_b $ %flt(x, y) -instance Ord Float32 - def (>) (x, y) = w8_to_b $ %fgt(x, y) - def (<) (x, y) = w8_to_b $ %flt(x, y) +instance Ord(Float32) + def (>)(x, y) = w8_to_b $ %fgt(x, y) + def (<)(x, y) = w8_to_b $ %flt(x, y) -instance Ord Int64 - def (>) (x, y) = w8_to_b $ %igt(x, y) - def (<) (x, y) = w8_to_b $ %ilt(x, y) +instance Ord(Int64) + def (>)(x, y) = w8_to_b $ %igt(x, y) + def (<)(x, y) = w8_to_b $ %ilt(x, y) -instance Ord Int32 - def (>) (x, y) = w8_to_b $ %igt(x, y) - def (<) (x, y) = w8_to_b $ %ilt(x, y) +instance Ord(Int32) + def (>)(x, y) = w8_to_b $ %igt(x, y) + def (<)(x, y) = w8_to_b $ %ilt(x, y) -instance Ord Word8 - def (>) (x, y) = w8_to_b $ %igt(x, y) - def (<) (x, y) = w8_to_b $ %ilt(x, y) +instance Ord(Word8) + def (>)(x, y) = w8_to_b $ %igt(x, y) + def (<)(x, y) = w8_to_b $ %ilt(x, y) -instance Ord Word32 - def (>) (x, y) = w8_to_b $ %igt(x, y) - def (<) (x, y) = w8_to_b $ %ilt(x, y) +instance Ord(Word32) + def (>)(x, y) = w8_to_b $ %igt(x, y) + def (<)(x, y) = w8_to_b $ %ilt(x, y) -instance Ord Word64 - def (>) (x, y) = w8_to_b $ %igt(x, y) - def (<) (x, y) = w8_to_b $ %ilt(x, y) +instance Ord(Word64) + def (>)(x, y) = w8_to_b $ %igt(x, y) + def (<)(x, y) = w8_to_b $ %ilt(x, y) -instance Ord () - def (>) (x, y) = False - def (<) (x, y) = False +instance Ord(()) + def (>)(x, y) = False + def (<)(x, y) = False -instance Eq (a, b) given (a|Eq, b|Eq) - def (==) (p1, p2) = +instance Eq((a, b)) given (a|Eq, b|Eq) + def (==)(p1, p2) = (x1, y1) = p1 (x2, y2) = p2 x1 == x2 && y1 == y2 -instance Ord (a, b) given (a|Ord, b|Ord) - def (>) (p1, p2) = +instance Ord((a, b)) given (a|Ord, b|Ord) + def (>)(p1, p2) = (x1, y1) = p1 (x2, y2) = p2 x1 > x2 || (x1 == x2 && y1 > y2) - def (<) (p1, p2) = + def (<)(p1, p2) = (x1, y1) = p1 (x2, y2) = p2 x1 < x2 || (x1 == x2 && y1 < y2) -instance Eq Ordering - def (==) (x, y) = o_to_w8 x == o_to_w8 y +instance Eq(Ordering) + def (==)(x, y) = o_to_w8 x == o_to_w8 y -instance Eq Nat - def (==) (x, y) = nat_to_rep x == nat_to_rep y +instance Eq(Nat) + def (==)(x, y) = nat_to_rep x == nat_to_rep y -instance Ord Nat - def (>) (x, y) = nat_to_rep x > nat_to_rep y - def (<) (x, y) = nat_to_rep x < nat_to_rep y +instance Ord(Nat) + def (>)(x, y) = nat_to_rep x > nat_to_rep y + def (<)(x, y) = nat_to_rep x < nat_to_rep y -- TODO: we want Eq and Ord for all index sets, not just `Fin n` -instance Eq (Fin n) given (n) - def (==) (x, y) = ordinal x == ordinal y +instance Eq(Fin n) given (n) + def (==)(x, y) = ordinal x == ordinal y -instance Ord (Fin n) given (n) - def (>) (x, y) = ordinal x > ordinal y - def (<) (x, y) = ordinal x < ordinal y +instance Ord(Fin n) given (n) + def (>)(x, y) = ordinal x > ordinal y + def (<)(x, y) = ordinal x < ordinal y -instance Ix Bool - def size' () = 2 - def ordinal (b) = case b of +instance Ix(Bool) + def size'() = 2 + def ordinal(b) = case b of False -> 0 True -> 1 - def unsafe_from_ordinal (i) = i > 0 + def unsafe_from_ordinal(i) = i > 0 -instance Ix (Maybe a) given (a|Ix) - def size' () = size a + 1 - def ordinal (i) = case i of +instance Ix(Maybe a) given (a|Ix) + def size'() = size a + 1 + def ordinal(i) = case i of Just(ai) -> ordinal ai Nothing -> size a def unsafe_from_ordinal(o) = @@ -889,13 +889,13 @@ instance Ix (Maybe a) given (a|Ix) False -> Just $ unsafe_from_ordinal o True -> Nothing -instance NonEmpty Bool +instance NonEmpty(Bool) first_ix = unsafe_from_ordinal 0 -instance NonEmpty (a,b) given (a|NonEmpty, b|NonEmpty) +instance NonEmpty((a,b)) given (a|NonEmpty, b|NonEmpty) first_ix = unsafe_from_ordinal 0 -instance NonEmpty (Either(a,b)) given (a|NonEmpty, b|Ix) +instance NonEmpty(Either(a,b)) given (a|NonEmpty, b|Ix) first_ix = unsafe_from_ordinal 0 -- The below instance is valid, but causes "multiple candidate dictionaries" @@ -903,7 +903,7 @@ instance NonEmpty (Either(a,b)) given (a|NonEmpty, b|Ix) -- instance NonEmpty (a|b) given {a b} [Ix a, NonEmpty b] -- first_ix = unsafe_from_ordinal _ 0 -instance NonEmpty (Maybe a) given (a|Ix) +instance NonEmpty(Maybe a) given (a|Ix) first_ix = unsafe_from_ordinal 0 def scan( @@ -916,42 +916,42 @@ def scan( s := c' y -def fold (init:a, body:(n,a)->a) -> a given (n|Ix, a|Data) = +def fold(init:a, body:(n,a)->a) -> a given (n|Ix, a|Data) = fst $ scan init \i x. (body(i, x), ()) -def compare (x:a, y:a) -> Ordering given (a|Ord) = +def compare(x:a, y:a) -> Ordering given (a|Ord) = if x < y then LT else if x == y then EQ else GT -instance Monoid Ordering +instance Monoid(Ordering) mempty = EQ - def (<>) (x, y) = + def (<>)(x, y) = case x of LT -> LT GT -> GT EQ -> y -instance Eq (n=>a) given (n|Ix, a|Eq) - def (==) (xs, ys) = +instance Eq(n=>a) given (n|Ix, a|Eq) + def (==)(xs, ys) = yield_accum AndMonoid \ref. for i. ref += xs[i] == ys[i] -instance Ord (n=>a) given (n|Ix, a|Ord) - def (>) (xs, ys) = +instance Ord(n=>a) given (n|Ix, a|Ord) + def (>)(xs, ys) = f: Ordering = fold EQ $ \i c. c <> compare(xs[i], ys[i]) f == GT - def (<) (xs, ys) = + def (<)(xs, ys) = f: Ordering = fold EQ $ \i c. c <> compare(xs[i], ys[i]) f == LT '## Subset class -interface Subset (subset, superset) +interface Subset(subset, superset) inject : (subset) -> superset project : (superset) -> Maybe subset unsafe_project : (superset) -> subset @@ -963,34 +963,34 @@ instance Subset(a, c) given (a, b, c) (Subset(a, b), Subset(b, c)) Just(y)-> project y def unsafe_project(x) = unsafe_project $ unsafe_project(subset=b, x) -def unsafe_project_rangefrom (j:q) -> RangeFrom(q, i) given (q|Ix, i:q) = +def unsafe_project_rangefrom(j:q) -> RangeFrom(q, i) given (q|Ix, i:q) = RangeFrom.new unsafe_nat_diff(ordinal j, ordinal i) instance Subset(RangeFrom(q, i), q) given (q|Ix, i:q) - def inject (j) = + def inject(j) = unsafe_from_ordinal $ j.val + ordinal i - def project (j) = + def project(j) = j' = ordinal j i' = ordinal i if j' < i' then Nothing else Just $ RangeFrom.new $ unsafe_nat_diff(j', i') - def unsafe_project (j) = RangeFrom.new unsafe_nat_diff(ordinal j, ordinal i) + def unsafe_project(j) = RangeFrom.new unsafe_nat_diff(ordinal j, ordinal i) instance Subset(RangeFromExc(q, i), q) given (q|Ix, i:q) - def inject (j) = unsafe_from_ordinal $ j.val + ordinal i + 1 - def project (j) = + def inject(j) = unsafe_from_ordinal $ j.val + ordinal i + 1 + def project(j) = j' = ordinal j i' = ordinal i if j' <= i' then Nothing else Just $ RangeFromExc.new unsafe_nat_diff(j', i' + 1) - def unsafe_project (j) = + def unsafe_project(j) = RangeFromExc.new unsafe_nat_diff(ordinal j, ordinal i + 1) instance Subset(RangeTo(q, i), q) given (q|Ix, i:q) - def inject (j) = unsafe_from_ordinal j.val - def project (j) = + def inject(j) = unsafe_from_ordinal j.val + def project(j) = j' = ordinal j i' = ordinal i if j' > i' @@ -999,24 +999,24 @@ instance Subset(RangeTo(q, i), q) given (q|Ix, i:q) def unsafe_project(j) = RangeTo.new (ordinal j) instance Subset(RangeToExc(q, i), q) given (q|Ix, i:q) - def inject (j) = unsafe_from_ordinal j.val - def project (j) = + def inject(j) = unsafe_from_ordinal j.val + def project(j) = j' = ordinal j i' = ordinal i if j' >= i' then Nothing else Just $ RangeToExc.new j' - def unsafe_project (j) = RangeToExc.new (ordinal j) + def unsafe_project(j) = RangeToExc.new (ordinal j) instance Subset(RangeToExc(q, i), RangeTo(q, i)) given (q|Ix, i:q) - def inject (j) = unsafe_from_ordinal j.val - def project (j) = + def inject(j) = unsafe_from_ordinal j.val + def project(j) = j' = ordinal j i' = ordinal i if j' >= i' then Nothing else Just $ RangeToExc.new j' - def unsafe_project (j) = RangeToExc.new (ordinal j) + def unsafe_project(j) = RangeToExc.new (ordinal j) '## Elementary/Special Functions This is more or less the standard [LibM fare](https://en.wikipedia.org/wiki/C_mathematical_functions). @@ -1024,7 +1024,7 @@ Roughly it lines up with some definitions of the set of [Elementary](https://en. In truth, nothing is elementary or special except that we humans have decided it is. Many, but not all of these functions are [Transcendental](https://en.wikipedia.org/wiki/Transcendental_function). -interface Floating (a:Type) +interface Floating(a:Type) exp : (a) -> a exp2 : (a) -> a log : (a) -> a @@ -1046,124 +1046,124 @@ interface Floating (a:Type) erf : (a) -> a erfc : (a) -> a -def lbeta (x:a, y:a) -> a given (a|Sub|Floating) = lgamma x + lgamma y - lgamma (x + y) +def lbeta(x:a, y:a) -> a given (a|Sub|Floating) = lgamma x + lgamma y - lgamma (x + y) -- Todo: better numerics for very large and small values. -- Using %exp here to avoid circular definition problems. -def float32_sinh (x:Float32) -> Float32 = %fdiv(%fsub(%exp(x), %exp(%fsub(0.0,x))), 2.0) -def float32_cosh (x:Float32) -> Float32 = %fdiv(%fadd(%exp(x), %exp(%fsub(0.0,x))), 2.0) -def float32_tanh (x:Float32) -> Float32 = %fdiv(%fsub(%exp(x), %exp(%fsub(0.0,x))) +def float32_sinh(x:Float32) -> Float32 = %fdiv(%fsub(%exp(x), %exp(%fsub(0.0,x))), 2.0) +def float32_cosh(x:Float32) -> Float32 = %fdiv(%fadd(%exp(x), %exp(%fsub(0.0,x))), 2.0) +def float32_tanh(x:Float32) -> Float32 = %fdiv(%fsub(%exp(x), %exp(%fsub(0.0,x))) ,%fadd(%exp(x), %exp(%fsub(0.0,x)))) -- Todo: unify this with float32 functions. -def float64_sinh (x:Float64) -> Float64 = %fdiv(%fsub(%exp(x), %exp(%fsub(f_to_f64 0.0, x))), f_to_f64 2.0) -def float64_cosh (x:Float64) -> Float64 = %fdiv(%fadd(%exp(x), %exp(%fsub(f_to_f64 0.0, x))), f_to_f64 2.0) -def float64_tanh (x:Float64) -> Float64 = %fdiv(%fsub(%exp(x), %exp(%fsub(f_to_f64 0.0, x))) +def float64_sinh(x:Float64) -> Float64 = %fdiv(%fsub(%exp(x), %exp(%fsub(f_to_f64 0.0, x))), f_to_f64 2.0) +def float64_cosh(x:Float64) -> Float64 = %fdiv(%fadd(%exp(x), %exp(%fsub(f_to_f64 0.0, x))), f_to_f64 2.0) +def float64_tanh(x:Float64) -> Float64 = %fdiv(%fsub(%exp(x), %exp(%fsub(f_to_f64 0.0, x))) ,%fadd(%exp(x), %exp(%fsub(f_to_f64 0.0, x)))) -instance Floating Float64 - def exp (x) = %exp (x) - def exp2 (x) = %exp2 (x) - def log (x) = %log (x) - def log2 (x) = %log2 (x) - def log10 (x) = %log10 (x) - def log1p (x) = %log1p (x) - def sin (x) = %sin (x) - def cos (x) = %cos (x) - def tan (x) = %tan (x) - def sinh (x) = float64_sinh x - def cosh (x) = float64_cosh x - def tanh (x) = float64_tanh x - def floor (x) = %floor (x) - def ceil (x) = %ceil (x) - def round (x) = %round (x) - def sqrt (x) = %sqrt (x) - def pow (x, y) = %fpow(x, y) - def lgamma (x) = %lgamma (x) - def erf (x) = %erf (x) - def erfc (x) = %erfc (x) - -instance Floating Float32 - def exp (x) = %exp (x) - def exp2 (x) = %exp2 (x) - def log (x) = %log (x) - def log2 (x) = %log2 (x) - def log10 (x) = %log10 (x) - def log1p (x) = %log1p (x) - def sin (x) = %sin (x) - def cos (x) = %cos (x) - def tan (x) = %tan (x) - def sinh (x) = float32_sinh x - def cosh (x) = float32_cosh x - def tanh (x) = float32_tanh x - def floor (x) = %floor (x) - def ceil (x) = %ceil (x) - def round (x) = %round (x) - def sqrt (x) = %sqrt (x) - def pow (x, y) = %fpow(x, y) - def lgamma (x) = %lgamma (x) - def erf (x) = %erf (x) - def erfc (x) = %erfc (x) +instance Floating(Float64) + def exp(x) = %exp(x) + def exp2(x) = %exp2(x) + def log(x) = %log(x) + def log2(x) = %log2(x) + def log10(x) = %log10(x) + def log1p(x) = %log1p(x) + def sin(x) = %sin(x) + def cos(x) = %cos(x) + def tan(x) = %tan( x) + def sinh(x) = float64_sinh(x) + def cosh(x) = float64_cosh(x) + def tanh(x) = float64_tanh(x) + def floor(x) = %floor(x) + def ceil(x) = %ceil(x) + def round(x) = %round(x) + def sqrt(x) = %sqrt(x) + def pow(x,y) = %fpow(x,y) + def lgamma(x)= %lgamma(x) + def erf(x) = %erf(x) + def erfc(x) = %erfc(x) + +instance Floating(Float32) + def exp(x) = %exp(x) + def exp2(x) = %exp2(x) + def log(x) = %log(x) + def log2(x) = %log2(x) + def log10(x) = %log10(x) + def log1p(x) = %log1p(x) + def sin(x) = %sin(x) + def cos(x) = %cos(x) + def tan(x) = %tan(x) + def sinh(x) = float32_sinh(x) + def cosh(x) = float32_cosh(x) + def tanh(x) = float32_tanh(x) + def floor(x) = %floor(x) + def ceil(x) = %ceil(x) + def round(x) = %round(x) + def sqrt(x) = %sqrt(x) + def pow(x,y) = %fpow(x, y) + def lgamma(x)= %lgamma(x) + def erf(x) = %erf(x) + def erfc(x) = %erfc(x) '## Raw pointer operations -struct Ptr (a:Type) = +struct Ptr(a:Type) = val : RawPtr -def cast_ptr (ptr: Ptr a) -> Ptr b given (a, b) = Ptr.new(ptr.val) +def cast_ptr(ptr: Ptr a) -> Ptr b given (a, b) = Ptr.new(ptr.val) -interface Storable (a|Data) +interface Storable(a|Data) store : (Ptr a, a) -> {IO} () load : (Ptr a) -> {IO} a storage_size : () -> Nat -instance Storable Word8 - def store (ptr, x) = %ptrStore(ptr.val, x) - def load (ptr) = %ptrLoad (ptr.val) - def storage_size () = 1 +instance Storable(Word8) + def store(ptr, x) = %ptrStore(ptr.val, x) + def load(ptr) = %ptrLoad(ptr.val) + def storage_size() = 1 -instance Storable Int32 - def store (ptr, x) = %ptrStore(internal_cast(to=%Int32Ptr(), ptr.val), x) - def load (ptr) = %ptrLoad (internal_cast(to=%Int32Ptr(), ptr.val)) - def storage_size () = 4 +instance Storable(Int32) + def store(ptr, x) = %ptrStore(internal_cast(to=%Int32Ptr(), ptr.val), x) + def load(ptr) = %ptrLoad(internal_cast(to=%Int32Ptr(), ptr.val)) + def storage_size() = 4 -instance Storable Word32 - def store (ptr, x) = %ptrStore(internal_cast(to=%Word32Ptr(), ptr), x) - def load (ptr) = %ptrLoad (internal_cast(to=%Word32Ptr(), ptr)) - def storage_size () = 4 +instance Storable(Word32) + def store(ptr, x) = %ptrStore(internal_cast(to=%Word32Ptr(), ptr), x) + def load(ptr) = %ptrLoad(internal_cast(to=%Word32Ptr(), ptr)) + def storage_size() = 4 -instance Storable Float32 - def store (ptr, x) = %ptrStore (internal_cast(to=%Float32Ptr(), ptr.val), x) - def load (ptr) = %ptrLoad (internal_cast(to=%Float32Ptr(), ptr.val)) - def storage_size () = 4 +instance Storable(Float32) + def store(ptr, x) = %ptrStore(internal_cast(to=%Float32Ptr(), ptr.val), x) + def load(ptr) = %ptrLoad(internal_cast(to=%Float32Ptr(), ptr.val)) + def storage_size() = 4 -instance Storable Nat - def store (ptr, x) = store(Ptr.new(ptr.val), nat_to_rep x) - def load (ptr) = rep_to_nat $ load(Ptr.new(ptr.val)) - def storage_size () = storage_size(a=NatRep) +instance Storable(Nat) + def store(ptr, x) = store(Ptr.new(ptr.val), nat_to_rep x) + def load(ptr) = rep_to_nat $ load(Ptr.new(ptr.val)) + def storage_size() = storage_size(a=NatRep) -instance Storable (Ptr a) given (a) - def store (ptr, x) = %ptrStore (internal_cast(to=%PtrPtr(), ptr.val), x.val) - def load (ptr) = Ptr.new(%ptrLoad(internal_cast(to=%PtrPtr(), ptr))) - def storage_size () = 8 -- TODO: something more portable? +instance Storable(Ptr a) given (a) + def store(ptr, x) = %ptrStore(internal_cast(to=%PtrPtr(), ptr.val), x.val) + def load(ptr) = Ptr.new(%ptrLoad(internal_cast(to=%PtrPtr(), ptr))) + def storage_size() = 8 -- TODO: something more portable? -- TODO: Storable instances for other types -def malloc (n:Nat) -> {IO} (Ptr a) given (a|Storable) = +def malloc(n:Nat) -> {IO} (Ptr a) given (a|Storable) = numBytes = storage_size(a=a) * n - Ptr.new(%alloc (nat_to_rep numBytes)) + Ptr.new(%alloc(nat_to_rep numBytes)) -def free (ptr:Ptr a) -> {IO} () given (a) = %free(ptr.val) +def free(ptr:Ptr a) -> {IO} () given (a) = %free(ptr.val) -def (+>>) (ptr:Ptr a, i:Nat) -> Ptr a given (a|Storable) = +def (+>>)(ptr:Ptr a, i:Nat) -> Ptr a given (a|Storable) = i' = nat_to_rep $ i * storage_size(a=a) Ptr.new(%ptrOffset(ptr.val, i')) -- TODO: consider making a Storable instance for tables instead -def store_table (ptr: Ptr a, tab:n=>a) -> {IO} () given (a|Storable, n|Ix) = +def store_table(ptr: Ptr a, tab:n=>a) -> {IO} () given (a|Storable, n|Ix) = for_ i. store(ptr +>> ordinal i, tab[i]) -def memcpy (dest:Ptr a, src:Ptr a, n:Nat) -> {IO} () given (a|Storable) = +def memcpy(dest:Ptr a, src:Ptr a, n:Nat) -> {IO} () given (a|Storable) = for_ i:(Fin n). i' = ordinal i store(dest +>> i', load $ src +>> i') @@ -1187,74 +1187,75 @@ def with_table_ptr( for i. store(ptr +>> ordinal i, xs[i]) action ptr -def table_from_ptr (ptr:Ptr a) -> {IO} n=>a given (a|Storable, n|Ix) = +def table_from_ptr(ptr:Ptr a) -> {IO} n=>a given (a|Storable, n|Ix) = for i. load $ ptr +>> ordinal i '## Miscellaneous common utilities pi : Float = 3.141592653589793 -def id (x:a) -> a given (a) = x -def dup (x:a) -> (a, a) given (a) = (x, x) -def map (f:(a)->{|eff} b, xs: n=>a) -> {|eff} (n=>b) given (a, b, n|Ix, eff) = +def id(x:a) -> a given (a) = x +def dup(x:a) -> (a, a) given (a) = (x, x) +def map(f:(a)->{|eff} b, xs: n=>a) -> {|eff} (n=>b) given (a, b, n|Ix, eff) = for i. f xs[i] -- map, flipped so that the function goes last -def each (xs: n=>a, f:(a)->{|eff} b) -> {|eff} (n=>b) given (a, b, n|Ix, eff) = +def each(xs: n=>a, f:(a)->{|eff} b) -> {|eff} (n=>b) given (a, b, n|Ix, eff) = for i. f xs[i] -def zip (xs:n=>a, ys:n=>b) -> (n=>(a,b)) given (a, b, n|Ix) = for i. (xs[i], ys[i]) -def unzip (xys:n=>(a,b)) -> (n=>a , n=>b) given (a, b, n|Ix)= (each xys fst, each xys snd) -def fanout (x:a) -> n=>a given (n|Ix, a) = for i. x -def sq (x:a) -> a given (a|Mul) = x * x -def abs (x:a) -> a given (a|Sub|Ord) = select(x > zero, x, zero - x) -def mod (x:a, y:a) -> a given (a|Add|Integral) = rem(y + rem(x, y), y) +def zip(xs:n=>a, ys:n=>b) -> (n=>(a,b)) given (a, b, n|Ix) = for i. (xs[i], ys[i]) +def unzip(xys:n=>(a,b)) -> (n=>a , n=>b) given (a, b, n|Ix)= (each xys fst, each xys snd) +def fanout(x:a) -> n=>a given (n|Ix, a) = for i. x +def sq(x:a) -> a given (a|Mul) = x * x +def abs(x:a) -> a given (a|Sub|Ord) = select(x > zero, x, zero - x) +def mod(x:a, y:a) -> a given (a|Add|Integral) = rem(y + rem(x, y), y) '## Table Operations -instance Floating (n=>a) given (a|Floating, n|Ix) - def exp (x) = each x exp - def exp2 (x) = each x exp2 - def log (x) = each x log - def log2 (x) = each x log2 - def log10 (x) = each x log10 - def log1p (x) = each x log1p - def sin (x) = each x sin - def cos (x) = each x cos - def tan (x) = each x tan - def sinh (x) = each x sinh - def cosh (x) = each x cosh - def tanh (x) = each x tanh - def floor (x) = each x floor - def ceil (x) = each x ceil - def round (x) = each x round - def sqrt (x) = each x sqrt - def pow (x, y) = for i. pow(x[i], y[i]) - def lgamma (x) = each x lgamma - def erf (x) = each x erf - def erfc (x) = each x erfc +instance Floating(n=>a) given (a|Floating, n|Ix) + def exp(x) = each x exp + def exp2(x) = each x exp2 + def log(x) = each x log + def log2(x) = each x log2 + def log10(x) = each x log10 + def log1p(x) = each x log1p + def sin(x) = each x sin + def cos(x) = each x cos + def tan(x) = each x tan + def sinh(x) = each x sinh + def cosh(x) = each x cosh + def tanh(x) = each x tanh + def floor(x) = each x floor + def ceil(x) = each x ceil + def round(x) = each x round + def sqrt(x) = each x sqrt + def pow(x, y) = for i. pow(x[i], y[i]) + def lgamma(x) = each x lgamma + def erf(x) = each x erf + def erfc(x) = each x erfc '### Reductions -- `combine` should be a commutative and associative, and form a -- commutative monoid with `identity` -def reduce (identity:a, combine:(a,a)->a, xs:n=>a) -> a given (a|Data, n|Ix) = +def reduce(identity:a, combine:(a,a)->a, xs:n=>a) -> a given (a|Data, n|Ix) = -- TODO: implement with the accumulator effect fold identity \i c. combine(c, xs[i]) -- TODO: call this `scan` and call the current `scan` something else -def scan' (init:a, body:(n,a)->a) -> n=>a given (a|Data, n|Ix) = +def scan'(init:a, body:(n,a)->a) -> n=>a given (a|Data, n|Ix) = snd $ scan init \i x. dup(body(i, x)) -def fsum (xs:n=>Float) -> Float given (n|Ix) = yield_accum(AddMonoid Float) \ref. for i. ref += xs[i] -def sum (xs:n=>v) -> v given (n|Ix, v|Add) = reduce(zero, (+), xs) -def prod (xs:n=>v) -> v given (n|Ix, v|Mul) = reduce(one , (*), xs) -def mean (xs:n=>v) -> v given (n|Ix, v|VSpace) = sum xs / n_to_f (size n) -def std (xs:n=>v) -> v given (n|Ix, v|Mul|Sub|VSpace|Floating) = sqrt $ mean (each xs sq) - sq (mean xs) -def any (xs:n=>Bool) -> Bool given (n|Ix) = reduce(False, (||), xs) -def all (xs:n=>Bool) -> Bool given (n|Ix) = reduce(True , (&&), xs) +def fsum(xs:n=>Float) -> Float given (n|Ix) = + yield_accum(AddMonoid Float) \ref. for i. ref += xs[i] +def sum(xs:n=>v) -> v given (n|Ix, v|Add) = reduce(zero, (+), xs) +def prod(xs:n=>v) -> v given (n|Ix, v|Mul) = reduce(one , (*), xs) +def mean(xs:n=>v) -> v given (n|Ix, v|VSpace) = sum xs / n_to_f (size n) +def std(xs:n=>v) -> v given (n|Ix, v|Mul|Sub|VSpace|Floating) = sqrt $ mean (each xs sq) - sq (mean xs) +def any(xs:n=>Bool) -> Bool given (n|Ix) = reduce(False, (||), xs) +def all(xs:n=>Bool) -> Bool given (n|Ix) = reduce(True , (&&), xs) '### apply_n -def apply_n (n:Nat, x:a, f:(a) -> a) -> a given (a|Data) = +def apply_n(n:Nat, x:a, f:(a) -> a) -> a given (a|Data) = yield_state x \ref. for _:(Fin n). ref := f (get ref) @@ -1262,14 +1263,14 @@ def apply_n (n:Nat, x:a, f:(a) -> a) -> a given (a|Data) = TODO: Move this to be with reductions? It's a kind of `scan`. -def cumsum (xs: n=>a) -> n=>a given (n|Ix, a|Add) = +def cumsum(xs: n=>a) -> n=>a given (n|Ix, a|Add) = total <- with_state zero for i. newTotal = get total + xs[i] total := newTotal newTotal -def cumsum_low (xs: n=>a) -> n=>a given (n|Ix, a|Add) = +def cumsum_low(xs: n=>a) -> n=>a given (n|Ix, a|Add) = total <- with_state zero for i. oldTotal = get total @@ -1281,14 +1282,14 @@ def cumsum_low (xs: n=>a) -> n=>a given (n|Ix, a|Add) = '### AD operations -- TODO: add vector space constraints -def linearize (f:(a)->b, x:a) -> (b, (a)->b) given (a, b) = +def linearize(f:(a)->b, x:a) -> (b, (a)->b) given (a, b) = %linearize(\x. f x, x) -def jvp (f:(a)->b, x:a, t:a) -> b given (a, b) = (snd $ linearize(f, x))(t) -def transpose_linear (f:(a)->b) -> (b)->a given (a, b) = \ct. +def jvp(f:(a)->b, x:a, t:a) -> b given (a, b) = (snd $ linearize(f, x))(t) +def transpose_linear(f:(a)->b) -> (b)->a given (a, b) = \ct. %linearTranspose(\x. f x, ct) -def vjp (f:(a)->b, x:a) -> (b, (b)->a) given (a, b) = +def vjp(f:(a)->b, x:a) -> (b, (b)->a) given (a, b) = (y, df) = linearize(f, x) (y, transpose_linear df) @@ -1319,61 +1320,61 @@ interface HasDefaultTolerance(a) default_atol : a default_rtol : a -def (~~) (x:a, y:a) -> Bool given (a|HasAllClose|HasDefaultTolerance) = +def (~~)(x:a, y:a) -> Bool given (a|HasAllClose|HasDefaultTolerance) = allclose(default_atol, default_rtol, x, y) -instance HasAllClose Float32 - def allclose (atol, rtol, x, y) = abs (x - y) <= (atol + rtol * abs y) +instance HasAllClose(Float32) + def allclose(atol, rtol, x, y) = abs (x - y) <= (atol + rtol * abs y) -instance HasAllClose Float64 - def allclose (atol, rtol, x, y) = abs (x - y) <= (atol + rtol * abs y) +instance HasAllClose(Float64) + def allclose(atol, rtol, x, y) = abs (x - y) <= (atol + rtol * abs y) -instance HasDefaultTolerance Float32 +instance HasDefaultTolerance(Float32) default_atol = f_to_f32 0.00001 default_rtol = f_to_f32 0.0001 -instance HasDefaultTolerance Float64 +instance HasDefaultTolerance(Float64) default_atol = f_to_f64 0.00000001 default_rtol = f_to_f64 0.00001 -instance HasAllClose (a, b) given ( a|HasDefaultTolerance|HasAllClose - , b|HasDefaultTolerance|HasAllClose) - def allclose (atol, rtol, pair1, pair2) = +instance HasAllClose((a, b)) given ( a|HasDefaultTolerance|HasAllClose + , b|HasDefaultTolerance|HasAllClose) + def allclose(atol, rtol, pair1, pair2) = (x1, x2) = pair1 (y1, y2) = pair2 (x1 ~~ y1) && (x2 ~~ y2) -instance HasDefaultTolerance (a, b) given (a|HasDefaultTolerance,b|HasDefaultTolerance) +instance HasDefaultTolerance((a, b)) given (a|HasDefaultTolerance,b|HasDefaultTolerance) default_atol = (default_atol, default_atol) default_rtol = (default_rtol, default_rtol) -instance HasAllClose (n=>t) given (n|Ix, t|HasAllClose) - def allclose (atol, rtol, a, b) = +instance HasAllClose(n=>t) given (n|Ix, t|HasAllClose) + def allclose(atol, rtol, a, b) = all for i:n. allclose(atol[i], rtol[i], a[i], b[i]) -instance HasDefaultTolerance (n=>t) given (n|Ix, t|HasDefaultTolerance) +instance HasDefaultTolerance(n=>t) given (n|Ix, t|HasDefaultTolerance) default_atol = for i. default_atol default_rtol = for i. default_rtol '### AD Checking tools -def check_deriv_base (f:(Float)->Float, x:Float) -> Bool = +def check_deriv_base(f:(Float)->Float, x:Float) -> Bool = eps = 0.01 ansFwd = deriv( f, x) ansRev = deriv_rev( f, x) ansNumeric = (f (x + eps) - f (x - eps)) / (2. * eps) ansFwd ~~ ansNumeric && ansRev ~~ ansNumeric -def check_deriv (f:(Float)->Float, x:Float) -> Bool = +def check_deriv(f:(Float)->Float, x:Float) -> Bool = check_deriv_base(f, x) && check_deriv_base(\x. deriv(f, x), x) '## Length-erased lists data List(a)= - AsList (n:Nat, elements:(Fin n => a)) + AsList(n:Nat, elements:(Fin n => a)) -instance Eq (List a) given (a|Eq) - def (==) (xsList, ysList) = +instance Eq(List a) given (a|Eq) + def (==)(xsList, ysList) = AsList(nx,xs) = xsList AsList(ny,ys) = ysList if nx /= ny @@ -1381,16 +1382,16 @@ instance Eq (List a) given (a|Eq) else all for i:(Fin nx). xs[i] == ys[unsafe_from_ordinal (ordinal i)] -def unsafe_cast_table (xs:from=>a) -> to=>a given (to|Ix, from|Ix, a) = +def unsafe_cast_table(xs:from=>a) -> to=>a given (to|Ix, from|Ix, a) = for i. xs[unsafe_from_ordinal (ordinal i)] -def to_list (xs:n=>a) -> List a given (n|Ix, a) = +def to_list(xs:n=>a) -> List a given (n|Ix, a) = n' = size n AsList(_, unsafe_cast_table(to=Fin n', xs)) -instance Monoid (List a) given (a|Data) +instance Monoid(List a) given (a|Data) mempty = AsList(_, []) - def (<>) (x, y) = + def (<>)(x, y) = AsList(nx,xs) = x AsList(ny,ys) = y nz = nx + ny @@ -1400,30 +1401,30 @@ instance Monoid (List a) given (a|Data) True -> xs[unsafe_from_ordinal i'] False -> ys[unsafe_from_ordinal $ unsafe_nat_diff(i', nx)] -named-instance ListMonoid (a|Data) -> Monoid (List a) +named-instance ListMonoid (a|Data) -> Monoid(List a) mempty = mempty - def (<>) (x, y) = x <> y + def (<>)(x, y) = x <> y -- TODO Eliminate or reimplement this operation, since it costs O(n) -- where n is the length of the list held in the reference. -def append (list: Ref(h, List a), x:a) -> {Accum h} () +def append(list: Ref(h, List a), x:a) -> {Accum h} () given (a|Data, h) (AccumMonoid(h, List a)) = list += to_list [x] -- TODO: replace `slice` with this? -def post_slice (xs:n=>a, start:Post n, end:Post n) -> List a given (n|Ix, a) = +def post_slice(xs:n=>a, start:Post n, end:Post n) -> List a given (n|Ix, a) = slice_size = unsafe_nat_diff(ordinal end, ordinal start) to_list for i:(Fin slice_size). xs[unsafe_from_ordinal(n=n, ordinal i + ordinal start)] '## Dynamic buffer -struct DynBuffer (a) = +struct DynBuffer(a) = size : Ptr Nat max_size : Ptr Nat buffer : Ptr (Ptr a) -def with_dynamic_buffer (action: (DynBuffer a) -> {IO} b) -> {IO} b given (a|Storable, b) = +def with_dynamic_buffer(action: (DynBuffer a) -> {IO} b) -> {IO} b given (a|Storable, b) = initMaxSize = 256 sizePtr <- with_alloc 1 store(sizePtr, 0) @@ -1435,7 +1436,7 @@ def with_dynamic_buffer (action: (DynBuffer a) -> {IO} b) -> {IO} b given (a|Sto free $ load bufferPtr result -def maybe_increase_buffer_size (db: DynBuffer a, sizeDelta:Nat) -> {IO} () given (a|Storable) = +def maybe_increase_buffer_size(db: DynBuffer a, sizeDelta:Nat) -> {IO} () given (a|Storable) = size = load db.size max_size = load db.max_size bufPtr = load db.buffer @@ -1449,10 +1450,10 @@ def maybe_increase_buffer_size (db: DynBuffer a, sizeDelta:Nat) -> {IO} () given store(db.max_size, newMaxSize) store(db.buffer , newBufPtr) -def add_at_nat_ptr (ptr: Ptr Nat, n:Nat) -> {IO} () = +def add_at_nat_ptr(ptr: Ptr Nat, n:Nat) -> {IO} () = store(ptr, load ptr + n) -def extend_dynamic_buffer (buf: DynBuffer a, new:List a) -> {IO} () given (a|Storable) = +def extend_dynamic_buffer(buf: DynBuffer a, new:List a) -> {IO} () given (a|Storable) = AsList(n, xs) = new maybe_increase_buffer_size(buf, n) bufPtr = load buf.buffer @@ -1460,23 +1461,23 @@ def extend_dynamic_buffer (buf: DynBuffer a, new:List a) -> {IO} () given (a|St store_table(bufPtr +>> size, xs) add_at_nat_ptr(buf.size, n) -def load_dynamic_buffer (buf: DynBuffer a) -> {IO} (List a) given (a|Storable) = +def load_dynamic_buffer(buf: DynBuffer a) -> {IO} (List a) given (a|Storable) = bufPtr = load buf.buffer size = load buf.size AsList(size, table_from_ptr bufPtr) -def push_dynamic_buffer (buf: DynBuffer a, x:a) -> {IO} () given (a|Storable) = +def push_dynamic_buffer(buf: DynBuffer a, x:a) -> {IO} () given (a|Storable) = extend_dynamic_buffer(buf, to_list [x]) '## Strings and Characters String : Type = List Char -def string_from_char_ptr (n:Word32, ptr:Ptr Char) -> {IO} String = +def string_from_char_ptr(n:Word32, ptr:Ptr Char) -> {IO} String = AsList(rep_to_nat n, table_from_ptr ptr) -- TODO. This is ASCII code point. It really should be Int32 for Unicode codepoint -def codepoint (c:Char) -> Int = w8_to_i c +def codepoint(c:Char) -> Int = w8_to_i c struct CString = ptr : RawPtr @@ -1488,6 +1489,7 @@ def with_c_string( ) -> {IO} a given (a) = AsList(n, s') = s <> "\NUL" with_table_ptr s' \ptr. action $ CString.new(ptr.val) + '### Show interface For things that can be shown. `show` gives a string representation of its input. @@ -1495,57 +1497,57 @@ No particular promises are made to exactly what that representation will contain In particular it is **not** promised to be parseable. Nor does it promise a particular level of precision for numeric values. -interface Show (a) +interface Show(a) show : (a) -> String -instance Show String +instance Show(String) def show(x) = x foreign "showInt32" showInt32 : (Int32) -> {IO} (Word32, RawPtr) -instance Show Int32 - def show (x) = unsafe_io \. +instance Show(Int32) + def show(x) = unsafe_io \. (n, ptr) = showInt32 x string_from_char_ptr(n, Ptr.new ptr) foreign "showInt64" showInt64 : (Int64) -> {IO} (Word32, RawPtr) -instance Show Int64 - def show (x) = unsafe_io \. +instance Show(Int64) + def show(x) = unsafe_io \. (n, ptr) = showInt64 x string_from_char_ptr(n, Ptr.new ptr) -instance Show Nat - def show (x) = show $ n_to_i64 x +instance Show(Nat) + def show(x) = show $ n_to_i64 x foreign "showFloat32" showFloat32 : (Float32) -> {IO} (Word32, RawPtr) -instance Show Float32 - def show (x) = unsafe_io \. +instance Show(Float32) + def show(x) = unsafe_io \. (n, ptr) = showFloat32 x string_from_char_ptr(n, Ptr.new ptr) foreign "showFloat64" showFloat64 : (Float64) -> {IO} (Word32, RawPtr) -instance Show Float64 - def show (x) = unsafe_io \. +instance Show(Float64) + def show(x) = unsafe_io \. (n, ptr) = showFloat64 x string_from_char_ptr(n, Ptr.new ptr) -instance Show () +instance Show(()) def show(_) = "()" -instance Show (a, b) given (a|Show, b|Show) +instance Show((a, b)) given (a|Show, b|Show) def show(x) = (a, b) = x "(" <> show a <> ", " <> show b <> ")" -instance Show (a, b, c) given (a|Show, b|Show, c|Show) +instance Show((a, b, c)) given (a|Show, b|Show, c|Show) def show(x) = (a, b, c) = x "(" <> show a <> ", " <> show b <> ", " <> show c <> ")" -instance Show (a, b, c, d) given (a|Show, b|Show, c|Show, d|Show) +instance Show((a, b, c, d)) given (a|Show, b|Show, c|Show, d|Show) def show(x) = (a, b, c, d) = x "(" <> show a <> ", " <> show b <> ", " <> show c <> ", " <> show d <> ")" @@ -1553,13 +1555,13 @@ instance Show (a, b, c, d) given (a|Show, b|Show, c|Show, d|Show) '### Parse interface For types that can be parsed from a `String`. -interface Parse (a) +interface Parse(a) parseString : (String) -> Maybe a foreign "strtof" strtofFFI : (RawPtr, RawPtr) -> {IO} Float -instance Parse Float - def parseString (str) = unsafe_io \. +instance Parse(Float) + def parseString(str) = unsafe_io \. AsList(str_len, _) = str with_c_string str \cStr. with_alloc 1 \end_ptr:(Ptr (Ptr Char)). @@ -1571,14 +1573,14 @@ instance Parse Float '## Floating-point helper functions TODO: Move these to be with Elementary/Special functions. Or move those to be here. -def sign (x:Float) -> Float = +def sign(x:Float) -> Float = case x > 0.0 of True -> 1.0 False -> case x < 0.0 of True -> -1.0 False -> x -def copysign (a:Float, b:Float) -> Float = +def copysign(a:Float, b:Float) -> Float = case b > 0.0 of True -> a False -> case b < 0.0 of @@ -1590,31 +1592,31 @@ infinity = 1.0 / 0.0 nan = 0.0 / 0.0 -- Todo: use IEEE floating-point builtins. -def isinf (x:Float) -> Bool = (x == infinity) || (x == -infinity) -def isnan (x:Float) -> Bool = not (x >= x && x <= x) +def isinf(x:Float) -> Bool = (x == infinity) || (x == -infinity) +def isnan(x:Float) -> Bool = not (x >= x && x <= x) -- Todo: use IEEE-754R 5.11: Floating Point Comparison Relation cmpUnordered. -def either_is_nan (x:Float, y:Float) -> Bool = (isnan x) || (isnan y) +def either_is_nan(x:Float, y:Float) -> Bool = (isnan x) || (isnan y) '## File system operations FilePath : Type = String -def is_null_raw_ptr (ptr:RawPtr) -> Bool = +def is_null_raw_ptr(ptr:RawPtr) -> Bool = raw_ptr_to_i64 ptr == 0 -def from_nullable_raw_ptr (ptr:RawPtr) -> Maybe (Ptr a) given (a) = +def from_nullable_raw_ptr(ptr:RawPtr) -> Maybe (Ptr a) given (a) = if is_null_raw_ptr ptr then Nothing else Just $ Ptr.new ptr -def c_string_ptr (s:CString) -> Maybe (Ptr Char) = from_nullable_raw_ptr s.ptr +def c_string_ptr(s:CString) -> Maybe (Ptr Char) = from_nullable_raw_ptr s.ptr data StreamMode = ReadMode WriteMode -struct Stream (mode:StreamMode) = +struct Stream(mode:StreamMode) = ptr : RawPtr '### Stream IO @@ -1625,7 +1627,7 @@ foreign "fwrite" fwriteFFI : (RawPtr, Int64, Int64, RawPtr) -> {IO} Int64 foreign "fread" freadFFI : (RawPtr, Int64, Int64, RawPtr) -> {IO} Int64 foreign "fflush" fflushFFI : (RawPtr) -> {IO} Int64 -def fopen (path:String, mode:StreamMode) -> {IO} (Stream mode) = +def fopen(path:String, mode:StreamMode) -> {IO} (Stream mode) = modeStr = case mode of ReadMode -> "r" WriteMode -> "w" @@ -1633,11 +1635,11 @@ def fopen (path:String, mode:StreamMode) -> {IO} (Stream mode) = with_c_string modeStr \cMode. Stream.new $ fopenFFI(cPath.ptr, cMode.ptr) -def fclose (stream:Stream mode) -> {IO} () given (mode) = +def fclose(stream:Stream mode) -> {IO} () given (mode) = fcloseFFI stream.ptr () -def fwrite (stream:Stream WriteMode, s:String) -> {IO} () = +def fwrite(stream:Stream WriteMode, s:String) -> {IO} () = AsList(n, s') = s with_table_ptr s' \ptr. fwriteFFI(ptr.val, i_to_i64 1, n_to_i64 n, stream.ptr) @@ -1647,21 +1649,21 @@ def fwrite (stream:Stream WriteMode, s:String) -> {IO} () = '### Iteration TODO: move this out of the file-system section -def while (body: () -> {|eff} Bool) -> {|eff} () given (eff) = +def while(body: () -> {|eff} Bool) -> {|eff} () given (eff) = body' : () -> {|eff} Word8 = \. b_to_w8 $ body() - %while body' + %while(body') -data IterResult (a|Data) = +data IterResult(a|Data) = Continue Done(a) -- TODO: can we improve effect inference so we don't need this? -def lift_state (ref: Ref(h, c), f:(a) -> {|eff} b, x:a) -> {State h|eff} b +def lift_state(ref: Ref(h, c), f:(a) -> {|eff} b, x:a) -> {State h|eff} b given (a, b, c, h, eff) = f x -- A little iteration combinator -def iter (body: (Nat) -> {|eff} IterResult a) -> {|eff} a given (a|Data, eff) = +def iter(body: (Nat) -> {|eff} IterResult a) -> {|eff} a given (a|Data, eff) = result = yield_state Nothing \resultRef. i <- with_state 0 while \. @@ -1675,18 +1677,18 @@ def iter (body: (Nat) -> {|eff} IterResult a) -> {|eff} a given (a|Data, eff) = Just(ans) -> ans Nothing -> unreachable() -def bounded_iter - (maxIters:Nat, fallback:a) - -> ((Nat) -> {|eff} IterResult a) - -> {|eff} a - given (a|Data, eff) = \body. iter \i. +def bounded_iter( + maxIters:Nat, + fallback:a, + body:(Nat) -> {|eff} IterResult a + ) -> {|eff} a given (a|Data, eff) = iter \i. if i >= maxIters then Done fallback else body i '### Environment Variables -def from_c_string (s:CString) -> {IO} (Maybe String) = +def from_c_string(s:CString) -> {IO} (Maybe String) = case c_string_ptr s of Nothing -> Nothing Just(ptr) -> @@ -1702,16 +1704,16 @@ def from_c_string (s:CString) -> {IO} (Maybe String) = foreign "getenv" getenvFFI : (RawPtr) -> {IO} RawPtr -def get_env (name:String) -> {IO} Maybe String = +def get_env(name:String) -> {IO} Maybe String = cStr <- with_c_string name getenvFFI cStr.ptr | CString.new | from_c_string -def check_env (name:String) -> {IO} Bool = +def check_env(name:String) -> {IO} Bool = is_just $ get_env name '### More Stream IO -def fread (stream:Stream ReadMode) -> {IO} String = +def fread(stream:Stream ReadMode) -> {IO} String = -- TODO: allow reading longer files! n = 4096 ptr:(Ptr Char) <- with_alloc n @@ -1726,11 +1728,11 @@ def fread (stream:Stream ReadMode) -> {IO} String = '### Print -def get_output_stream () -> {IO} Stream WriteMode = +def get_output_stream() -> {IO} Stream WriteMode = Stream.new $ %outputStream() @noinline -def print (s:String) -> {IO} () = +def print(s:String) -> {IO} () = stream = get_output_stream() fwrite(stream, s) fwrite(stream, "\n") @@ -1742,7 +1744,7 @@ foreign "remove" removeFFI : (RawPtr) -> {IO} Int64 foreign "mkstemp" mkstempFFI : (RawPtr) -> {IO} Int32 foreign "close" closeFFI : (Int32) -> {IO} Int32 -def shell_out (command:String) -> {IO} String = +def shell_out(command:String) -> {IO} String = modeStr = "r" with_c_string command \command'. with_c_string modeStr \modeStr'. @@ -1757,15 +1759,15 @@ Not to be confused with a partially applied function '### Error throwing @noinline -def error (s:String) -> a given (a|Data) = unsafe_io \. +def error(s:String) -> a given (a|Data) = unsafe_io \. print s - %throwError a + %throwError(a) -def todo () ->> a given (a|Data) = error "TODO: implement it!" +def todo() ->> a given (a|Data) = error "TODO: implement it!" '### File Operations -def delete_file (f:FilePath) -> {IO} () = +def delete_file(f:FilePath) -> {IO} () = s <- with_c_string(f) removeFFI s.ptr () @@ -1784,13 +1786,13 @@ def with_file( fclose stream result -def write_file (f:FilePath, s:String) -> {IO} () = +def write_file(f:FilePath, s:String) -> {IO} () = with_file(f, WriteMode) \stream. fwrite(stream, s) -def read_file (f:FilePath) -> {IO} String = +def read_file(f:FilePath) -> {IO} String = with_file(f, ReadMode) \stream. fread stream -def has_file (f:FilePath) -> {IO} Bool = +def has_file(f:FilePath) -> {IO} Bool = stream = fopen(f, ReadMode) result = not (is_null_raw_ptr stream.ptr) if result then fclose stream @@ -1798,19 +1800,19 @@ def has_file (f:FilePath) -> {IO} Bool = '### Temporary Files -def new_temp_file () -> {IO} FilePath = +def new_temp_file() -> {IO} FilePath = s <- with_c_string "/tmp/dex-XXXXXX" fd = mkstempFFI s.ptr closeFFI fd string_from_char_ptr(15, (Ptr.new s.ptr)) -def with_temp_file (action: (FilePath) -> {IO} a) -> {IO} a given (a) = +def with_temp_file(action: (FilePath) -> {IO} a) -> {IO} a given (a) = tmpFile = new_temp_file() result = action tmpFile delete_file tmpFile result -def with_temp_files (action: (n=>FilePath) -> {IO} a) -> {IO} a given (n|Ix, a) = +def with_temp_files(action: (n=>FilePath) -> {IO} a) -> {IO} a given (n|Ix, a) = tmpFiles = for i. new_temp_file() result = action tmpFiles for i. delete_file tmpFiles[i] @@ -1819,37 +1821,37 @@ def with_temp_files (action: (n=>FilePath) -> {IO} a) -> {IO} a given (n|Ix, a) '### Table operations @noinline -def from_ordinal_error (i:Nat, upper:Nat) -> String = +def from_ordinal_error(i:Nat, upper:Nat) -> String = "Ordinal index out of range:" <> show i <> " >= " <> show upper -def from_ordinal (i:Nat) -> n given (n|Ix) = +def from_ordinal(i:Nat) -> n given (n|Ix) = case i < size n of True -> unsafe_from_ordinal i False -> error $ from_ordinal_error(i, size n) -- TODO: should this be called `from_ordinal`? -def to_ix (i:Nat) -> Maybe n given (n|Ix) = +def to_ix(i:Nat) -> Maybe n given (n|Ix) = case i < size n of True -> Just $ unsafe_from_ordinal i False -> Nothing -- TODO: could make an `unsafeCastIndex` and this could avoid the runtime copy -- TODO: safe (runtime-checked) and unsafe versions -def cast_table (xs:to=>a) -> from=>a given (from|Ix, to|Ix, a|Data) = +def cast_table(xs:to=>a) -> from=>a given (from|Ix, to|Ix, a|Data) = case size from == size to of True -> unsafe_cast_table xs False -> error $ "Table size mismatch in cast: " <> show (size from) <> " vs " <> show (size to) -def asidx (i:Nat) -> n given (n|Ix) = from_ordinal i -def (@) (i:Nat, n|Ix) -> n = from_ordinal i +def asidx(i:Nat) -> n given (n|Ix) = from_ordinal i +def (@)(i:Nat, n|Ix) -> n = from_ordinal i -def slice (xs:n=>a, start:Nat, m|Ix) -> m=>a given (n|Ix, a) = +def slice(xs:n=>a, start:Nat, m|Ix) -> m=>a given (n|Ix, a) = for i. xs[from_ordinal (ordinal i + start)] -def head (xs:n=>a) -> a given (n|Ix, a) = xs[0@_] +def head(xs:n=>a) -> a given (n|Ix, a) = xs[0@_] -def tail (xs:n=>a, start:Nat) -> List a given (n|Ix, a) = +def tail(xs:n=>a, start:Nat) -> List a given (n|Ix, a) = numElts = size n -| start to_list $ slice(xs, start, Fin numElts) @@ -1864,7 +1866,7 @@ Dex's PRNG system is modelled directly after [JAX's](https://github.com/google/j Key = Word64 @noinline -def threefry_2x32 (k:Word64, count:Word64) -> Word64 = +def threefry_2x32(k:Word64, count:Word64) -> Word64 = -- Based on jax's threefry_2x32 by Matt Johnson and Peter Hawkins rotations1 = [13, 15, 26, 6] rotations2 = [17, 29, 16, 24] @@ -1897,31 +1899,31 @@ def threefry_2x32 (k:Word64, count:Word64) -> Word64 = (w32_to_w64 x .<<. 32) .|. (w32_to_w64 y) -def hash (x:Key, y:Nat) -> Key = +def hash(x:Key, y:Nat) -> Key = y64 = n_to_w64 y threefry_2x32(x, y64) -def new_key (x:Nat) -> Key = hash(0, x) -def many (f:(Key)->a, k:Key, i:n) -> a given (a, n|Ix) = f hash(k, ordinal i) -def ixkey (k:Key, i:n) -> Key given (n|Ix) = hash(k, ordinal i) -def split_key (k:Key) -> Fin n => Key given (n:Nat) = for i. ixkey(k, i) +def new_key(x:Nat) -> Key = hash(0, x) +def many(f:(Key)->a, k:Key, i:n) -> a given (a, n|Ix) = f hash(k, ordinal i) +def ixkey(k:Key, i:n) -> Key given (n|Ix) = hash(k, ordinal i) +def split_key(k:Key) -> Fin n => Key given (n:Nat) = for i. ixkey(k, i) '### Sample Generators These functions generate samples taken from, different distributions. Such as `rand_mat` with samples from the distribution of floating point matrices where each element is taken from a i.i.d. uniform distribution. Note that additional standard distributions are provided by the `stats` library. -def rand (k:Key) -> Float = +def rand(k:Key) -> Float = exponent_bits = 1065353216 -- 1065353216 = 127 << 23 mantissa_bits = (high_word k .&. 8388607) -- 8388607 == (1 << 23) - 1 bits = exponent_bits .|. mantissa_bits %bitcast(Float, bits) - 1.0 -def rand_vec (n:Nat, f: (Key) -> a, k: Key) -> Fin n => a given (n|Ix, a) = +def rand_vec(n:Nat, f: (Key) -> a, k: Key) -> Fin n => a given (n|Ix, a) = for i:(Fin n). f ixkey(k, i) -def rand_mat (n:Nat, m:Nat, f: (Key) -> a, k: Key) -> Fin n => Fin m => a given (a) = +def rand_mat(n:Nat, m:Nat, f: (Key) -> a, k: Key) -> Fin n => Fin m => a given (a) = for i j. f ixkey(k, (i, j)) -def randn (k:Key) -> Float = +def randn(k:Key) -> Float = [k1, k2] = split_key k -- rand is uniform between 0 and 1, but implemented such that it rounds to 0 -- (in float32) once every few million draws, but never rounds to 1. @@ -1930,14 +1932,14 @@ def randn (k:Key) -> Float = sqrt ((-2.0) * log u1) * cos (2.0 * pi * u2) -- TODO: Make this better... -def rand_int (k:Key) -> Nat = w64_to_n k `mod` 2147483647 +def rand_int(k:Key) -> Nat = w64_to_n k `mod` 2147483647 -def bern (p:Float, k:Key) -> Bool = rand k < p +def bern(p:Float, k:Key) -> Bool = rand k < p -def randn_vec (k:Key) -> n=>Float given (n|Ix) = +def randn_vec(k:Key) -> n=>Float given (n|Ix) = for i. randn (ixkey(k, i)) -def rand_idx (k:Key) -> n given (n|Ix) = +def rand_idx(k:Key) -> n given (n|Ix) = rand k * n_to_f (size n) | floor | f_to_n | unsafe_from_ordinal '## Inner product typeclass @@ -1945,10 +1947,10 @@ def rand_idx (k:Key) -> n given (n|Ix) = interface InnerProd(v|VSpace) inner_prod : (v, v) -> Float -instance InnerProd Float +instance InnerProd(Float) def inner_prod(x, y) = x * y -instance InnerProd (n=>a) given (a|InnerProd, n|Ix) +instance InnerProd(n=>a) given (a|InnerProd, n|Ix) def inner_prod(x, y) =sum for i. inner_prod(x[i], y[i]) '## Arbitrary @@ -1957,40 +1959,40 @@ Type class for generating example values interface Arbitrary(a) arb : (Key) -> a -instance Arbitrary Bool - def arb (key) = key .&. 1 == 0 +instance Arbitrary(Bool) + def arb(key) = key .&. 1 == 0 -instance Arbitrary Float32 - def arb (key) = randn key +instance Arbitrary(Float32) + def arb(key) = randn key -instance Arbitrary Int32 - def arb (key) = f_to_i $ randn key * 5.0 +instance Arbitrary(Int32) + def arb(key) = f_to_i $ randn key * 5.0 -instance Arbitrary Nat - def arb (key) = f_to_n $ randn key * 5.0 +instance Arbitrary(Nat) + def arb(key) = f_to_n $ randn key * 5.0 -instance Arbitrary (n=>a) given (n|Ix, a|Arbitrary) - def arb (key) = for i. arb $ ixkey(key, i) +instance Arbitrary(n=>a) given (n|Ix, a|Arbitrary) + def arb(key) = for i. arb $ ixkey(key, i) -instance Arbitrary ((i:n)=>(..<i) => a) given (n|Ix, a|Arbitrary) - def arb (key) = for i. arb $ ixkey(key, i) +instance Arbitrary((i:n)=>(..<i) => a) given (n|Ix, a|Arbitrary) + def arb(key) = for i. arb $ ixkey(key, i) -instance Arbitrary ((i:n)=>(..i) => a) given (n|Ix, a|Arbitrary) - def arb (key) = for i. arb $ ixkey(key, i) +instance Arbitrary((i:n)=>(..i) => a) given (n|Ix, a|Arbitrary) + def arb(key) = for i. arb $ ixkey(key, i) -instance Arbitrary ((i:n)=>(i..) => a) given (n|Ix, a|Arbitrary) - def arb (key) = for i. arb $ ixkey(key, i) +instance Arbitrary((i:n)=>(i..) => a) given (n|Ix, a|Arbitrary) + def arb(key) = for i. arb $ ixkey(key, i) -instance Arbitrary ((i:n)=>(i<..) => a) given (n|Ix, a|Arbitrary) - def arb (key) = for i. arb $ ixkey(key, i) +instance Arbitrary((i:n)=>(i<..) => a) given (n|Ix, a|Arbitrary) + def arb(key) = for i. arb $ ixkey(key, i) -instance Arbitrary (a, b) given (a|Arbitrary, b|Arbitrary) - def arb (key) = +instance Arbitrary((a, b)) given (a|Arbitrary, b|Arbitrary) + def arb(key) = [k1, k2] = split_key key (arb k1, arb k2) -instance Arbitrary (Fin n) given (n) - def arb (key) = rand_idx key +instance Arbitrary(Fin n) given (n) + def arb(key) = rand_idx key '## Ord on Arrays @@ -1998,7 +2000,7 @@ instance Arbitrary (Fin n) given (n) 'returns the highest index `i` such that `xs.i <= x` -def search_sorted (xs:n=>a, x:a) -> Maybe n given (n|Ix, a|Ord) = +def search_sorted(xs:n=>a, x:a) -> Maybe n given (n|Ix, a|Ord) = if size n == 0 then Nothing else if x < xs[from_ordinal 0] @@ -2019,24 +2021,24 @@ def search_sorted (xs:n=>a, x:a) -> Maybe n given (n|Ix, a|Ord) = '### min / max etc -def min_by (f:(a)->o, x:a, y:a) -> a given (o|Ord, a) = select(f x < f y, x, y) -def max_by (f:(a)->o, x:a, y:a) -> a given (o|Ord, a) = select(f x > f y, x, y) +def min_by(f:(a)->o, x:a, y:a) -> a given (o|Ord, a) = select(f x < f y, x, y) +def max_by(f:(a)->o, x:a, y:a) -> a given (o|Ord, a) = select(f x > f y, x, y) -def min (x1: o, x2: o) -> o given (o|Ord) = min_by(id, x1, x2) -def max (x1: o, x2: o) -> o given (o|Ord) = max_by(id, x1, x2) +def min(x1: o, x2: o) -> o given (o|Ord) = min_by(id, x1, x2) +def max(x1: o, x2: o) -> o given (o|Ord) = max_by(id, x1, x2) -def minimum_by (f:(a)->o, xs:n=>a) -> a given (a|Data, o|Ord, n|Ix) = +def minimum_by(f:(a)->o, xs:n=>a) -> a given (a|Data, o|Ord, n|Ix) = reduce(xs[0@_], \x y. min_by(f, x, y), xs) -def maximum_by (f:(a)->o, xs:n=>a) -> a given (a|Data, o|Ord, n|Ix) = +def maximum_by(f:(a)->o, xs:n=>a) -> a given (a|Data, o|Ord, n|Ix) = reduce(xs[0@_], \x y. max_by(f, x, y), xs) -def minimum (xs:n=>o) -> o given (n|Ix, o|Ord) = minimum_by(id, xs) -def maximum (xs:n=>o) -> o given (n|Ix, o|Ord) = maximum_by(id, xs) +def minimum(xs:n=>o) -> o given (n|Ix, o|Ord) = minimum_by(id, xs) +def maximum(xs:n=>o) -> o given (n|Ix, o|Ord) = maximum_by(id, xs) '### argmin/argmax -- TODO: put in same section as `searchsorted` -def argscan (comp:(o,o)->Bool, xs:n=>o) -> n given (o|Ord, n|Ix) = +def argscan(comp:(o,o)->Bool, xs:n=>o) -> n given (o|Ord, n|Ix) = zeroth = (0@_, xs[0@_]) compare = \p1 p2. (idx1, x1) = p1 @@ -2045,8 +2047,8 @@ def argscan (comp:(o,o)->Bool, xs:n=>o) -> n given (o|Ord, n|Ix) = zipped = for i. (i, xs[i]) fst $ reduce(zeroth, compare, zipped) -def argmin (xs:n=>o) -> n given (n|Ix, o|Ord) = argscan((<), xs) -def argmax (xs:n=>o) -> n given (n|Ix, o|Ord) = argscan((>), xs) +def argmin(xs:n=>o) -> n given (n|Ix, o|Ord) = argscan((<), xs) +def argmax(xs:n=>o) -> n given (n|Ix, o|Ord) = argscan((>), xs) def lexical_order( compareElements:(n,n)->Bool, @@ -2080,13 +2082,13 @@ def lexical_order( True -> Continue False -> Done False -instance Ord (List n) given (n|Ord) - def (>) (xs, ys) = lexical_order((>), (<), xs, ys) - def (<) (xs, ys) = lexical_order((>), (<), xs, ys) +instance Ord(List n) given (n|Ord) + def (>)(xs, ys) = lexical_order((>), (>), xs, ys) + def (<)(xs, ys) = lexical_order((<), (<), xs, ys) '### clip -def clip (bounds:(a,a), x:a) -> a given (a|Ord) = +def clip(bounds:(a,a), x:a) -> a given (a|Ord) = (low,high) = bounds min(high, max(low, x)) @@ -2094,7 +2096,7 @@ def clip (bounds:(a,a), x:a) -> a given (a|Ord) = TODO: these should be with the other Elementary/Special Functions ### atan/atan2 -def atan_inner (x:Float) -> Float = +def atan_inner(x:Float) -> Float = -- From "Computing accurate Horner form approximations to -- special functions in finite precision arithmetic" -- https://arxiv.org/abs/1508.03211 @@ -2112,10 +2114,10 @@ def atan_inner (x:Float) -> Float = r * x + x -def min_and_max (x:a, y:a) -> (a, a) given (a|Ord) = +def min_and_max(x:a, y:a) -> (a, a) given (a|Ord) = select(x < y, (x, y), (y, x)) -- get both with one comparison. -def atan2 (y:Float, x:Float) -> Float = +def atan2(y:Float, x:Float) -> Float = -- Based off of the Tensorflow implementation at -- github.com/tensorflow/mlir-hlo/blob/master/lib/ -- Dialect/mhlo/transforms/legalize_trigonometric_to_approximation.cc#L147 @@ -2133,21 +2135,21 @@ def atan2 (y:Float, x:Float) -> Float = a = copysign(a, y) select(either_is_nan(x, y), nan, a) -- Propagate NaNs. -def atan (x:Float) -> Float = atan2(x, 1.0) +def atan(x:Float) -> Float = atan2(x, 1.0) '## Miscellaneous utilities TODO: all of these should be in some other section -def reflect (i:n) -> n given (n|Ix) = +def reflect(i:n) -> n given (n|Ix) = unsafe_from_ordinal $ unsafe_nat_diff(size n, ordinal i + 1) -def reverse (x:n=>a) -> n=>a given (n|Ix, a) = +def reverse(x:n=>a) -> n=>a given (n|Ix, a) = for i. x[reflect i] -def wrap_periodic (n|Ix, i:Nat) -> n = +def wrap_periodic(n|Ix, i:Nat) -> n = unsafe_from_ordinal(n=n, i `mod` size n) -def pad_to (m|Ix, x:a, xs:n=>a) -> m=>a given (n|Ix, a) = +def pad_to(m|Ix, x:a, xs:n=>a) -> m=>a given (n|Ix, a) = n' = size n for i. i' = ordinal i @@ -2155,13 +2157,13 @@ def pad_to (m|Ix, x:a, xs:n=>a) -> m=>a given (n|Ix, a) = True -> xs[i'@_] False -> x -def idiv_ceil (x:Nat, y:Nat) -> Nat = x `idiv` y + b_to_n (x `rem` y /= 0) -def intdiv2 (x:Nat) -> Nat = rep_to_nat $ %shr(nat_to_rep x, 1 :: NatRep) -def intpow2 (power:Nat) -> Nat = rep_to_nat $ %shl(1 :: NatRep, nat_to_rep power) -def is_odd (x:Nat) -> Bool = rem(x, 2) == 1 -def is_even (x:Nat) -> Bool = rem(x, 2) == 0 +def idiv_ceil(x:Nat, y:Nat) -> Nat = x `idiv` y + b_to_n (x `rem` y /= 0) +def intdiv2(x:Nat) -> Nat = rep_to_nat $ %shr(nat_to_rep x, 1 :: NatRep) +def intpow2(power:Nat) -> Nat = rep_to_nat $ %shl(1 :: NatRep, nat_to_rep power) +def is_odd(x:Nat) -> Bool = rem(x, 2) == 1 +def is_even(x:Nat) -> Bool = rem(x, 2) == 0 -def is_power_of_2 (x:Nat) -> Bool = +def is_power_of_2(x:Nat) -> Bool = -- A fast trick based on bitwise AND. -- This works on integer types larger than 8 bits. -- Note: The bitwise and operator (.&.) @@ -2178,7 +2180,7 @@ def is_power_of_2 (x:Nat) -> Bool = -- code path in ImpToLLVM, because it's the first LLVM intrinsic -- we have with a fixed-point argument. -- https://llvm.org/docs/LangRef.html#llvm-ctlz-intrinsic -def natlog2 (x:Nat) -> Nat = +def natlog2(x:Nat) -> Nat = tmp = yield_state 0 \ans. cmp <- run_state 1 while \. @@ -2191,8 +2193,11 @@ def natlog2 (x:Nat) -> Nat = False unsafe_nat_diff(tmp, 1) -- TODO: something less horrible -def general_integer_power - (times:(a,a)->a, one:a, base:a, power:Nat) -> a given (a|Data) = +def general_integer_power( + times:(a,a)->a, + one:a, base:a, + power:Nat + ) -> a given (a|Data) = iters = if power == 0 then 0 else 1 + natlog2 power -- Implements exponentiation by squaring. -- This could be nicer if there were a way to explicitly @@ -2206,33 +2211,33 @@ def general_integer_power z := times(get z, get z) pow := intdiv2 (get pow) -def intpow (base:a, power:Nat) -> a given (a|Mul) = +def intpow(base:a, power:Nat) -> a given (a|Mul) = general_integer_power((*), one, base, power) def from_just(x:Maybe a) -> a given (a) = case x of Just(x') -> x' -def any_sat (f:(a)->Bool, xs:n=>a) -> Bool given (a, n|Ix) = any(each xs f) +def any_sat(f:(a)->Bool, xs:n=>a) -> Bool given (a, n|Ix) = any(each xs f) -def seq_maybes (xs: n=>Maybe a) -> Maybe (n => a) given (n|Ix, a) = +def seq_maybes(xs: n=>Maybe a) -> Maybe (n => a) given (n|Ix, a) = -- is it possible to implement this safely? (i.e. without using partial -- functions) case any_sat(is_nothing, xs) of True -> Nothing False -> Just $ each xs from_just -def linear_search (xs:n=>a, query:a) -> Maybe n given (n|Ix, a|Eq) = +def linear_search(xs:n=>a, query:a) -> Maybe n given (n|Ix, a|Eq) = yield_state Nothing \ref. for i. case xs[i] == query of True -> ref := Just i False -> () -def list_length (l:List a) -> Nat given (a) = +def list_length(l:List a) -> Nat given (a) = AsList(n, _) = l n -- This is for efficiency (rather than using `<>` repeatedly) -- TODO: we want this for any monoid but this implementation won't work. -def concat (lists:n=>(List a)) -> List a given (a, n|Ix) = +def concat(lists:n=>(List a)) -> List a given (a, n|Ix) = totalSize = sum for i. list_length lists[i] to_list $ with_state 0 \listIdx. eltIdx <- with_state 0 @@ -2250,7 +2255,7 @@ def concat (lists:n=>(List a)) -> List a given (a, n|Ix) = eltIdx := eltIdxVal + 1 xs[eltIdxVal@_] -def cat_maybes (xs:n=>Maybe a) -> List a given (n|Ix, a|Data) = +def cat_maybes(xs:n=>Maybe a) -> List a given (n|Ix, a|Data) = (num_res, res_inds) = yield_state (0::Nat, for i:n. Nothing) \ref. for i. case xs[i] of Just(_) -> @@ -2265,19 +2270,19 @@ def cat_maybes (xs:n=>Maybe a) -> List a given (n|Ix, a|Data) = Nothing -> todo -- Impossible Nothing -> todo -- Impossible -def filter (xs:n=>a, condition:(a)->Bool) -> List a given (a|Data, n|Ix) = +def filter(xs:n=>a, condition:(a)->Bool) -> List a given (a|Data, n|Ix) = cat_maybes $ for i. if condition xs[i] then Just xs[i] else Nothing -def arg_filter (xs:n=>a, condition:(a)->Bool) -> List n given (a|Data, n|Ix) = +def arg_filter(xs:n=>a, condition:(a)->Bool) -> List n given (a|Data, n|Ix) = cat_maybes $ for i. if condition xs[i] then Just i else Nothing -- TODO: use `ix_offset : [Ix n] -> n -> Int -> Maybe n` instead -def prev_ix (i:n) -> Maybe n given (n|Ix) = +def prev_ix(i:n) -> Maybe n given (n|Ix) = case i_to_n (n_to_i (ordinal i) - 1) of Nothing -> Nothing Just(i_prev) -> unsafe_from_ordinal(i_prev) | Just -def lines (source:String) -> List String = +def lines(source:String) -> List String = AsList(_, s) = source AsList(num_lines, newline_ixs) = cat_maybes for i_char. if s[i_char] == '\n' @@ -2293,35 +2298,35 @@ def lines (source:String) -> List String = '## Probability -- cdf should include 0.0 but not 1.0 -def categorical_from_cdf (cdf: n=>Float, key: Key) -> n given (n|Ix) = +def categorical_from_cdf(cdf: n=>Float, key: Key) -> n given (n|Ix) = r = rand key case search_sorted(cdf, r) of Just(i) -> i -def normalize_pdf (xs: d=>Float) -> d=>Float given (d|Ix) = xs / sum xs +def normalize_pdf(xs: d=>Float) -> d=>Float given (d|Ix) = xs / sum xs -def cdf_for_categorical (logprobs: n=>Float) -> n=>Float given (n|Ix) = +def cdf_for_categorical(logprobs: n=>Float) -> n=>Float given (n|Ix) = maxLogProb = maximum logprobs cumsum_low $ normalize_pdf $ for i. exp(logprobs[i] - maxLogProb) -def categorical (logprobs: n=>Float, key: Key) -> n given (n|Ix) = +def categorical(logprobs: n=>Float, key: Key) -> n given (n|Ix) = categorical_from_cdf(cdf_for_categorical logprobs, key) -- batch variant to share the work of forming the cumsum -- (alternatively we could rely on hoisting of loop constants) -def categorical_batch (logprobs: n=>Float, key: Key) -> m=>n given (n|Ix, m|Ix) = +def categorical_batch(logprobs: n=>Float, key: Key) -> m=>n given (n|Ix, m|Ix) = cdf = cdf_for_categorical logprobs for i. categorical_from_cdf(cdf, ixkey(key, i)) -def logsumexp (x: n=>Float) -> Float given (n|Ix) = +def logsumexp(x: n=>Float) -> Float given (n|Ix) = m = maximum x m + (log $ sum for i. exp (x[i] - m)) -def logsoftmax (x: n=>Float) -> n=>Float given (n|Ix) = +def logsoftmax(x: n=>Float) -> n=>Float given (n|Ix) = lse = logsumexp x for i. x[i] - lse -def softmax (x: n=>Float) -> n=>Float given (n|Ix) = +def softmax(x: n=>Float) -> n=>Float given (n|Ix) = m = maximum x e = for i. exp (x[i] - m) s = sum e @@ -2330,45 +2335,45 @@ def softmax (x: n=>Float) -> n=>Float given (n|Ix) = '## Polynomials TODO: Move this somewhere else -def evalpoly (coefficients:n=>v, x:Float) -> v given (n|Ix, v|VSpace) = +def evalpoly(coefficients:n=>v, x:Float) -> v given (n|Ix, v|VSpace) = -- Evaluate a polynomial at x. Same as Numpy's polyval. fold zero \i c. coefficients[i] + x .* c '## TestMode -- TODO: move this to be in Testing Helpers -def dex_test_mode () -> Bool = unsafe_io \. check_env "DEX_TEST_MODE" +def dex_test_mode() -> Bool = unsafe_io \. check_env "DEX_TEST_MODE" '## Exception effect -- TODO: move `error` and `todo` to here. -def catch (f:() -> {Except|eff} a) -> {|eff} Maybe a given (a, eff) = +def catch(f:() -> {Except|eff} a) -> {|eff} Maybe a given (a, eff) = f' : (() -> {Except|eff} a) = \. f() - %catchException f' + %catchException(f') -def throw () -> {Except} a given (a) = - %throwException a +def throw() -> {Except} a given (a) = + %throwException(a) -def assert (b:Bool) -> {Except} () = +def assert(b:Bool) -> {Except} () = if not b then throw() '### Misc instances that require `error` instance Subset(a, Either(a,b)) given (a|Data, b|Data) - def inject (x) = Left x - def project (x) = case x of + def inject(x) = Left x + def project(x) = case x of Left( y) -> Just y Right(x) -> Nothing - def unsafe_project (x) = case x of + def unsafe_project(x) = case x of Left( x) -> x Right(x) -> error "Can't project Right branch to Left branch" instance Subset(b, Either(a,b)) given (a|Data, b|Data) - def inject (x) = Right x - def project (x) = case x of + def inject(x) = Right x + def project(x) = case x of Left( x) -> Nothing Right(y) -> Just y - def unsafe_project (x) = case x of + def unsafe_project(x) = case x of Left( x) -> error "Can't project Left branch to Right branch" Right(x) -> x @@ -2383,14 +2388,14 @@ instance Subset(b, Either(a,b)) given (a|Data, b|Data) '### Index set for tables -def int_to_reversed_digits (k:Nat) -> a=>b given (a|Ix, b|Ix) = +def int_to_reversed_digits(k:Nat) -> a=>b given (a|Ix, b|Ix) = base = size b snd $ scan k \_ cur_k. next_k = cur_k `idiv` base digit = cur_k `mod` base (next_k, unsafe_from_ordinal(n=b, digit)) -def reversed_digits_to_int (digits: a=>b) -> Nat given (a|Ix, b|Ix) = +def reversed_digits_to_int(digits: a=>b) -> Nat given (a|Ix, b|Ix) = base = size b fst $ fold (0, 1) \j pair. (cur_k, cur_base) = pair @@ -2398,30 +2403,30 @@ def reversed_digits_to_int (digits: a=>b) -> Nat given (a|Ix, b|Ix) = next_base = cur_base * base (next_k, next_base) -instance Ix (a=>b) given (a|Ix, b|Ix) +instance Ix(a=>b) given (a|Ix, b|Ix) -- 0@a is the least significant digit, -- while (size a - 1)@a is the most significant digit. - def size' () = size b `intpow` size a - def ordinal (i) = reversed_digits_to_int i - def unsafe_from_ordinal (i) = int_to_reversed_digits i + def size'() = size b `intpow` size a + def ordinal(i) = reversed_digits_to_int i + def unsafe_from_ordinal(i) = int_to_reversed_digits i -instance NonEmpty (a=>b) given (a|Ix, b|NonEmpty) +instance NonEmpty(a=>b) given (a|Ix, b|NonEmpty) first_ix = unsafe_from_ordinal 0 '### stack -- TODO: replace `DynBuffer` with this? -def Stack (h:Heap, a|Data) = Ref(h, (Nat, List a)) +def Stack(h:Heap, a|Data) = Ref(h, (Nat, List a)) -def stack_size (stack:Stack(h, a)) -> {State h} Nat given (h, a) = get $ fst_ref stack +def stack_size(stack:Stack(h, a)) -> {State h} Nat given (h, a) = get $ fst_ref stack -def unsafe_get_stack_buffer (stack:Stack(h, a)) -> {State h} (Ref(h, Fin 0 => a)) given (h, a|Data) = +def unsafe_get_stack_buffer(stack:Stack(h, a)) -> {State h} (Ref(h, Fin 0 => a)) given (h, a|Data) = get $ snd_ref $ unsafe_coerce(to=Ref h (Nat, Ref h (Fin 0 => a)), snd_ref stack) -def stack_buf_size (stack:Stack(h, a)) -> {State h} Nat given (h, a|Data) = +def stack_buf_size(stack:Stack(h, a)) -> {State h} Nat given (h, a|Data) = get $ fst_ref $ unsafe_coerce(to=Ref h (Nat, Ref h (Fin 0 => a)), snd_ref stack) -def ensure_size_at_least (stack:Stack(h, a), req_size:Nat) -> {State h} () given (h, a|Data) = +def ensure_size_at_least(stack:Stack(h, a), req_size:Nat) -> {State h} () given (h, a|Data) = if req_size > stack_buf_size stack then -- TODO: maybe this should use integer arithmetic? new_buf_size = f_to_n $ 2.0 `pow` (ceil $ log2 $ n_to_f req_size) @@ -2433,13 +2438,13 @@ def ensure_size_at_least (stack:Stack(h, a), req_size:Nat) -> {State h} () given Just(i') -> cur_data[i'] Nothing -> uninitialized_value() -def read_stack (stack:Stack(h, a)) -> {State h} (List a) given (h, a|Data) = +def read_stack(stack:Stack(h, a)) -> {State h} (List a) given (h, a|Data) = n = stack_size stack buf = unsafe_coerce(to=Ref(h, Fin n => a), unsafe_get_stack_buffer stack) AsList(n, get buf) @noinline -def stack_push (stack:Stack(h, a), x:a) -> {State h} () given (a|Data, h) = +def stack_push(stack:Stack(h, a), x:a) -> {State h} () given (a|Data, h) = n_old = stack_size stack n_new = n_old + 1 ensure_size_at_least(stack, n_new) @@ -2448,7 +2453,7 @@ def stack_push (stack:Stack(h, a), x:a) -> {State h} () given (a|Data, h) = fst_ref stack := n_new @noinline -def stack_extend (stack:Stack(h, a), x:n=>a) -> {State h} () given (a|Data, n|Ix, h) = +def stack_extend(stack:Stack(h, a), x:n=>a) -> {State h} () given (a|Data, n|Ix, h) = n_old = stack_size stack n_new = n_old + size n ensure_size_at_least(stack, n_new) @@ -2457,7 +2462,7 @@ def stack_extend (stack:Stack(h, a), x:n=>a) -> {State h} () given (a|Data, n|Ix buf_slice := x fst_ref stack := n_new -def stack_pop (stack:Stack(h, a)) -> {State h} Maybe a given (a|Data, h) = +def stack_pop(stack:Stack(h, a)) -> {State h} Maybe a given (a|Data, h) = n_old = stack_size stack case n_old == 0 of True -> Nothing @@ -2475,52 +2480,52 @@ def with_stack( init_stack = to_list for i:(Fin stack_init_size). uninitialized_value() with_state (0, init_stack) \ref . action ref -def stack_extend_internal (stack:Stack(h, Char), x:Fin n=>Char) -> {State h} () given (n, h) = +def stack_extend_internal(stack:Stack(h, Char), x:Fin n=>Char) -> {State h} () given (n, h) = stack_extend(stack, x) -def stack_push_internal (stack:Stack(h, Char), x:Char) -> {State h} () given (h) = +def stack_push_internal(stack:Stack(h, Char), x:Char) -> {State h} () given (h) = stack_push(stack, x) -def with_stack_internal (f:(given (h:Heap), Stack(h, Char)) -> {State h} ()) -> List Char = +def with_stack_internal(f:(given (h:Heap), Stack(h, Char)) -> {State h} ()) -> List Char = with_stack Char \stack. f stack read_stack stack -def show_any (x:a) -> String given (a) = unsafe_coerce(to=String, %showAny(x)) +def show_any(x:a) -> String given (a) = unsafe_coerce(to=String, %showAny(x)) -def coerce_table (m|Ix, x:n=>a) -> m => a given (n|Ix, a|Data) = +def coerce_table(m|Ix, x:n=>a) -> m => a given (n|Ix, a|Data) = if size m == size n then unsafe_coerce(to=m=>a, x) else error "mismatched sizes in table coercion" '### Linear Algebra -def linspace (n|Ix, low:Float, high:Float) -> n=>Float = +def linspace(n|Ix, low:Float, high:Float) -> n=>Float = dx = (high - low) / n_to_f (size n) for i:n. low + n_to_f (ordinal i) * dx -def transpose (x:n=>m=>a) -> m=>n=>a given (n|Ix, m|Ix, a) = for i j. x[j,i] -def vdot (x:n=>Float, y:n=>Float) -> Float given (n|Ix) = fsum for i. x[i] * y[i] -def dot (s:n=>Float, vs:n=>v) -> v given (n|Ix, v|VSpace) = sum for j. s[j] .* vs[j] +def transpose(x:n=>m=>a) -> m=>n=>a given (n|Ix, m|Ix, a) = for i j. x[j,i] +def vdot(x:n=>Float, y:n=>Float) -> Float given (n|Ix) = fsum for i. x[i] * y[i] +def dot(s:n=>Float, vs:n=>v) -> v given (n|Ix, v|VSpace) = sum for j. s[j] .* vs[j] -def naive_matmul (x: l=>m=>Float, y: m=>n=>Float) -> (l=>n=>Float) given (l|Ix, m|Ix, n|Ix) = +def naive_matmul(x: l=>m=>Float, y: m=>n=>Float) -> (l=>n=>Float) given (l|Ix, m|Ix, n|Ix) = for i k. fsum for j. x[i,j] * y[j,k] -- A `FullTileIx` type represents `tile_ix`th full tile (of size -- `tile_size`) iterating over the index set `n`. -- This type is only well formed when tile_ix * tile_size < size n. -struct FullTileIx (n|Ix, tile_size:Nat, tile_ix:Nat) = +struct FullTileIx(n|Ix, tile_size:Nat, tile_ix:Nat) = unwrap : Fin tile_size -instance Ix (FullTileIx(n, tile_size, tile_ix)) given (n|Ix, tile_size:Nat, tile_ix:Nat) - def size' () = tile_size - def ordinal (i) = ordinal i.unwrap - def unsafe_from_ordinal (i) =FullTileIx.new $ unsafe_from_ordinal i +instance Ix(FullTileIx(n, tile_size, tile_ix)) given (n|Ix, tile_size:Nat, tile_ix:Nat) + def size'() = tile_size + def ordinal(i) = ordinal i.unwrap + def unsafe_from_ordinal(i) =FullTileIx.new $ unsafe_from_ordinal i instance Subset(FullTileIx(n, tile_size, tile_ix), n) given (n|Ix, tile_size:Nat, tile_ix:Nat) - def inject (i) = unsafe_from_ordinal $ tile_size * tile_ix + ordinal i.unwrap - def project (i) = todo - def unsafe_project (i) = todo + def inject(i) = unsafe_from_ordinal $ tile_size * tile_ix + ordinal i.unwrap + def project(i) = todo + def unsafe_project(i) = todo -- A `CodaIx` type represents the last few elements of the index set `n`, -- as might be left over after iterating by tiles. @@ -2528,13 +2533,13 @@ instance Subset(FullTileIx(n, tile_size, tile_ix), n) given (n|Ix, tile_size:Nat struct CodaIx(n|Ix, coda_offset:Nat, coda_size:Nat) = unwrap : Fin coda_size -instance Ix (CodaIx(n, coda_offset, coda_size)) given (n|Ix, coda_offset:Nat, coda_size:Nat) - def size' () = coda_size - def ordinal (i) = ordinal i.unwrap - def unsafe_from_ordinal (i) = CodaIx.new $ unsafe_from_ordinal i +instance Ix(CodaIx(n, coda_offset, coda_size)) given (n|Ix, coda_offset:Nat, coda_size:Nat) + def size'() = coda_size + def ordinal(i) = ordinal i.unwrap + def unsafe_from_ordinal(i) = CodaIx.new $ unsafe_from_ordinal i instance Subset(CodaIx(n, coda_offset, coda_size), n) given (n|Ix, coda_offset:Nat, coda_size:Nat) - def inject (i) = unsafe_from_ordinal $ coda_offset + ordinal i.unwrap + def inject(i) = unsafe_from_ordinal $ coda_offset + ordinal i.unwrap def project(i) = todo def unsafe_project(i) = todo @@ -2572,15 +2577,15 @@ def (**)( m_ix = inject m_offset result!l_ix!n_ix += x[l_ix,m_ix] * y[m_ix,n_ix] -def (**.) (mat: n=>m=>Float, v: m=>Float) -> (n=>Float) given (n|Ix, m|Ix) = +def (**.)(mat: n=>m=>Float, v: m=>Float) -> (n=>Float) given (n|Ix, m|Ix) = for i. vdot(mat[i], v) -def (.**) (v: m=>Float, mat: n=>m=>Float) -> (n=>Float) given (n|Ix, m|Ix) = +def(.**)(v: m=>Float, mat: n=>m=>Float) -> (n=>Float) given (n|Ix, m|Ix) = mat **. v -def inner (x:n=>Float, mat:n=>m=>Float, y:m=>Float) -> Float given (n|Ix, m|Ix) = +def inner(x:n=>Float, mat:n=>m=>Float, y:m=>Float) -> Float given (n|Ix, m|Ix) = fsum for p. (i,j) = p x[i] * mat[i,j] * y[j] -def eye () ->> n=>n=>a given (n|Ix, a|Add|Mul) = +def eye() ->> n=>n=>a given (n|Ix, a|Add|Mul) = for i j. select(ordinal i == ordinal j, one, zero) diff --git a/lib/sort.dx b/lib/sort.dx index e6a7c77b..1178dd0c 100644 --- a/lib/sort.dx +++ b/lib/sort.dx @@ -12,58 +12,60 @@ it's doing bubble / insertion sort with quadratic time cost. However, if it breaks the list in half recursively, it'll be doing parallel mergesort. Currently the Dex compiler will do the quadratic-time version. -def concat_table {a b v} (leftin: a=>v) (rightin: b=>v) : ((a|b)=>v) = +def concat_table(leftin: a=>v, rightin: b=>v) -> (Either a b=>v) given (a|Ix, b|Ix, v) = for idx. case idx of - Left i -> leftin.i - Right i -> rightin.i + Left i -> leftin[i] + Right i -> rightin[i] -def merge_sorted_tables {a m n} [Ord a] (xs:m=>a) (ys:n=>a) : ((m|n)=>a) = +def merge_sorted_tables(xs:m=>a, ys:n=>a) -> (Either m n=>a) given (a|Ord, m|Ix, n|Ix) = -- Possible improvements: -- 1) Using a SortedTable type. -- 2) Avoid needlessly initializing the return array. init = concat_table xs ys -- Initialize array of correct size. yield_state init \buf. with_state (0, 0) \countrefs. - for i:(m|n). + for i:(Either m n). (cur_x, cur_y) = get countrefs if cur_y >= size n -- no ys left then countrefs := (cur_x + 1, cur_y) - buf!i := xs.(unsafe_from_ordinal _ cur_x) + buf!i := xs[unsafe_from_ordinal cur_x] else if cur_x < size m -- still xs left - then - if xs.(unsafe_from_ordinal _ cur_x) <= ys.(unsafe_from_ordinal _ cur_y) + then + if xs[unsafe_from_ordinal cur_x] <= ys[unsafe_from_ordinal cur_y] then countrefs := (cur_x + 1, cur_y) - buf!i := xs.(unsafe_from_ordinal _ cur_x) + buf!i := xs[unsafe_from_ordinal cur_x] else countrefs := (cur_x, cur_y + 1) - buf!i := ys.(unsafe_from_ordinal _ cur_y) + buf!i := ys[unsafe_from_ordinal cur_y] -def merge_sorted_lists {a} [Ord a] (AsList nx xs: List a) (AsList ny ys: List a) : List a = +def merge_sorted_lists(lx: List a, ly: List a) -> List a given (a|Ord) = -- Need this wrapper because Dex can't automatically weaken -- (a | b)=>c to ((Fin d)=>c) + AsList(nx, xs) = lx + AsList(ny, ys) = ly sorted = merge_sorted_tables xs ys newsize = nx + ny - AsList _ $ unsafe_cast_table (Fin newsize) sorted + AsList _ $ unsafe_cast_table(to=Fin newsize, sorted) -def sort {a n} [Ord a] (xs: n=>a) : n=>a = +def sort(xs: n=>a) -> n=>a given (n|Ix, a|Ord) = -- Warning: Has quadratic runtime cost for now. - xlists = for i:n. (AsList 1 [xs.i]) + xlists = for i:n. (AsList 1 [xs[i]]) -- Merge sort monoid: mempty = AsList 0 [] mcombine = merge_sorted_lists -- reduce might someday recursively subdivide the problem. - (AsList _ r) = reduce mempty mcombine xlists - unsafe_cast_table n r + AsList(_, r) = reduce mempty mcombine xlists + unsafe_cast_table(to=n, r) -def (+|) {n} [Ix n] (i:n) (delta:Nat) : n = +def (+|)(i:n, delta:Nat) -> n given (n|Ix) = i' = ordinal i + delta - from_ordinal _ $ select (i' >= size n) (size n -| 1) i' + from_ordinal $ select (i' >= size n) (size n -| 1) i' -def is_sorted {a n} [Ord a] (xs:n=>a) : Bool = - all for i. xs.i <= xs.(i +| 1) +def is_sorted(xs:n=>a) -> Bool given (a|Ord, n|Ix) = + all for i. xs[i] <= xs[i +| 1] @@ -221,8 +221,8 @@ endif test-names = uexpr-tests print-tests adt-tests type-tests struct-tests cast-tests eval-tests show-tests \ read-tests shadow-tests monad-tests io-tests exception-tests sort-tests \ - standalone-function-tests \ - ad-tests parser-tests serialize-tests parser-combinator-tests \ + parser-tests standalone-function-tests \ + ad-tests serialize-tests parser-combinator-tests \ record-variant-tests typeclass-tests complex-tests trig-tests \ linalg-tests set-tests fft-tests stats-tests stack-tests diff --git a/src/lib/ConcreteSyntax.hs b/src/lib/ConcreteSyntax.hs index 8a7c7533..c5fecf4c 100644 --- a/src/lib/ConcreteSyntax.hs +++ b/src/lib/ConcreteSyntax.hs @@ -238,8 +238,7 @@ structDef :: Parser CTopDecl structDef = withSrc do keyWord StructKW tyName <- anyName - params <- explicitParams - sym "=" + params <- typeParams fields <- onePerLine nameAndType return $ CStruct tyName params fields @@ -247,11 +246,10 @@ dataDef :: Parser CTopDecl dataDef = withSrc do keyWord DataKW tyName <- anyName - params <- explicitParams - sym "=" + params <- typeParams dataCons <- onePerLine do dataConName <- anyName - dataConArgs <- explicitParams + dataConArgs <- optExplicitParams return (dataConName, dataConArgs) return $ CData tyName params dataCons @@ -273,9 +271,6 @@ interfaceDef = withSrc do return (methodName, ty) return $ CInterface className params methodDecls --- superclassConstraints :: Parser [Group] --- superclassConstraints = optionalMonoid $ brackets $ cNames `sepBy` sym "," - effectDef :: Parser CTopDecl effectDef = withSrc do keyWord EffectKW @@ -364,7 +359,7 @@ instanceDef isNamed = do args <- (sym ":" >> return Nothing) <|> ((Just <$> parens (commaSep cParenGroup)) <* sym "->") return $ Just (name, args) - className <- anyNameNoSC + className <- anyName args <- argList givens <- optional givenClause methods <- realBlock @@ -375,7 +370,7 @@ funDefLet = label "function definition" do keyWord DefKW mayBreak do name <- anyName - params <- parens (commaSep cParenGroup) + params <- explicitParams rhs <- optional do expl <- explicitness effs <- optional cEffs @@ -391,20 +386,39 @@ explicitness :: Parser AppExplicitness explicitness = (sym "->" $> ExplicitApp) <|> (sym "->>" $> ImplicitApp) -anyNameNoSC :: Parser SourceName -anyNameNoSC = withConsumption ConsumeNothing $ anyName - -- Intended for occurrences, like `foo(x, y, z)` (cf. defParamsList). --- Previous lexeme shouldn't consume trailing whitespace. argList :: Parser [Group] -argList = nextChar >>= \case - '(' -> parens (commaSep cGroup) - _ -> do - arg <- sc >> leafGroup - return [arg] +argList = immediateParens (commaSep cParenGroup) + +immediateLParen :: Parser () +immediateLParen = label "'(' (without preceding whitespace)" do + nextChar >>= \case + '(' -> precededByWhitespace >>= \case + True -> empty + False -> charLexeme '(' + _ -> empty + +immediateParens :: Parser a -> Parser a +immediateParens p = bracketed immediateLParen rParen p + +-- Putting `sym =` inside the cases gives better errors. +typeParams :: Parser ExplicitParams +typeParams = + (explicitParams <* sym "=") + <|> (return [] <* sym "=") + +optExplicitParams :: Parser ExplicitParams +optExplicitParams = label "optional parameter list" $ + explicitParams <|> return [] explicitParams :: Parser ExplicitParams -explicitParams = label "parameter list" $ parens (commaSep cGroup) <|> return [] +explicitParams = label "parameter list in parentheses (without preceding whitespace)" $ + immediateParens $ commaSep cParenGroup + +noGap :: Parser () +noGap = precededByWhitespace >>= \case + True -> fail "Unexpected whitespace" + False -> return () givenClause :: Parser GivenClause givenClause = keyWord GivenKW >> do @@ -571,22 +585,20 @@ noElse msg = (optional $ try $ sc >> keyWord ElseKW) >>= \case leafGroup :: Parser Group leafGroup = do - leaf <- leafGroupNoSC - postOps <- many postfixGroupNoSC <* sc + leaf <- leafGroup' + postOps <- many postfixGroup return $ foldl (\accum (op, opLhs) -> joinSrc accum opLhs $ CBin (WithSrc Nothing op) accum opLhs) leaf postOps where - -- These "noSC" functions don't consume trailing whitespace. We want to parse - -- things like `f(x,y).foo(z)[q]`. - leafGroupNoSC :: Parser Group - leafGroupNoSC = withSrc do + leafGroup' :: Parser Group + leafGroup' = withSrc do next <- nextChar case next of - '_' -> noSC $ underscore $> CHole + '_' -> underscore $> CHole '(' -> (CIdentifier <$> symName) - <|> cParensNoSC - '[' -> cBracketsNoSC - '{' -> CBraces <$> bracketNoSC '{' '}' (commaSep cGroup) + <|> cParens + '[' -> cBrackets + '{' -> cBraces '\"' -> CString <$> strLit '\'' -> CChar <$> charLit '%' -> do @@ -598,33 +610,30 @@ leafGroup = do <|> CFloat <$> doubleLit) '\\' -> cNullaryLam <|> cLam -- For exprs include for, rof, for_, rof_ - 'f' -> cFor <|> cIdentifierNoSC - 'd' -> cDo <|> cIdentifierNoSC - 'r' -> cFor <|> cIdentifierNoSC - 'c' -> cCase <|> cIdentifierNoSC - 'i' -> cIf <|> cIdentifierNoSC - _ -> cIdentifierNoSC - - noSC :: Parser a -> Parser a - noSC p = withConsumption ConsumeNothing p - - postfixGroupNoSC :: Parser (Bin', Group) - postfixGroupNoSC = - ((JuxtaposeNoSpace,) <$> withSrc cParensNoSC) - <|> ((JuxtaposeNoSpace,) <$> withSrc cBracketsNoSC) - <|> ((Dot,) <$> (try $ char '.' >> withSrc cIdentifierNoSC)) - - bracketNoSC :: Char -> Char -> Parser a -> Parser a - bracketNoSC l r p = charLexeme l >> mayBreak (sc >> p) <* char r - - cIdentifierNoSC :: Parser Group' - cIdentifierNoSC = noSC $ CIdentifier <$> anyName - - cParensNoSC :: Parser Group' - cParensNoSC = CParens <$> bracketNoSC '(' ')' (cParenGroup `sepBy` sym ",") - - cBracketsNoSC :: Parser Group' - cBracketsNoSC = CBrackets <$> bracketNoSC '[' ']' (commaSep cGroup) + 'f' -> cFor <|> cIdentifier + 'd' -> cDo <|> cIdentifier + 'r' -> cFor <|> cIdentifier + 'c' -> cCase <|> cIdentifier + 'i' -> cIf <|> cIdentifier + _ -> cIdentifier + + postfixGroup :: Parser (Bin', Group) + postfixGroup = noGap >> + ((JuxtaposeNoSpace,) <$> withSrc cParens) + <|> ((JuxtaposeNoSpace,) <$> withSrc cBrackets) + <|> ((Dot,) <$> (try $ char '.' >> withSrc cIdentifier)) + +cIdentifier :: Parser Group' +cIdentifier = CIdentifier <$> anyName + +cParens :: Parser Group' +cParens = CParens <$> parens (commaSep cParenGroup) + +cBrackets :: Parser Group' +cBrackets = CBrackets <$> brackets (commaSep cGroup) + +cBraces :: Parser Group' +cBraces = CBrackets <$> braces (commaSep cGroup) -- A `PrecTable` is enough information to (i) remove or replace -- operators for special contexts, and (ii) build the input structure diff --git a/src/lib/Lexing.hs b/src/lib/Lexing.hs index 4320b8c0..f49d3d1d 100644 --- a/src/lib/Lexing.hs +++ b/src/lib/Lexing.hs @@ -1,6 +1,12 @@ +-- Copyright 2022 Google LLC +-- +-- Use of this source code is governed by a BSD-style +-- license that can be found in the LICENSE file or at +-- https://developers.google.com/open-source/licenses/bsd + module Lexing where -import Control.Monad.Reader +import Control.Monad.State.Strict import Data.Char import Data.HashSet qualified as HS import qualified Data.Scientific as Scientific @@ -20,17 +26,19 @@ import Err import LabeledItems import Types.Primitives -data ParseCtx = ParseCtx { curIndent :: Int - , whitespaceConsumption :: WhitespaceConsumption } -type Parser = ReaderT ParseCtx (Parsec Void Text) +data ParseCtx = ParseCtx + { curIndent :: Int -- used Reader-style (i.e. ask/local) + , canBreak :: Bool -- used Reader-style (i.e. ask/local) + , prevWhitespace :: Bool -- tracks whether we just consumed whitespace + } + +initParseCtx :: ParseCtx +initParseCtx = ParseCtx 0 False False -data WhitespaceConsumption = - ConsumeNothing - | ConsumeAllExceptLineBreaks - | ConsumeAll +type Parser = StateT ParseCtx (Parsec Void Text) parseit :: Text -> Parser a -> Except a -parseit s p = case parse (runReaderT p (ParseCtx 0 ConsumeAllExceptLineBreaks)) "" s of +parseit s p = case parse (fst <$> runStateT p initParseCtx) "" s of Left e -> throw ParseErr $ errorBundlePretty e Right x -> return x @@ -41,8 +49,8 @@ mustParseit s p = case parseit s p of debug :: (Show a) => String -> Parser a -> Parser a debug lbl action = do - ctx <- ask - lift $ dbg lbl $ runReaderT action ctx + ctx <- get + lift $ dbg lbl $ fst <$> runStateT action ctx -- === Lexemes === @@ -160,11 +168,7 @@ knownSymStrs = HS.fromList -- string must be in `knownSymStrs` sym :: Text -> Lexer () -sym s = lexeme $ try $ string s >> notFollowedBy continuation - -- This awful hack is because "|}" should be a single lexeme for - -- variants, but we can't make "}" be a symChar because that would - -- allow user-defined symbols like --}, which is horrible. - where continuation = if s == "|" then (symChar <|> char '}') else symChar +sym s = lexeme $ try $ string s >> notFollowedBy symChar anySym :: Lexer String anySym = lexeme $ try $ do @@ -209,7 +213,8 @@ symChars = HS.fromList ".,!$^&*:-~+/=<>|?\\@#" -- === Util === sc :: Parser () -sc = skipMany $ hidden space <|> hidden lineComment +sc = (skipSome s >> recordWhitespace) <|> return () + where s = hidden space <|> hidden lineComment lineComment :: Parser () lineComment = do @@ -220,33 +225,29 @@ outputLines :: Parser () outputLines = void $ many (symbol ">" >> takeWhileP Nothing (/= '\n') >> ((eol >> return ()) <|> eof)) space :: Parser () -space = asks whitespaceConsumption >>= \case - ConsumeNothing -> fail "" - ConsumeAllExceptLineBreaks -> void $ takeWhile1P (Just "white space") (`elem` (" \t" :: String)) - ConsumeAll -> space1 - -withConsumption :: WhitespaceConsumption -> Parser a -> Parser a -withConsumption c p = localConsumption p \_ -> c -{-# INLINE withConsumption #-} - -localConsumption :: Parser a -> (WhitespaceConsumption -> WhitespaceConsumption) -> Parser a -localConsumption p update = - local (\ctx -> ctx { whitespaceConsumption = update $ whitespaceConsumption ctx }) p -{-# INLINE localConsumption #-} +space = gets canBreak >>= \case + True -> space1 + False -> void $ takeWhile1P (Just "white space") (`elem` (" \t" :: String)) mayBreak :: Parser a -> Parser a -mayBreak = withConsumption ConsumeAll +mayBreak p = pLocal (\ctx -> ctx { canBreak = True }) p {-# INLINE mayBreak #-} mayNotBreak :: Parser a -> Parser a -mayNotBreak p = localConsumption p \case - ConsumeAll -> ConsumeAllExceptLineBreaks - c -> c +mayNotBreak p = pLocal (\ctx -> ctx { canBreak = False }) p {-# INLINE mayNotBreak #-} -optionalMonoid :: Monoid a => Parser a -> Parser a -optionalMonoid p = p <|> return mempty -{-# INLINE optionalMonoid #-} +precededByWhitespace :: Parser Bool +precededByWhitespace = gets prevWhitespace +{-# INLINE precededByWhitespace #-} + +recordWhitespace :: Parser () +recordWhitespace = modify \ctx -> ctx { prevWhitespace = True } +{-# INLINE recordWhitespace #-} + +recordNonWhitespace :: Parser () +recordNonWhitespace = modify \ctx -> ctx { prevWhitespace = False } +{-# INLINE recordNonWhitespace #-} nameString :: Parser String nameString = lexeme . try $ (:) <$> lowerChar <*> many alphaNumChar @@ -281,7 +282,7 @@ withPos p = do nextLine :: Parser () nextLine = do eol - n <- asks curIndent + n <- curIndent <$> get void $ mayNotBreak $ many $ try (sc >> eol) void $ replicateM n (char ' ') @@ -297,9 +298,15 @@ withIndent p = do nextLine indent <- T.length <$> takeWhileP (Just "space") (==' ') when (indent <= 0) empty - local (\ctx -> ctx { curIndent = curIndent ctx + indent }) $ p + pLocal (\ctx -> ctx { curIndent = curIndent ctx + indent }) $ p {-# INLINE withIndent #-} +pLocal :: (ParseCtx -> ParseCtx) -> Parser a -> Parser a +pLocal f p = do + s <- get + put (f s) >> p <* put s +{-# INLINE pLocal #-} + eol :: Parser () eol = void MC.eol @@ -311,7 +318,7 @@ failIf True s = fail s failIf False _ = return () lexeme :: Parser a -> Parser a -lexeme = L.lexeme sc +lexeme p = L.lexeme sc (p <* recordNonWhitespace) {-# INLINE lexeme #-} symbol :: Text -> Parser () diff --git a/tests/adt-tests.dx b/tests/adt-tests.dx index 4a817de7..e9a5aff9 100644 --- a/tests/adt-tests.dx +++ b/tests/adt-tests.dx @@ -51,7 +51,7 @@ MkMyPair(z1, MkMyPair(z2, z3)) = zz > [(1, 1), (3, 2), (5, 3)] data MyEither(a:Type, b:Type) = - MyLeft (a) + MyLeft(a) MyRight(b) x : MyEither Int Float = MyLeft 1 @@ -189,7 +189,7 @@ AsList(_, xsTab) = xsList (x, y) > (3, 4) -def catLists (xs:List a, ys:List a) -> List a given (a) = +def catLists(xs:List a, ys:List a) -> List a given (a) = AsList(nx, xs') = xs AsList(ny, ys') = ys nz = nx + ny @@ -215,7 +215,7 @@ def catLists (xs:List a, ys:List a) -> List a given (a) = -def listToTable2 (l: List a) -> (Fin (list_length l))=>a given (a) = +def listToTable2(l: List a) -> (Fin (list_length l))=>a given (a) = AsList(_, xs) = l xs @@ -231,7 +231,7 @@ l2 = AsList _ [1, 2, 3] :p sum $ listToTable2 l2 > 6 -def zerosLikeList (l : List a) -> (Fin (list_length l))=>Float given (a) = +def zerosLikeList(l: List a) -> (Fin (list_length l))=>Float given (a) = for i:(Fin $ list_length l). 0.0 :p zerosLikeList l2 @@ -240,7 +240,7 @@ def zerosLikeList (l : List a) -> (Fin (list_length l))=>Float given (a) = data Graph(n|Ix, a:Type) = MkGraph(nodes:(n=>a), edges:(List (n, n))) -def graphToAdjacencyMatrix (g:Graph n a) -> n=>n=>Bool given (n|Ix, a) = +def graphToAdjacencyMatrix(g:Graph n a) -> n=>n=>Bool given (n|Ix, a) = MkGraph(nodes, AsList(_, edges)) = g init = for i j. False yield_state init \mRef. @@ -271,7 +271,7 @@ data MySum = data MySum2 = Foo2 - Bar2 (Fin 3 => Int) + Bar2(Fin 3 => Int) -- bug #348 :p concat for i:(Fin 4). AsList _ [(Foo2, Foo2)] diff --git a/tests/eval-tests.dx b/tests/eval-tests.dx index 474c79bc..73d1a541 100644 --- a/tests/eval-tests.dx +++ b/tests/eval-tests.dx @@ -136,7 +136,7 @@ s = 1.0 for i. xs[i] + 10 > [12, 11, 10] -def cumsumplus (xs:n=>Float) -> n=>Float given (n|Ix) = +def cumsumplus(xs:n=>Float) -> n=>Float given (n|Ix) = snd $ scan 0.0 \i c. ans = c + xs[i] (ans, 1.0 + ans) @@ -144,7 +144,7 @@ def cumsumplus (xs:n=>Float) -> n=>Float given (n|Ix) = :p cumsumplus [1.0, 2.0, 3.0] > [2., 4., 7.] -def cumsum2d (xs:n=>m=>a) -> n=>m=>a given (n|Ix, m|Ix, a|Add) = +def cumsum2d(xs:n=>m=>a) -> n=>m=>a given (n|Ix, m|Ix, a|Add) = cumsum for i. cumsum xs[i] xs : (Fin 4)=>(Fin 3)=>Float = arb $ new_key 0 @@ -520,7 +520,7 @@ count = (w32_to_w64 (i_to_w32 608135816) .<<. 32) .|. (w32_to_w64 (i_to_w32 2242 c > ([3., 2., 1., 0.], 4.) -def eitherFloor (x:(Int `Either` Float)) -> Int = case x of +def eitherFloor(x:(Int `Either` Float)) -> Int = case x of Left i -> i Right f -> f_to_i $ floor f @@ -569,7 +569,7 @@ Just 1 == Just 0 -- Userspace-Bool breaks these -- -- Needed to avoid ambiguous type variables if both sides use the same constructor --- def cmpEither (x:Int|Int) (y:Int|Int) : Bool = x == y +-- def cmpEither(x:Int|Int) (y:Int|Int) : Bool = x == y -- :p cmpEither (Left 1) (Left 1) -- > True @@ -708,7 +708,7 @@ def newtonSolve(tol:Float, f: (Float)->Float, x0:Float) -> Float = (f 5, w) > (7, 2.) --- def add (n : Type) ?-> (a : n=>Float) (b : n=>Float) : n=>Float = +-- def add(n : Type) ?-> (a : n=>Float) (b : n=>Float) : n=>Float = -- (tile (\t:(Tile n (Fin VectorWidth)). storeVector $ loadTile t a + loadTile t b) -- (\i:n. a.i + b.i)) -- toAdd = for _:(Fin 10). 1.0 @@ -809,7 +809,7 @@ s1 = "hello world" -- > [[0], [1, 2], [3, 4, 5], [6, 7, 8, 9]] -- @noinline -def fromLeftFloat (x:(Float `Either` Int)) -> Float = +def fromLeftFloat(x:(Float `Either` Int)) -> Float = case x of Left x' -> x' Right _ -> error "this is an error" @@ -822,27 +822,27 @@ def fromLeftFloat (x:(Float `Either` Int)) -> Float = > 1.2 @noinline -def f1 (x:Int) -> Int = x + 1 +def f1(x:Int) -> Int = x + 1 @noinline -def f2 (x:Int) -> Int = f1 $ f1 $ f1 $ f1 $ f1 $ f1 $ f1 $ f1 $ f1 $ f1 $ x +def f2(x:Int) -> Int = f1 $ f1 $ f1 $ f1 $ f1 $ f1 $ f1 $ f1 $ f1 $ f1 $ x @noinline -def f3 (x:Int) -> Int = f2 $ f2 $ f2 $ f2 $ f2 $ f2 $ f2 $ f2 $ f2 $ f2 $ x +def f3(x:Int) -> Int = f2 $ f2 $ f2 $ f2 $ f2 $ f2 $ f2 $ f2 $ f2 $ f2 $ x @noinline -def f4 (x:Int) -> Int = f3 $ f3 $ f3 $ f3 $ f3 $ f3 $ f3 $ f3 $ f3 $ f3 $ x +def f4(x:Int) -> Int = f3 $ f3 $ f3 $ f3 $ f3 $ f3 $ f3 $ f3 $ f3 $ f3 $ x @noinline -def f5 (x:Int) -> Int = f4 $ f4 $ f4 $ f4 $ f4 $ f4 $ f4 $ f4 $ f4 $ f4 $ x +def f5(x:Int) -> Int = f4 $ f4 $ f4 $ f4 $ f4 $ f4 $ f4 $ f4 $ f4 $ f4 $ x @noinline -def f6 (x:Int) -> Int = f5 $ f5 $ f5 $ f5 $ f5 $ f5 $ f5 $ f5 $ f5 $ f5 $ x +def f6(x:Int) -> Int = f5 $ f5 $ f5 $ f5 $ f5 $ f5 $ f5 $ f5 $ f5 $ f5 $ x -- regression test for #1229 - checks that constraints and data args can be interleaved @noinline -def interleave_args (x:Float, xs:n=>Float) -> Float given (n|Ix) = +def interleave_args(x:Float, xs:n=>Float) -> Float given (n|Ix) = x + sum xs interleave_args(1.0, [2.0, 3.0]) @@ -872,7 +872,7 @@ interleave_args(1.0, [2.0, 3.0]) :p intpow 2 17 > 131072 -def f (x:Float) -> Float = +def f(x:Float) -> Float = intpow x 2 x2 = 3.1 @@ -908,7 +908,7 @@ interface AssociatedWithTwo(a:Type, c:Type) value2 : () -> Int instance AssociatedWithTwo(Int, Float) - def value2 () = 8 + def value2() = 8 -- TODO: the source location for this error message is a bit off since we added -- n-ary applications. @@ -959,7 +959,7 @@ all $ for i:(Maybe (Fin 2)). i == (unsafe_from_ordinal (ordinal i)) > True -def all_prefixes (s:String) -> List String = +def all_prefixes(s:String) -> List String = AsList(_, s') = s to_list for i. post_slice s' first_ix i @@ -970,7 +970,7 @@ def all_prefixes (s:String) -> List String = -- -- Regression test for https://github.com/google-research/dex-lang/issues/694 -- data Thing = T {a: Float} --- def scan_over (xs:n=>Float) -> n=>Float given (n|Ix) = +-- def scan_over(xs:n=>Float) -> n=>Float given (n|Ix) = -- snd $ scan (T {a=1.0}) \i t. -- (T {a}) = t -- (T {a=a + xs[i]}, a) @@ -1005,19 +1005,19 @@ first_ix :: (Fin 0) > Type error:Couldn't synthesize a class dictionary for: (NonEmpty (Fin 0)) > > first_ix :: (Fin 0) -> ^^^^^^^^ +> ^^^^^^^^^ first_ix :: (Fin 0 `Either` Fin 0) > Type error:Couldn't synthesize a class dictionary for: (NonEmpty (Either (Fin 0) (Fin 0))) > > first_ix :: (Fin 0 `Either` Fin 0) -> ^^^^^^^^ +> ^^^^^^^^^ first_ix :: (Fin 0, ()) > Type error:Couldn't synthesize a class dictionary for: (NonEmpty (Fin 0, ())) > > first_ix :: (Fin 0, ()) -> ^^^^^^^^ +> ^^^^^^^^^ -- Subset tests @@ -1076,7 +1076,7 @@ Just f64_to_f -- from a case (this tickles ACase colliding with the code generated -- to print). The @noinline defeats case-of-known-constructor. @noinline -def id_bool (x:Bool) -> Bool = x +def id_bool(x:Bool) -> Bool = x case (id_bool True) of True -> Just f64_to_f diff --git a/tests/exception-tests.dx b/tests/exception-tests.dx index cc1effcd..d35aaebb 100644 --- a/tests/exception-tests.dx +++ b/tests/exception-tests.dx @@ -1,57 +1,57 @@ -def checkFloatInUnitInterval (x:Float) : {Except} Float = +def checkFloatInUnitInterval(x:Float) -> {Except} Float = assert $ x >= 0.0 assert $ x <= 1.0 x -:p catch do assert False +:p catch \. assert False > Nothing -:p catch do assert True +:p catch \. assert True > (Just ()) -:p catch do checkFloatInUnitInterval 1.2 +:p catch \. checkFloatInUnitInterval 1.2 > Nothing -:p catch do checkFloatInUnitInterval (-1.2) +:p catch \. checkFloatInUnitInterval (-1.2) > Nothing -:p catch do checkFloatInUnitInterval 0.2 +:p catch \. checkFloatInUnitInterval 0.2 > (Just 0.2) :p yield_state 0 \ref. - catch do + catch \. ref := 1 assert False ref := 2 > 1 -:p catch do +:p catch \. for i:(Fin 5). if ordinal i > 3 - then throw () + then throw() else 23 > Nothing -:p catch do +:p catch \. for i:(Fin 3). if ordinal i > 3 - then throw () + then throw() else 23 > (Just [23, 23, 23]) -- Is this the result we want? :p yield_state zero \ref. - catch do + catch \. for i:(Fin 6). if (ordinal i `rem` 2) == 0 - then throw () + then throw() else () ref!i := 1 > [0, 1, 0, 1, 0, 1] -:p catch do +:p catch \. run_state 0 \ref. ref := 1 assert False @@ -59,9 +59,9 @@ def checkFloatInUnitInterval (x:Float) : {Except} Float = > Nothing -- https://github.com/google-research/dex-lang/issues/612 -def sashabug (h: Unit) : {Except} List Int = +def sashabug(h: ()) -> {Except} List Int = yield_state mempty \results. results := (get results) <> AsList 1 [2] -catch do (catch do sashabug ()) +catch \. (catch \. sashabug ()) > (Just (Just (AsList 1 [2]))) diff --git a/tests/monad-tests.dx b/tests/monad-tests.dx index 92868294..6641f79e 100644 --- a/tests/monad-tests.dx +++ b/tests/monad-tests.dx @@ -1,21 +1,21 @@ :p - def m (ref:Ref h Int) -> {State h} Int given (h:Heap) = get ref + def m(ref:Ref h Int) -> {State h} Int given (h:Heap) = get ref run_state 2 m > (2, 2) :p - def m (ref:Ref h Int) -> {State h} () given (h:Heap) = ref := 3 + def m(ref:Ref h Int) -> {State h} () given (h:Heap) = ref := 3 run_state 0 m > ((), 3) :p - def m (ref:Ref h Int) -> {Read h} Int given (h:Heap) = ask ref + def m(ref:Ref h Int) -> {Read h} Int given (h:Heap) = ask ref with_reader 5 m > 5 :p - def stateAction (ref:Ref h Float) -> {State h} () given (h:Heap) = + def stateAction(ref:Ref h Float) -> {State h} () given (h:Heap) = x = get ref ref := (x + 2.0) z = get ref @@ -87,7 +87,7 @@ def myAction(w:Ref hw Float, r:Ref hr Float) run_accum (AddMonoid Float) \w1. run_accum (AddMonoid Float) \w2. m w1 w2 > (((), 3.), 2.) -def foom (s:Ref h ((Fin 3)=>Int)) -> {State h} () given (h:Heap) = +def foom(s:Ref h ((Fin 3)=>Int)) -> {State h} () given (h:Heap) = s!(from_ordinal 0) := 1 s!(from_ordinal 2) := 2 @@ -96,7 +96,7 @@ def foom (s:Ref h ((Fin 3)=>Int)) -> {State h} () given (h:Heap) = -- TODO: handle effects returning functions -- :p --- def foo (x:Float) : Float = +-- def foo(x:Float) : Float = -- f = withReader x \r. -- y = ask r -- \z. 100.0 * x + 10.0 * y + z @@ -198,7 +198,7 @@ def effectsAtZero(f: (Int)->{|eff} ()) -> {|eff} () given (eff:Effects) = > False -- Test custom list monoid with accum -def adjacencyMatrixToEdgeList (mat: n=>n=>Bool) -> List (n, n) given (n|Ix) = +def adjacencyMatrixToEdgeList(mat: n=>n=>Bool) -> List (n, n) given (n|Ix) = yield_accum(ListMonoid((n,n))) \list. for idxs. (i, j) = idxs diff --git a/tests/parser-tests.dx b/tests/parser-tests.dx index 446f666c..cb5b312e 100644 --- a/tests/parser-tests.dx +++ b/tests/parser-tests.dx @@ -30,7 +30,7 @@ f = \x. x + 10. :p f -1.0 -- parses as (-) f (-1.0) > Type error: -> Expected: (Float32 -> Float32) +> Expected: ((x:Float32) -> Float32) > Actual: Float32 > > :p f -1.0 -- parses as (-) f (-1.0) @@ -39,16 +39,6 @@ f = \x. x + 10. :p f (-1.0) > 9. -'Lambdas can have specific arrow annotations. - -lam1 = \{n}. \x:(Fin n). ordinal x -:t lam1 -> ((n:Nat) ?-> (Fin n) -> Nat) - -lam4 = \{n m}. (Fin n, Fin m) -:t lam4 -> (Nat ?-> Nat ?-> (Type & Type)) - :p ( 1 + @@ -59,7 +49,7 @@ lam4 = \{n m}. (Fin n, Fin m) :p xs = [1,2,3] for i. - if xs.i > 1 + if xs[i] > 1 then 0 else 1 > [1, 0, 0] @@ -71,24 +61,14 @@ lam4 = \{n m}. (Fin n, Fin m) ref := get ref + 1 > ((), 10) -def myInt : Int = 1 -:p myInt -> 1 - -def myInt2 : {State Int} Int = 1 -> Syntax error: Nullary def can't have effects -> -> def myInt2 : {State Int} Int = 1 -> ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -- Check that a syntax error in a funDefLet doesn't try to reparse the -- whole definition as something it's not. -def frob (x:Int) (y:Int) : + = +def frob(x:Int, y:Int) -> + = -> Parse error:86:30: +> Parse error:66:29: > | -> 86 | def frob (x:Int) (y:Int) : + = -> | ^ +> 66 | def frob(x:Int, y:Int) -> + = +> | ^ > unexpected '=' > expecting name or symbol name @@ -96,35 +76,31 @@ def frob (x:Int) (y:Int) : + = for i. i -> Parse error:97:1: +> Parse error:77:1: > | -> 97 | i +> 77 | i > | ^ > expecting end of line or space def (foo + bar) : Int = 6 -> Parse error:105:6: -> | -> 105 | def (foo + bar) : Int = 6 -> | ^ +> Parse error:85:6: +> | +> 85 | def (foo + bar) : Int = 6 +> | ^ > unexpected 'f' > expecting symbol name 'Data definitions allow but do not require type / kind annotations -data MyPair1 a b = MkPair1 a b - -data MyPair2 a:Type b:Type = MkPair2 x:a y:b - -'which may be grouped with parentheses +data MyPair1(a, b) = MkPair1(a, b) -data MyPair3 (a:Type) (b:Type) = MkPair3 (x:a) (y:b) +data MyPair2(a:Type, b:Type) = MkPair2(x:a, y:b) 'Data definitions allow interleaving arguments and class constraints (regression test for Issue 1015) -data TableInType n a [Ix n] table:(n=>a) = +data TableInType(n|Ix, a, table:(n=>a)) = MkTableInType -- Doesn't store any data except in the type! 'Left arrow <- desugars to a continuation lambda @@ -135,3 +111,83 @@ data TableInType n a [Ix n] table:(n=>a) = x := 4 get x > 4 + +'Check that we get reasonably helpful error messages if we try to write param lists +without parens or with whitespace before the parens. + + +data MyMaybe a = + MyNothing + +> Parse error:119:14: +> | +> 119 | data MyMaybe a = +> | ^ +> unexpected 'a' +> expecting '=' or parameter list in parentheses (without preceding whitespace) +data MyMaybe (a) = + MyNothing + +> Parse error:122:14: +> | +> 122 | data MyMaybe (a) = +> | ^ +> unexpected '(' +> expecting '=' or parameter list in parentheses (without preceding whitespace) +data MyMaybe(a) = + MyNothing + MyJust a + +> Parse error:127:10: +> | +> 127 | MyJust a +> | ^^ +> unexpected "a<newline>" +> expecting end of input, end of line, or optional parameter list +data MyMaybe(a) = + MyNothing + MyJust (a) + +> Parse error:131:10: +> | +> 131 | MyJust (a) +> | ^^ +> unexpected "(a" +> expecting end of input, end of line, or optional parameter list +interface MyClass a + pass + +> Parse error:133:19: +> | +> 133 | interface MyClass a +> | ^ +> expecting parameter list in parentheses (without preceding whitespace) +instance MyClass Int + pass + +> Parse error:136:18: +> | +> 136 | instance MyClass Int +> | ^ +> expecting '(' (without preceding whitespace) +instance MyClass (Int) + pass + +> Parse error:139:18: +> | +> 139 | instance MyClass (Int) +> | ^ +> expecting '(' (without preceding whitespace) +def myFunction (x) = () + +> Parse error:142:16: +> | +> 142 | def myFunction (x) = () +> | ^ +> expecting parameter list in parentheses (without preceding whitespace) +def myFunction x = () +> Parse error:144:16: +> | +> 144 | def myFunction x = () +> | ^ +> expecting parameter list in parentheses (without preceding whitespace) diff --git a/tests/print-tests.dx b/tests/print-tests.dx index d3cdaec3..e1cdbe8d 100644 --- a/tests/print-tests.dx +++ b/tests/print-tests.dx @@ -14,7 +14,7 @@ :pcodegen [Just (Just 1.0), Just Nothing, Nothing] > [(Just (Just 1.)), (Just Nothing), Nothing] -data MyType = MyValue (Nat) +data MyType = MyValue(Nat) :pcodegen MyValue 1 > (MyValue 1) diff --git a/tests/shadow-tests.dx b/tests/shadow-tests.dx index 9592aba6..4a059804 100644 --- a/tests/shadow-tests.dx +++ b/tests/shadow-tests.dx @@ -36,7 +36,7 @@ arr = 20 > Error: variable already defined: arr > > arr = 20 -> ^^^ +> ^^^^ :p arr > Error: ambiguous variable: arr is defined: @@ -89,4 +89,4 @@ ShadowCon' = 1 > Error: variable already defined: ShadowCon' > > ShadowCon' = 1 -> ^^^^^^^^^^ +> ^^^^^^^^^^^ diff --git a/tests/sort-tests.dx b/tests/sort-tests.dx index e5940dfb..e4b6876b 100644 --- a/tests/sort-tests.dx +++ b/tests/sort-tests.dx @@ -5,7 +5,6 @@ import sort :p is_sorted $ sort [9, 3, 7, 4, 6, 1, 9, 1, 9, -1, 10, 10, 100, 0] > True - '### Lexical Sorting Tests :p "aaa" < "bbb" diff --git a/tests/type-tests.dx b/tests/type-tests.dx index e6d16946..2e6aef63 100644 --- a/tests/type-tests.dx +++ b/tests/type-tests.dx @@ -153,13 +153,13 @@ MyPair : (Type) -> Type = -- TODO: put source annotation on effect for a better message here -def fEff () -> {|a} a given (a) = todo +def fEff() -> {|a} a given (a) = todo > Type error: > Expected: Type > Actual: EffKind > -> def fEff () -> {|a} a given (a) = todo -> ^ +> def fEff() -> {|a} a given (a) = todo +> ^^ :p for i:(Fin 7). sum for j:(Fin unboundName). 1.0 @@ -328,25 +328,25 @@ newPair : NewPair Int Float = MkNewPair 1 2.0 > :p (\x:Int. x) == (\x:Int. x) > ^^^^^^^^^^^^^^^^^^^^^^^^^^ -def getFst1 (xs:n=>b) -> b given (n|Ix, b) = +def getFst1(xs:n=>b) -> b given (n|Ix, b) = xs[from_ordinal 0] :p getFst1 [1,2,3] > 1 -def getFst2 (xs:n=>b) -> b given (n|Ix, b) = +def getFst2(xs:n=>b) -> b given (n|Ix, b) = xs[from_ordinal 0] :p getFst2 [1,2,3] > 1 -def getFst3 (xs:n=>b) -> b given (b, n|Ix) = +def getFst3(xs:n=>b) -> b given (b, n|Ix) = xs[from_ordinal 0] :p getFst3 [1,2,3] > 1 -def triRefIndex (ref:Ref h ((i':n)=>(..i')=>Float), i:n) -> Ref h ((..i)=>Float) +def triRefIndex(ref:Ref h ((i':n)=>(..i')=>Float), i:n) -> Ref h ((..i)=>Float) given (h, n|Ix) (Data ((i':n)=>(..i')=>Float)) = %indexRef(ref, i) @@ -369,7 +369,7 @@ def triRefIndex (ref:Ref h ((i':n)=>(..i')=>Float), i:n) -> Ref h ((..i)=>Float) id (for i:(Fin 2). for j:(..i). 1.0) > [[1.], [1., 1.]] -def weakerInferenceReduction (l: (i:n)=>(..i)=>Float, j:n) -> () given (n|Ix) = +def weakerInferenceReduction(l: (i:n)=>(..i)=>Float, j:n) -> () given (n|Ix) = for i:(..j). i' = inject(superset=n, i) for k:(..i'). @@ -408,7 +408,7 @@ c = [1, 2] -- Tests for type inference of table literals -def mkEmpty (a|Data) -> (Fin 0)=>a = [] +def mkEmpty(a|Data) -> (Fin 0)=>a = [] :t [0.0, 1.0] > ((Fin 2) => Float32) @@ -419,7 +419,7 @@ def mkEmpty (a|Data) -> (Fin 0)=>a = [] :t (coerce_table _ [0.0, 1.0]) :: (Fin 1, Fin 2)=>Float > ((Fin 1, Fin 2) => Float32) -def uncurryTable (x: (Fin 2, Fin 2)=>a) -> (Fin 2)=>(Fin 2)=>a given (a) = +def uncurryTable(x: (Fin 2, Fin 2)=>a) -> (Fin 2)=>(Fin 2)=>a given (a) = for i j. x[(i, j)] -- We should be able to infer the tuple type here @@ -431,8 +431,8 @@ def uncurryTable (x: (Fin 2, Fin 2)=>a) -> (Fin 2)=>(Fin 2)=>a given (a) = > ((Fin 2) => (Fin 2) => Nat) -- Make sure that the local type alias is unifiable with Int -def GetInt (n: Int) -> Type = Int -def ff (n : Int) -> Int = +def GetInt(n: Int) -> Type = Int +def ff(n : Int) -> Int = i = GetInt n the i 2 @@ -440,7 +440,7 @@ ff 0 > 2 -- The two local aliases for Fin n should be unifiable with each other and Fin n -def q (n: Nat) -> (Fin n)=>Nat = +def q(n: Nat) -> (Fin n)=>Nat = ix1 = Fin n x1 = for i:ix1. ordinal i ix2 = Fin n @@ -478,7 +478,7 @@ q 5 -- Make sure we fail gracefully when the annotated index set doesn't -- have a static size. -def frob (_:()) -> () given (n) = +def frob(_:()) -> () given (n) = [0.0, 1.0]::((Fin n)=>Float) () > Type error: @@ -494,7 +494,7 @@ def frob (_:()) -> () given (n) = -- foo is a function with all-implicit arguments (whether that's a -- good idea or not). -def foo () ->> a=>Float given (a|Ix) = +def foo() ->> a=>Float given (a|Ix) = for z:a. 1.0 :t foo @@ -517,7 +517,7 @@ foo : Fin 3 => Float = [1.0, 2.0, 3.0] > Error: variable already defined: foo > > foo : Fin 3 => Float = [1.0, 2.0, 3.0] -> ^^^ +> ^^^^ '### Tests for dependent pair syntax @@ -539,10 +539,10 @@ if False '### Tests for dependent pair pattern match -def LowerTriMat (n|Ix, v:Type) -> Type = (i:n)=>(..i)=>v -def UpperTriMat (n|Ix, v:Type) -> Type = (i:n)=>(i..)=>v +def LowerTriMat(n|Ix, v:Type) -> Type = (i:n)=>(..i)=>v +def UpperTriMat(n|Ix, v:Type) -> Type = (i:n)=>(i..)=>v -def transpose_ix (i:n, j:(i..)) -> (i:n &> ..i) given (n|Ix) = +def transpose_ix(i:n, j:(i..)) -> (i:n &> ..i) given (n|Ix) = j' = inject(superset=n, j) i' = unsafe_project i (j' ,> i') diff --git a/tests/uexpr-tests.dx b/tests/uexpr-tests.dx index 1fa396e2..5524ef79 100644 --- a/tests/uexpr-tests.dx +++ b/tests/uexpr-tests.dx @@ -2,12 +2,12 @@ :p 3 + (4 + 5) > 12 -def depId (a:Type, x:a) -> a = x +def depId(a:Type, x:a) -> a = x :p depId Int 1 > 1 -def returnFirstArg (a:Type, b:Type, x:a, y:b) -> a = x +def returnFirstArg(a:Type, b:Type, x:a, y:b) -> a = x :p returnFirstArg Int Float 1 2.0 > 1 @@ -15,17 +15,17 @@ def returnFirstArg (a:Type, b:Type, x:a, y:b) -> a = x :p 1.0 + 2.0 > 3. -def triple (x:Float) -> Float = x + x + x +def triple(x:Float) -> Float = x + x + x :p triple 1.0 > 3. -def idExplicit (a:Type, x:a) -> a = x +def idExplicit(a:Type, x:a) -> a = x :p idExplicit Int 1 > 1 -def idImplicit (x:a) -> a given (a:Type) = x +def idImplicit(x:a) -> a given (a:Type) = x :p idImplicit 1 > 1 @@ -54,27 +54,27 @@ idImplicit2 : (given (a:Type), a) -> a = \x. x > Error: variable not in scope: x > > :p x + x -> ^ +> ^^ idiv = 1 > Error: variable already defined: idiv > > idiv = 1 -> ^^^^ +> ^^^^^ -def TyId (a:Type) -> Type = a +def TyId(a:Type) -> Type = a :p x:(TyId Int) = 1 x > 1 :p - def TyId2 (a:Type) -> Type = a + def TyId2(a:Type) -> Type = a x:(TyId2 Int) = 1 x > 1 -def tabId (x:n=>Int) -> n=>Int given (n|Ix) = for i. x[i] +def tabId(x:n=>Int) -> n=>Int given (n|Ix) = for i. x[i] -- bug: this doesn't work if we split it across top-level decls :p @@ -176,7 +176,7 @@ myPair = (1, 2.3) xsRef!i := ordinal i > [0, 1, 2] -def passthrough (f:(a)->{|eff} b, x:a) -> {|eff} b +def passthrough(f:(a)->{|eff} b, x:a) -> {|eff} b given (a, b, eff:Effects) = f x @@ -208,13 +208,13 @@ def passthrough (f:(a)->{|eff} b, x:a) -> {|eff} b > f : (aa:Type, aa) -> aa = \bb x. myId x > ^^^^^^^^^^^^^ -def myFst (p:(a,b)) -> a given (a, b) = +def myFst(p:(a,b)) -> a given (a, b) = (x, _) = p x :p myFst (1,2) > 1 -def myOtherFst (pair:(a,b)) -> a given (a, b) = +def myOtherFst(pair:(a,b)) -> a given (a, b) = (x, _) = pair x :p myOtherFst (1,2) @@ -253,7 +253,7 @@ def myOtherFst (pair:(a,b)) -> a given (a, b) = id'' : (given (b:Type), b) -> b = id -def eitherFloor (x:(Either Int Float)) -> Int = case x of +def eitherFloor(x:(Either Int Float)) -> Int = case x of Left i -> i Right f -> f_to_i f @@ -285,7 +285,7 @@ def eitherFloor (x:(Either Int Float)) -> Int = case x of > (Solving for: [a, b]) > > [a, b] = [1, 2, 3] -> ^^^^^^ +> ^^^^^^^ :p [[a, _], [_, d]] = [[1, 2], [3, 4]] @@ -302,7 +302,7 @@ def eitherFloor (x:(Either Int Float)) -> Int = case x of > (Solving for: [a, b]) > > [a, b, c, d] = coerce_table (Fin 2, Fin 2) [1, 2, 3, 4] -> ^^^^^^^^^^^^ +> ^^^^^^^^^^^^^ -- Needs delayed inference (can't verify Ix and reduce the type before we infer the hole). -- :p @@ -310,7 +310,7 @@ def eitherFloor (x:(Either Int Float)) -> Int = case x of -- (a, b, c, d) -- > (0, (0, (0, 0))) -def bug (n|Data) -> () = +def bug(n|Data) -> () = for w':n. w : n = todo for i:(w..). () @@ -336,7 +336,7 @@ badDefinition = 4 > Error: variable already defined: badDefinition > > badDefinition = 4 -> ^^^^^^^^^^^^^ +> ^^^^^^^^^^^^^^ :p badDefinition > Error: ambiguous variable: badDefinition is defined: |