summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--lib/prelude.dx1583
-rw-r--r--lib/sort.dx42
-rw-r--r--makefile4
-rw-r--r--src/lib/ConcreteSyntax.hs123
-rw-r--r--src/lib/Lexing.hs87
-rw-r--r--tests/adt-tests.dx12
-rw-r--r--tests/eval-tests.dx42
-rw-r--r--tests/exception-tests.dx32
-rw-r--r--tests/monad-tests.dx14
-rw-r--r--tests/parser-tests.dx134
-rw-r--r--tests/print-tests.dx2
-rw-r--r--tests/shadow-tests.dx4
-rw-r--r--tests/sort-tests.dx1
-rw-r--r--tests/type-tests.dx38
-rw-r--r--tests/uexpr-tests.dx36
15 files changed, 1116 insertions, 1038 deletions
diff --git a/lib/prelude.dx b/lib/prelude.dx
index e0dcbdeb..d9f9def7 100644
--- a/lib/prelude.dx
+++ b/lib/prelude.dx
@@ -28,455 +28,455 @@ RawPtr : Type = %Word8Ptr()
Int = Int32
Float = Float32
-def the (a:Type, x:a) -> a = x
+def the(a:Type, x:a) -> a = x
-interface Data (a:Type)
+interface Data(a:Type)
do_not_implement_this_interface_for_the_compiler_relies_on_the_invariant_it_protects : (a) -> a
'### Casting
-def internal_cast (x:from) -> to given (from, to) =
+def internal_cast(x:from) -> to given (from, to) =
%cast(to, x)
-def unsafe_coerce (x:from) -> to given (from|Data, to|Data) = %unsafeCoerce(to, x)
-def uninitialized_value () -> a given (a|Data) = %garbageVal(a)
-
-def f64_to_f (x : Float64) -> Float = internal_cast x
-def f32_to_f (x : Float32) -> Float = internal_cast x
-def f_to_f64 (x : Float) -> Float64 = internal_cast x
-def f_to_f32 (x : Float) -> Float32 = internal_cast x
-def i64_to_i (x : Int64) -> Int = internal_cast x
-def i32_to_i (x : Int32) -> Int = internal_cast x
-def w8_to_i (x : Word8) -> Int = internal_cast x
-def i_to_i64 (x : Int) -> Int64 = internal_cast x
-def i_to_i32 (x : Int) -> Int32 = internal_cast x
-def i_to_w8 (x : Int) -> Word8 = internal_cast x
-def i_to_w32 (x : Int) -> Word32 = internal_cast x
-def i_to_w64 (x : Int) -> Word64 = internal_cast x
-def w32_to_w64 (x : Word32)-> Word64 = internal_cast x
-def i_to_f (x:Int) -> Float = internal_cast x
-def f_to_i (x:Float) -> Int = internal_cast x
-def raw_ptr_to_i64 (x:RawPtr) -> Int64 = internal_cast x
+def unsafe_coerce(x:from) -> to given (from|Data, to|Data) = %unsafeCoerce(to, x)
+def uninitialized_value() -> a given (a|Data) = %garbageVal(a)
+
+def f64_to_f(x: Float64) -> Float = internal_cast x
+def f32_to_f(x: Float32) -> Float = internal_cast x
+def f_to_f64(x: Float) -> Float64 = internal_cast x
+def f_to_f32(x: Float) -> Float32 = internal_cast x
+def i64_to_i(x: Int64) -> Int = internal_cast x
+def i32_to_i(x: Int32) -> Int = internal_cast x
+def w8_to_i(x: Word8) -> Int = internal_cast x
+def i_to_i64(x: Int) -> Int64 = internal_cast x
+def i_to_i32(x: Int) -> Int32 = internal_cast x
+def i_to_w8(x: Int) -> Word8 = internal_cast x
+def i_to_w32(x: Int) -> Word32 = internal_cast x
+def i_to_w64(x: Int) -> Word64 = internal_cast x
+def w32_to_w64(x: Word32)-> Word64 = internal_cast x
+def i_to_f(x:Int) -> Float = internal_cast x
+def f_to_i(x:Float) -> Int = internal_cast x
+def raw_ptr_to_i64(x:RawPtr) -> Int64 = internal_cast x
Nat = %Nat()
NatRep = Word32
-def nat_to_rep (x : Nat) -> NatRep = %projNewtype(x)
-def rep_to_nat (x : NatRep) -> Nat = %NatCon(x)
-
-def n_to_w8 (x : Nat) -> Word8 = nat_to_rep x | internal_cast
-def n_to_w32 (x : Nat) -> Word32 = nat_to_rep x | internal_cast
-def n_to_w64 (x : Nat) -> Word64 = nat_to_rep x | internal_cast
-def n_to_i32 (x : Nat) -> Int32 = nat_to_rep x | internal_cast
-def n_to_i64 (x : Nat) -> Int64 = nat_to_rep x | internal_cast
-def n_to_f32 (x : Nat) -> Float32 = nat_to_rep x | internal_cast
-def n_to_f64 (x : Nat) -> Float64 = nat_to_rep x | internal_cast
-def n_to_f (x : Nat) -> Float = nat_to_rep x | internal_cast
-
-def w8_to_n (x : Word8) -> Nat = internal_cast x | rep_to_nat
-def w32_to_n (x : Word32) -> Nat = internal_cast x | rep_to_nat
-def w64_to_n (x : Word64) -> Nat = internal_cast x | rep_to_nat
-def i32_to_n (x : Int32) -> Nat = internal_cast x | rep_to_nat
-def i64_to_n (x : Int64) -> Nat = internal_cast x | rep_to_nat
-def f32_to_n (x : Float32) -> Nat = internal_cast x | rep_to_nat
-def f64_to_n (x : Float64) -> Nat = internal_cast x | rep_to_nat
-def f_to_n (x : Float) -> Nat = internal_cast x | rep_to_nat
-
-interface FromUnsignedInteger (a:Type)
+def nat_to_rep(x : Nat) -> NatRep = %projNewtype(x)
+def rep_to_nat(x : NatRep) -> Nat = %NatCon(x)
+
+def n_to_w8(x: Nat) -> Word8 = nat_to_rep x | internal_cast
+def n_to_w32(x: Nat) -> Word32 = nat_to_rep x | internal_cast
+def n_to_w64(x: Nat) -> Word64 = nat_to_rep x | internal_cast
+def n_to_i32(x: Nat) -> Int32 = nat_to_rep x | internal_cast
+def n_to_i64(x: Nat) -> Int64 = nat_to_rep x | internal_cast
+def n_to_f32(x: Nat) -> Float32 = nat_to_rep x | internal_cast
+def n_to_f64(x: Nat) -> Float64 = nat_to_rep x | internal_cast
+def n_to_f(x: Nat) -> Float = nat_to_rep x | internal_cast
+
+def w8_to_n(x : Word8) -> Nat = internal_cast x | rep_to_nat
+def w32_to_n(x : Word32) -> Nat = internal_cast x | rep_to_nat
+def w64_to_n(x : Word64) -> Nat = internal_cast x | rep_to_nat
+def i32_to_n(x : Int32) -> Nat = internal_cast x | rep_to_nat
+def i64_to_n(x : Int64) -> Nat = internal_cast x | rep_to_nat
+def f32_to_n(x : Float32) -> Nat = internal_cast x | rep_to_nat
+def f64_to_n(x : Float64) -> Nat = internal_cast x | rep_to_nat
+def f_to_n(x : Float) -> Nat = internal_cast x | rep_to_nat
+
+interface FromUnsignedInteger(a:Type)
from_unsigned_integer : (Word64) -> a
-instance FromUnsignedInteger Word8
- def from_unsigned_integer (x) = internal_cast x
+instance FromUnsignedInteger(Word8)
+ def from_unsigned_integer(x) = internal_cast x
-instance FromUnsignedInteger Word32
- def from_unsigned_integer (x) = internal_cast x
+instance FromUnsignedInteger(Word32)
+ def from_unsigned_integer(x) = internal_cast x
-instance FromUnsignedInteger Word64
- def from_unsigned_integer (x) = x
+instance FromUnsignedInteger(Word64)
+ def from_unsigned_integer(x) = x
-instance FromUnsignedInteger Int32
- def from_unsigned_integer (x) = internal_cast x
+instance FromUnsignedInteger(Int32)
+ def from_unsigned_integer(x) = internal_cast x
-instance FromUnsignedInteger Int64
- def from_unsigned_integer (x) = internal_cast x
+instance FromUnsignedInteger(Int64)
+ def from_unsigned_integer(x) = internal_cast x
-instance FromUnsignedInteger Float32
- def from_unsigned_integer (x) = internal_cast x
+instance FromUnsignedInteger(Float32)
+ def from_unsigned_integer(x) = internal_cast x
-instance FromUnsignedInteger Float64
- def from_unsigned_integer (x) = internal_cast x
+instance FromUnsignedInteger(Float64)
+ def from_unsigned_integer(x) = internal_cast x
-instance FromUnsignedInteger Nat
- def from_unsigned_integer (x) = w64_to_n(x)
+instance FromUnsignedInteger(Nat)
+ def from_unsigned_integer(x) = w64_to_n(x)
-interface FromInteger (a:Type)
+interface FromInteger(a:Type)
from_integer : (Int64) -> a
-instance FromInteger Float32
- def from_integer (x) = internal_cast x
+instance FromInteger(Float32)
+ def from_integer(x) = internal_cast x
-instance FromInteger Int32
- def from_integer (x) = internal_cast x
+instance FromInteger(Int32)
+ def from_integer(x) = internal_cast x
-instance FromInteger Float64
- def from_integer (x) = internal_cast x
+instance FromInteger(Float64)
+ def from_integer(x) = internal_cast x
-instance FromInteger Int64
- def from_integer (x) = x
+instance FromInteger(Int64)
+ def from_integer(x) = x
'## Bitwise operations
-interface Bits (a:Type)
+interface Bits(a:Type)
(.<<.) : (a, Int) -> a
(.>>.) : (a, Int) -> a
(.|.) : (a, a) -> a
(.&.) : (a, a) -> a
(.^.) : (a, a) -> a
-instance Bits Word8
- def (.<<.) (x, y) = %shl(x, i_to_w8 y)
- def (.>>.) (x, y) = %shr(x, i_to_w8 y)
- def (.|.) (x, y) = %or( x, y)
- def (.&.) (x, y) = %and(x, y)
- def (.^.) (x, y) = %xor(x, y)
-
-instance Bits Word32
- def (.<<.) (x, y) = %shl(x, i_to_w32 y)
- def (.>>.) (x, y) = %shr(x, i_to_w32 y)
- def (.|.) (x, y) = %or( x, y)
- def (.&.) (x, y) = %and(x, y)
- def (.^.) (x, y) = %xor(x, y)
-
-instance Bits Word64
- def (.<<.) (x, y) = %shl(x, i_to_w64 y)
- def (.>>.) (x, y) = %shr(x, i_to_w64 y)
- def (.|.) (x, y) = %or( x ,y)
- def (.&.) (x, y) = %and(x ,y)
- def (.^.) (x, y) = %xor(x ,y)
-
-def low_word (x : Word64) -> Word32 = internal_cast(x .>>. 32)
-def high_word (x : Word64) -> Word32 = internal_cast(x)
+instance Bits(Word8)
+ def (.<<.)(x, y) = %shl(x, i_to_w8 y)
+ def (.>>.)(x, y) = %shr(x, i_to_w8 y)
+ def (.|.)(x, y) = %or( x, y)
+ def (.&.)(x, y) = %and(x, y)
+ def (.^.)(x, y) = %xor(x, y)
+
+instance Bits(Word32)
+ def (.<<.)(x, y) = %shl(x, i_to_w32 y)
+ def (.>>.)(x, y) = %shr(x, i_to_w32 y)
+ def (.|.)(x, y) = %or( x, y)
+ def (.&.)(x, y) = %and(x, y)
+ def (.^.)(x, y) = %xor(x, y)
+
+instance Bits(Word64)
+ def (.<<.)(x, y) = %shl(x, i_to_w64 y)
+ def (.>>.)(x, y) = %shr(x, i_to_w64 y)
+ def (.|.)(x, y) = %or( x ,y)
+ def (.&.)(x, y) = %and(x ,y)
+ def (.^.)(x, y) = %xor(x ,y)
+
+def low_word( x : Word64) -> Word32 = internal_cast(x .>>. 32)
+def high_word(x : Word64) -> Word32 = internal_cast(x)
'### Basic Arithmetic
#### Add
Things that can be added.
This defines the `Add` [group](https://en.wikipedia.org/wiki/Group_(mathematics)) and its operators.
-interface Add (a|Data)
+interface Add(a|Data)
(+) : (a, a) -> a
zero : a
-interface Sub (a|Add)
+interface Sub(a|Add)
(-) : (a, a) -> a
-instance Add Float64
- def (+) (x, y) = %fadd(x, y)
+instance Add(Float64)
+ def (+)(x, y) = %fadd(x, y)
zero = 0
-instance Sub Float64
- def (-) (x, y) = %fsub(x, y)
+instance Sub(Float64)
+ def (-)(x, y) = %fsub(x, y)
-instance Add Float32
- def (+) (x, y) = %fadd(x, y)
+instance Add(Float32)
+ def (+)(x, y) = %fadd(x, y)
zero = 0
-instance Sub Float32
- def (-) (x, y) = %fsub(x, y)
+instance Sub(Float32)
+ def (-)(x, y) = %fsub(x, y)
-instance Add Int64
- def (+) (x, y) = %iadd(x, y)
+instance Add(Int64)
+ def (+)(x, y) = %iadd(x, y)
zero = 0
-instance Sub Int64
- def (-) (x, y) = %isub(x, y)
+instance Sub(Int64)
+ def (-)(x, y) = %isub(x, y)
-instance Add Int32
- def (+) (x, y) = %iadd(x, y)
+instance Add(Int32)
+ def (+)(x, y) = %iadd(x, y)
zero = 0
-instance Sub Int32
- def (-) (x, y) = %isub(x, y)
+instance Sub(Int32)
+ def (-)(x, y) = %isub(x, y)
-instance Add Word8
- def (+) (x, y) = %iadd(x, y)
+instance Add(Word8)
+ def (+)(x, y) = %iadd(x, y)
zero = 0
-instance Sub Word8
- def (-) (x, y) = %isub(x, y)
+instance Sub(Word8)
+ def (-)(x, y) = %isub(x, y)
-instance Add Word32
- def (+) (x, y) = %iadd(x, y)
+instance Add(Word32)
+ def (+)(x, y) = %iadd(x, y)
zero = 0
-instance Sub Word32
- def (-) (x, y) = %isub(x, y)
+instance Sub(Word32)
+ def (-)(x, y) = %isub(x, y)
-instance Add Word64
- def (+) (x, y) = %iadd(x, y)
+instance Add(Word64)
+ def (+)(x, y) = %iadd(x, y)
zero = 0
-instance Sub Word64
- def (-) (x, y) = %isub(x, y)
+instance Sub(Word64)
+ def (-)(x, y) = %isub(x, y)
-instance Add Nat
- def (+) (x, y) = rep_to_nat %iadd(nat_to_rep x, nat_to_rep y)
+instance Add(Nat)
+ def (+)(x, y) = rep_to_nat %iadd(nat_to_rep x, nat_to_rep y)
zero = 0
-instance Add ()
+instance Add(())
def (+)(x, y) = ()
zero = ()
-instance Sub ()
+instance Sub(())
def (-)(x, y) = ()
'#### Mul
Things that can be multiplied.
This defines the `Mul` [Monoid](https://en.wikipedia.org/wiki/Monoid), and its operator.
-interface Mul (a|Data)
+interface Mul(a|Data)
(*) : (a, a) -> a
one : a
-instance Mul Float64
- def (*) (x, y) = %fmul(x, y)
+instance Mul(Float64)
+ def (*)(x, y) = %fmul(x, y)
one = f_to_f64 1.0
-instance Mul Float32
- def (*) (x, y) = %fmul(x, y)
+instance Mul(Float32)
+ def (*)(x, y) = %fmul(x, y)
one = f_to_f32 1.0
-instance Mul Int64
- def (*) (x, y) = %imul(x, y)
+instance Mul(Int64)
+ def (*)(x, y) = %imul(x, y)
one = 1
-instance Mul Int32
- def (*) (x, y) = %imul(x, y)
+instance Mul(Int32)
+ def (*)(x, y) = %imul(x, y)
one = 1
-instance Mul Word8
- def (*) (x, y) = %imul(x, y)
+instance Mul(Word8)
+ def (*)(x, y) = %imul(x, y)
one = 1
-instance Mul Word32
- def (*) (x, y) = %imul(x, y)
+instance Mul(Word32)
+ def (*)(x, y) = %imul(x, y)
one = 1
-instance Mul Word64
- def (*) (x, y) = %imul(x, y)
+instance Mul(Word64)
+ def (*)(x, y) = %imul(x, y)
one = 1
-instance Mul Nat
- def (*) (x, y) = rep_to_nat %imul(nat_to_rep x, nat_to_rep y)
+instance Mul(Nat)
+ def(*)(x, y) = rep_to_nat %imul(nat_to_rep x, nat_to_rep y)
one = 1
-instance Mul ()
- def (*) (x, y) = ()
+instance Mul(())
+ def (*)(x, y) = ()
one = ()
'#### Integral
Integer-like things.
-interface Integral (a)
+interface Integral(a)
idiv : (a,a)->a
rem : (a,a)->a
-instance Integral Int64
- def idiv (x, y) = %idiv(x, y)
- def rem (x, y) = %irem(x, y)
+instance Integral(Int64)
+ def idiv(x, y) = %idiv(x, y)
+ def rem(x, y) = %irem(x, y)
-instance Integral Int32
- def idiv (x, y) = %idiv(x, y)
- def rem (x, y) = %irem(x, y)
+instance Integral(Int32)
+ def idiv(x, y) = %idiv(x, y)
+ def rem(x, y) = %irem(x, y)
-instance Integral Word8
- def idiv (x, y) = %idiv(x, y)
- def rem (x, y) = %irem(x, y)
+instance Integral(Word8)
+ def idiv(x, y) = %idiv(x, y)
+ def rem(x, y) = %irem(x, y)
-instance Integral Word32
- def idiv (x, y) = %idiv(x, y)
- def rem (x, y) = %irem(x, y)
+instance Integral(Word32)
+ def idiv(x, y) = %idiv(x, y)
+ def rem(x, y) = %irem(x, y)
-instance Integral Word64
- def idiv (x, y) = %idiv(x, y)
- def rem (x, y) = %irem(x, y)
+instance Integral(Word64)
+ def idiv(x, y) = %idiv(x, y)
+ def rem(x, y) = %irem(x, y)
-instance Integral Nat
- def idiv (x, y) = rep_to_nat %idiv(nat_to_rep x, (nat_to_rep y))
- def rem (x, y) = rep_to_nat %irem(nat_to_rep x, (nat_to_rep y))
+instance Integral(Nat)
+ def idiv(x, y) = rep_to_nat %idiv(nat_to_rep x, (nat_to_rep y))
+ def rem(x, y) = rep_to_nat %irem(nat_to_rep x, (nat_to_rep y))
'#### Fractional
Rational-like things.
Includes floating point and two field rational representations.
-interface Fractional (a)
+interface Fractional(a)
divide : (a, a) -> a
-instance Fractional Float64
- def divide (x, y) = %fdiv(x, y)
+instance Fractional(Float64)
+ def divide(x, y) = %fdiv(x, y)
-instance Fractional Float32
- def divide (x, y) = %fdiv(x, y)
+instance Fractional(Float32)
+ def divide(x, y) = %fdiv(x, y)
'## Index set interface and instances
-interface Ix (n|Data)
+interface Ix(n|Data)
size' : () -> Nat
ordinal : (n) -> Nat
unsafe_from_ordinal : (Nat) -> n
-def size (n|Ix) -> Nat = size'(n=n)
+def size(n|Ix) -> Nat = size'(n=n)
-def Fin (n:Nat) -> Type = %Fin(n)
+def Fin(n:Nat) -> Type = %Fin(n)
-- version of subtraction on Nats that clamps at zero
-def (-|) (x: Nat, y:Nat) -> Nat =
+def (-|)(x: Nat, y:Nat) -> Nat =
x' = nat_to_rep x
y' = nat_to_rep y
requires_clamp = %ilt(x', y')
rep_to_nat %select(requires_clamp, 0, (%isub(x', y')))
-def unsafe_nat_diff (x:Nat, y:Nat) -> Nat =
+def unsafe_nat_diff(x:Nat, y:Nat) -> Nat =
x' = nat_to_rep x
y' = nat_to_rep y
rep_to_nat %isub(x', y')
-- `(i..)` parses as `RangeFrom _ i`
-- TODO: need to a way to indicate `.new` as private
-struct RangeFrom (q:Type, i:q) = val : Nat
+struct RangeFrom(q:Type, i:q) = val : Nat
-- `(i<..)` parses as `RangeFromExc _ i`
-struct RangeFromExc (q:Type, i:q) = val : Nat
+struct RangeFromExc(q:Type, i:q) = val : Nat
-- `(..i)` parses as `RangeTo _ i`
-struct RangeTo (q:Type, i:q) = val : Nat
+struct RangeTo(q:Type, i:q) = val : Nat
-- `(..<i)` parses as `RangeToExc _ i`
-struct RangeToExc (q:Type, i:q) = val : Nat
+struct RangeToExc(q:Type, i:q) = val : Nat
-instance Ix RangeFrom(q, i) given (q|Ix, i:q)
- def size' () = unsafe_nat_diff(size q, ordinal i)
+instance Ix(RangeFrom q i) given (q|Ix, i:q)
+ def size'() = unsafe_nat_diff(size q, ordinal i)
def ordinal(j) = j.val
def unsafe_from_ordinal(j) = RangeFrom.new(j)
-instance Ix RangeFromExc(q, i) given (q|Ix, i:q)
- def size' () = unsafe_nat_diff(size q, ordinal i + 1)
+instance Ix(RangeFromExc q i) given (q|Ix, i:q)
+ def size'() = unsafe_nat_diff(size q, ordinal i + 1)
def ordinal(j) = j.val
def unsafe_from_ordinal(j) = RangeFromExc.new(j)
-instance Ix RangeTo(q, i) given (q|Ix, i:q)
- def size' () = ordinal i + 1
+instance Ix(RangeTo q i) given (q|Ix, i:q)
+ def size'() = ordinal i + 1
def ordinal(j) = j.val
def unsafe_from_ordinal(j) = RangeTo.new(j)
-instance Ix RangeToExc(q, i) given (q|Ix, i:q)
- def size' () = ordinal i
+instance Ix(RangeToExc q i) given (q|Ix, i:q)
+ def size'() = ordinal i
def ordinal(j) = j.val
def unsafe_from_ordinal(j) = RangeToExc.new(j)
-instance Ix ()
- def size' () = 1
+instance Ix(())
+ def size'() = 1
def ordinal(_) = 0
def unsafe_from_ordinal(_) = ()
-def iota (n|Ix) -> n=>Nat = for i. ordinal i
+def iota(n|Ix) -> n=>Nat = for i. ordinal i
'## Arithmetic instances for table types
-instance Add (n=>a) given (a|Add, n|Ix)
- def (+) (xs, ys) = for i. xs[i] + ys[i]
+instance Add(n=>a) given (a|Add, n|Ix)
+ def (+)(xs, ys) = for i. xs[i] + ys[i]
zero = for _. zero
-instance Sub (n=>a) given (a|Sub, n|Ix)
- def (-) (xs, ys) = for i. xs[i] - ys[i]
+instance Sub(n=>a) given (a|Sub, n|Ix)
+ def (-)(xs, ys) = for i. xs[i] - ys[i]
-instance Add ((i:n) => (i..) => a) given (a|Add, n|Ix) -- Upper triangular tables
- def (+) (xs, ys) = for i. xs[i] + ys[i]
+instance Add((i:n) => (i..) => a) given (a|Add, n|Ix) -- Upper triangular tables
+ def (+)(xs, ys) = for i. xs[i] + ys[i]
zero = for _. zero
-instance Sub ((i:n) => (i..) => a) given (a|Sub, n|Ix) -- Upper triangular tables
- def (-) (xs, ys) = for i. xs[i] - ys[i]
+instance Sub((i:n) => (i..) => a) given (a|Sub, n|Ix) -- Upper triangular tables
+ def (-)(xs, ys) = for i. xs[i] - ys[i]
-instance Add ((i:n) => (..i) => a) given (a|Add, n|Ix) -- Lower triangular tables
- def (+) (xs, ys) = for i. xs[i] + ys[i]
+instance Add((i:n) => (..i) => a) given (a|Add, n|Ix) -- Lower triangular tables
+ def (+)(xs, ys) = for i. xs[i] + ys[i]
zero = for _. zero
-instance Sub ((i:n) => (..i) => a) given (a|Sub, n|Ix) -- Lower triangular tables
- def (-) (xs, ys) = for i. xs[i] - ys[i]
+instance Sub((i:n) => (..i) => a) given (a|Sub, n|Ix) -- Lower triangular tables
+ def (-)(xs, ys) = for i. xs[i] - ys[i]
-instance Add ((i:n) => (..<i) => a) given (a|Add, n|Ix)
- def (+) (xs, ys) = for i. xs[i] + ys[i]
+instance Add((i:n) => (..<i) => a) given (a|Add, n|Ix)
+ def (+)(xs, ys) = for i. xs[i] + ys[i]
zero = for _. zero
-instance Sub ((i:n) => (..<i) => a) given (a|Sub, n|Ix)
- def (-) (xs, ys) = for i. xs[i] - ys[i]
+instance Sub((i:n) => (..<i) => a) given (a|Sub, n|Ix)
+ def (-)(xs, ys) = for i. xs[i] - ys[i]
-instance Add ((i:n) => (i<..) => a) given (a|Add, n|Ix)
- def (+) (xs, ys) = for i. xs[i] + ys[i]
+instance Add((i:n) => (i<..) => a) given (a|Add, n|Ix)
+ def (+)(xs, ys) = for i. xs[i] + ys[i]
zero = for _. zero
-instance Sub ((i:n) => (i<..) => a) given (a|Sub, n|Ix)
- def (-) (xs, ys) = for i. xs[i] - ys[i]
+instance Sub((i:n) => (i<..) => a) given (a|Sub, n|Ix)
+ def (-)(xs, ys) = for i. xs[i] - ys[i]
-instance Mul (n=>a) given (a|Mul, n|Ix)
- def (*) (xs, ys) = for i. xs[i] * ys[i]
+instance Mul(n=>a) given (a|Mul, n|Ix)
+ def (*)(xs, ys) = for i. xs[i] * ys[i]
one = for _. one
'## Basic polymorphic functions and types
-def fst (pair:(a, b)) -> a given (a, b) =
+def fst(pair:(a, b)) -> a given (a, b) =
(x, _) = pair
x
-def snd (pair:(a, b)) -> b given (a, b) =
+def snd(pair:(a, b)) -> b given (a, b) =
(_, y) = pair
y
-def swap (pair:(a, b)) -> (b, a) given (a, b) =
+def swap(pair:(a, b)) -> (b, a) given (a, b) =
(x, y) = pair
(y, x)
-instance Add (a, b) given (a|Add, b|Add)
- def (+) (x, y) =
+instance Add((a, b)) given (a|Add, b|Add)
+ def (+)(x, y) =
(x1, x2) = x
(y1, y2) = y
(x1 + y1, x2 + y2)
zero = (zero, zero)
-instance Sub (a, b) given (a|Sub, b|Sub)
- def (-) (x, y) =
+instance Sub((a, b)) given (a|Sub, b|Sub)
+ def(-)(x, y) =
(x1, x2) = x
(y1, y2) = y
(x1 - y1, x2 - y2)
-instance Ix (a, b) given (a|Ix, b|Ix)
- def size' () = size a * size b
- def ordinal (pair) =
+instance Ix((a, b)) given (a|Ix, b|Ix)
+ def size'() = size a * size b
+ def ordinal(pair) =
(i, j) = pair
(ordinal i * size b) + ordinal j
- def unsafe_from_ordinal (o) =
+ def unsafe_from_ordinal(o) =
bs = size b
(unsafe_from_ordinal(n=a, idiv(o, bs)), unsafe_from_ordinal(n=b, rem(o, bs)))
'## Vector spaces
-interface VSpace (a|Add|Sub)
+interface VSpace(a|Add|Sub)
(.*) : (Float, a) -> a
-def (*.) (v:a, s:Float) -> a given (a|VSpace) = s .* v
-def (/) (v:a, s:Float) -> a given (a|VSpace) = divide(1.0, s) .* v
-def neg (v:a) -> a given (a|VSpace) = (-1.0) .* v
+def (*.)(v:a, s:Float) -> a given (a|VSpace) = s .* v
+def (/)( v:a, s:Float) -> a given (a|VSpace) = divide(1.0, s) .* v
+def neg( v:a) -> a given (a|VSpace) = (-1.0) .* v
-instance VSpace Float
- def (.*) (x, y) = x * y
+instance VSpace(Float)
+ def (.*)(x, y) = x * y
-instance VSpace (n=>a) given (a|VSpace, n|Ix)
- def (.*) (s, xs) = for i. s .* xs[i]
+instance VSpace(n=>a) given (a|VSpace, n|Ix)
+ def (.*)(s, xs) = for i. s .* xs[i]
-instance VSpace (a, b) given (a|VSpace, b|VSpace)
- def (.*) (s, pair) =
+instance VSpace((a, b)) given (a|VSpace, b|VSpace)
+ def (.*)(s, pair) =
(a, b) = pair
(s .* a, s .* b)
-instance VSpace ((i:n) => (..i) => a) given (n|Ix, a|VSpace)
- def (.*) (s, xs) = for i. s .* xs[i]
+instance VSpace((i:n) => (..i) => a) given (n|Ix, a|VSpace)
+ def (.*)(s, xs) = for i. s .* xs[i]
-instance VSpace ((i:n) => (i..) => a) given (n|Ix, a|VSpace)
- def (.*) (s, xs) = for i. s .* xs[i]
+instance VSpace((i:n) => (i..) => a) given (n|Ix, a|VSpace)
+ def (.*)(s, xs) = for i. s .* xs[i]
-instance VSpace ((i:n) => (..<i) => a) given (n|Ix, a|VSpace)
- def (.*) (s, xs) = for i. s .* xs[i]
+instance VSpace((i:n) => (..<i) => a) given (n|Ix, a|VSpace)
+ def (.*)(s, xs) = for i. s .* xs[i]
-instance VSpace ((i:n) => (i<..) => a) given (n|Ix, a|VSpace)
- def (.*) (s, xs) = for i. s .* xs[i]
+instance VSpace((i:n) => (i<..) => a) given (n|Ix, a|VSpace)
+ def (.*)(s, xs) = for i. s .* xs[i]
-instance VSpace ()
- def (.*) (_, _) = ()
+instance VSpace(())
+ def (.*)(_, _) = ()
'## Boolean type
@@ -484,21 +484,21 @@ data Bool =
False
True
-def b_to_w8 (x:Bool) -> Word8 = %dataConTag(x)
+def b_to_w8(x:Bool) -> Word8 = %dataConTag(x)
-def w8_to_b (x:Word8) -> Bool = %toEnum(Bool, x)
+def w8_to_b(x:Word8) -> Bool = %toEnum(Bool, x)
-def (&&) (x:Bool, y:Bool) -> Bool =
+def (&&)(x:Bool, y:Bool) -> Bool =
x' = b_to_w8 x
y' = b_to_w8 y
w8_to_b $ %and(x', y')
-def (||) (x:Bool, y:Bool) -> Bool =
+def (||)(x:Bool, y:Bool) -> Bool =
x' = b_to_w8 x
y' = b_to_w8 y
w8_to_b $ %or(x', y')
-def not (x:Bool) -> Bool =
+def not(x:Bool) -> Bool =
x' = b_to_w8 x
w8_to_b $ %not(x')
@@ -507,14 +507,14 @@ TODO: move these with the others?
-- Can't use `%select` because it lowers to `ISelect`, which requires
-- `a` to be a `BaseTy`.
-def select (p:Bool, x:a, y:a) -> a given (a) =
+def select(p:Bool, x:a, y:a) -> a given (a) =
case p of
True -> x
False -> y
-def b_to_i (x:Bool) -> Int = w8_to_i(b_to_w8 x)
-def b_to_n (x:Bool) -> Nat = w8_to_n(b_to_w8 x)
-def b_to_f (x:Bool) -> Float = i_to_f(b_to_i x)
+def b_to_i(x:Bool) -> Int = w8_to_i(b_to_w8 x)
+def b_to_n(x:Bool) -> Nat = w8_to_n(b_to_w8 x)
+def b_to_f(x:Bool) -> Float = i_to_f(b_to_i x)
'## Ordering
TODO: move this down to with `Ord`?
@@ -524,39 +524,39 @@ data Ordering =
EQ
GT
-def o_to_w8 (x:Ordering) -> Word8 = %dataConTag(x)
+def o_to_w8(x:Ordering) -> Word8 = %dataConTag(x)
'## Sum types
A [sum type, or tagged union](https://en.wikipedia.org/wiki/Tagged_union) can hold values from a fixed set of types, distinguished by tags.
For those familiar with the C language, they can be though of as a combination of an `enum` with a `union`.
Here we define several basic kinds, and some operators on them.
-data Maybe (a:Type) =
+data Maybe(a:Type) =
Nothing
Just(a)
-def is_nothing (x:Maybe a) -> Bool given (a) =
+def is_nothing(x:Maybe a) -> Bool given (a) =
case x of
Nothing -> True
Just(_) -> False
-def is_just (x:Maybe a) -> Bool given (a) = not $ is_nothing x
+def is_just(x:Maybe a) -> Bool given (a) = not $ is_nothing x
-def maybe (d:b, f:(a)->b, x:Maybe a) -> b given (a, b) =
+def maybe(d:b, f:(a)->b, x:Maybe a) -> b given (a, b) =
case x of
Nothing -> d
Just(x') -> f x'
-data Either (a:Type, b:Type) =
+data Either(a:Type, b:Type) =
Left(a)
Right(b)
-instance Ix (Either(a, b)) given (a|Ix, b|Ix)
- def size' () = size a + size b
- def ordinal (i) = case i of
+instance Ix(Either(a, b)) given (a|Ix, b|Ix)
+ def size'() = size a + size b
+ def ordinal(i) = case i of
Left(ai) -> ordinal ai
Right(bi) -> ordinal bi + size a
- def unsafe_from_ordinal (o) =
+ def unsafe_from_ordinal(o) =
as = nat_to_rep $ size a
o' = nat_to_rep o
-- TODO: Reshuffle the prelude to be able to use (<) here
@@ -568,43 +568,43 @@ instance Ix (Either(a, b)) given (a|Ix, b|Ix)
'## Subtraction on Nats
-- TODO: think more about the right API here
-def unsafe_i_to_n (x:Int) -> Nat =
+def unsafe_i_to_n(x:Int) -> Nat =
rep_to_nat $ internal_cast x
-def n_to_i (x:Nat) -> Int =
+def n_to_i(x:Nat) -> Int =
internal_cast (nat_to_rep x)
-def i_to_n (x:Int) -> Maybe Nat =
+def i_to_n(x:Int) -> Maybe Nat =
if w8_to_b $ %ilt(x, 0::Int)
then Nothing
else Just $ unsafe_i_to_n x
'## Fencepost index sets
-struct Post (segment:Type) =
+struct Post(segment:Type) =
val : Nat
-instance Ix (Post segment) given (segment|Ix)
- def size' () = size segment + 1
- def ordinal (i) = i.val
- def unsafe_from_ordinal (i) = Post.new(i)
+instance Ix(Post segment) given (segment|Ix)
+ def size'() = size segment + 1
+ def ordinal(i) = i.val
+ def unsafe_from_ordinal(i) = Post.new(i)
-def left_post (i:n) -> Post n given (n|Ix) =
+def left_post(i:n) -> Post n given (n|Ix) =
unsafe_from_ordinal(n=Post n, ordinal i)
-def right_post (i:n) -> Post n given (n|Ix) =
+def right_post(i:n) -> Post n given (n|Ix) =
unsafe_from_ordinal(n=Post n, ordinal i + 1)
-interface NonEmpty (n|Ix)
+interface NonEmpty(n|Ix)
first_ix : n
-def last_ix () ->> n given (n|NonEmpty) =
+def last_ix() ->> n given (n|NonEmpty) =
unsafe_from_ordinal(unsafe_i_to_n(n_to_i(size n) - 1))
-instance NonEmpty (Post n) given (n|Ix)
+instance NonEmpty(Post n) given (n|Ix)
first_ix = unsafe_from_ordinal(n=Post n, 0)
-instance NonEmpty ()
+instance NonEmpty(())
first_ix = unsafe_from_ordinal(0)
'### Monoid
@@ -616,41 +616,41 @@ It includes:
- Concatenation of Lists (including strings)
Monoids support `fold` operations, and similar.
-interface Monoid (a|Data)
+interface Monoid(a|Data)
mempty : a
(<>) : (a, a) -> a
-instance Monoid (n=>a) given (a|Monoid, n|Ix)
+instance Monoid(n=>a) given (a|Monoid, n|Ix)
mempty = for i. mempty
- def (<>) (x, y) = for i. x[i] <> y[i]
+ def (<>)(x, y) = for i. x[i] <> y[i]
-named-instance AndMonoid : Monoid Bool
+named-instance AndMonoid : Monoid(Bool)
mempty = True
- def (<>) (x, y) = x && y
+ def (<>)(x, y) = x && y
-named-instance OrMonoid : Monoid Bool
+named-instance OrMonoid : Monoid(Bool)
mempty = False
- def (<>) (x, y) = x || y
+ def (<>)(x, y) = x || y
-named-instance AddMonoid (a|Add) -> Monoid a
+named-instance AddMonoid(a|Add) -> Monoid(a)
mempty = zero
- def (<>) (x, y) = x + y
+ def (<>)(x, y) = x + y
-named-instance MulMonoid (a|Mul) -> Monoid a
+named-instance MulMonoid(a|Mul) -> Monoid(a)
mempty = one
- def (<>) (x, y) = x * y
+ def (<>)(x, y) = x * y
'## Effects
-def Ref (r:Heap, a|Data) -> Type = %Ref(r, a)
-def get (ref:Ref h s) -> {State h} s given (h, s) = %get(ref)
-def (:=) (ref:Ref h s, x:s) -> {State h} () given (h, s) = %put(ref, x)
+def Ref(r:Heap, a|Data) -> Type = %Ref(r, a)
+def get(ref:Ref h s) -> {State h} s given (h, s) = %get(ref)
+def (:=)(ref:Ref h s, x:s) -> {State h} () given (h, s) = %put(ref, x)
-def ask (ref:Ref h r) -> {Read h} r given (h, r) = %ask(ref)
+def ask(ref:Ref h r) -> {Read h} r given (h, r) = %ask(ref)
-data AccumMonoidData(h:Heap, w:Type) = UnsafeMkAccumMonoidData (b:Type, Monoid b)
+data AccumMonoidData(h:Heap, w:Type) = UnsafeMkAccumMonoidData(b:Type, Monoid b)
-interface AccumMonoid (h:Heap, w)
+interface AccumMonoid(h:Heap, w)
getAccumMonoidData : AccumMonoidData(h, w)
instance AccumMonoid(h, n=>w) given (n|Ix, h, w) (am:AccumMonoid(h, w))
@@ -658,15 +658,15 @@ instance AccumMonoid(h, n=>w) given (n|Ix, h, w) (am:AccumMonoid(h, w))
UnsafeMkAccumMonoidData(b, bm) = %applyMethod0(am)
UnsafeMkAccumMonoidData(b, bm)
-def (+=) (ref:Ref h w, x:w) -> {Accum h} ()
+def (+=)(ref:Ref h w, x:w) -> {Accum h} ()
given (h, w) (am:AccumMonoid(h, w)) =
UnsafeMkAccumMonoidData(b, bm) = %applyMethod0(am)
empty = %applyMethod0(bm)
%mextend(ref, empty, \x:b y:b. %applyMethod1(bm, x, y), x)
-def (!) (ref: Ref h (n=>a), i:n) -> Ref h a given (n|Ix, a|Data, h) = %indexRef(ref, i)
-def fst_ref (ref: Ref h (a,b)) -> Ref h a given (a|Data, b, h) = %fstRef(ref)
-def snd_ref (ref: Ref h (a,b)) -> Ref h b given (a, b|Data, h) = %sndRef(ref)
+def (!)(ref: Ref h (n=>a), i:n) -> Ref h a given (n|Ix, a|Data, h) = %indexRef(ref, i)
+def fst_ref(ref: Ref h (a,b)) -> Ref h a given (a|Data, b, h) = %fstRef(ref)
+def snd_ref(ref: Ref h (a,b)) -> Ref h b given (a, b|Data, h) = %sndRef(ref)
def run_reader(
init:r,
@@ -681,7 +681,7 @@ def with_reader(
) -> {|eff} a given (r|Data, a, eff) =
run_reader(init, action)
-def MonoidLifter (b:Type, w:Type) -> Type =
+def MonoidLifter(b:Type, w:Type) -> Type =
(given (h) (AccumMonoid(h, b))) ->> AccumMonoid(h, w)
named-instance mk_accum_monoid (given (h, w), d:AccumMonoidData(h, w)) -> AccumMonoid(h, w)
@@ -692,7 +692,7 @@ def run_accum(
action: (given (h) (AccumMonoid(h, b)), Ref h w) -> {Accum h|eff} a
) -> {|eff} (a, w) given (a, b, w|Data, eff) (MonoidLifter(b,w)) =
empty = %applyMethod0(bm)
- def explicitAction (h':Heap, ref:Ref h' w) -> {Accum h'|eff} a =
+ def explicitAction(h':Heap, ref:Ref h' w) -> {Accum h'|eff} a =
accumMonoidData : AccumMonoidData h' b = UnsafeMkAccumMonoidData b bm
accumBaseMonoid = mk_accum_monoid accumMonoidData
%explicitApply(action, h', accumBaseMonoid, ref)
@@ -708,7 +708,7 @@ def run_state(
init:s,
action: (given (h), Ref h s) -> {State h |eff} a
) -> {|eff} (a,s) given (a, s|Data, eff) =
- def explicitAction (h':Heap, ref:Ref h' s) -> {State h'|eff} a = action ref
+ def explicitAction(h':Heap, ref:Ref h' s) -> {State h'|eff} a = action ref
%runState(init, explicitAction)
def with_state(
@@ -723,13 +723,13 @@ def yield_state(
) -> {|eff} s given (a, s|Data, eff) =
snd $ run_state(init, action)
-def unsafe_io (
+def unsafe_io(
f:()->{IO|eff} a
) -> {|eff} a given (a, eff) =
f' : (() -> {IO|eff} a) = \. f()
%runIO(f')
-def unreachable () -> a given (a|Data) = unsafe_io \. %throwError(a)
+def unreachable() -> a given (a|Data) = unsafe_io \. %throwError(a)
'## Type classes
@@ -739,10 +739,10 @@ def unreachable () -> a given (a|Data) = unsafe_io \. %throwError(a)
Equatable.
Things that we can tell if they are equal or not to other things.
-interface Eq (a|Data)
+interface Eq(a|Data)
(==) : (a, a) -> Bool
-def (/=) (x:a, y:a) -> Bool given (a|Eq) = not $ x == y
+def (/=)(x:a, y:a) -> Bool given (a|Eq) = not $ x == y
'#### Ord
Orderable / Comparable.
@@ -751,42 +751,42 @@ i.e. things that can be compared to other things to find if larger, smaller or e
'We take the standard false-hood and pretend that this applies to Floats, even though strictly speaking this not true as our floats follow [IEEE754](https://en.wikipedia.org/wiki/IEEE_754), and thus have `NaN < 1.0 == false` and `1.0 < NaN == false`.
-interface Ord (a|Eq)
+interface Ord(a|Eq)
(>) : (a, a) -> Bool
(<) : (a, a) -> Bool
-def (<=) (x:a, y:a) -> Bool given (a|Ord) = x<y || x==y
-def (>=) (x:a, y:a) -> Bool given (a|Ord) = x>y || x==y
+def (<=)(x:a, y:a) -> Bool given (a|Ord) = x<y || x==y
+def (>=)(x:a, y:a) -> Bool given (a|Ord) = x>y || x==y
-instance Eq Float64
- def (==) (x, y) = w8_to_b $ %feq(x, y)
+instance Eq(Float64)
+ def (==)(x, y) = w8_to_b $ %feq(x, y)
-instance Eq Float32
- def (==) (x, y) = w8_to_b $ %feq(x, y)
+instance Eq(Float32)
+ def (==)(x, y) = w8_to_b $ %feq(x, y)
-instance Eq Int64
- def (==) (x, y) = w8_to_b $ %ieq(x, y)
+instance Eq(Int64)
+ def (==)(x, y) = w8_to_b $ %ieq(x, y)
-instance Eq Int32
- def (==) (x, y) = w8_to_b $ %ieq(x, y)
+instance Eq(Int32)
+ def (==)(x, y) = w8_to_b $ %ieq(x, y)
-instance Eq Word8
- def (==) (x, y) = w8_to_b $ %ieq(x, y)
+instance Eq(Word8)
+ def (==)(x, y) = w8_to_b $ %ieq(x, y)
-instance Eq Word32
- def (==) (x, y) = w8_to_b $ %ieq(x, y)
+instance Eq(Word32)
+ def (==)(x, y) = w8_to_b $ %ieq(x, y)
-instance Eq Word64
- def (==) (x, y) = w8_to_b $ %ieq(x, y)
+instance Eq(Word64)
+ def (==)(x, y) = w8_to_b $ %ieq(x, y)
-instance Eq Bool
- def (==) (x, y) = b_to_w8 x == b_to_w8 y
+instance Eq(Bool)
+ def (==)(x, y) = b_to_w8 x == b_to_w8 y
-instance Eq ()
- def (==) (_, _) = True
+instance Eq(())
+ def (==)(_, _) = True
-instance Eq (Either(a, b)) given (a|Eq, b|Eq)
- def (==) (x, y) = case x of
+instance Eq(Either(a, b)) given (a|Eq, b|Eq)
+ def (==)(x, y) = case x of
Left(x) -> case y of
Left( y) -> x == y
Right(y) -> False
@@ -794,8 +794,8 @@ instance Eq (Either(a, b)) given (a|Eq, b|Eq)
Left( y) -> False
Right(y) -> x == y
-instance Eq (Maybe a) given (a|Eq)
- def (==) (x, y) = case x of
+instance Eq(Maybe a) given (a|Eq)
+ def (==)(x, y) = case x of
Just(x) -> case y of
Just(y) -> x == y
Nothing -> False
@@ -803,85 +803,85 @@ instance Eq (Maybe a) given (a|Eq)
Just(y) -> False
Nothing -> True
-instance Eq RawPtr
- def (==) (x, y) = raw_ptr_to_i64 x == raw_ptr_to_i64 y
+instance Eq(RawPtr)
+ def (==)(x, y) = raw_ptr_to_i64 x == raw_ptr_to_i64 y
-instance Ord Float64
- def (>) (x, y) = w8_to_b $ %fgt(x, y)
- def (<) (x, y) = w8_to_b $ %flt(x, y)
+instance Ord(Float64)
+ def (>)(x, y) = w8_to_b $ %fgt(x, y)
+ def (<)(x, y) = w8_to_b $ %flt(x, y)
-instance Ord Float32
- def (>) (x, y) = w8_to_b $ %fgt(x, y)
- def (<) (x, y) = w8_to_b $ %flt(x, y)
+instance Ord(Float32)
+ def (>)(x, y) = w8_to_b $ %fgt(x, y)
+ def (<)(x, y) = w8_to_b $ %flt(x, y)
-instance Ord Int64
- def (>) (x, y) = w8_to_b $ %igt(x, y)
- def (<) (x, y) = w8_to_b $ %ilt(x, y)
+instance Ord(Int64)
+ def (>)(x, y) = w8_to_b $ %igt(x, y)
+ def (<)(x, y) = w8_to_b $ %ilt(x, y)
-instance Ord Int32
- def (>) (x, y) = w8_to_b $ %igt(x, y)
- def (<) (x, y) = w8_to_b $ %ilt(x, y)
+instance Ord(Int32)
+ def (>)(x, y) = w8_to_b $ %igt(x, y)
+ def (<)(x, y) = w8_to_b $ %ilt(x, y)
-instance Ord Word8
- def (>) (x, y) = w8_to_b $ %igt(x, y)
- def (<) (x, y) = w8_to_b $ %ilt(x, y)
+instance Ord(Word8)
+ def (>)(x, y) = w8_to_b $ %igt(x, y)
+ def (<)(x, y) = w8_to_b $ %ilt(x, y)
-instance Ord Word32
- def (>) (x, y) = w8_to_b $ %igt(x, y)
- def (<) (x, y) = w8_to_b $ %ilt(x, y)
+instance Ord(Word32)
+ def (>)(x, y) = w8_to_b $ %igt(x, y)
+ def (<)(x, y) = w8_to_b $ %ilt(x, y)
-instance Ord Word64
- def (>) (x, y) = w8_to_b $ %igt(x, y)
- def (<) (x, y) = w8_to_b $ %ilt(x, y)
+instance Ord(Word64)
+ def (>)(x, y) = w8_to_b $ %igt(x, y)
+ def (<)(x, y) = w8_to_b $ %ilt(x, y)
-instance Ord ()
- def (>) (x, y) = False
- def (<) (x, y) = False
+instance Ord(())
+ def (>)(x, y) = False
+ def (<)(x, y) = False
-instance Eq (a, b) given (a|Eq, b|Eq)
- def (==) (p1, p2) =
+instance Eq((a, b)) given (a|Eq, b|Eq)
+ def (==)(p1, p2) =
(x1, y1) = p1
(x2, y2) = p2
x1 == x2 && y1 == y2
-instance Ord (a, b) given (a|Ord, b|Ord)
- def (>) (p1, p2) =
+instance Ord((a, b)) given (a|Ord, b|Ord)
+ def (>)(p1, p2) =
(x1, y1) = p1
(x2, y2) = p2
x1 > x2 || (x1 == x2 && y1 > y2)
- def (<) (p1, p2) =
+ def (<)(p1, p2) =
(x1, y1) = p1
(x2, y2) = p2
x1 < x2 || (x1 == x2 && y1 < y2)
-instance Eq Ordering
- def (==) (x, y) = o_to_w8 x == o_to_w8 y
+instance Eq(Ordering)
+ def (==)(x, y) = o_to_w8 x == o_to_w8 y
-instance Eq Nat
- def (==) (x, y) = nat_to_rep x == nat_to_rep y
+instance Eq(Nat)
+ def (==)(x, y) = nat_to_rep x == nat_to_rep y
-instance Ord Nat
- def (>) (x, y) = nat_to_rep x > nat_to_rep y
- def (<) (x, y) = nat_to_rep x < nat_to_rep y
+instance Ord(Nat)
+ def (>)(x, y) = nat_to_rep x > nat_to_rep y
+ def (<)(x, y) = nat_to_rep x < nat_to_rep y
-- TODO: we want Eq and Ord for all index sets, not just `Fin n`
-instance Eq (Fin n) given (n)
- def (==) (x, y) = ordinal x == ordinal y
+instance Eq(Fin n) given (n)
+ def (==)(x, y) = ordinal x == ordinal y
-instance Ord (Fin n) given (n)
- def (>) (x, y) = ordinal x > ordinal y
- def (<) (x, y) = ordinal x < ordinal y
+instance Ord(Fin n) given (n)
+ def (>)(x, y) = ordinal x > ordinal y
+ def (<)(x, y) = ordinal x < ordinal y
-instance Ix Bool
- def size' () = 2
- def ordinal (b) = case b of
+instance Ix(Bool)
+ def size'() = 2
+ def ordinal(b) = case b of
False -> 0
True -> 1
- def unsafe_from_ordinal (i) = i > 0
+ def unsafe_from_ordinal(i) = i > 0
-instance Ix (Maybe a) given (a|Ix)
- def size' () = size a + 1
- def ordinal (i) = case i of
+instance Ix(Maybe a) given (a|Ix)
+ def size'() = size a + 1
+ def ordinal(i) = case i of
Just(ai) -> ordinal ai
Nothing -> size a
def unsafe_from_ordinal(o) =
@@ -889,13 +889,13 @@ instance Ix (Maybe a) given (a|Ix)
False -> Just $ unsafe_from_ordinal o
True -> Nothing
-instance NonEmpty Bool
+instance NonEmpty(Bool)
first_ix = unsafe_from_ordinal 0
-instance NonEmpty (a,b) given (a|NonEmpty, b|NonEmpty)
+instance NonEmpty((a,b)) given (a|NonEmpty, b|NonEmpty)
first_ix = unsafe_from_ordinal 0
-instance NonEmpty (Either(a,b)) given (a|NonEmpty, b|Ix)
+instance NonEmpty(Either(a,b)) given (a|NonEmpty, b|Ix)
first_ix = unsafe_from_ordinal 0
-- The below instance is valid, but causes "multiple candidate dictionaries"
@@ -903,7 +903,7 @@ instance NonEmpty (Either(a,b)) given (a|NonEmpty, b|Ix)
-- instance NonEmpty (a|b) given {a b} [Ix a, NonEmpty b]
-- first_ix = unsafe_from_ordinal _ 0
-instance NonEmpty (Maybe a) given (a|Ix)
+instance NonEmpty(Maybe a) given (a|Ix)
first_ix = unsafe_from_ordinal 0
def scan(
@@ -916,42 +916,42 @@ def scan(
s := c'
y
-def fold (init:a, body:(n,a)->a) -> a given (n|Ix, a|Data) =
+def fold(init:a, body:(n,a)->a) -> a given (n|Ix, a|Data) =
fst $ scan init \i x. (body(i, x), ())
-def compare (x:a, y:a) -> Ordering given (a|Ord) =
+def compare(x:a, y:a) -> Ordering given (a|Ord) =
if x < y
then LT
else if x == y
then EQ
else GT
-instance Monoid Ordering
+instance Monoid(Ordering)
mempty = EQ
- def (<>) (x, y) =
+ def (<>)(x, y) =
case x of
LT -> LT
GT -> GT
EQ -> y
-instance Eq (n=>a) given (n|Ix, a|Eq)
- def (==) (xs, ys) =
+instance Eq(n=>a) given (n|Ix, a|Eq)
+ def (==)(xs, ys) =
yield_accum AndMonoid \ref.
for i. ref += xs[i] == ys[i]
-instance Ord (n=>a) given (n|Ix, a|Ord)
- def (>) (xs, ys) =
+instance Ord(n=>a) given (n|Ix, a|Ord)
+ def (>)(xs, ys) =
f: Ordering =
fold EQ $ \i c. c <> compare(xs[i], ys[i])
f == GT
- def (<) (xs, ys) =
+ def (<)(xs, ys) =
f: Ordering =
fold EQ $ \i c. c <> compare(xs[i], ys[i])
f == LT
'## Subset class
-interface Subset (subset, superset)
+interface Subset(subset, superset)
inject : (subset) -> superset
project : (superset) -> Maybe subset
unsafe_project : (superset) -> subset
@@ -963,34 +963,34 @@ instance Subset(a, c) given (a, b, c) (Subset(a, b), Subset(b, c))
Just(y)-> project y
def unsafe_project(x) = unsafe_project $ unsafe_project(subset=b, x)
-def unsafe_project_rangefrom (j:q) -> RangeFrom(q, i) given (q|Ix, i:q) =
+def unsafe_project_rangefrom(j:q) -> RangeFrom(q, i) given (q|Ix, i:q) =
RangeFrom.new unsafe_nat_diff(ordinal j, ordinal i)
instance Subset(RangeFrom(q, i), q) given (q|Ix, i:q)
- def inject (j) =
+ def inject(j) =
unsafe_from_ordinal $ j.val + ordinal i
- def project (j) =
+ def project(j) =
j' = ordinal j
i' = ordinal i
if j' < i'
then Nothing
else Just $ RangeFrom.new $ unsafe_nat_diff(j', i')
- def unsafe_project (j) = RangeFrom.new unsafe_nat_diff(ordinal j, ordinal i)
+ def unsafe_project(j) = RangeFrom.new unsafe_nat_diff(ordinal j, ordinal i)
instance Subset(RangeFromExc(q, i), q) given (q|Ix, i:q)
- def inject (j) = unsafe_from_ordinal $ j.val + ordinal i + 1
- def project (j) =
+ def inject(j) = unsafe_from_ordinal $ j.val + ordinal i + 1
+ def project(j) =
j' = ordinal j
i' = ordinal i
if j' <= i'
then Nothing
else Just $ RangeFromExc.new unsafe_nat_diff(j', i' + 1)
- def unsafe_project (j) =
+ def unsafe_project(j) =
RangeFromExc.new unsafe_nat_diff(ordinal j, ordinal i + 1)
instance Subset(RangeTo(q, i), q) given (q|Ix, i:q)
- def inject (j) = unsafe_from_ordinal j.val
- def project (j) =
+ def inject(j) = unsafe_from_ordinal j.val
+ def project(j) =
j' = ordinal j
i' = ordinal i
if j' > i'
@@ -999,24 +999,24 @@ instance Subset(RangeTo(q, i), q) given (q|Ix, i:q)
def unsafe_project(j) = RangeTo.new (ordinal j)
instance Subset(RangeToExc(q, i), q) given (q|Ix, i:q)
- def inject (j) = unsafe_from_ordinal j.val
- def project (j) =
+ def inject(j) = unsafe_from_ordinal j.val
+ def project(j) =
j' = ordinal j
i' = ordinal i
if j' >= i'
then Nothing
else Just $ RangeToExc.new j'
- def unsafe_project (j) = RangeToExc.new (ordinal j)
+ def unsafe_project(j) = RangeToExc.new (ordinal j)
instance Subset(RangeToExc(q, i), RangeTo(q, i)) given (q|Ix, i:q)
- def inject (j) = unsafe_from_ordinal j.val
- def project (j) =
+ def inject(j) = unsafe_from_ordinal j.val
+ def project(j) =
j' = ordinal j
i' = ordinal i
if j' >= i'
then Nothing
else Just $ RangeToExc.new j'
- def unsafe_project (j) = RangeToExc.new (ordinal j)
+ def unsafe_project(j) = RangeToExc.new (ordinal j)
'## Elementary/Special Functions
This is more or less the standard [LibM fare](https://en.wikipedia.org/wiki/C_mathematical_functions).
@@ -1024,7 +1024,7 @@ Roughly it lines up with some definitions of the set of [Elementary](https://en.
In truth, nothing is elementary or special except that we humans have decided it is.
Many, but not all of these functions are [Transcendental](https://en.wikipedia.org/wiki/Transcendental_function).
-interface Floating (a:Type)
+interface Floating(a:Type)
exp : (a) -> a
exp2 : (a) -> a
log : (a) -> a
@@ -1046,124 +1046,124 @@ interface Floating (a:Type)
erf : (a) -> a
erfc : (a) -> a
-def lbeta (x:a, y:a) -> a given (a|Sub|Floating) = lgamma x + lgamma y - lgamma (x + y)
+def lbeta(x:a, y:a) -> a given (a|Sub|Floating) = lgamma x + lgamma y - lgamma (x + y)
-- Todo: better numerics for very large and small values.
-- Using %exp here to avoid circular definition problems.
-def float32_sinh (x:Float32) -> Float32 = %fdiv(%fsub(%exp(x), %exp(%fsub(0.0,x))), 2.0)
-def float32_cosh (x:Float32) -> Float32 = %fdiv(%fadd(%exp(x), %exp(%fsub(0.0,x))), 2.0)
-def float32_tanh (x:Float32) -> Float32 = %fdiv(%fsub(%exp(x), %exp(%fsub(0.0,x)))
+def float32_sinh(x:Float32) -> Float32 = %fdiv(%fsub(%exp(x), %exp(%fsub(0.0,x))), 2.0)
+def float32_cosh(x:Float32) -> Float32 = %fdiv(%fadd(%exp(x), %exp(%fsub(0.0,x))), 2.0)
+def float32_tanh(x:Float32) -> Float32 = %fdiv(%fsub(%exp(x), %exp(%fsub(0.0,x)))
,%fadd(%exp(x), %exp(%fsub(0.0,x))))
-- Todo: unify this with float32 functions.
-def float64_sinh (x:Float64) -> Float64 = %fdiv(%fsub(%exp(x), %exp(%fsub(f_to_f64 0.0, x))), f_to_f64 2.0)
-def float64_cosh (x:Float64) -> Float64 = %fdiv(%fadd(%exp(x), %exp(%fsub(f_to_f64 0.0, x))), f_to_f64 2.0)
-def float64_tanh (x:Float64) -> Float64 = %fdiv(%fsub(%exp(x), %exp(%fsub(f_to_f64 0.0, x)))
+def float64_sinh(x:Float64) -> Float64 = %fdiv(%fsub(%exp(x), %exp(%fsub(f_to_f64 0.0, x))), f_to_f64 2.0)
+def float64_cosh(x:Float64) -> Float64 = %fdiv(%fadd(%exp(x), %exp(%fsub(f_to_f64 0.0, x))), f_to_f64 2.0)
+def float64_tanh(x:Float64) -> Float64 = %fdiv(%fsub(%exp(x), %exp(%fsub(f_to_f64 0.0, x)))
,%fadd(%exp(x), %exp(%fsub(f_to_f64 0.0, x))))
-instance Floating Float64
- def exp (x) = %exp (x)
- def exp2 (x) = %exp2 (x)
- def log (x) = %log (x)
- def log2 (x) = %log2 (x)
- def log10 (x) = %log10 (x)
- def log1p (x) = %log1p (x)
- def sin (x) = %sin (x)
- def cos (x) = %cos (x)
- def tan (x) = %tan (x)
- def sinh (x) = float64_sinh x
- def cosh (x) = float64_cosh x
- def tanh (x) = float64_tanh x
- def floor (x) = %floor (x)
- def ceil (x) = %ceil (x)
- def round (x) = %round (x)
- def sqrt (x) = %sqrt (x)
- def pow (x, y) = %fpow(x, y)
- def lgamma (x) = %lgamma (x)
- def erf (x) = %erf (x)
- def erfc (x) = %erfc (x)
-
-instance Floating Float32
- def exp (x) = %exp (x)
- def exp2 (x) = %exp2 (x)
- def log (x) = %log (x)
- def log2 (x) = %log2 (x)
- def log10 (x) = %log10 (x)
- def log1p (x) = %log1p (x)
- def sin (x) = %sin (x)
- def cos (x) = %cos (x)
- def tan (x) = %tan (x)
- def sinh (x) = float32_sinh x
- def cosh (x) = float32_cosh x
- def tanh (x) = float32_tanh x
- def floor (x) = %floor (x)
- def ceil (x) = %ceil (x)
- def round (x) = %round (x)
- def sqrt (x) = %sqrt (x)
- def pow (x, y) = %fpow(x, y)
- def lgamma (x) = %lgamma (x)
- def erf (x) = %erf (x)
- def erfc (x) = %erfc (x)
+instance Floating(Float64)
+ def exp(x) = %exp(x)
+ def exp2(x) = %exp2(x)
+ def log(x) = %log(x)
+ def log2(x) = %log2(x)
+ def log10(x) = %log10(x)
+ def log1p(x) = %log1p(x)
+ def sin(x) = %sin(x)
+ def cos(x) = %cos(x)
+ def tan(x) = %tan( x)
+ def sinh(x) = float64_sinh(x)
+ def cosh(x) = float64_cosh(x)
+ def tanh(x) = float64_tanh(x)
+ def floor(x) = %floor(x)
+ def ceil(x) = %ceil(x)
+ def round(x) = %round(x)
+ def sqrt(x) = %sqrt(x)
+ def pow(x,y) = %fpow(x,y)
+ def lgamma(x)= %lgamma(x)
+ def erf(x) = %erf(x)
+ def erfc(x) = %erfc(x)
+
+instance Floating(Float32)
+ def exp(x) = %exp(x)
+ def exp2(x) = %exp2(x)
+ def log(x) = %log(x)
+ def log2(x) = %log2(x)
+ def log10(x) = %log10(x)
+ def log1p(x) = %log1p(x)
+ def sin(x) = %sin(x)
+ def cos(x) = %cos(x)
+ def tan(x) = %tan(x)
+ def sinh(x) = float32_sinh(x)
+ def cosh(x) = float32_cosh(x)
+ def tanh(x) = float32_tanh(x)
+ def floor(x) = %floor(x)
+ def ceil(x) = %ceil(x)
+ def round(x) = %round(x)
+ def sqrt(x) = %sqrt(x)
+ def pow(x,y) = %fpow(x, y)
+ def lgamma(x)= %lgamma(x)
+ def erf(x) = %erf(x)
+ def erfc(x) = %erfc(x)
'## Raw pointer operations
-struct Ptr (a:Type) =
+struct Ptr(a:Type) =
val : RawPtr
-def cast_ptr (ptr: Ptr a) -> Ptr b given (a, b) = Ptr.new(ptr.val)
+def cast_ptr(ptr: Ptr a) -> Ptr b given (a, b) = Ptr.new(ptr.val)
-interface Storable (a|Data)
+interface Storable(a|Data)
store : (Ptr a, a) -> {IO} ()
load : (Ptr a) -> {IO} a
storage_size : () -> Nat
-instance Storable Word8
- def store (ptr, x) = %ptrStore(ptr.val, x)
- def load (ptr) = %ptrLoad (ptr.val)
- def storage_size () = 1
+instance Storable(Word8)
+ def store(ptr, x) = %ptrStore(ptr.val, x)
+ def load(ptr) = %ptrLoad(ptr.val)
+ def storage_size() = 1
-instance Storable Int32
- def store (ptr, x) = %ptrStore(internal_cast(to=%Int32Ptr(), ptr.val), x)
- def load (ptr) = %ptrLoad (internal_cast(to=%Int32Ptr(), ptr.val))
- def storage_size () = 4
+instance Storable(Int32)
+ def store(ptr, x) = %ptrStore(internal_cast(to=%Int32Ptr(), ptr.val), x)
+ def load(ptr) = %ptrLoad(internal_cast(to=%Int32Ptr(), ptr.val))
+ def storage_size() = 4
-instance Storable Word32
- def store (ptr, x) = %ptrStore(internal_cast(to=%Word32Ptr(), ptr), x)
- def load (ptr) = %ptrLoad (internal_cast(to=%Word32Ptr(), ptr))
- def storage_size () = 4
+instance Storable(Word32)
+ def store(ptr, x) = %ptrStore(internal_cast(to=%Word32Ptr(), ptr), x)
+ def load(ptr) = %ptrLoad(internal_cast(to=%Word32Ptr(), ptr))
+ def storage_size() = 4
-instance Storable Float32
- def store (ptr, x) = %ptrStore (internal_cast(to=%Float32Ptr(), ptr.val), x)
- def load (ptr) = %ptrLoad (internal_cast(to=%Float32Ptr(), ptr.val))
- def storage_size () = 4
+instance Storable(Float32)
+ def store(ptr, x) = %ptrStore(internal_cast(to=%Float32Ptr(), ptr.val), x)
+ def load(ptr) = %ptrLoad(internal_cast(to=%Float32Ptr(), ptr.val))
+ def storage_size() = 4
-instance Storable Nat
- def store (ptr, x) = store(Ptr.new(ptr.val), nat_to_rep x)
- def load (ptr) = rep_to_nat $ load(Ptr.new(ptr.val))
- def storage_size () = storage_size(a=NatRep)
+instance Storable(Nat)
+ def store(ptr, x) = store(Ptr.new(ptr.val), nat_to_rep x)
+ def load(ptr) = rep_to_nat $ load(Ptr.new(ptr.val))
+ def storage_size() = storage_size(a=NatRep)
-instance Storable (Ptr a) given (a)
- def store (ptr, x) = %ptrStore (internal_cast(to=%PtrPtr(), ptr.val), x.val)
- def load (ptr) = Ptr.new(%ptrLoad(internal_cast(to=%PtrPtr(), ptr)))
- def storage_size () = 8 -- TODO: something more portable?
+instance Storable(Ptr a) given (a)
+ def store(ptr, x) = %ptrStore(internal_cast(to=%PtrPtr(), ptr.val), x.val)
+ def load(ptr) = Ptr.new(%ptrLoad(internal_cast(to=%PtrPtr(), ptr)))
+ def storage_size() = 8 -- TODO: something more portable?
-- TODO: Storable instances for other types
-def malloc (n:Nat) -> {IO} (Ptr a) given (a|Storable) =
+def malloc(n:Nat) -> {IO} (Ptr a) given (a|Storable) =
numBytes = storage_size(a=a) * n
- Ptr.new(%alloc (nat_to_rep numBytes))
+ Ptr.new(%alloc(nat_to_rep numBytes))
-def free (ptr:Ptr a) -> {IO} () given (a) = %free(ptr.val)
+def free(ptr:Ptr a) -> {IO} () given (a) = %free(ptr.val)
-def (+>>) (ptr:Ptr a, i:Nat) -> Ptr a given (a|Storable) =
+def (+>>)(ptr:Ptr a, i:Nat) -> Ptr a given (a|Storable) =
i' = nat_to_rep $ i * storage_size(a=a)
Ptr.new(%ptrOffset(ptr.val, i'))
-- TODO: consider making a Storable instance for tables instead
-def store_table (ptr: Ptr a, tab:n=>a) -> {IO} () given (a|Storable, n|Ix) =
+def store_table(ptr: Ptr a, tab:n=>a) -> {IO} () given (a|Storable, n|Ix) =
for_ i. store(ptr +>> ordinal i, tab[i])
-def memcpy (dest:Ptr a, src:Ptr a, n:Nat) -> {IO} () given (a|Storable) =
+def memcpy(dest:Ptr a, src:Ptr a, n:Nat) -> {IO} () given (a|Storable) =
for_ i:(Fin n).
i' = ordinal i
store(dest +>> i', load $ src +>> i')
@@ -1187,74 +1187,75 @@ def with_table_ptr(
for i. store(ptr +>> ordinal i, xs[i])
action ptr
-def table_from_ptr (ptr:Ptr a) -> {IO} n=>a given (a|Storable, n|Ix) =
+def table_from_ptr(ptr:Ptr a) -> {IO} n=>a given (a|Storable, n|Ix) =
for i. load $ ptr +>> ordinal i
'## Miscellaneous common utilities
pi : Float = 3.141592653589793
-def id (x:a) -> a given (a) = x
-def dup (x:a) -> (a, a) given (a) = (x, x)
-def map (f:(a)->{|eff} b, xs: n=>a) -> {|eff} (n=>b) given (a, b, n|Ix, eff) =
+def id(x:a) -> a given (a) = x
+def dup(x:a) -> (a, a) given (a) = (x, x)
+def map(f:(a)->{|eff} b, xs: n=>a) -> {|eff} (n=>b) given (a, b, n|Ix, eff) =
for i. f xs[i]
-- map, flipped so that the function goes last
-def each (xs: n=>a, f:(a)->{|eff} b) -> {|eff} (n=>b) given (a, b, n|Ix, eff) =
+def each(xs: n=>a, f:(a)->{|eff} b) -> {|eff} (n=>b) given (a, b, n|Ix, eff) =
for i. f xs[i]
-def zip (xs:n=>a, ys:n=>b) -> (n=>(a,b)) given (a, b, n|Ix) = for i. (xs[i], ys[i])
-def unzip (xys:n=>(a,b)) -> (n=>a , n=>b) given (a, b, n|Ix)= (each xys fst, each xys snd)
-def fanout (x:a) -> n=>a given (n|Ix, a) = for i. x
-def sq (x:a) -> a given (a|Mul) = x * x
-def abs (x:a) -> a given (a|Sub|Ord) = select(x > zero, x, zero - x)
-def mod (x:a, y:a) -> a given (a|Add|Integral) = rem(y + rem(x, y), y)
+def zip(xs:n=>a, ys:n=>b) -> (n=>(a,b)) given (a, b, n|Ix) = for i. (xs[i], ys[i])
+def unzip(xys:n=>(a,b)) -> (n=>a , n=>b) given (a, b, n|Ix)= (each xys fst, each xys snd)
+def fanout(x:a) -> n=>a given (n|Ix, a) = for i. x
+def sq(x:a) -> a given (a|Mul) = x * x
+def abs(x:a) -> a given (a|Sub|Ord) = select(x > zero, x, zero - x)
+def mod(x:a, y:a) -> a given (a|Add|Integral) = rem(y + rem(x, y), y)
'## Table Operations
-instance Floating (n=>a) given (a|Floating, n|Ix)
- def exp (x) = each x exp
- def exp2 (x) = each x exp2
- def log (x) = each x log
- def log2 (x) = each x log2
- def log10 (x) = each x log10
- def log1p (x) = each x log1p
- def sin (x) = each x sin
- def cos (x) = each x cos
- def tan (x) = each x tan
- def sinh (x) = each x sinh
- def cosh (x) = each x cosh
- def tanh (x) = each x tanh
- def floor (x) = each x floor
- def ceil (x) = each x ceil
- def round (x) = each x round
- def sqrt (x) = each x sqrt
- def pow (x, y) = for i. pow(x[i], y[i])
- def lgamma (x) = each x lgamma
- def erf (x) = each x erf
- def erfc (x) = each x erfc
+instance Floating(n=>a) given (a|Floating, n|Ix)
+ def exp(x) = each x exp
+ def exp2(x) = each x exp2
+ def log(x) = each x log
+ def log2(x) = each x log2
+ def log10(x) = each x log10
+ def log1p(x) = each x log1p
+ def sin(x) = each x sin
+ def cos(x) = each x cos
+ def tan(x) = each x tan
+ def sinh(x) = each x sinh
+ def cosh(x) = each x cosh
+ def tanh(x) = each x tanh
+ def floor(x) = each x floor
+ def ceil(x) = each x ceil
+ def round(x) = each x round
+ def sqrt(x) = each x sqrt
+ def pow(x, y) = for i. pow(x[i], y[i])
+ def lgamma(x) = each x lgamma
+ def erf(x) = each x erf
+ def erfc(x) = each x erfc
'### Reductions
-- `combine` should be a commutative and associative, and form a
-- commutative monoid with `identity`
-def reduce (identity:a, combine:(a,a)->a, xs:n=>a) -> a given (a|Data, n|Ix) =
+def reduce(identity:a, combine:(a,a)->a, xs:n=>a) -> a given (a|Data, n|Ix) =
-- TODO: implement with the accumulator effect
fold identity \i c. combine(c, xs[i])
-- TODO: call this `scan` and call the current `scan` something else
-def scan' (init:a, body:(n,a)->a) -> n=>a given (a|Data, n|Ix) =
+def scan'(init:a, body:(n,a)->a) -> n=>a given (a|Data, n|Ix) =
snd $ scan init \i x. dup(body(i, x))
-def fsum (xs:n=>Float) -> Float given (n|Ix) = yield_accum(AddMonoid Float) \ref. for i. ref += xs[i]
-def sum (xs:n=>v) -> v given (n|Ix, v|Add) = reduce(zero, (+), xs)
-def prod (xs:n=>v) -> v given (n|Ix, v|Mul) = reduce(one , (*), xs)
-def mean (xs:n=>v) -> v given (n|Ix, v|VSpace) = sum xs / n_to_f (size n)
-def std (xs:n=>v) -> v given (n|Ix, v|Mul|Sub|VSpace|Floating) = sqrt $ mean (each xs sq) - sq (mean xs)
-def any (xs:n=>Bool) -> Bool given (n|Ix) = reduce(False, (||), xs)
-def all (xs:n=>Bool) -> Bool given (n|Ix) = reduce(True , (&&), xs)
+def fsum(xs:n=>Float) -> Float given (n|Ix) =
+ yield_accum(AddMonoid Float) \ref. for i. ref += xs[i]
+def sum(xs:n=>v) -> v given (n|Ix, v|Add) = reduce(zero, (+), xs)
+def prod(xs:n=>v) -> v given (n|Ix, v|Mul) = reduce(one , (*), xs)
+def mean(xs:n=>v) -> v given (n|Ix, v|VSpace) = sum xs / n_to_f (size n)
+def std(xs:n=>v) -> v given (n|Ix, v|Mul|Sub|VSpace|Floating) = sqrt $ mean (each xs sq) - sq (mean xs)
+def any(xs:n=>Bool) -> Bool given (n|Ix) = reduce(False, (||), xs)
+def all(xs:n=>Bool) -> Bool given (n|Ix) = reduce(True , (&&), xs)
'### apply_n
-def apply_n (n:Nat, x:a, f:(a) -> a) -> a given (a|Data) =
+def apply_n(n:Nat, x:a, f:(a) -> a) -> a given (a|Data) =
yield_state x \ref. for _:(Fin n).
ref := f (get ref)
@@ -1262,14 +1263,14 @@ def apply_n (n:Nat, x:a, f:(a) -> a) -> a given (a|Data) =
TODO: Move this to be with reductions?
It's a kind of `scan`.
-def cumsum (xs: n=>a) -> n=>a given (n|Ix, a|Add) =
+def cumsum(xs: n=>a) -> n=>a given (n|Ix, a|Add) =
total <- with_state zero
for i.
newTotal = get total + xs[i]
total := newTotal
newTotal
-def cumsum_low (xs: n=>a) -> n=>a given (n|Ix, a|Add) =
+def cumsum_low(xs: n=>a) -> n=>a given (n|Ix, a|Add) =
total <- with_state zero
for i.
oldTotal = get total
@@ -1281,14 +1282,14 @@ def cumsum_low (xs: n=>a) -> n=>a given (n|Ix, a|Add) =
'### AD operations
-- TODO: add vector space constraints
-def linearize (f:(a)->b, x:a) -> (b, (a)->b) given (a, b) =
+def linearize(f:(a)->b, x:a) -> (b, (a)->b) given (a, b) =
%linearize(\x. f x, x)
-def jvp (f:(a)->b, x:a, t:a) -> b given (a, b) = (snd $ linearize(f, x))(t)
-def transpose_linear (f:(a)->b) -> (b)->a given (a, b) = \ct.
+def jvp(f:(a)->b, x:a, t:a) -> b given (a, b) = (snd $ linearize(f, x))(t)
+def transpose_linear(f:(a)->b) -> (b)->a given (a, b) = \ct.
%linearTranspose(\x. f x, ct)
-def vjp (f:(a)->b, x:a) -> (b, (b)->a) given (a, b) =
+def vjp(f:(a)->b, x:a) -> (b, (b)->a) given (a, b) =
(y, df) = linearize(f, x)
(y, transpose_linear df)
@@ -1319,61 +1320,61 @@ interface HasDefaultTolerance(a)
default_atol : a
default_rtol : a
-def (~~) (x:a, y:a) -> Bool given (a|HasAllClose|HasDefaultTolerance) =
+def (~~)(x:a, y:a) -> Bool given (a|HasAllClose|HasDefaultTolerance) =
allclose(default_atol, default_rtol, x, y)
-instance HasAllClose Float32
- def allclose (atol, rtol, x, y) = abs (x - y) <= (atol + rtol * abs y)
+instance HasAllClose(Float32)
+ def allclose(atol, rtol, x, y) = abs (x - y) <= (atol + rtol * abs y)
-instance HasAllClose Float64
- def allclose (atol, rtol, x, y) = abs (x - y) <= (atol + rtol * abs y)
+instance HasAllClose(Float64)
+ def allclose(atol, rtol, x, y) = abs (x - y) <= (atol + rtol * abs y)
-instance HasDefaultTolerance Float32
+instance HasDefaultTolerance(Float32)
default_atol = f_to_f32 0.00001
default_rtol = f_to_f32 0.0001
-instance HasDefaultTolerance Float64
+instance HasDefaultTolerance(Float64)
default_atol = f_to_f64 0.00000001
default_rtol = f_to_f64 0.00001
-instance HasAllClose (a, b) given ( a|HasDefaultTolerance|HasAllClose
- , b|HasDefaultTolerance|HasAllClose)
- def allclose (atol, rtol, pair1, pair2) =
+instance HasAllClose((a, b)) given ( a|HasDefaultTolerance|HasAllClose
+ , b|HasDefaultTolerance|HasAllClose)
+ def allclose(atol, rtol, pair1, pair2) =
(x1, x2) = pair1
(y1, y2) = pair2
(x1 ~~ y1) && (x2 ~~ y2)
-instance HasDefaultTolerance (a, b) given (a|HasDefaultTolerance,b|HasDefaultTolerance)
+instance HasDefaultTolerance((a, b)) given (a|HasDefaultTolerance,b|HasDefaultTolerance)
default_atol = (default_atol, default_atol)
default_rtol = (default_rtol, default_rtol)
-instance HasAllClose (n=>t) given (n|Ix, t|HasAllClose)
- def allclose (atol, rtol, a, b) =
+instance HasAllClose(n=>t) given (n|Ix, t|HasAllClose)
+ def allclose(atol, rtol, a, b) =
all for i:n. allclose(atol[i], rtol[i], a[i], b[i])
-instance HasDefaultTolerance (n=>t) given (n|Ix, t|HasDefaultTolerance)
+instance HasDefaultTolerance(n=>t) given (n|Ix, t|HasDefaultTolerance)
default_atol = for i. default_atol
default_rtol = for i. default_rtol
'### AD Checking tools
-def check_deriv_base (f:(Float)->Float, x:Float) -> Bool =
+def check_deriv_base(f:(Float)->Float, x:Float) -> Bool =
eps = 0.01
ansFwd = deriv( f, x)
ansRev = deriv_rev( f, x)
ansNumeric = (f (x + eps) - f (x - eps)) / (2. * eps)
ansFwd ~~ ansNumeric && ansRev ~~ ansNumeric
-def check_deriv (f:(Float)->Float, x:Float) -> Bool =
+def check_deriv(f:(Float)->Float, x:Float) -> Bool =
check_deriv_base(f, x) && check_deriv_base(\x. deriv(f, x), x)
'## Length-erased lists
data List(a)=
- AsList (n:Nat, elements:(Fin n => a))
+ AsList(n:Nat, elements:(Fin n => a))
-instance Eq (List a) given (a|Eq)
- def (==) (xsList, ysList) =
+instance Eq(List a) given (a|Eq)
+ def (==)(xsList, ysList) =
AsList(nx,xs) = xsList
AsList(ny,ys) = ysList
if nx /= ny
@@ -1381,16 +1382,16 @@ instance Eq (List a) given (a|Eq)
else all for i:(Fin nx).
xs[i] == ys[unsafe_from_ordinal (ordinal i)]
-def unsafe_cast_table (xs:from=>a) -> to=>a given (to|Ix, from|Ix, a) =
+def unsafe_cast_table(xs:from=>a) -> to=>a given (to|Ix, from|Ix, a) =
for i. xs[unsafe_from_ordinal (ordinal i)]
-def to_list (xs:n=>a) -> List a given (n|Ix, a) =
+def to_list(xs:n=>a) -> List a given (n|Ix, a) =
n' = size n
AsList(_, unsafe_cast_table(to=Fin n', xs))
-instance Monoid (List a) given (a|Data)
+instance Monoid(List a) given (a|Data)
mempty = AsList(_, [])
- def (<>) (x, y) =
+ def (<>)(x, y) =
AsList(nx,xs) = x
AsList(ny,ys) = y
nz = nx + ny
@@ -1400,30 +1401,30 @@ instance Monoid (List a) given (a|Data)
True -> xs[unsafe_from_ordinal i']
False -> ys[unsafe_from_ordinal $ unsafe_nat_diff(i', nx)]
-named-instance ListMonoid (a|Data) -> Monoid (List a)
+named-instance ListMonoid (a|Data) -> Monoid(List a)
mempty = mempty
- def (<>) (x, y) = x <> y
+ def (<>)(x, y) = x <> y
-- TODO Eliminate or reimplement this operation, since it costs O(n)
-- where n is the length of the list held in the reference.
-def append (list: Ref(h, List a), x:a) -> {Accum h} ()
+def append(list: Ref(h, List a), x:a) -> {Accum h} ()
given (a|Data, h) (AccumMonoid(h, List a)) =
list += to_list [x]
-- TODO: replace `slice` with this?
-def post_slice (xs:n=>a, start:Post n, end:Post n) -> List a given (n|Ix, a) =
+def post_slice(xs:n=>a, start:Post n, end:Post n) -> List a given (n|Ix, a) =
slice_size = unsafe_nat_diff(ordinal end, ordinal start)
to_list for i:(Fin slice_size).
xs[unsafe_from_ordinal(n=n, ordinal i + ordinal start)]
'## Dynamic buffer
-struct DynBuffer (a) =
+struct DynBuffer(a) =
size : Ptr Nat
max_size : Ptr Nat
buffer : Ptr (Ptr a)
-def with_dynamic_buffer (action: (DynBuffer a) -> {IO} b) -> {IO} b given (a|Storable, b) =
+def with_dynamic_buffer(action: (DynBuffer a) -> {IO} b) -> {IO} b given (a|Storable, b) =
initMaxSize = 256
sizePtr <- with_alloc 1
store(sizePtr, 0)
@@ -1435,7 +1436,7 @@ def with_dynamic_buffer (action: (DynBuffer a) -> {IO} b) -> {IO} b given (a|Sto
free $ load bufferPtr
result
-def maybe_increase_buffer_size (db: DynBuffer a, sizeDelta:Nat) -> {IO} () given (a|Storable) =
+def maybe_increase_buffer_size(db: DynBuffer a, sizeDelta:Nat) -> {IO} () given (a|Storable) =
size = load db.size
max_size = load db.max_size
bufPtr = load db.buffer
@@ -1449,10 +1450,10 @@ def maybe_increase_buffer_size (db: DynBuffer a, sizeDelta:Nat) -> {IO} () given
store(db.max_size, newMaxSize)
store(db.buffer , newBufPtr)
-def add_at_nat_ptr (ptr: Ptr Nat, n:Nat) -> {IO} () =
+def add_at_nat_ptr(ptr: Ptr Nat, n:Nat) -> {IO} () =
store(ptr, load ptr + n)
-def extend_dynamic_buffer (buf: DynBuffer a, new:List a) -> {IO} () given (a|Storable) =
+def extend_dynamic_buffer(buf: DynBuffer a, new:List a) -> {IO} () given (a|Storable) =
AsList(n, xs) = new
maybe_increase_buffer_size(buf, n)
bufPtr = load buf.buffer
@@ -1460,23 +1461,23 @@ def extend_dynamic_buffer (buf: DynBuffer a, new:List a) -> {IO} () given (a|St
store_table(bufPtr +>> size, xs)
add_at_nat_ptr(buf.size, n)
-def load_dynamic_buffer (buf: DynBuffer a) -> {IO} (List a) given (a|Storable) =
+def load_dynamic_buffer(buf: DynBuffer a) -> {IO} (List a) given (a|Storable) =
bufPtr = load buf.buffer
size = load buf.size
AsList(size, table_from_ptr bufPtr)
-def push_dynamic_buffer (buf: DynBuffer a, x:a) -> {IO} () given (a|Storable) =
+def push_dynamic_buffer(buf: DynBuffer a, x:a) -> {IO} () given (a|Storable) =
extend_dynamic_buffer(buf, to_list [x])
'## Strings and Characters
String : Type = List Char
-def string_from_char_ptr (n:Word32, ptr:Ptr Char) -> {IO} String =
+def string_from_char_ptr(n:Word32, ptr:Ptr Char) -> {IO} String =
AsList(rep_to_nat n, table_from_ptr ptr)
-- TODO. This is ASCII code point. It really should be Int32 for Unicode codepoint
-def codepoint (c:Char) -> Int = w8_to_i c
+def codepoint(c:Char) -> Int = w8_to_i c
struct CString =
ptr : RawPtr
@@ -1488,6 +1489,7 @@ def with_c_string(
) -> {IO} a given (a) =
AsList(n, s') = s <> "\NUL"
with_table_ptr s' \ptr. action $ CString.new(ptr.val)
+
'### Show interface
For things that can be shown.
`show` gives a string representation of its input.
@@ -1495,57 +1497,57 @@ No particular promises are made to exactly what that representation will contain
In particular it is **not** promised to be parseable.
Nor does it promise a particular level of precision for numeric values.
-interface Show (a)
+interface Show(a)
show : (a) -> String
-instance Show String
+instance Show(String)
def show(x) = x
foreign "showInt32" showInt32 : (Int32) -> {IO} (Word32, RawPtr)
-instance Show Int32
- def show (x) = unsafe_io \.
+instance Show(Int32)
+ def show(x) = unsafe_io \.
(n, ptr) = showInt32 x
string_from_char_ptr(n, Ptr.new ptr)
foreign "showInt64" showInt64 : (Int64) -> {IO} (Word32, RawPtr)
-instance Show Int64
- def show (x) = unsafe_io \.
+instance Show(Int64)
+ def show(x) = unsafe_io \.
(n, ptr) = showInt64 x
string_from_char_ptr(n, Ptr.new ptr)
-instance Show Nat
- def show (x) = show $ n_to_i64 x
+instance Show(Nat)
+ def show(x) = show $ n_to_i64 x
foreign "showFloat32" showFloat32 : (Float32) -> {IO} (Word32, RawPtr)
-instance Show Float32
- def show (x) = unsafe_io \.
+instance Show(Float32)
+ def show(x) = unsafe_io \.
(n, ptr) = showFloat32 x
string_from_char_ptr(n, Ptr.new ptr)
foreign "showFloat64" showFloat64 : (Float64) -> {IO} (Word32, RawPtr)
-instance Show Float64
- def show (x) = unsafe_io \.
+instance Show(Float64)
+ def show(x) = unsafe_io \.
(n, ptr) = showFloat64 x
string_from_char_ptr(n, Ptr.new ptr)
-instance Show ()
+instance Show(())
def show(_) = "()"
-instance Show (a, b) given (a|Show, b|Show)
+instance Show((a, b)) given (a|Show, b|Show)
def show(x) =
(a, b) = x
"(" <> show a <> ", " <> show b <> ")"
-instance Show (a, b, c) given (a|Show, b|Show, c|Show)
+instance Show((a, b, c)) given (a|Show, b|Show, c|Show)
def show(x) =
(a, b, c) = x
"(" <> show a <> ", " <> show b <> ", " <> show c <> ")"
-instance Show (a, b, c, d) given (a|Show, b|Show, c|Show, d|Show)
+instance Show((a, b, c, d)) given (a|Show, b|Show, c|Show, d|Show)
def show(x) =
(a, b, c, d) = x
"(" <> show a <> ", " <> show b <> ", " <> show c <> ", " <> show d <> ")"
@@ -1553,13 +1555,13 @@ instance Show (a, b, c, d) given (a|Show, b|Show, c|Show, d|Show)
'### Parse interface
For types that can be parsed from a `String`.
-interface Parse (a)
+interface Parse(a)
parseString : (String) -> Maybe a
foreign "strtof" strtofFFI : (RawPtr, RawPtr) -> {IO} Float
-instance Parse Float
- def parseString (str) = unsafe_io \.
+instance Parse(Float)
+ def parseString(str) = unsafe_io \.
AsList(str_len, _) = str
with_c_string str \cStr.
with_alloc 1 \end_ptr:(Ptr (Ptr Char)).
@@ -1571,14 +1573,14 @@ instance Parse Float
'## Floating-point helper functions
TODO: Move these to be with Elementary/Special functions. Or move those to be here.
-def sign (x:Float) -> Float =
+def sign(x:Float) -> Float =
case x > 0.0 of
True -> 1.0
False -> case x < 0.0 of
True -> -1.0
False -> x
-def copysign (a:Float, b:Float) -> Float =
+def copysign(a:Float, b:Float) -> Float =
case b > 0.0 of
True -> a
False -> case b < 0.0 of
@@ -1590,31 +1592,31 @@ infinity = 1.0 / 0.0
nan = 0.0 / 0.0
-- Todo: use IEEE floating-point builtins.
-def isinf (x:Float) -> Bool = (x == infinity) || (x == -infinity)
-def isnan (x:Float) -> Bool = not (x >= x && x <= x)
+def isinf(x:Float) -> Bool = (x == infinity) || (x == -infinity)
+def isnan(x:Float) -> Bool = not (x >= x && x <= x)
-- Todo: use IEEE-754R 5.11: Floating Point Comparison Relation cmpUnordered.
-def either_is_nan (x:Float, y:Float) -> Bool = (isnan x) || (isnan y)
+def either_is_nan(x:Float, y:Float) -> Bool = (isnan x) || (isnan y)
'## File system operations
FilePath : Type = String
-def is_null_raw_ptr (ptr:RawPtr) -> Bool =
+def is_null_raw_ptr(ptr:RawPtr) -> Bool =
raw_ptr_to_i64 ptr == 0
-def from_nullable_raw_ptr (ptr:RawPtr) -> Maybe (Ptr a) given (a) =
+def from_nullable_raw_ptr(ptr:RawPtr) -> Maybe (Ptr a) given (a) =
if is_null_raw_ptr ptr
then Nothing
else Just $ Ptr.new ptr
-def c_string_ptr (s:CString) -> Maybe (Ptr Char) = from_nullable_raw_ptr s.ptr
+def c_string_ptr(s:CString) -> Maybe (Ptr Char) = from_nullable_raw_ptr s.ptr
data StreamMode =
ReadMode
WriteMode
-struct Stream (mode:StreamMode) =
+struct Stream(mode:StreamMode) =
ptr : RawPtr
'### Stream IO
@@ -1625,7 +1627,7 @@ foreign "fwrite" fwriteFFI : (RawPtr, Int64, Int64, RawPtr) -> {IO} Int64
foreign "fread" freadFFI : (RawPtr, Int64, Int64, RawPtr) -> {IO} Int64
foreign "fflush" fflushFFI : (RawPtr) -> {IO} Int64
-def fopen (path:String, mode:StreamMode) -> {IO} (Stream mode) =
+def fopen(path:String, mode:StreamMode) -> {IO} (Stream mode) =
modeStr = case mode of
ReadMode -> "r"
WriteMode -> "w"
@@ -1633,11 +1635,11 @@ def fopen (path:String, mode:StreamMode) -> {IO} (Stream mode) =
with_c_string modeStr \cMode.
Stream.new $ fopenFFI(cPath.ptr, cMode.ptr)
-def fclose (stream:Stream mode) -> {IO} () given (mode) =
+def fclose(stream:Stream mode) -> {IO} () given (mode) =
fcloseFFI stream.ptr
()
-def fwrite (stream:Stream WriteMode, s:String) -> {IO} () =
+def fwrite(stream:Stream WriteMode, s:String) -> {IO} () =
AsList(n, s') = s
with_table_ptr s' \ptr.
fwriteFFI(ptr.val, i_to_i64 1, n_to_i64 n, stream.ptr)
@@ -1647,21 +1649,21 @@ def fwrite (stream:Stream WriteMode, s:String) -> {IO} () =
'### Iteration
TODO: move this out of the file-system section
-def while (body: () -> {|eff} Bool) -> {|eff} () given (eff) =
+def while(body: () -> {|eff} Bool) -> {|eff} () given (eff) =
body' : () -> {|eff} Word8 = \. b_to_w8 $ body()
- %while body'
+ %while(body')
-data IterResult (a|Data) =
+data IterResult(a|Data) =
Continue
Done(a)
-- TODO: can we improve effect inference so we don't need this?
-def lift_state (ref: Ref(h, c), f:(a) -> {|eff} b, x:a) -> {State h|eff} b
+def lift_state(ref: Ref(h, c), f:(a) -> {|eff} b, x:a) -> {State h|eff} b
given (a, b, c, h, eff) =
f x
-- A little iteration combinator
-def iter (body: (Nat) -> {|eff} IterResult a) -> {|eff} a given (a|Data, eff) =
+def iter(body: (Nat) -> {|eff} IterResult a) -> {|eff} a given (a|Data, eff) =
result = yield_state Nothing \resultRef.
i <- with_state 0
while \.
@@ -1675,18 +1677,18 @@ def iter (body: (Nat) -> {|eff} IterResult a) -> {|eff} a given (a|Data, eff) =
Just(ans) -> ans
Nothing -> unreachable()
-def bounded_iter
- (maxIters:Nat, fallback:a)
- -> ((Nat) -> {|eff} IterResult a)
- -> {|eff} a
- given (a|Data, eff) = \body. iter \i.
+def bounded_iter(
+ maxIters:Nat,
+ fallback:a,
+ body:(Nat) -> {|eff} IterResult a
+ ) -> {|eff} a given (a|Data, eff) = iter \i.
if i >= maxIters
then Done fallback
else body i
'### Environment Variables
-def from_c_string (s:CString) -> {IO} (Maybe String) =
+def from_c_string(s:CString) -> {IO} (Maybe String) =
case c_string_ptr s of
Nothing -> Nothing
Just(ptr) ->
@@ -1702,16 +1704,16 @@ def from_c_string (s:CString) -> {IO} (Maybe String) =
foreign "getenv" getenvFFI : (RawPtr) -> {IO} RawPtr
-def get_env (name:String) -> {IO} Maybe String =
+def get_env(name:String) -> {IO} Maybe String =
cStr <- with_c_string name
getenvFFI cStr.ptr | CString.new | from_c_string
-def check_env (name:String) -> {IO} Bool =
+def check_env(name:String) -> {IO} Bool =
is_just $ get_env name
'### More Stream IO
-def fread (stream:Stream ReadMode) -> {IO} String =
+def fread(stream:Stream ReadMode) -> {IO} String =
-- TODO: allow reading longer files!
n = 4096
ptr:(Ptr Char) <- with_alloc n
@@ -1726,11 +1728,11 @@ def fread (stream:Stream ReadMode) -> {IO} String =
'### Print
-def get_output_stream () -> {IO} Stream WriteMode =
+def get_output_stream() -> {IO} Stream WriteMode =
Stream.new $ %outputStream()
@noinline
-def print (s:String) -> {IO} () =
+def print(s:String) -> {IO} () =
stream = get_output_stream()
fwrite(stream, s)
fwrite(stream, "\n")
@@ -1742,7 +1744,7 @@ foreign "remove" removeFFI : (RawPtr) -> {IO} Int64
foreign "mkstemp" mkstempFFI : (RawPtr) -> {IO} Int32
foreign "close" closeFFI : (Int32) -> {IO} Int32
-def shell_out (command:String) -> {IO} String =
+def shell_out(command:String) -> {IO} String =
modeStr = "r"
with_c_string command \command'.
with_c_string modeStr \modeStr'.
@@ -1757,15 +1759,15 @@ Not to be confused with a partially applied function
'### Error throwing
@noinline
-def error (s:String) -> a given (a|Data) = unsafe_io \.
+def error(s:String) -> a given (a|Data) = unsafe_io \.
print s
- %throwError a
+ %throwError(a)
-def todo () ->> a given (a|Data) = error "TODO: implement it!"
+def todo() ->> a given (a|Data) = error "TODO: implement it!"
'### File Operations
-def delete_file (f:FilePath) -> {IO} () =
+def delete_file(f:FilePath) -> {IO} () =
s <- with_c_string(f)
removeFFI s.ptr
()
@@ -1784,13 +1786,13 @@ def with_file(
fclose stream
result
-def write_file (f:FilePath, s:String) -> {IO} () =
+def write_file(f:FilePath, s:String) -> {IO} () =
with_file(f, WriteMode) \stream. fwrite(stream, s)
-def read_file (f:FilePath) -> {IO} String =
+def read_file(f:FilePath) -> {IO} String =
with_file(f, ReadMode) \stream. fread stream
-def has_file (f:FilePath) -> {IO} Bool =
+def has_file(f:FilePath) -> {IO} Bool =
stream = fopen(f, ReadMode)
result = not (is_null_raw_ptr stream.ptr)
if result then fclose stream
@@ -1798,19 +1800,19 @@ def has_file (f:FilePath) -> {IO} Bool =
'### Temporary Files
-def new_temp_file () -> {IO} FilePath =
+def new_temp_file() -> {IO} FilePath =
s <- with_c_string "/tmp/dex-XXXXXX"
fd = mkstempFFI s.ptr
closeFFI fd
string_from_char_ptr(15, (Ptr.new s.ptr))
-def with_temp_file (action: (FilePath) -> {IO} a) -> {IO} a given (a) =
+def with_temp_file(action: (FilePath) -> {IO} a) -> {IO} a given (a) =
tmpFile = new_temp_file()
result = action tmpFile
delete_file tmpFile
result
-def with_temp_files (action: (n=>FilePath) -> {IO} a) -> {IO} a given (n|Ix, a) =
+def with_temp_files(action: (n=>FilePath) -> {IO} a) -> {IO} a given (n|Ix, a) =
tmpFiles = for i. new_temp_file()
result = action tmpFiles
for i. delete_file tmpFiles[i]
@@ -1819,37 +1821,37 @@ def with_temp_files (action: (n=>FilePath) -> {IO} a) -> {IO} a given (n|Ix, a)
'### Table operations
@noinline
-def from_ordinal_error (i:Nat, upper:Nat) -> String =
+def from_ordinal_error(i:Nat, upper:Nat) -> String =
"Ordinal index out of range:" <> show i <> " >= " <> show upper
-def from_ordinal (i:Nat) -> n given (n|Ix) =
+def from_ordinal(i:Nat) -> n given (n|Ix) =
case i < size n of
True -> unsafe_from_ordinal i
False -> error $ from_ordinal_error(i, size n)
-- TODO: should this be called `from_ordinal`?
-def to_ix (i:Nat) -> Maybe n given (n|Ix) =
+def to_ix(i:Nat) -> Maybe n given (n|Ix) =
case i < size n of
True -> Just $ unsafe_from_ordinal i
False -> Nothing
-- TODO: could make an `unsafeCastIndex` and this could avoid the runtime copy
-- TODO: safe (runtime-checked) and unsafe versions
-def cast_table (xs:to=>a) -> from=>a given (from|Ix, to|Ix, a|Data) =
+def cast_table(xs:to=>a) -> from=>a given (from|Ix, to|Ix, a|Data) =
case size from == size to of
True -> unsafe_cast_table xs
False -> error $
"Table size mismatch in cast: " <> show (size from) <> " vs " <> show (size to)
-def asidx (i:Nat) -> n given (n|Ix) = from_ordinal i
-def (@) (i:Nat, n|Ix) -> n = from_ordinal i
+def asidx(i:Nat) -> n given (n|Ix) = from_ordinal i
+def (@)(i:Nat, n|Ix) -> n = from_ordinal i
-def slice (xs:n=>a, start:Nat, m|Ix) -> m=>a given (n|Ix, a) =
+def slice(xs:n=>a, start:Nat, m|Ix) -> m=>a given (n|Ix, a) =
for i. xs[from_ordinal (ordinal i + start)]
-def head (xs:n=>a) -> a given (n|Ix, a) = xs[0@_]
+def head(xs:n=>a) -> a given (n|Ix, a) = xs[0@_]
-def tail (xs:n=>a, start:Nat) -> List a given (n|Ix, a) =
+def tail(xs:n=>a, start:Nat) -> List a given (n|Ix, a) =
numElts = size n -| start
to_list $ slice(xs, start, Fin numElts)
@@ -1864,7 +1866,7 @@ Dex's PRNG system is modelled directly after [JAX's](https://github.com/google/j
Key = Word64
@noinline
-def threefry_2x32 (k:Word64, count:Word64) -> Word64 =
+def threefry_2x32(k:Word64, count:Word64) -> Word64 =
-- Based on jax's threefry_2x32 by Matt Johnson and Peter Hawkins
rotations1 = [13, 15, 26, 6]
rotations2 = [17, 29, 16, 24]
@@ -1897,31 +1899,31 @@ def threefry_2x32 (k:Word64, count:Word64) -> Word64 =
(w32_to_w64 x .<<. 32) .|. (w32_to_w64 y)
-def hash (x:Key, y:Nat) -> Key =
+def hash(x:Key, y:Nat) -> Key =
y64 = n_to_w64 y
threefry_2x32(x, y64)
-def new_key (x:Nat) -> Key = hash(0, x)
-def many (f:(Key)->a, k:Key, i:n) -> a given (a, n|Ix) = f hash(k, ordinal i)
-def ixkey (k:Key, i:n) -> Key given (n|Ix) = hash(k, ordinal i)
-def split_key (k:Key) -> Fin n => Key given (n:Nat) = for i. ixkey(k, i)
+def new_key(x:Nat) -> Key = hash(0, x)
+def many(f:(Key)->a, k:Key, i:n) -> a given (a, n|Ix) = f hash(k, ordinal i)
+def ixkey(k:Key, i:n) -> Key given (n|Ix) = hash(k, ordinal i)
+def split_key(k:Key) -> Fin n => Key given (n:Nat) = for i. ixkey(k, i)
'### Sample Generators
These functions generate samples taken from, different distributions.
Such as `rand_mat` with samples from the distribution of floating point matrices where each element is taken from a i.i.d. uniform distribution. Note that additional standard distributions are provided by the `stats` library.
-def rand (k:Key) -> Float =
+def rand(k:Key) -> Float =
exponent_bits = 1065353216 -- 1065353216 = 127 << 23
mantissa_bits = (high_word k .&. 8388607) -- 8388607 == (1 << 23) - 1
bits = exponent_bits .|. mantissa_bits
%bitcast(Float, bits) - 1.0
-def rand_vec (n:Nat, f: (Key) -> a, k: Key) -> Fin n => a given (n|Ix, a) =
+def rand_vec(n:Nat, f: (Key) -> a, k: Key) -> Fin n => a given (n|Ix, a) =
for i:(Fin n). f ixkey(k, i)
-def rand_mat (n:Nat, m:Nat, f: (Key) -> a, k: Key) -> Fin n => Fin m => a given (a) =
+def rand_mat(n:Nat, m:Nat, f: (Key) -> a, k: Key) -> Fin n => Fin m => a given (a) =
for i j. f ixkey(k, (i, j))
-def randn (k:Key) -> Float =
+def randn(k:Key) -> Float =
[k1, k2] = split_key k
-- rand is uniform between 0 and 1, but implemented such that it rounds to 0
-- (in float32) once every few million draws, but never rounds to 1.
@@ -1930,14 +1932,14 @@ def randn (k:Key) -> Float =
sqrt ((-2.0) * log u1) * cos (2.0 * pi * u2)
-- TODO: Make this better...
-def rand_int (k:Key) -> Nat = w64_to_n k `mod` 2147483647
+def rand_int(k:Key) -> Nat = w64_to_n k `mod` 2147483647
-def bern (p:Float, k:Key) -> Bool = rand k < p
+def bern(p:Float, k:Key) -> Bool = rand k < p
-def randn_vec (k:Key) -> n=>Float given (n|Ix) =
+def randn_vec(k:Key) -> n=>Float given (n|Ix) =
for i. randn (ixkey(k, i))
-def rand_idx (k:Key) -> n given (n|Ix) =
+def rand_idx(k:Key) -> n given (n|Ix) =
rand k * n_to_f (size n) | floor | f_to_n | unsafe_from_ordinal
'## Inner product typeclass
@@ -1945,10 +1947,10 @@ def rand_idx (k:Key) -> n given (n|Ix) =
interface InnerProd(v|VSpace)
inner_prod : (v, v) -> Float
-instance InnerProd Float
+instance InnerProd(Float)
def inner_prod(x, y) = x * y
-instance InnerProd (n=>a) given (a|InnerProd, n|Ix)
+instance InnerProd(n=>a) given (a|InnerProd, n|Ix)
def inner_prod(x, y) =sum for i. inner_prod(x[i], y[i])
'## Arbitrary
@@ -1957,40 +1959,40 @@ Type class for generating example values
interface Arbitrary(a)
arb : (Key) -> a
-instance Arbitrary Bool
- def arb (key) = key .&. 1 == 0
+instance Arbitrary(Bool)
+ def arb(key) = key .&. 1 == 0
-instance Arbitrary Float32
- def arb (key) = randn key
+instance Arbitrary(Float32)
+ def arb(key) = randn key
-instance Arbitrary Int32
- def arb (key) = f_to_i $ randn key * 5.0
+instance Arbitrary(Int32)
+ def arb(key) = f_to_i $ randn key * 5.0
-instance Arbitrary Nat
- def arb (key) = f_to_n $ randn key * 5.0
+instance Arbitrary(Nat)
+ def arb(key) = f_to_n $ randn key * 5.0
-instance Arbitrary (n=>a) given (n|Ix, a|Arbitrary)
- def arb (key) = for i. arb $ ixkey(key, i)
+instance Arbitrary(n=>a) given (n|Ix, a|Arbitrary)
+ def arb(key) = for i. arb $ ixkey(key, i)
-instance Arbitrary ((i:n)=>(..<i) => a) given (n|Ix, a|Arbitrary)
- def arb (key) = for i. arb $ ixkey(key, i)
+instance Arbitrary((i:n)=>(..<i) => a) given (n|Ix, a|Arbitrary)
+ def arb(key) = for i. arb $ ixkey(key, i)
-instance Arbitrary ((i:n)=>(..i) => a) given (n|Ix, a|Arbitrary)
- def arb (key) = for i. arb $ ixkey(key, i)
+instance Arbitrary((i:n)=>(..i) => a) given (n|Ix, a|Arbitrary)
+ def arb(key) = for i. arb $ ixkey(key, i)
-instance Arbitrary ((i:n)=>(i..) => a) given (n|Ix, a|Arbitrary)
- def arb (key) = for i. arb $ ixkey(key, i)
+instance Arbitrary((i:n)=>(i..) => a) given (n|Ix, a|Arbitrary)
+ def arb(key) = for i. arb $ ixkey(key, i)
-instance Arbitrary ((i:n)=>(i<..) => a) given (n|Ix, a|Arbitrary)
- def arb (key) = for i. arb $ ixkey(key, i)
+instance Arbitrary((i:n)=>(i<..) => a) given (n|Ix, a|Arbitrary)
+ def arb(key) = for i. arb $ ixkey(key, i)
-instance Arbitrary (a, b) given (a|Arbitrary, b|Arbitrary)
- def arb (key) =
+instance Arbitrary((a, b)) given (a|Arbitrary, b|Arbitrary)
+ def arb(key) =
[k1, k2] = split_key key
(arb k1, arb k2)
-instance Arbitrary (Fin n) given (n)
- def arb (key) = rand_idx key
+instance Arbitrary(Fin n) given (n)
+ def arb(key) = rand_idx key
'## Ord on Arrays
@@ -1998,7 +2000,7 @@ instance Arbitrary (Fin n) given (n)
'returns the highest index `i` such that `xs.i <= x`
-def search_sorted (xs:n=>a, x:a) -> Maybe n given (n|Ix, a|Ord) =
+def search_sorted(xs:n=>a, x:a) -> Maybe n given (n|Ix, a|Ord) =
if size n == 0
then Nothing
else if x < xs[from_ordinal 0]
@@ -2019,24 +2021,24 @@ def search_sorted (xs:n=>a, x:a) -> Maybe n given (n|Ix, a|Ord) =
'### min / max etc
-def min_by (f:(a)->o, x:a, y:a) -> a given (o|Ord, a) = select(f x < f y, x, y)
-def max_by (f:(a)->o, x:a, y:a) -> a given (o|Ord, a) = select(f x > f y, x, y)
+def min_by(f:(a)->o, x:a, y:a) -> a given (o|Ord, a) = select(f x < f y, x, y)
+def max_by(f:(a)->o, x:a, y:a) -> a given (o|Ord, a) = select(f x > f y, x, y)
-def min (x1: o, x2: o) -> o given (o|Ord) = min_by(id, x1, x2)
-def max (x1: o, x2: o) -> o given (o|Ord) = max_by(id, x1, x2)
+def min(x1: o, x2: o) -> o given (o|Ord) = min_by(id, x1, x2)
+def max(x1: o, x2: o) -> o given (o|Ord) = max_by(id, x1, x2)
-def minimum_by (f:(a)->o, xs:n=>a) -> a given (a|Data, o|Ord, n|Ix) =
+def minimum_by(f:(a)->o, xs:n=>a) -> a given (a|Data, o|Ord, n|Ix) =
reduce(xs[0@_], \x y. min_by(f, x, y), xs)
-def maximum_by (f:(a)->o, xs:n=>a) -> a given (a|Data, o|Ord, n|Ix) =
+def maximum_by(f:(a)->o, xs:n=>a) -> a given (a|Data, o|Ord, n|Ix) =
reduce(xs[0@_], \x y. max_by(f, x, y), xs)
-def minimum (xs:n=>o) -> o given (n|Ix, o|Ord) = minimum_by(id, xs)
-def maximum (xs:n=>o) -> o given (n|Ix, o|Ord) = maximum_by(id, xs)
+def minimum(xs:n=>o) -> o given (n|Ix, o|Ord) = minimum_by(id, xs)
+def maximum(xs:n=>o) -> o given (n|Ix, o|Ord) = maximum_by(id, xs)
'### argmin/argmax
-- TODO: put in same section as `searchsorted`
-def argscan (comp:(o,o)->Bool, xs:n=>o) -> n given (o|Ord, n|Ix) =
+def argscan(comp:(o,o)->Bool, xs:n=>o) -> n given (o|Ord, n|Ix) =
zeroth = (0@_, xs[0@_])
compare = \p1 p2.
(idx1, x1) = p1
@@ -2045,8 +2047,8 @@ def argscan (comp:(o,o)->Bool, xs:n=>o) -> n given (o|Ord, n|Ix) =
zipped = for i. (i, xs[i])
fst $ reduce(zeroth, compare, zipped)
-def argmin (xs:n=>o) -> n given (n|Ix, o|Ord) = argscan((<), xs)
-def argmax (xs:n=>o) -> n given (n|Ix, o|Ord) = argscan((>), xs)
+def argmin(xs:n=>o) -> n given (n|Ix, o|Ord) = argscan((<), xs)
+def argmax(xs:n=>o) -> n given (n|Ix, o|Ord) = argscan((>), xs)
def lexical_order(
compareElements:(n,n)->Bool,
@@ -2080,13 +2082,13 @@ def lexical_order(
True -> Continue
False -> Done False
-instance Ord (List n) given (n|Ord)
- def (>) (xs, ys) = lexical_order((>), (<), xs, ys)
- def (<) (xs, ys) = lexical_order((>), (<), xs, ys)
+instance Ord(List n) given (n|Ord)
+ def (>)(xs, ys) = lexical_order((>), (>), xs, ys)
+ def (<)(xs, ys) = lexical_order((<), (<), xs, ys)
'### clip
-def clip (bounds:(a,a), x:a) -> a given (a|Ord) =
+def clip(bounds:(a,a), x:a) -> a given (a|Ord) =
(low,high) = bounds
min(high, max(low, x))
@@ -2094,7 +2096,7 @@ def clip (bounds:(a,a), x:a) -> a given (a|Ord) =
TODO: these should be with the other Elementary/Special Functions
### atan/atan2
-def atan_inner (x:Float) -> Float =
+def atan_inner(x:Float) -> Float =
-- From "Computing accurate Horner form approximations to
-- special functions in finite precision arithmetic"
-- https://arxiv.org/abs/1508.03211
@@ -2112,10 +2114,10 @@ def atan_inner (x:Float) -> Float =
r * x + x
-def min_and_max (x:a, y:a) -> (a, a) given (a|Ord) =
+def min_and_max(x:a, y:a) -> (a, a) given (a|Ord) =
select(x < y, (x, y), (y, x)) -- get both with one comparison.
-def atan2 (y:Float, x:Float) -> Float =
+def atan2(y:Float, x:Float) -> Float =
-- Based off of the Tensorflow implementation at
-- github.com/tensorflow/mlir-hlo/blob/master/lib/
-- Dialect/mhlo/transforms/legalize_trigonometric_to_approximation.cc#L147
@@ -2133,21 +2135,21 @@ def atan2 (y:Float, x:Float) -> Float =
a = copysign(a, y)
select(either_is_nan(x, y), nan, a) -- Propagate NaNs.
-def atan (x:Float) -> Float = atan2(x, 1.0)
+def atan(x:Float) -> Float = atan2(x, 1.0)
'## Miscellaneous utilities
TODO: all of these should be in some other section
-def reflect (i:n) -> n given (n|Ix) =
+def reflect(i:n) -> n given (n|Ix) =
unsafe_from_ordinal $ unsafe_nat_diff(size n, ordinal i + 1)
-def reverse (x:n=>a) -> n=>a given (n|Ix, a) =
+def reverse(x:n=>a) -> n=>a given (n|Ix, a) =
for i. x[reflect i]
-def wrap_periodic (n|Ix, i:Nat) -> n =
+def wrap_periodic(n|Ix, i:Nat) -> n =
unsafe_from_ordinal(n=n, i `mod` size n)
-def pad_to (m|Ix, x:a, xs:n=>a) -> m=>a given (n|Ix, a) =
+def pad_to(m|Ix, x:a, xs:n=>a) -> m=>a given (n|Ix, a) =
n' = size n
for i.
i' = ordinal i
@@ -2155,13 +2157,13 @@ def pad_to (m|Ix, x:a, xs:n=>a) -> m=>a given (n|Ix, a) =
True -> xs[i'@_]
False -> x
-def idiv_ceil (x:Nat, y:Nat) -> Nat = x `idiv` y + b_to_n (x `rem` y /= 0)
-def intdiv2 (x:Nat) -> Nat = rep_to_nat $ %shr(nat_to_rep x, 1 :: NatRep)
-def intpow2 (power:Nat) -> Nat = rep_to_nat $ %shl(1 :: NatRep, nat_to_rep power)
-def is_odd (x:Nat) -> Bool = rem(x, 2) == 1
-def is_even (x:Nat) -> Bool = rem(x, 2) == 0
+def idiv_ceil(x:Nat, y:Nat) -> Nat = x `idiv` y + b_to_n (x `rem` y /= 0)
+def intdiv2(x:Nat) -> Nat = rep_to_nat $ %shr(nat_to_rep x, 1 :: NatRep)
+def intpow2(power:Nat) -> Nat = rep_to_nat $ %shl(1 :: NatRep, nat_to_rep power)
+def is_odd(x:Nat) -> Bool = rem(x, 2) == 1
+def is_even(x:Nat) -> Bool = rem(x, 2) == 0
-def is_power_of_2 (x:Nat) -> Bool =
+def is_power_of_2(x:Nat) -> Bool =
-- A fast trick based on bitwise AND.
-- This works on integer types larger than 8 bits.
-- Note: The bitwise and operator (.&.)
@@ -2178,7 +2180,7 @@ def is_power_of_2 (x:Nat) -> Bool =
-- code path in ImpToLLVM, because it's the first LLVM intrinsic
-- we have with a fixed-point argument.
-- https://llvm.org/docs/LangRef.html#llvm-ctlz-intrinsic
-def natlog2 (x:Nat) -> Nat =
+def natlog2(x:Nat) -> Nat =
tmp = yield_state 0 \ans.
cmp <- run_state 1
while \.
@@ -2191,8 +2193,11 @@ def natlog2 (x:Nat) -> Nat =
False
unsafe_nat_diff(tmp, 1) -- TODO: something less horrible
-def general_integer_power
- (times:(a,a)->a, one:a, base:a, power:Nat) -> a given (a|Data) =
+def general_integer_power(
+ times:(a,a)->a,
+ one:a, base:a,
+ power:Nat
+ ) -> a given (a|Data) =
iters = if power == 0 then 0 else 1 + natlog2 power
-- Implements exponentiation by squaring.
-- This could be nicer if there were a way to explicitly
@@ -2206,33 +2211,33 @@ def general_integer_power
z := times(get z, get z)
pow := intdiv2 (get pow)
-def intpow (base:a, power:Nat) -> a given (a|Mul) =
+def intpow(base:a, power:Nat) -> a given (a|Mul) =
general_integer_power((*), one, base, power)
def from_just(x:Maybe a) -> a given (a) = case x of Just(x') -> x'
-def any_sat (f:(a)->Bool, xs:n=>a) -> Bool given (a, n|Ix) = any(each xs f)
+def any_sat(f:(a)->Bool, xs:n=>a) -> Bool given (a, n|Ix) = any(each xs f)
-def seq_maybes (xs: n=>Maybe a) -> Maybe (n => a) given (n|Ix, a) =
+def seq_maybes(xs: n=>Maybe a) -> Maybe (n => a) given (n|Ix, a) =
-- is it possible to implement this safely? (i.e. without using partial
-- functions)
case any_sat(is_nothing, xs) of
True -> Nothing
False -> Just $ each xs from_just
-def linear_search (xs:n=>a, query:a) -> Maybe n given (n|Ix, a|Eq) =
+def linear_search(xs:n=>a, query:a) -> Maybe n given (n|Ix, a|Eq) =
yield_state Nothing \ref. for i.
case xs[i] == query of
True -> ref := Just i
False -> ()
-def list_length (l:List a) -> Nat given (a) =
+def list_length(l:List a) -> Nat given (a) =
AsList(n, _) = l
n
-- This is for efficiency (rather than using `<>` repeatedly)
-- TODO: we want this for any monoid but this implementation won't work.
-def concat (lists:n=>(List a)) -> List a given (a, n|Ix) =
+def concat(lists:n=>(List a)) -> List a given (a, n|Ix) =
totalSize = sum for i. list_length lists[i]
to_list $ with_state 0 \listIdx.
eltIdx <- with_state 0
@@ -2250,7 +2255,7 @@ def concat (lists:n=>(List a)) -> List a given (a, n|Ix) =
eltIdx := eltIdxVal + 1
xs[eltIdxVal@_]
-def cat_maybes (xs:n=>Maybe a) -> List a given (n|Ix, a|Data) =
+def cat_maybes(xs:n=>Maybe a) -> List a given (n|Ix, a|Data) =
(num_res, res_inds) = yield_state (0::Nat, for i:n. Nothing) \ref.
for i. case xs[i] of
Just(_) ->
@@ -2265,19 +2270,19 @@ def cat_maybes (xs:n=>Maybe a) -> List a given (n|Ix, a|Data) =
Nothing -> todo -- Impossible
Nothing -> todo -- Impossible
-def filter (xs:n=>a, condition:(a)->Bool) -> List a given (a|Data, n|Ix) =
+def filter(xs:n=>a, condition:(a)->Bool) -> List a given (a|Data, n|Ix) =
cat_maybes $ for i. if condition xs[i] then Just xs[i] else Nothing
-def arg_filter (xs:n=>a, condition:(a)->Bool) -> List n given (a|Data, n|Ix) =
+def arg_filter(xs:n=>a, condition:(a)->Bool) -> List n given (a|Data, n|Ix) =
cat_maybes $ for i. if condition xs[i] then Just i else Nothing
-- TODO: use `ix_offset : [Ix n] -> n -> Int -> Maybe n` instead
-def prev_ix (i:n) -> Maybe n given (n|Ix) =
+def prev_ix(i:n) -> Maybe n given (n|Ix) =
case i_to_n (n_to_i (ordinal i) - 1) of
Nothing -> Nothing
Just(i_prev) -> unsafe_from_ordinal(i_prev) | Just
-def lines (source:String) -> List String =
+def lines(source:String) -> List String =
AsList(_, s) = source
AsList(num_lines, newline_ixs) = cat_maybes for i_char.
if s[i_char] == '\n'
@@ -2293,35 +2298,35 @@ def lines (source:String) -> List String =
'## Probability
-- cdf should include 0.0 but not 1.0
-def categorical_from_cdf (cdf: n=>Float, key: Key) -> n given (n|Ix) =
+def categorical_from_cdf(cdf: n=>Float, key: Key) -> n given (n|Ix) =
r = rand key
case search_sorted(cdf, r) of
Just(i) -> i
-def normalize_pdf (xs: d=>Float) -> d=>Float given (d|Ix) = xs / sum xs
+def normalize_pdf(xs: d=>Float) -> d=>Float given (d|Ix) = xs / sum xs
-def cdf_for_categorical (logprobs: n=>Float) -> n=>Float given (n|Ix) =
+def cdf_for_categorical(logprobs: n=>Float) -> n=>Float given (n|Ix) =
maxLogProb = maximum logprobs
cumsum_low $ normalize_pdf $ for i. exp(logprobs[i] - maxLogProb)
-def categorical (logprobs: n=>Float, key: Key) -> n given (n|Ix) =
+def categorical(logprobs: n=>Float, key: Key) -> n given (n|Ix) =
categorical_from_cdf(cdf_for_categorical logprobs, key)
-- batch variant to share the work of forming the cumsum
-- (alternatively we could rely on hoisting of loop constants)
-def categorical_batch (logprobs: n=>Float, key: Key) -> m=>n given (n|Ix, m|Ix) =
+def categorical_batch(logprobs: n=>Float, key: Key) -> m=>n given (n|Ix, m|Ix) =
cdf = cdf_for_categorical logprobs
for i. categorical_from_cdf(cdf, ixkey(key, i))
-def logsumexp (x: n=>Float) -> Float given (n|Ix) =
+def logsumexp(x: n=>Float) -> Float given (n|Ix) =
m = maximum x
m + (log $ sum for i. exp (x[i] - m))
-def logsoftmax (x: n=>Float) -> n=>Float given (n|Ix) =
+def logsoftmax(x: n=>Float) -> n=>Float given (n|Ix) =
lse = logsumexp x
for i. x[i] - lse
-def softmax (x: n=>Float) -> n=>Float given (n|Ix) =
+def softmax(x: n=>Float) -> n=>Float given (n|Ix) =
m = maximum x
e = for i. exp (x[i] - m)
s = sum e
@@ -2330,45 +2335,45 @@ def softmax (x: n=>Float) -> n=>Float given (n|Ix) =
'## Polynomials
TODO: Move this somewhere else
-def evalpoly (coefficients:n=>v, x:Float) -> v given (n|Ix, v|VSpace) =
+def evalpoly(coefficients:n=>v, x:Float) -> v given (n|Ix, v|VSpace) =
-- Evaluate a polynomial at x. Same as Numpy's polyval.
fold zero \i c. coefficients[i] + x .* c
'## TestMode
-- TODO: move this to be in Testing Helpers
-def dex_test_mode () -> Bool = unsafe_io \. check_env "DEX_TEST_MODE"
+def dex_test_mode() -> Bool = unsafe_io \. check_env "DEX_TEST_MODE"
'## Exception effect
-- TODO: move `error` and `todo` to here.
-def catch (f:() -> {Except|eff} a) -> {|eff} Maybe a given (a, eff) =
+def catch(f:() -> {Except|eff} a) -> {|eff} Maybe a given (a, eff) =
f' : (() -> {Except|eff} a) = \. f()
- %catchException f'
+ %catchException(f')
-def throw () -> {Except} a given (a) =
- %throwException a
+def throw() -> {Except} a given (a) =
+ %throwException(a)
-def assert (b:Bool) -> {Except} () =
+def assert(b:Bool) -> {Except} () =
if not b then throw()
'### Misc instances that require `error`
instance Subset(a, Either(a,b)) given (a|Data, b|Data)
- def inject (x) = Left x
- def project (x) = case x of
+ def inject(x) = Left x
+ def project(x) = case x of
Left( y) -> Just y
Right(x) -> Nothing
- def unsafe_project (x) = case x of
+ def unsafe_project(x) = case x of
Left( x) -> x
Right(x) -> error "Can't project Right branch to Left branch"
instance Subset(b, Either(a,b)) given (a|Data, b|Data)
- def inject (x) = Right x
- def project (x) = case x of
+ def inject(x) = Right x
+ def project(x) = case x of
Left( x) -> Nothing
Right(y) -> Just y
- def unsafe_project (x) = case x of
+ def unsafe_project(x) = case x of
Left( x) -> error "Can't project Left branch to Right branch"
Right(x) -> x
@@ -2383,14 +2388,14 @@ instance Subset(b, Either(a,b)) given (a|Data, b|Data)
'### Index set for tables
-def int_to_reversed_digits (k:Nat) -> a=>b given (a|Ix, b|Ix) =
+def int_to_reversed_digits(k:Nat) -> a=>b given (a|Ix, b|Ix) =
base = size b
snd $ scan k \_ cur_k.
next_k = cur_k `idiv` base
digit = cur_k `mod` base
(next_k, unsafe_from_ordinal(n=b, digit))
-def reversed_digits_to_int (digits: a=>b) -> Nat given (a|Ix, b|Ix) =
+def reversed_digits_to_int(digits: a=>b) -> Nat given (a|Ix, b|Ix) =
base = size b
fst $ fold (0, 1) \j pair.
(cur_k, cur_base) = pair
@@ -2398,30 +2403,30 @@ def reversed_digits_to_int (digits: a=>b) -> Nat given (a|Ix, b|Ix) =
next_base = cur_base * base
(next_k, next_base)
-instance Ix (a=>b) given (a|Ix, b|Ix)
+instance Ix(a=>b) given (a|Ix, b|Ix)
-- 0@a is the least significant digit,
-- while (size a - 1)@a is the most significant digit.
- def size' () = size b `intpow` size a
- def ordinal (i) = reversed_digits_to_int i
- def unsafe_from_ordinal (i) = int_to_reversed_digits i
+ def size'() = size b `intpow` size a
+ def ordinal(i) = reversed_digits_to_int i
+ def unsafe_from_ordinal(i) = int_to_reversed_digits i
-instance NonEmpty (a=>b) given (a|Ix, b|NonEmpty)
+instance NonEmpty(a=>b) given (a|Ix, b|NonEmpty)
first_ix = unsafe_from_ordinal 0
'### stack
-- TODO: replace `DynBuffer` with this?
-def Stack (h:Heap, a|Data) = Ref(h, (Nat, List a))
+def Stack(h:Heap, a|Data) = Ref(h, (Nat, List a))
-def stack_size (stack:Stack(h, a)) -> {State h} Nat given (h, a) = get $ fst_ref stack
+def stack_size(stack:Stack(h, a)) -> {State h} Nat given (h, a) = get $ fst_ref stack
-def unsafe_get_stack_buffer (stack:Stack(h, a)) -> {State h} (Ref(h, Fin 0 => a)) given (h, a|Data) =
+def unsafe_get_stack_buffer(stack:Stack(h, a)) -> {State h} (Ref(h, Fin 0 => a)) given (h, a|Data) =
get $ snd_ref $ unsafe_coerce(to=Ref h (Nat, Ref h (Fin 0 => a)), snd_ref stack)
-def stack_buf_size (stack:Stack(h, a)) -> {State h} Nat given (h, a|Data) =
+def stack_buf_size(stack:Stack(h, a)) -> {State h} Nat given (h, a|Data) =
get $ fst_ref $ unsafe_coerce(to=Ref h (Nat, Ref h (Fin 0 => a)), snd_ref stack)
-def ensure_size_at_least (stack:Stack(h, a), req_size:Nat) -> {State h} () given (h, a|Data) =
+def ensure_size_at_least(stack:Stack(h, a), req_size:Nat) -> {State h} () given (h, a|Data) =
if req_size > stack_buf_size stack then
-- TODO: maybe this should use integer arithmetic?
new_buf_size = f_to_n $ 2.0 `pow` (ceil $ log2 $ n_to_f req_size)
@@ -2433,13 +2438,13 @@ def ensure_size_at_least (stack:Stack(h, a), req_size:Nat) -> {State h} () given
Just(i') -> cur_data[i']
Nothing -> uninitialized_value()
-def read_stack (stack:Stack(h, a)) -> {State h} (List a) given (h, a|Data) =
+def read_stack(stack:Stack(h, a)) -> {State h} (List a) given (h, a|Data) =
n = stack_size stack
buf = unsafe_coerce(to=Ref(h, Fin n => a), unsafe_get_stack_buffer stack)
AsList(n, get buf)
@noinline
-def stack_push (stack:Stack(h, a), x:a) -> {State h} () given (a|Data, h) =
+def stack_push(stack:Stack(h, a), x:a) -> {State h} () given (a|Data, h) =
n_old = stack_size stack
n_new = n_old + 1
ensure_size_at_least(stack, n_new)
@@ -2448,7 +2453,7 @@ def stack_push (stack:Stack(h, a), x:a) -> {State h} () given (a|Data, h) =
fst_ref stack := n_new
@noinline
-def stack_extend (stack:Stack(h, a), x:n=>a) -> {State h} () given (a|Data, n|Ix, h) =
+def stack_extend(stack:Stack(h, a), x:n=>a) -> {State h} () given (a|Data, n|Ix, h) =
n_old = stack_size stack
n_new = n_old + size n
ensure_size_at_least(stack, n_new)
@@ -2457,7 +2462,7 @@ def stack_extend (stack:Stack(h, a), x:n=>a) -> {State h} () given (a|Data, n|Ix
buf_slice := x
fst_ref stack := n_new
-def stack_pop (stack:Stack(h, a)) -> {State h} Maybe a given (a|Data, h) =
+def stack_pop(stack:Stack(h, a)) -> {State h} Maybe a given (a|Data, h) =
n_old = stack_size stack
case n_old == 0 of
True -> Nothing
@@ -2475,52 +2480,52 @@ def with_stack(
init_stack = to_list for i:(Fin stack_init_size). uninitialized_value()
with_state (0, init_stack) \ref . action ref
-def stack_extend_internal (stack:Stack(h, Char), x:Fin n=>Char) -> {State h} () given (n, h) =
+def stack_extend_internal(stack:Stack(h, Char), x:Fin n=>Char) -> {State h} () given (n, h) =
stack_extend(stack, x)
-def stack_push_internal (stack:Stack(h, Char), x:Char) -> {State h} () given (h) =
+def stack_push_internal(stack:Stack(h, Char), x:Char) -> {State h} () given (h) =
stack_push(stack, x)
-def with_stack_internal (f:(given (h:Heap), Stack(h, Char)) -> {State h} ()) -> List Char =
+def with_stack_internal(f:(given (h:Heap), Stack(h, Char)) -> {State h} ()) -> List Char =
with_stack Char \stack.
f stack
read_stack stack
-def show_any (x:a) -> String given (a) = unsafe_coerce(to=String, %showAny(x))
+def show_any(x:a) -> String given (a) = unsafe_coerce(to=String, %showAny(x))
-def coerce_table (m|Ix, x:n=>a) -> m => a given (n|Ix, a|Data) =
+def coerce_table(m|Ix, x:n=>a) -> m => a given (n|Ix, a|Data) =
if size m == size n
then unsafe_coerce(to=m=>a, x)
else error "mismatched sizes in table coercion"
'### Linear Algebra
-def linspace (n|Ix, low:Float, high:Float) -> n=>Float =
+def linspace(n|Ix, low:Float, high:Float) -> n=>Float =
dx = (high - low) / n_to_f (size n)
for i:n. low + n_to_f (ordinal i) * dx
-def transpose (x:n=>m=>a) -> m=>n=>a given (n|Ix, m|Ix, a) = for i j. x[j,i]
-def vdot (x:n=>Float, y:n=>Float) -> Float given (n|Ix) = fsum for i. x[i] * y[i]
-def dot (s:n=>Float, vs:n=>v) -> v given (n|Ix, v|VSpace) = sum for j. s[j] .* vs[j]
+def transpose(x:n=>m=>a) -> m=>n=>a given (n|Ix, m|Ix, a) = for i j. x[j,i]
+def vdot(x:n=>Float, y:n=>Float) -> Float given (n|Ix) = fsum for i. x[i] * y[i]
+def dot(s:n=>Float, vs:n=>v) -> v given (n|Ix, v|VSpace) = sum for j. s[j] .* vs[j]
-def naive_matmul (x: l=>m=>Float, y: m=>n=>Float) -> (l=>n=>Float) given (l|Ix, m|Ix, n|Ix) =
+def naive_matmul(x: l=>m=>Float, y: m=>n=>Float) -> (l=>n=>Float) given (l|Ix, m|Ix, n|Ix) =
for i k. fsum for j. x[i,j] * y[j,k]
-- A `FullTileIx` type represents `tile_ix`th full tile (of size
-- `tile_size`) iterating over the index set `n`.
-- This type is only well formed when tile_ix * tile_size < size n.
-struct FullTileIx (n|Ix, tile_size:Nat, tile_ix:Nat) =
+struct FullTileIx(n|Ix, tile_size:Nat, tile_ix:Nat) =
unwrap : Fin tile_size
-instance Ix (FullTileIx(n, tile_size, tile_ix)) given (n|Ix, tile_size:Nat, tile_ix:Nat)
- def size' () = tile_size
- def ordinal (i) = ordinal i.unwrap
- def unsafe_from_ordinal (i) =FullTileIx.new $ unsafe_from_ordinal i
+instance Ix(FullTileIx(n, tile_size, tile_ix)) given (n|Ix, tile_size:Nat, tile_ix:Nat)
+ def size'() = tile_size
+ def ordinal(i) = ordinal i.unwrap
+ def unsafe_from_ordinal(i) =FullTileIx.new $ unsafe_from_ordinal i
instance Subset(FullTileIx(n, tile_size, tile_ix), n) given (n|Ix, tile_size:Nat, tile_ix:Nat)
- def inject (i) = unsafe_from_ordinal $ tile_size * tile_ix + ordinal i.unwrap
- def project (i) = todo
- def unsafe_project (i) = todo
+ def inject(i) = unsafe_from_ordinal $ tile_size * tile_ix + ordinal i.unwrap
+ def project(i) = todo
+ def unsafe_project(i) = todo
-- A `CodaIx` type represents the last few elements of the index set `n`,
-- as might be left over after iterating by tiles.
@@ -2528,13 +2533,13 @@ instance Subset(FullTileIx(n, tile_size, tile_ix), n) given (n|Ix, tile_size:Nat
struct CodaIx(n|Ix, coda_offset:Nat, coda_size:Nat) =
unwrap : Fin coda_size
-instance Ix (CodaIx(n, coda_offset, coda_size)) given (n|Ix, coda_offset:Nat, coda_size:Nat)
- def size' () = coda_size
- def ordinal (i) = ordinal i.unwrap
- def unsafe_from_ordinal (i) = CodaIx.new $ unsafe_from_ordinal i
+instance Ix(CodaIx(n, coda_offset, coda_size)) given (n|Ix, coda_offset:Nat, coda_size:Nat)
+ def size'() = coda_size
+ def ordinal(i) = ordinal i.unwrap
+ def unsafe_from_ordinal(i) = CodaIx.new $ unsafe_from_ordinal i
instance Subset(CodaIx(n, coda_offset, coda_size), n) given (n|Ix, coda_offset:Nat, coda_size:Nat)
- def inject (i) = unsafe_from_ordinal $ coda_offset + ordinal i.unwrap
+ def inject(i) = unsafe_from_ordinal $ coda_offset + ordinal i.unwrap
def project(i) = todo
def unsafe_project(i) = todo
@@ -2572,15 +2577,15 @@ def (**)(
m_ix = inject m_offset
result!l_ix!n_ix += x[l_ix,m_ix] * y[m_ix,n_ix]
-def (**.) (mat: n=>m=>Float, v: m=>Float) -> (n=>Float) given (n|Ix, m|Ix) =
+def (**.)(mat: n=>m=>Float, v: m=>Float) -> (n=>Float) given (n|Ix, m|Ix) =
for i. vdot(mat[i], v)
-def (.**) (v: m=>Float, mat: n=>m=>Float) -> (n=>Float) given (n|Ix, m|Ix) =
+def(.**)(v: m=>Float, mat: n=>m=>Float) -> (n=>Float) given (n|Ix, m|Ix) =
mat **. v
-def inner (x:n=>Float, mat:n=>m=>Float, y:m=>Float) -> Float given (n|Ix, m|Ix) =
+def inner(x:n=>Float, mat:n=>m=>Float, y:m=>Float) -> Float given (n|Ix, m|Ix) =
fsum for p.
(i,j) = p
x[i] * mat[i,j] * y[j]
-def eye () ->> n=>n=>a given (n|Ix, a|Add|Mul) =
+def eye() ->> n=>n=>a given (n|Ix, a|Add|Mul) =
for i j. select(ordinal i == ordinal j, one, zero)
diff --git a/lib/sort.dx b/lib/sort.dx
index e6a7c77b..1178dd0c 100644
--- a/lib/sort.dx
+++ b/lib/sort.dx
@@ -12,58 +12,60 @@ it's doing bubble / insertion sort with quadratic time cost.
However, if it breaks the list in half recursively, it'll be doing parallel mergesort.
Currently the Dex compiler will do the quadratic-time version.
-def concat_table {a b v} (leftin: a=>v) (rightin: b=>v) : ((a|b)=>v) =
+def concat_table(leftin: a=>v, rightin: b=>v) -> (Either a b=>v) given (a|Ix, b|Ix, v) =
for idx. case idx of
- Left i -> leftin.i
- Right i -> rightin.i
+ Left i -> leftin[i]
+ Right i -> rightin[i]
-def merge_sorted_tables {a m n} [Ord a] (xs:m=>a) (ys:n=>a) : ((m|n)=>a) =
+def merge_sorted_tables(xs:m=>a, ys:n=>a) -> (Either m n=>a) given (a|Ord, m|Ix, n|Ix) =
-- Possible improvements:
-- 1) Using a SortedTable type.
-- 2) Avoid needlessly initializing the return array.
init = concat_table xs ys -- Initialize array of correct size.
yield_state init \buf.
with_state (0, 0) \countrefs.
- for i:(m|n).
+ for i:(Either m n).
(cur_x, cur_y) = get countrefs
if cur_y >= size n -- no ys left
then
countrefs := (cur_x + 1, cur_y)
- buf!i := xs.(unsafe_from_ordinal _ cur_x)
+ buf!i := xs[unsafe_from_ordinal cur_x]
else
if cur_x < size m -- still xs left
- then
- if xs.(unsafe_from_ordinal _ cur_x) <= ys.(unsafe_from_ordinal _ cur_y)
+ then
+ if xs[unsafe_from_ordinal cur_x] <= ys[unsafe_from_ordinal cur_y]
then
countrefs := (cur_x + 1, cur_y)
- buf!i := xs.(unsafe_from_ordinal _ cur_x)
+ buf!i := xs[unsafe_from_ordinal cur_x]
else
countrefs := (cur_x, cur_y + 1)
- buf!i := ys.(unsafe_from_ordinal _ cur_y)
+ buf!i := ys[unsafe_from_ordinal cur_y]
-def merge_sorted_lists {a} [Ord a] (AsList nx xs: List a) (AsList ny ys: List a) : List a =
+def merge_sorted_lists(lx: List a, ly: List a) -> List a given (a|Ord) =
-- Need this wrapper because Dex can't automatically weaken
-- (a | b)=>c to ((Fin d)=>c)
+ AsList(nx, xs) = lx
+ AsList(ny, ys) = ly
sorted = merge_sorted_tables xs ys
newsize = nx + ny
- AsList _ $ unsafe_cast_table (Fin newsize) sorted
+ AsList _ $ unsafe_cast_table(to=Fin newsize, sorted)
-def sort {a n} [Ord a] (xs: n=>a) : n=>a =
+def sort(xs: n=>a) -> n=>a given (n|Ix, a|Ord) =
-- Warning: Has quadratic runtime cost for now.
- xlists = for i:n. (AsList 1 [xs.i])
+ xlists = for i:n. (AsList 1 [xs[i]])
-- Merge sort monoid:
mempty = AsList 0 []
mcombine = merge_sorted_lists
-- reduce might someday recursively subdivide the problem.
- (AsList _ r) = reduce mempty mcombine xlists
- unsafe_cast_table n r
+ AsList(_, r) = reduce mempty mcombine xlists
+ unsafe_cast_table(to=n, r)
-def (+|) {n} [Ix n] (i:n) (delta:Nat) : n =
+def (+|)(i:n, delta:Nat) -> n given (n|Ix) =
i' = ordinal i + delta
- from_ordinal _ $ select (i' >= size n) (size n -| 1) i'
+ from_ordinal $ select (i' >= size n) (size n -| 1) i'
-def is_sorted {a n} [Ord a] (xs:n=>a) : Bool =
- all for i. xs.i <= xs.(i +| 1)
+def is_sorted(xs:n=>a) -> Bool given (a|Ord, n|Ix) =
+ all for i. xs[i] <= xs[i +| 1]
diff --git a/makefile b/makefile
index 650b7b34..e4340598 100644
--- a/makefile
+++ b/makefile
@@ -221,8 +221,8 @@ endif
test-names = uexpr-tests print-tests adt-tests type-tests struct-tests cast-tests eval-tests show-tests \
read-tests shadow-tests monad-tests io-tests exception-tests sort-tests \
- standalone-function-tests \
- ad-tests parser-tests serialize-tests parser-combinator-tests \
+ parser-tests standalone-function-tests \
+ ad-tests serialize-tests parser-combinator-tests \
record-variant-tests typeclass-tests complex-tests trig-tests \
linalg-tests set-tests fft-tests stats-tests stack-tests
diff --git a/src/lib/ConcreteSyntax.hs b/src/lib/ConcreteSyntax.hs
index 8a7c7533..c5fecf4c 100644
--- a/src/lib/ConcreteSyntax.hs
+++ b/src/lib/ConcreteSyntax.hs
@@ -238,8 +238,7 @@ structDef :: Parser CTopDecl
structDef = withSrc do
keyWord StructKW
tyName <- anyName
- params <- explicitParams
- sym "="
+ params <- typeParams
fields <- onePerLine nameAndType
return $ CStruct tyName params fields
@@ -247,11 +246,10 @@ dataDef :: Parser CTopDecl
dataDef = withSrc do
keyWord DataKW
tyName <- anyName
- params <- explicitParams
- sym "="
+ params <- typeParams
dataCons <- onePerLine do
dataConName <- anyName
- dataConArgs <- explicitParams
+ dataConArgs <- optExplicitParams
return (dataConName, dataConArgs)
return $ CData tyName params dataCons
@@ -273,9 +271,6 @@ interfaceDef = withSrc do
return (methodName, ty)
return $ CInterface className params methodDecls
--- superclassConstraints :: Parser [Group]
--- superclassConstraints = optionalMonoid $ brackets $ cNames `sepBy` sym ","
-
effectDef :: Parser CTopDecl
effectDef = withSrc do
keyWord EffectKW
@@ -364,7 +359,7 @@ instanceDef isNamed = do
args <- (sym ":" >> return Nothing)
<|> ((Just <$> parens (commaSep cParenGroup)) <* sym "->")
return $ Just (name, args)
- className <- anyNameNoSC
+ className <- anyName
args <- argList
givens <- optional givenClause
methods <- realBlock
@@ -375,7 +370,7 @@ funDefLet = label "function definition" do
keyWord DefKW
mayBreak do
name <- anyName
- params <- parens (commaSep cParenGroup)
+ params <- explicitParams
rhs <- optional do
expl <- explicitness
effs <- optional cEffs
@@ -391,20 +386,39 @@ explicitness :: Parser AppExplicitness
explicitness = (sym "->" $> ExplicitApp)
<|> (sym "->>" $> ImplicitApp)
-anyNameNoSC :: Parser SourceName
-anyNameNoSC = withConsumption ConsumeNothing $ anyName
-
-- Intended for occurrences, like `foo(x, y, z)` (cf. defParamsList).
--- Previous lexeme shouldn't consume trailing whitespace.
argList :: Parser [Group]
-argList = nextChar >>= \case
- '(' -> parens (commaSep cGroup)
- _ -> do
- arg <- sc >> leafGroup
- return [arg]
+argList = immediateParens (commaSep cParenGroup)
+
+immediateLParen :: Parser ()
+immediateLParen = label "'(' (without preceding whitespace)" do
+ nextChar >>= \case
+ '(' -> precededByWhitespace >>= \case
+ True -> empty
+ False -> charLexeme '('
+ _ -> empty
+
+immediateParens :: Parser a -> Parser a
+immediateParens p = bracketed immediateLParen rParen p
+
+-- Putting `sym =` inside the cases gives better errors.
+typeParams :: Parser ExplicitParams
+typeParams =
+ (explicitParams <* sym "=")
+ <|> (return [] <* sym "=")
+
+optExplicitParams :: Parser ExplicitParams
+optExplicitParams = label "optional parameter list" $
+ explicitParams <|> return []
explicitParams :: Parser ExplicitParams
-explicitParams = label "parameter list" $ parens (commaSep cGroup) <|> return []
+explicitParams = label "parameter list in parentheses (without preceding whitespace)" $
+ immediateParens $ commaSep cParenGroup
+
+noGap :: Parser ()
+noGap = precededByWhitespace >>= \case
+ True -> fail "Unexpected whitespace"
+ False -> return ()
givenClause :: Parser GivenClause
givenClause = keyWord GivenKW >> do
@@ -571,22 +585,20 @@ noElse msg = (optional $ try $ sc >> keyWord ElseKW) >>= \case
leafGroup :: Parser Group
leafGroup = do
- leaf <- leafGroupNoSC
- postOps <- many postfixGroupNoSC <* sc
+ leaf <- leafGroup'
+ postOps <- many postfixGroup
return $ foldl (\accum (op, opLhs) -> joinSrc accum opLhs $ CBin (WithSrc Nothing op) accum opLhs) leaf postOps
where
- -- These "noSC" functions don't consume trailing whitespace. We want to parse
- -- things like `f(x,y).foo(z)[q]`.
- leafGroupNoSC :: Parser Group
- leafGroupNoSC = withSrc do
+ leafGroup' :: Parser Group
+ leafGroup' = withSrc do
next <- nextChar
case next of
- '_' -> noSC $ underscore $> CHole
+ '_' -> underscore $> CHole
'(' -> (CIdentifier <$> symName)
- <|> cParensNoSC
- '[' -> cBracketsNoSC
- '{' -> CBraces <$> bracketNoSC '{' '}' (commaSep cGroup)
+ <|> cParens
+ '[' -> cBrackets
+ '{' -> cBraces
'\"' -> CString <$> strLit
'\'' -> CChar <$> charLit
'%' -> do
@@ -598,33 +610,30 @@ leafGroup = do
<|> CFloat <$> doubleLit)
'\\' -> cNullaryLam <|> cLam
-- For exprs include for, rof, for_, rof_
- 'f' -> cFor <|> cIdentifierNoSC
- 'd' -> cDo <|> cIdentifierNoSC
- 'r' -> cFor <|> cIdentifierNoSC
- 'c' -> cCase <|> cIdentifierNoSC
- 'i' -> cIf <|> cIdentifierNoSC
- _ -> cIdentifierNoSC
-
- noSC :: Parser a -> Parser a
- noSC p = withConsumption ConsumeNothing p
-
- postfixGroupNoSC :: Parser (Bin', Group)
- postfixGroupNoSC =
- ((JuxtaposeNoSpace,) <$> withSrc cParensNoSC)
- <|> ((JuxtaposeNoSpace,) <$> withSrc cBracketsNoSC)
- <|> ((Dot,) <$> (try $ char '.' >> withSrc cIdentifierNoSC))
-
- bracketNoSC :: Char -> Char -> Parser a -> Parser a
- bracketNoSC l r p = charLexeme l >> mayBreak (sc >> p) <* char r
-
- cIdentifierNoSC :: Parser Group'
- cIdentifierNoSC = noSC $ CIdentifier <$> anyName
-
- cParensNoSC :: Parser Group'
- cParensNoSC = CParens <$> bracketNoSC '(' ')' (cParenGroup `sepBy` sym ",")
-
- cBracketsNoSC :: Parser Group'
- cBracketsNoSC = CBrackets <$> bracketNoSC '[' ']' (commaSep cGroup)
+ 'f' -> cFor <|> cIdentifier
+ 'd' -> cDo <|> cIdentifier
+ 'r' -> cFor <|> cIdentifier
+ 'c' -> cCase <|> cIdentifier
+ 'i' -> cIf <|> cIdentifier
+ _ -> cIdentifier
+
+ postfixGroup :: Parser (Bin', Group)
+ postfixGroup = noGap >>
+ ((JuxtaposeNoSpace,) <$> withSrc cParens)
+ <|> ((JuxtaposeNoSpace,) <$> withSrc cBrackets)
+ <|> ((Dot,) <$> (try $ char '.' >> withSrc cIdentifier))
+
+cIdentifier :: Parser Group'
+cIdentifier = CIdentifier <$> anyName
+
+cParens :: Parser Group'
+cParens = CParens <$> parens (commaSep cParenGroup)
+
+cBrackets :: Parser Group'
+cBrackets = CBrackets <$> brackets (commaSep cGroup)
+
+cBraces :: Parser Group'
+cBraces = CBrackets <$> braces (commaSep cGroup)
-- A `PrecTable` is enough information to (i) remove or replace
-- operators for special contexts, and (ii) build the input structure
diff --git a/src/lib/Lexing.hs b/src/lib/Lexing.hs
index 4320b8c0..f49d3d1d 100644
--- a/src/lib/Lexing.hs
+++ b/src/lib/Lexing.hs
@@ -1,6 +1,12 @@
+-- Copyright 2022 Google LLC
+--
+-- Use of this source code is governed by a BSD-style
+-- license that can be found in the LICENSE file or at
+-- https://developers.google.com/open-source/licenses/bsd
+
module Lexing where
-import Control.Monad.Reader
+import Control.Monad.State.Strict
import Data.Char
import Data.HashSet qualified as HS
import qualified Data.Scientific as Scientific
@@ -20,17 +26,19 @@ import Err
import LabeledItems
import Types.Primitives
-data ParseCtx = ParseCtx { curIndent :: Int
- , whitespaceConsumption :: WhitespaceConsumption }
-type Parser = ReaderT ParseCtx (Parsec Void Text)
+data ParseCtx = ParseCtx
+ { curIndent :: Int -- used Reader-style (i.e. ask/local)
+ , canBreak :: Bool -- used Reader-style (i.e. ask/local)
+ , prevWhitespace :: Bool -- tracks whether we just consumed whitespace
+ }
+
+initParseCtx :: ParseCtx
+initParseCtx = ParseCtx 0 False False
-data WhitespaceConsumption =
- ConsumeNothing
- | ConsumeAllExceptLineBreaks
- | ConsumeAll
+type Parser = StateT ParseCtx (Parsec Void Text)
parseit :: Text -> Parser a -> Except a
-parseit s p = case parse (runReaderT p (ParseCtx 0 ConsumeAllExceptLineBreaks)) "" s of
+parseit s p = case parse (fst <$> runStateT p initParseCtx) "" s of
Left e -> throw ParseErr $ errorBundlePretty e
Right x -> return x
@@ -41,8 +49,8 @@ mustParseit s p = case parseit s p of
debug :: (Show a) => String -> Parser a -> Parser a
debug lbl action = do
- ctx <- ask
- lift $ dbg lbl $ runReaderT action ctx
+ ctx <- get
+ lift $ dbg lbl $ fst <$> runStateT action ctx
-- === Lexemes ===
@@ -160,11 +168,7 @@ knownSymStrs = HS.fromList
-- string must be in `knownSymStrs`
sym :: Text -> Lexer ()
-sym s = lexeme $ try $ string s >> notFollowedBy continuation
- -- This awful hack is because "|}" should be a single lexeme for
- -- variants, but we can't make "}" be a symChar because that would
- -- allow user-defined symbols like --}, which is horrible.
- where continuation = if s == "|" then (symChar <|> char '}') else symChar
+sym s = lexeme $ try $ string s >> notFollowedBy symChar
anySym :: Lexer String
anySym = lexeme $ try $ do
@@ -209,7 +213,8 @@ symChars = HS.fromList ".,!$^&*:-~+/=<>|?\\@#"
-- === Util ===
sc :: Parser ()
-sc = skipMany $ hidden space <|> hidden lineComment
+sc = (skipSome s >> recordWhitespace) <|> return ()
+ where s = hidden space <|> hidden lineComment
lineComment :: Parser ()
lineComment = do
@@ -220,33 +225,29 @@ outputLines :: Parser ()
outputLines = void $ many (symbol ">" >> takeWhileP Nothing (/= '\n') >> ((eol >> return ()) <|> eof))
space :: Parser ()
-space = asks whitespaceConsumption >>= \case
- ConsumeNothing -> fail ""
- ConsumeAllExceptLineBreaks -> void $ takeWhile1P (Just "white space") (`elem` (" \t" :: String))
- ConsumeAll -> space1
-
-withConsumption :: WhitespaceConsumption -> Parser a -> Parser a
-withConsumption c p = localConsumption p \_ -> c
-{-# INLINE withConsumption #-}
-
-localConsumption :: Parser a -> (WhitespaceConsumption -> WhitespaceConsumption) -> Parser a
-localConsumption p update =
- local (\ctx -> ctx { whitespaceConsumption = update $ whitespaceConsumption ctx }) p
-{-# INLINE localConsumption #-}
+space = gets canBreak >>= \case
+ True -> space1
+ False -> void $ takeWhile1P (Just "white space") (`elem` (" \t" :: String))
mayBreak :: Parser a -> Parser a
-mayBreak = withConsumption ConsumeAll
+mayBreak p = pLocal (\ctx -> ctx { canBreak = True }) p
{-# INLINE mayBreak #-}
mayNotBreak :: Parser a -> Parser a
-mayNotBreak p = localConsumption p \case
- ConsumeAll -> ConsumeAllExceptLineBreaks
- c -> c
+mayNotBreak p = pLocal (\ctx -> ctx { canBreak = False }) p
{-# INLINE mayNotBreak #-}
-optionalMonoid :: Monoid a => Parser a -> Parser a
-optionalMonoid p = p <|> return mempty
-{-# INLINE optionalMonoid #-}
+precededByWhitespace :: Parser Bool
+precededByWhitespace = gets prevWhitespace
+{-# INLINE precededByWhitespace #-}
+
+recordWhitespace :: Parser ()
+recordWhitespace = modify \ctx -> ctx { prevWhitespace = True }
+{-# INLINE recordWhitespace #-}
+
+recordNonWhitespace :: Parser ()
+recordNonWhitespace = modify \ctx -> ctx { prevWhitespace = False }
+{-# INLINE recordNonWhitespace #-}
nameString :: Parser String
nameString = lexeme . try $ (:) <$> lowerChar <*> many alphaNumChar
@@ -281,7 +282,7 @@ withPos p = do
nextLine :: Parser ()
nextLine = do
eol
- n <- asks curIndent
+ n <- curIndent <$> get
void $ mayNotBreak $ many $ try (sc >> eol)
void $ replicateM n (char ' ')
@@ -297,9 +298,15 @@ withIndent p = do
nextLine
indent <- T.length <$> takeWhileP (Just "space") (==' ')
when (indent <= 0) empty
- local (\ctx -> ctx { curIndent = curIndent ctx + indent }) $ p
+ pLocal (\ctx -> ctx { curIndent = curIndent ctx + indent }) $ p
{-# INLINE withIndent #-}
+pLocal :: (ParseCtx -> ParseCtx) -> Parser a -> Parser a
+pLocal f p = do
+ s <- get
+ put (f s) >> p <* put s
+{-# INLINE pLocal #-}
+
eol :: Parser ()
eol = void MC.eol
@@ -311,7 +318,7 @@ failIf True s = fail s
failIf False _ = return ()
lexeme :: Parser a -> Parser a
-lexeme = L.lexeme sc
+lexeme p = L.lexeme sc (p <* recordNonWhitespace)
{-# INLINE lexeme #-}
symbol :: Text -> Parser ()
diff --git a/tests/adt-tests.dx b/tests/adt-tests.dx
index 4a817de7..e9a5aff9 100644
--- a/tests/adt-tests.dx
+++ b/tests/adt-tests.dx
@@ -51,7 +51,7 @@ MkMyPair(z1, MkMyPair(z2, z3)) = zz
> [(1, 1), (3, 2), (5, 3)]
data MyEither(a:Type, b:Type) =
- MyLeft (a)
+ MyLeft(a)
MyRight(b)
x : MyEither Int Float = MyLeft 1
@@ -189,7 +189,7 @@ AsList(_, xsTab) = xsList
(x, y)
> (3, 4)
-def catLists (xs:List a, ys:List a) -> List a given (a) =
+def catLists(xs:List a, ys:List a) -> List a given (a) =
AsList(nx, xs') = xs
AsList(ny, ys') = ys
nz = nx + ny
@@ -215,7 +215,7 @@ def catLists (xs:List a, ys:List a) -> List a given (a) =
-def listToTable2 (l: List a) -> (Fin (list_length l))=>a given (a) =
+def listToTable2(l: List a) -> (Fin (list_length l))=>a given (a) =
AsList(_, xs) = l
xs
@@ -231,7 +231,7 @@ l2 = AsList _ [1, 2, 3]
:p sum $ listToTable2 l2
> 6
-def zerosLikeList (l : List a) -> (Fin (list_length l))=>Float given (a) =
+def zerosLikeList(l: List a) -> (Fin (list_length l))=>Float given (a) =
for i:(Fin $ list_length l). 0.0
:p zerosLikeList l2
@@ -240,7 +240,7 @@ def zerosLikeList (l : List a) -> (Fin (list_length l))=>Float given (a) =
data Graph(n|Ix, a:Type) =
MkGraph(nodes:(n=>a), edges:(List (n, n)))
-def graphToAdjacencyMatrix (g:Graph n a) -> n=>n=>Bool given (n|Ix, a) =
+def graphToAdjacencyMatrix(g:Graph n a) -> n=>n=>Bool given (n|Ix, a) =
MkGraph(nodes, AsList(_, edges)) = g
init = for i j. False
yield_state init \mRef.
@@ -271,7 +271,7 @@ data MySum =
data MySum2 =
Foo2
- Bar2 (Fin 3 => Int)
+ Bar2(Fin 3 => Int)
-- bug #348
:p concat for i:(Fin 4). AsList _ [(Foo2, Foo2)]
diff --git a/tests/eval-tests.dx b/tests/eval-tests.dx
index 474c79bc..73d1a541 100644
--- a/tests/eval-tests.dx
+++ b/tests/eval-tests.dx
@@ -136,7 +136,7 @@ s = 1.0
for i. xs[i] + 10
> [12, 11, 10]
-def cumsumplus (xs:n=>Float) -> n=>Float given (n|Ix) =
+def cumsumplus(xs:n=>Float) -> n=>Float given (n|Ix) =
snd $ scan 0.0 \i c.
ans = c + xs[i]
(ans, 1.0 + ans)
@@ -144,7 +144,7 @@ def cumsumplus (xs:n=>Float) -> n=>Float given (n|Ix) =
:p cumsumplus [1.0, 2.0, 3.0]
> [2., 4., 7.]
-def cumsum2d (xs:n=>m=>a) -> n=>m=>a given (n|Ix, m|Ix, a|Add) =
+def cumsum2d(xs:n=>m=>a) -> n=>m=>a given (n|Ix, m|Ix, a|Add) =
cumsum for i. cumsum xs[i]
xs : (Fin 4)=>(Fin 3)=>Float = arb $ new_key 0
@@ -520,7 +520,7 @@ count = (w32_to_w64 (i_to_w32 608135816) .<<. 32) .|. (w32_to_w64 (i_to_w32 2242
c
> ([3., 2., 1., 0.], 4.)
-def eitherFloor (x:(Int `Either` Float)) -> Int = case x of
+def eitherFloor(x:(Int `Either` Float)) -> Int = case x of
Left i -> i
Right f -> f_to_i $ floor f
@@ -569,7 +569,7 @@ Just 1 == Just 0
-- Userspace-Bool breaks these
-- -- Needed to avoid ambiguous type variables if both sides use the same constructor
--- def cmpEither (x:Int|Int) (y:Int|Int) : Bool = x == y
+-- def cmpEither(x:Int|Int) (y:Int|Int) : Bool = x == y
-- :p cmpEither (Left 1) (Left 1)
-- > True
@@ -708,7 +708,7 @@ def newtonSolve(tol:Float, f: (Float)->Float, x0:Float) -> Float =
(f 5, w)
> (7, 2.)
--- def add (n : Type) ?-> (a : n=>Float) (b : n=>Float) : n=>Float =
+-- def add(n : Type) ?-> (a : n=>Float) (b : n=>Float) : n=>Float =
-- (tile (\t:(Tile n (Fin VectorWidth)). storeVector $ loadTile t a + loadTile t b)
-- (\i:n. a.i + b.i))
-- toAdd = for _:(Fin 10). 1.0
@@ -809,7 +809,7 @@ s1 = "hello world"
-- > [[0], [1, 2], [3, 4, 5], [6, 7, 8, 9]]
-- @noinline
-def fromLeftFloat (x:(Float `Either` Int)) -> Float =
+def fromLeftFloat(x:(Float `Either` Int)) -> Float =
case x of
Left x' -> x'
Right _ -> error "this is an error"
@@ -822,27 +822,27 @@ def fromLeftFloat (x:(Float `Either` Int)) -> Float =
> 1.2
@noinline
-def f1 (x:Int) -> Int = x + 1
+def f1(x:Int) -> Int = x + 1
@noinline
-def f2 (x:Int) -> Int = f1 $ f1 $ f1 $ f1 $ f1 $ f1 $ f1 $ f1 $ f1 $ f1 $ x
+def f2(x:Int) -> Int = f1 $ f1 $ f1 $ f1 $ f1 $ f1 $ f1 $ f1 $ f1 $ f1 $ x
@noinline
-def f3 (x:Int) -> Int = f2 $ f2 $ f2 $ f2 $ f2 $ f2 $ f2 $ f2 $ f2 $ f2 $ x
+def f3(x:Int) -> Int = f2 $ f2 $ f2 $ f2 $ f2 $ f2 $ f2 $ f2 $ f2 $ f2 $ x
@noinline
-def f4 (x:Int) -> Int = f3 $ f3 $ f3 $ f3 $ f3 $ f3 $ f3 $ f3 $ f3 $ f3 $ x
+def f4(x:Int) -> Int = f3 $ f3 $ f3 $ f3 $ f3 $ f3 $ f3 $ f3 $ f3 $ f3 $ x
@noinline
-def f5 (x:Int) -> Int = f4 $ f4 $ f4 $ f4 $ f4 $ f4 $ f4 $ f4 $ f4 $ f4 $ x
+def f5(x:Int) -> Int = f4 $ f4 $ f4 $ f4 $ f4 $ f4 $ f4 $ f4 $ f4 $ f4 $ x
@noinline
-def f6 (x:Int) -> Int = f5 $ f5 $ f5 $ f5 $ f5 $ f5 $ f5 $ f5 $ f5 $ f5 $ x
+def f6(x:Int) -> Int = f5 $ f5 $ f5 $ f5 $ f5 $ f5 $ f5 $ f5 $ f5 $ f5 $ x
-- regression test for #1229 - checks that constraints and data args can be interleaved
@noinline
-def interleave_args (x:Float, xs:n=>Float) -> Float given (n|Ix) =
+def interleave_args(x:Float, xs:n=>Float) -> Float given (n|Ix) =
x + sum xs
interleave_args(1.0, [2.0, 3.0])
@@ -872,7 +872,7 @@ interleave_args(1.0, [2.0, 3.0])
:p intpow 2 17
> 131072
-def f (x:Float) -> Float =
+def f(x:Float) -> Float =
intpow x 2
x2 = 3.1
@@ -908,7 +908,7 @@ interface AssociatedWithTwo(a:Type, c:Type)
value2 : () -> Int
instance AssociatedWithTwo(Int, Float)
- def value2 () = 8
+ def value2() = 8
-- TODO: the source location for this error message is a bit off since we added
-- n-ary applications.
@@ -959,7 +959,7 @@ all $ for i:(Maybe (Fin 2)).
i == (unsafe_from_ordinal (ordinal i))
> True
-def all_prefixes (s:String) -> List String =
+def all_prefixes(s:String) -> List String =
AsList(_, s') = s
to_list for i. post_slice s' first_ix i
@@ -970,7 +970,7 @@ def all_prefixes (s:String) -> List String =
-- -- Regression test for https://github.com/google-research/dex-lang/issues/694
-- data Thing = T {a: Float}
--- def scan_over (xs:n=>Float) -> n=>Float given (n|Ix) =
+-- def scan_over(xs:n=>Float) -> n=>Float given (n|Ix) =
-- snd $ scan (T {a=1.0}) \i t.
-- (T {a}) = t
-- (T {a=a + xs[i]}, a)
@@ -1005,19 +1005,19 @@ first_ix :: (Fin 0)
> Type error:Couldn't synthesize a class dictionary for: (NonEmpty (Fin 0))
>
> first_ix :: (Fin 0)
-> ^^^^^^^^
+> ^^^^^^^^^
first_ix :: (Fin 0 `Either` Fin 0)
> Type error:Couldn't synthesize a class dictionary for: (NonEmpty (Either (Fin 0) (Fin 0)))
>
> first_ix :: (Fin 0 `Either` Fin 0)
-> ^^^^^^^^
+> ^^^^^^^^^
first_ix :: (Fin 0, ())
> Type error:Couldn't synthesize a class dictionary for: (NonEmpty (Fin 0, ()))
>
> first_ix :: (Fin 0, ())
-> ^^^^^^^^
+> ^^^^^^^^^
-- Subset tests
@@ -1076,7 +1076,7 @@ Just f64_to_f
-- from a case (this tickles ACase colliding with the code generated
-- to print). The @noinline defeats case-of-known-constructor.
@noinline
-def id_bool (x:Bool) -> Bool = x
+def id_bool(x:Bool) -> Bool = x
case (id_bool True) of
True -> Just f64_to_f
diff --git a/tests/exception-tests.dx b/tests/exception-tests.dx
index cc1effcd..d35aaebb 100644
--- a/tests/exception-tests.dx
+++ b/tests/exception-tests.dx
@@ -1,57 +1,57 @@
-def checkFloatInUnitInterval (x:Float) : {Except} Float =
+def checkFloatInUnitInterval(x:Float) -> {Except} Float =
assert $ x >= 0.0
assert $ x <= 1.0
x
-:p catch do assert False
+:p catch \. assert False
> Nothing
-:p catch do assert True
+:p catch \. assert True
> (Just ())
-:p catch do checkFloatInUnitInterval 1.2
+:p catch \. checkFloatInUnitInterval 1.2
> Nothing
-:p catch do checkFloatInUnitInterval (-1.2)
+:p catch \. checkFloatInUnitInterval (-1.2)
> Nothing
-:p catch do checkFloatInUnitInterval 0.2
+:p catch \. checkFloatInUnitInterval 0.2
> (Just 0.2)
:p yield_state 0 \ref.
- catch do
+ catch \.
ref := 1
assert False
ref := 2
> 1
-:p catch do
+:p catch \.
for i:(Fin 5).
if ordinal i > 3
- then throw ()
+ then throw()
else 23
> Nothing
-:p catch do
+:p catch \.
for i:(Fin 3).
if ordinal i > 3
- then throw ()
+ then throw()
else 23
> (Just [23, 23, 23])
-- Is this the result we want?
:p yield_state zero \ref.
- catch do
+ catch \.
for i:(Fin 6).
if (ordinal i `rem` 2) == 0
- then throw ()
+ then throw()
else ()
ref!i := 1
> [0, 1, 0, 1, 0, 1]
-:p catch do
+:p catch \.
run_state 0 \ref.
ref := 1
assert False
@@ -59,9 +59,9 @@ def checkFloatInUnitInterval (x:Float) : {Except} Float =
> Nothing
-- https://github.com/google-research/dex-lang/issues/612
-def sashabug (h: Unit) : {Except} List Int =
+def sashabug(h: ()) -> {Except} List Int =
yield_state mempty \results.
results := (get results) <> AsList 1 [2]
-catch do (catch do sashabug ())
+catch \. (catch \. sashabug ())
> (Just (Just (AsList 1 [2])))
diff --git a/tests/monad-tests.dx b/tests/monad-tests.dx
index 92868294..6641f79e 100644
--- a/tests/monad-tests.dx
+++ b/tests/monad-tests.dx
@@ -1,21 +1,21 @@
:p
- def m (ref:Ref h Int) -> {State h} Int given (h:Heap) = get ref
+ def m(ref:Ref h Int) -> {State h} Int given (h:Heap) = get ref
run_state 2 m
> (2, 2)
:p
- def m (ref:Ref h Int) -> {State h} () given (h:Heap) = ref := 3
+ def m(ref:Ref h Int) -> {State h} () given (h:Heap) = ref := 3
run_state 0 m
> ((), 3)
:p
- def m (ref:Ref h Int) -> {Read h} Int given (h:Heap) = ask ref
+ def m(ref:Ref h Int) -> {Read h} Int given (h:Heap) = ask ref
with_reader 5 m
> 5
:p
- def stateAction (ref:Ref h Float) -> {State h} () given (h:Heap) =
+ def stateAction(ref:Ref h Float) -> {State h} () given (h:Heap) =
x = get ref
ref := (x + 2.0)
z = get ref
@@ -87,7 +87,7 @@ def myAction(w:Ref hw Float, r:Ref hr Float)
run_accum (AddMonoid Float) \w1. run_accum (AddMonoid Float) \w2. m w1 w2
> (((), 3.), 2.)
-def foom (s:Ref h ((Fin 3)=>Int)) -> {State h} () given (h:Heap) =
+def foom(s:Ref h ((Fin 3)=>Int)) -> {State h} () given (h:Heap) =
s!(from_ordinal 0) := 1
s!(from_ordinal 2) := 2
@@ -96,7 +96,7 @@ def foom (s:Ref h ((Fin 3)=>Int)) -> {State h} () given (h:Heap) =
-- TODO: handle effects returning functions
-- :p
--- def foo (x:Float) : Float =
+-- def foo(x:Float) : Float =
-- f = withReader x \r.
-- y = ask r
-- \z. 100.0 * x + 10.0 * y + z
@@ -198,7 +198,7 @@ def effectsAtZero(f: (Int)->{|eff} ()) -> {|eff} () given (eff:Effects) =
> False
-- Test custom list monoid with accum
-def adjacencyMatrixToEdgeList (mat: n=>n=>Bool) -> List (n, n) given (n|Ix) =
+def adjacencyMatrixToEdgeList(mat: n=>n=>Bool) -> List (n, n) given (n|Ix) =
yield_accum(ListMonoid((n,n))) \list.
for idxs.
(i, j) = idxs
diff --git a/tests/parser-tests.dx b/tests/parser-tests.dx
index 446f666c..cb5b312e 100644
--- a/tests/parser-tests.dx
+++ b/tests/parser-tests.dx
@@ -30,7 +30,7 @@ f = \x. x + 10.
:p f -1.0 -- parses as (-) f (-1.0)
> Type error:
-> Expected: (Float32 -> Float32)
+> Expected: ((x:Float32) -> Float32)
> Actual: Float32
>
> :p f -1.0 -- parses as (-) f (-1.0)
@@ -39,16 +39,6 @@ f = \x. x + 10.
:p f (-1.0)
> 9.
-'Lambdas can have specific arrow annotations.
-
-lam1 = \{n}. \x:(Fin n). ordinal x
-:t lam1
-> ((n:Nat) ?-> (Fin n) -> Nat)
-
-lam4 = \{n m}. (Fin n, Fin m)
-:t lam4
-> (Nat ?-> Nat ?-> (Type & Type))
-
:p (
1
+
@@ -59,7 +49,7 @@ lam4 = \{n m}. (Fin n, Fin m)
:p
xs = [1,2,3]
for i.
- if xs.i > 1
+ if xs[i] > 1
then 0
else 1
> [1, 0, 0]
@@ -71,24 +61,14 @@ lam4 = \{n m}. (Fin n, Fin m)
ref := get ref + 1
> ((), 10)
-def myInt : Int = 1
-:p myInt
-> 1
-
-def myInt2 : {State Int} Int = 1
-> Syntax error: Nullary def can't have effects
->
-> def myInt2 : {State Int} Int = 1
-> ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-
-- Check that a syntax error in a funDefLet doesn't try to reparse the
-- whole definition as something it's not.
-def frob (x:Int) (y:Int) : + =
+def frob(x:Int, y:Int) -> + =
-> Parse error:86:30:
+> Parse error:66:29:
> |
-> 86 | def frob (x:Int) (y:Int) : + =
-> | ^
+> 66 | def frob(x:Int, y:Int) -> + =
+> | ^
> unexpected '='
> expecting name or symbol name
@@ -96,35 +76,31 @@ def frob (x:Int) (y:Int) : + =
for i.
i
-> Parse error:97:1:
+> Parse error:77:1:
> |
-> 97 | i
+> 77 | i
> | ^
> expecting end of line or space
def (foo + bar) : Int = 6
-> Parse error:105:6:
-> |
-> 105 | def (foo + bar) : Int = 6
-> | ^
+> Parse error:85:6:
+> |
+> 85 | def (foo + bar) : Int = 6
+> | ^
> unexpected 'f'
> expecting symbol name
'Data definitions allow but do not require type / kind annotations
-data MyPair1 a b = MkPair1 a b
-
-data MyPair2 a:Type b:Type = MkPair2 x:a y:b
-
-'which may be grouped with parentheses
+data MyPair1(a, b) = MkPair1(a, b)
-data MyPair3 (a:Type) (b:Type) = MkPair3 (x:a) (y:b)
+data MyPair2(a:Type, b:Type) = MkPair2(x:a, y:b)
'Data definitions allow interleaving arguments and class constraints
(regression test for Issue 1015)
-data TableInType n a [Ix n] table:(n=>a) =
+data TableInType(n|Ix, a, table:(n=>a)) =
MkTableInType -- Doesn't store any data except in the type!
'Left arrow <- desugars to a continuation lambda
@@ -135,3 +111,83 @@ data TableInType n a [Ix n] table:(n=>a) =
x := 4
get x
> 4
+
+'Check that we get reasonably helpful error messages if we try to write param lists
+without parens or with whitespace before the parens.
+
+
+data MyMaybe a =
+ MyNothing
+
+> Parse error:119:14:
+> |
+> 119 | data MyMaybe a =
+> | ^
+> unexpected 'a'
+> expecting '=' or parameter list in parentheses (without preceding whitespace)
+data MyMaybe (a) =
+ MyNothing
+
+> Parse error:122:14:
+> |
+> 122 | data MyMaybe (a) =
+> | ^
+> unexpected '('
+> expecting '=' or parameter list in parentheses (without preceding whitespace)
+data MyMaybe(a) =
+ MyNothing
+ MyJust a
+
+> Parse error:127:10:
+> |
+> 127 | MyJust a
+> | ^^
+> unexpected "a<newline>"
+> expecting end of input, end of line, or optional parameter list
+data MyMaybe(a) =
+ MyNothing
+ MyJust (a)
+
+> Parse error:131:10:
+> |
+> 131 | MyJust (a)
+> | ^^
+> unexpected "(a"
+> expecting end of input, end of line, or optional parameter list
+interface MyClass a
+ pass
+
+> Parse error:133:19:
+> |
+> 133 | interface MyClass a
+> | ^
+> expecting parameter list in parentheses (without preceding whitespace)
+instance MyClass Int
+ pass
+
+> Parse error:136:18:
+> |
+> 136 | instance MyClass Int
+> | ^
+> expecting '(' (without preceding whitespace)
+instance MyClass (Int)
+ pass
+
+> Parse error:139:18:
+> |
+> 139 | instance MyClass (Int)
+> | ^
+> expecting '(' (without preceding whitespace)
+def myFunction (x) = ()
+
+> Parse error:142:16:
+> |
+> 142 | def myFunction (x) = ()
+> | ^
+> expecting parameter list in parentheses (without preceding whitespace)
+def myFunction x = ()
+> Parse error:144:16:
+> |
+> 144 | def myFunction x = ()
+> | ^
+> expecting parameter list in parentheses (without preceding whitespace)
diff --git a/tests/print-tests.dx b/tests/print-tests.dx
index d3cdaec3..e1cdbe8d 100644
--- a/tests/print-tests.dx
+++ b/tests/print-tests.dx
@@ -14,7 +14,7 @@
:pcodegen [Just (Just 1.0), Just Nothing, Nothing]
> [(Just (Just 1.)), (Just Nothing), Nothing]
-data MyType = MyValue (Nat)
+data MyType = MyValue(Nat)
:pcodegen MyValue 1
> (MyValue 1)
diff --git a/tests/shadow-tests.dx b/tests/shadow-tests.dx
index 9592aba6..4a059804 100644
--- a/tests/shadow-tests.dx
+++ b/tests/shadow-tests.dx
@@ -36,7 +36,7 @@ arr = 20
> Error: variable already defined: arr
>
> arr = 20
-> ^^^
+> ^^^^
:p arr
> Error: ambiguous variable: arr is defined:
@@ -89,4 +89,4 @@ ShadowCon' = 1
> Error: variable already defined: ShadowCon'
>
> ShadowCon' = 1
-> ^^^^^^^^^^
+> ^^^^^^^^^^^
diff --git a/tests/sort-tests.dx b/tests/sort-tests.dx
index e5940dfb..e4b6876b 100644
--- a/tests/sort-tests.dx
+++ b/tests/sort-tests.dx
@@ -5,7 +5,6 @@ import sort
:p is_sorted $ sort [9, 3, 7, 4, 6, 1, 9, 1, 9, -1, 10, 10, 100, 0]
> True
-
'### Lexical Sorting Tests
:p "aaa" < "bbb"
diff --git a/tests/type-tests.dx b/tests/type-tests.dx
index e6d16946..2e6aef63 100644
--- a/tests/type-tests.dx
+++ b/tests/type-tests.dx
@@ -153,13 +153,13 @@ MyPair : (Type) -> Type =
-- TODO: put source annotation on effect for a better message here
-def fEff () -> {|a} a given (a) = todo
+def fEff() -> {|a} a given (a) = todo
> Type error:
> Expected: Type
> Actual: EffKind
>
-> def fEff () -> {|a} a given (a) = todo
-> ^
+> def fEff() -> {|a} a given (a) = todo
+> ^^
:p
for i:(Fin 7). sum for j:(Fin unboundName). 1.0
@@ -328,25 +328,25 @@ newPair : NewPair Int Float = MkNewPair 1 2.0
> :p (\x:Int. x) == (\x:Int. x)
> ^^^^^^^^^^^^^^^^^^^^^^^^^^
-def getFst1 (xs:n=>b) -> b given (n|Ix, b) =
+def getFst1(xs:n=>b) -> b given (n|Ix, b) =
xs[from_ordinal 0]
:p getFst1 [1,2,3]
> 1
-def getFst2 (xs:n=>b) -> b given (n|Ix, b) =
+def getFst2(xs:n=>b) -> b given (n|Ix, b) =
xs[from_ordinal 0]
:p getFst2 [1,2,3]
> 1
-def getFst3 (xs:n=>b) -> b given (b, n|Ix) =
+def getFst3(xs:n=>b) -> b given (b, n|Ix) =
xs[from_ordinal 0]
:p getFst3 [1,2,3]
> 1
-def triRefIndex (ref:Ref h ((i':n)=>(..i')=>Float), i:n) -> Ref h ((..i)=>Float)
+def triRefIndex(ref:Ref h ((i':n)=>(..i')=>Float), i:n) -> Ref h ((..i)=>Float)
given (h, n|Ix) (Data ((i':n)=>(..i')=>Float)) =
%indexRef(ref, i)
@@ -369,7 +369,7 @@ def triRefIndex (ref:Ref h ((i':n)=>(..i')=>Float), i:n) -> Ref h ((..i)=>Float)
id (for i:(Fin 2). for j:(..i). 1.0)
> [[1.], [1., 1.]]
-def weakerInferenceReduction (l: (i:n)=>(..i)=>Float, j:n) -> () given (n|Ix) =
+def weakerInferenceReduction(l: (i:n)=>(..i)=>Float, j:n) -> () given (n|Ix) =
for i:(..j).
i' = inject(superset=n, i)
for k:(..i').
@@ -408,7 +408,7 @@ c = [1, 2]
-- Tests for type inference of table literals
-def mkEmpty (a|Data) -> (Fin 0)=>a = []
+def mkEmpty(a|Data) -> (Fin 0)=>a = []
:t [0.0, 1.0]
> ((Fin 2) => Float32)
@@ -419,7 +419,7 @@ def mkEmpty (a|Data) -> (Fin 0)=>a = []
:t (coerce_table _ [0.0, 1.0]) :: (Fin 1, Fin 2)=>Float
> ((Fin 1, Fin 2) => Float32)
-def uncurryTable (x: (Fin 2, Fin 2)=>a) -> (Fin 2)=>(Fin 2)=>a given (a) =
+def uncurryTable(x: (Fin 2, Fin 2)=>a) -> (Fin 2)=>(Fin 2)=>a given (a) =
for i j. x[(i, j)]
-- We should be able to infer the tuple type here
@@ -431,8 +431,8 @@ def uncurryTable (x: (Fin 2, Fin 2)=>a) -> (Fin 2)=>(Fin 2)=>a given (a) =
> ((Fin 2) => (Fin 2) => Nat)
-- Make sure that the local type alias is unifiable with Int
-def GetInt (n: Int) -> Type = Int
-def ff (n : Int) -> Int =
+def GetInt(n: Int) -> Type = Int
+def ff(n : Int) -> Int =
i = GetInt n
the i 2
@@ -440,7 +440,7 @@ ff 0
> 2
-- The two local aliases for Fin n should be unifiable with each other and Fin n
-def q (n: Nat) -> (Fin n)=>Nat =
+def q(n: Nat) -> (Fin n)=>Nat =
ix1 = Fin n
x1 = for i:ix1. ordinal i
ix2 = Fin n
@@ -478,7 +478,7 @@ q 5
-- Make sure we fail gracefully when the annotated index set doesn't
-- have a static size.
-def frob (_:()) -> () given (n) =
+def frob(_:()) -> () given (n) =
[0.0, 1.0]::((Fin n)=>Float)
()
> Type error:
@@ -494,7 +494,7 @@ def frob (_:()) -> () given (n) =
-- foo is a function with all-implicit arguments (whether that's a
-- good idea or not).
-def foo () ->> a=>Float given (a|Ix) =
+def foo() ->> a=>Float given (a|Ix) =
for z:a. 1.0
:t foo
@@ -517,7 +517,7 @@ foo : Fin 3 => Float = [1.0, 2.0, 3.0]
> Error: variable already defined: foo
>
> foo : Fin 3 => Float = [1.0, 2.0, 3.0]
-> ^^^
+> ^^^^
'### Tests for dependent pair syntax
@@ -539,10 +539,10 @@ if False
'### Tests for dependent pair pattern match
-def LowerTriMat (n|Ix, v:Type) -> Type = (i:n)=>(..i)=>v
-def UpperTriMat (n|Ix, v:Type) -> Type = (i:n)=>(i..)=>v
+def LowerTriMat(n|Ix, v:Type) -> Type = (i:n)=>(..i)=>v
+def UpperTriMat(n|Ix, v:Type) -> Type = (i:n)=>(i..)=>v
-def transpose_ix (i:n, j:(i..)) -> (i:n &> ..i) given (n|Ix) =
+def transpose_ix(i:n, j:(i..)) -> (i:n &> ..i) given (n|Ix) =
j' = inject(superset=n, j)
i' = unsafe_project i
(j' ,> i')
diff --git a/tests/uexpr-tests.dx b/tests/uexpr-tests.dx
index 1fa396e2..5524ef79 100644
--- a/tests/uexpr-tests.dx
+++ b/tests/uexpr-tests.dx
@@ -2,12 +2,12 @@
:p 3 + (4 + 5)
> 12
-def depId (a:Type, x:a) -> a = x
+def depId(a:Type, x:a) -> a = x
:p depId Int 1
> 1
-def returnFirstArg (a:Type, b:Type, x:a, y:b) -> a = x
+def returnFirstArg(a:Type, b:Type, x:a, y:b) -> a = x
:p returnFirstArg Int Float 1 2.0
> 1
@@ -15,17 +15,17 @@ def returnFirstArg (a:Type, b:Type, x:a, y:b) -> a = x
:p 1.0 + 2.0
> 3.
-def triple (x:Float) -> Float = x + x + x
+def triple(x:Float) -> Float = x + x + x
:p triple 1.0
> 3.
-def idExplicit (a:Type, x:a) -> a = x
+def idExplicit(a:Type, x:a) -> a = x
:p idExplicit Int 1
> 1
-def idImplicit (x:a) -> a given (a:Type) = x
+def idImplicit(x:a) -> a given (a:Type) = x
:p idImplicit 1
> 1
@@ -54,27 +54,27 @@ idImplicit2 : (given (a:Type), a) -> a = \x. x
> Error: variable not in scope: x
>
> :p x + x
-> ^
+> ^^
idiv = 1
> Error: variable already defined: idiv
>
> idiv = 1
-> ^^^^
+> ^^^^^
-def TyId (a:Type) -> Type = a
+def TyId(a:Type) -> Type = a
:p
x:(TyId Int) = 1
x
> 1
:p
- def TyId2 (a:Type) -> Type = a
+ def TyId2(a:Type) -> Type = a
x:(TyId2 Int) = 1
x
> 1
-def tabId (x:n=>Int) -> n=>Int given (n|Ix) = for i. x[i]
+def tabId(x:n=>Int) -> n=>Int given (n|Ix) = for i. x[i]
-- bug: this doesn't work if we split it across top-level decls
:p
@@ -176,7 +176,7 @@ myPair = (1, 2.3)
xsRef!i := ordinal i
> [0, 1, 2]
-def passthrough (f:(a)->{|eff} b, x:a) -> {|eff} b
+def passthrough(f:(a)->{|eff} b, x:a) -> {|eff} b
given (a, b, eff:Effects) =
f x
@@ -208,13 +208,13 @@ def passthrough (f:(a)->{|eff} b, x:a) -> {|eff} b
> f : (aa:Type, aa) -> aa = \bb x. myId x
> ^^^^^^^^^^^^^
-def myFst (p:(a,b)) -> a given (a, b) =
+def myFst(p:(a,b)) -> a given (a, b) =
(x, _) = p
x
:p myFst (1,2)
> 1
-def myOtherFst (pair:(a,b)) -> a given (a, b) =
+def myOtherFst(pair:(a,b)) -> a given (a, b) =
(x, _) = pair
x
:p myOtherFst (1,2)
@@ -253,7 +253,7 @@ def myOtherFst (pair:(a,b)) -> a given (a, b) =
id'' : (given (b:Type), b) -> b = id
-def eitherFloor (x:(Either Int Float)) -> Int = case x of
+def eitherFloor(x:(Either Int Float)) -> Int = case x of
Left i -> i
Right f -> f_to_i f
@@ -285,7 +285,7 @@ def eitherFloor (x:(Either Int Float)) -> Int = case x of
> (Solving for: [a, b])
>
> [a, b] = [1, 2, 3]
-> ^^^^^^
+> ^^^^^^^
:p
[[a, _], [_, d]] = [[1, 2], [3, 4]]
@@ -302,7 +302,7 @@ def eitherFloor (x:(Either Int Float)) -> Int = case x of
> (Solving for: [a, b])
>
> [a, b, c, d] = coerce_table (Fin 2, Fin 2) [1, 2, 3, 4]
-> ^^^^^^^^^^^^
+> ^^^^^^^^^^^^^
-- Needs delayed inference (can't verify Ix and reduce the type before we infer the hole).
-- :p
@@ -310,7 +310,7 @@ def eitherFloor (x:(Either Int Float)) -> Int = case x of
-- (a, b, c, d)
-- > (0, (0, (0, 0)))
-def bug (n|Data) -> () =
+def bug(n|Data) -> () =
for w':n.
w : n = todo
for i:(w..). ()
@@ -336,7 +336,7 @@ badDefinition = 4
> Error: variable already defined: badDefinition
>
> badDefinition = 4
-> ^^^^^^^^^^^^^
+> ^^^^^^^^^^^^^^
:p badDefinition
> Error: ambiguous variable: badDefinition is defined: