summaryrefslogtreecommitdiff
path: root/lib/prelude.dx
diff options
context:
space:
mode:
authorDougal Maclaurin <dougalm@users.noreply.github.com>2023-04-04 18:33:22 -0400
committerGitHub <noreply@github.com>2023-04-04 18:33:22 -0400
commit5edfffe7a7a771030efab2287cb53e74eac24d1c (patch)
treecbf2d9b0ba0c197f2cb96d714aa627f3fa8ad5bc /lib/prelude.dx
parentb3f9aa7840c612025c3af81c7a92c8fad34a9a88 (diff)
parent02e7232b5d8c254f86147108d212020767a1d89c (diff)
Merge pull request #1261 from google-research/more-field-access-stuff
More struct features
Diffstat (limited to 'lib/prelude.dx')
-rw-r--r--lib/prelude.dx213
1 files changed, 106 insertions, 107 deletions
diff --git a/lib/prelude.dx b/lib/prelude.dx
index d5810069..b6f23d22 100644
--- a/lib/prelude.dx
+++ b/lib/prelude.dx
@@ -334,7 +334,7 @@ def unsafe_nat_diff(x:Nat, y:Nat) -> Nat =
rep_to_nat %isub(x', y')
-- `(i..)` parses as `RangeFrom _ i`
--- TODO: need to a way to indicate `.new` as private
+-- TODO: need to a way to indicate constructor as private
struct RangeFrom(q:Type, i:q) = val : Nat
-- `(i<..)` parses as `RangeFromExc _ i`
@@ -349,22 +349,22 @@ struct RangeToExc(q:Type, i:q) = val : Nat
instance Ix(RangeFrom q i) given (q|Ix, i:q)
def size'() = unsafe_nat_diff(size q, ordinal i)
def ordinal(j) = j.val
- def unsafe_from_ordinal(j) = RangeFrom.new(j)
+ def unsafe_from_ordinal(j) = RangeFrom(j)
instance Ix(RangeFromExc q i) given (q|Ix, i:q)
def size'() = unsafe_nat_diff(size q, ordinal i + 1)
def ordinal(j) = j.val
- def unsafe_from_ordinal(j) = RangeFromExc.new(j)
+ def unsafe_from_ordinal(j) = RangeFromExc(j)
instance Ix(RangeTo q i) given (q|Ix, i:q)
def size'() = ordinal i + 1
def ordinal(j) = j.val
- def unsafe_from_ordinal(j) = RangeTo.new(j)
+ def unsafe_from_ordinal(j) = RangeTo(j)
instance Ix(RangeToExc q i) given (q|Ix, i:q)
def size'() = ordinal i
def ordinal(j) = j.val
- def unsafe_from_ordinal(j) = RangeToExc.new(j)
+ def unsafe_from_ordinal(j) = RangeToExc(j)
instance Ix(())
def size'() = 1
@@ -411,12 +411,9 @@ instance Mul(n=>a) given (a|Mul, n|Ix)
'## Basic polymorphic functions and types
-def fst(pair:(a, b)) -> a given (a, b) =
- (x, _) = pair
- x
-def snd(pair:(a, b)) -> b given (a, b) =
- (_, y) = pair
- y
+def fst(pair:(a, b)) -> a given (a, b) = pair.0
+def snd(pair:(a, b)) -> b given (a, b) = pair.1
+
def swap(pair:(a, b)) -> (b, a) given (a, b) =
(x, y) = pair
(y, x)
@@ -596,7 +593,7 @@ struct Post(segment:Type) =
instance Ix(Post segment) given (segment|Ix)
def size'() = size segment + 1
def ordinal(i) = i.val
- def unsafe_from_ordinal(i) = Post.new(i)
+ def unsafe_from_ordinal(i) = Post(i)
def left_post(i:n) -> Post n given (n|Ix) =
unsafe_from_ordinal(n=Post n, ordinal i)
@@ -674,8 +671,8 @@ def (+=)(ref:Ref h w, x:w) -> {Accum h} ()
%mextend(ref, empty, \x:b y:b. %applyMethod1(bm, x, y), x)
def (!)(ref: Ref h (n=>a), i:n) -> Ref h a given (n|Ix, a|Data, h) = %indexRef(ref, i)
-def fst_ref(ref: Ref h (a,b)) -> Ref h a given (a|Data, b, h) = %fstRef(ref)
-def snd_ref(ref: Ref h (a,b)) -> Ref h b given (a, b|Data, h) = %sndRef(ref)
+def fst_ref(ref: Ref h (a,b)) -> Ref h a given (b, a|Data, h) = ref.0
+def snd_ref(ref: Ref h (a,b)) -> Ref h b given (a, b|Data, h) = ref.1
def run_reader(
init:r,
@@ -979,7 +976,7 @@ instance Subset(a, c) given (a, b, c) (Subset(a, b), Subset(b, c))
def unsafe_project'(x) = unsafe_project $ unsafe_project(to=b, x)
def unsafe_project_rangefrom(j:q) -> RangeFrom(q, i) given (q|Ix, i:q) =
- RangeFrom.new unsafe_nat_diff(ordinal j, ordinal i)
+ RangeFrom unsafe_nat_diff(ordinal j, ordinal i)
instance Subset(RangeFrom(q, i), q) given (q|Ix, i:q)
def inject'(j) =
@@ -989,8 +986,8 @@ instance Subset(RangeFrom(q, i), q) given (q|Ix, i:q)
i' = ordinal i
if j' < i'
then Nothing
- else Just $ RangeFrom.new $ unsafe_nat_diff(j', i')
- def unsafe_project'(j) = RangeFrom.new unsafe_nat_diff(ordinal j, ordinal i)
+ else Just $ RangeFrom $ unsafe_nat_diff(j', i')
+ def unsafe_project'(j) = RangeFrom unsafe_nat_diff(ordinal j, ordinal i)
instance Subset(RangeFromExc(q, i), q) given (q|Ix, i:q)
def inject'(j) = unsafe_from_ordinal $ j.val + ordinal i + 1
@@ -999,9 +996,9 @@ instance Subset(RangeFromExc(q, i), q) given (q|Ix, i:q)
i' = ordinal i
if j' <= i'
then Nothing
- else Just $ RangeFromExc.new unsafe_nat_diff(j', i' + 1)
+ else Just $ RangeFromExc unsafe_nat_diff(j', i' + 1)
def unsafe_project'(j) =
- RangeFromExc.new unsafe_nat_diff(ordinal j, ordinal i + 1)
+ RangeFromExc unsafe_nat_diff(ordinal j, ordinal i + 1)
instance Subset(RangeTo(q, i), q) given (q|Ix, i:q)
def inject'(j) = unsafe_from_ordinal j.val
@@ -1010,8 +1007,8 @@ instance Subset(RangeTo(q, i), q) given (q|Ix, i:q)
i' = ordinal i
if j' > i'
then Nothing
- else Just $ RangeTo.new j'
- def unsafe_project'(j) = RangeTo.new (ordinal j)
+ else Just $ RangeTo j'
+ def unsafe_project'(j) = RangeTo (ordinal j)
instance Subset(RangeToExc(q, i), q) given (q|Ix, i:q)
def inject'(j) = unsafe_from_ordinal j.val
@@ -1020,8 +1017,8 @@ instance Subset(RangeToExc(q, i), q) given (q|Ix, i:q)
i' = ordinal i
if j' >= i'
then Nothing
- else Just $ RangeToExc.new j'
- def unsafe_project'(j) = RangeToExc.new (ordinal j)
+ else Just $ RangeToExc j'
+ def unsafe_project'(j) = RangeToExc (ordinal j)
instance Subset(RangeToExc(q, i), RangeTo(q, i)) given (q|Ix, i:q)
def inject'(j) = unsafe_from_ordinal j.val
@@ -1030,8 +1027,8 @@ instance Subset(RangeToExc(q, i), RangeTo(q, i)) given (q|Ix, i:q)
i' = ordinal i
if j' >= i'
then Nothing
- else Just $ RangeToExc.new j'
- def unsafe_project'(j) = RangeToExc.new (ordinal j)
+ else Just $ RangeToExc j'
+ def unsafe_project'(j) = RangeToExc (ordinal j)
'## Elementary/Special Functions
This is more or less the standard [LibM fare](https://en.wikipedia.org/wiki/C_mathematical_functions).
@@ -1125,7 +1122,7 @@ instance Floating(Float32)
struct Ptr(a:Type) =
val : RawPtr
-def cast_ptr(ptr: Ptr a) -> Ptr b given (a, b) = Ptr.new(ptr.val)
+def cast_ptr(ptr: Ptr a) -> Ptr b given (a, b) = Ptr(ptr.val)
interface Storable(a|Data)
store : (Ptr a, a) -> {IO} ()
@@ -1153,26 +1150,26 @@ instance Storable(Float32)
def storage_size() = 4
instance Storable(Nat)
- def store(ptr, x) = store(Ptr.new(ptr.val), nat_to_rep x)
- def load(ptr) = rep_to_nat $ load(Ptr.new(ptr.val))
+ def store(ptr, x) = store(Ptr(ptr.val), nat_to_rep x)
+ def load(ptr) = rep_to_nat $ load(Ptr(ptr.val))
def storage_size() = storage_size(a=NatRep)
instance Storable(Ptr a) given (a)
def store(ptr, x) = %ptrStore(internal_cast(to=%PtrPtr(), ptr.val), x.val)
- def load(ptr) = Ptr.new(%ptrLoad(internal_cast(to=%PtrPtr(), ptr)))
+ def load(ptr) = Ptr(%ptrLoad(internal_cast(to=%PtrPtr(), ptr)))
def storage_size() = 8 -- TODO: something more portable?
-- TODO: Storable instances for other types
def malloc(n:Nat) -> {IO} (Ptr a) given (a|Storable) =
numBytes = storage_size(a=a) * n
- Ptr.new(%alloc(nat_to_rep numBytes))
+ Ptr(%alloc(nat_to_rep numBytes))
def free(ptr:Ptr a) -> {IO} () given (a) = %free(ptr.val)
def (+>>)(ptr:Ptr a, i:Nat) -> Ptr a given (a|Storable) =
i' = nat_to_rep $ i * storage_size(a=a)
- Ptr.new(%ptrOffset(ptr.val, i'))
+ Ptr(%ptrOffset(ptr.val, i'))
-- TODO: consider making a Storable instance for tables instead
def store_table(ptr: Ptr a, tab:n=>a) -> {IO} () given (a|Storable, n|Ix) =
@@ -1449,7 +1446,7 @@ def with_dynamic_buffer(action: (DynBuffer a) -> {IO} b) -> {IO} b given (a|Stor
store(maxSizePtr, initMaxSize)
bufferPtr <- with_alloc 1
store(bufferPtr, malloc initMaxSize)
- result = action $ DynBuffer.new(sizePtr, maxSizePtr, bufferPtr)
+ result = action $ DynBuffer(sizePtr, maxSizePtr, bufferPtr)
free $ load bufferPtr
result
@@ -1505,7 +1502,7 @@ def with_c_string(
action: (CString) -> {IO} a
) -> {IO} a given (a) =
AsList(n, s') = s <> "\NUL"
- with_table_ptr s' \ptr. action $ CString.new(ptr.val)
+ with_table_ptr s' \ptr. action CString(ptr.val)
'### Show interface
For things that can be shown.
@@ -1525,14 +1522,14 @@ foreign "showInt32" showInt32 : (Int32) -> {IO} (Word32, RawPtr)
instance Show(Int32)
def show(x) = unsafe_io \.
(n, ptr) = showInt32 x
- string_from_char_ptr(n, Ptr.new ptr)
+ string_from_char_ptr(n, Ptr ptr)
foreign "showInt64" showInt64 : (Int64) -> {IO} (Word32, RawPtr)
instance Show(Int64)
def show(x) = unsafe_io \.
(n, ptr) = showInt64 x
- string_from_char_ptr(n, Ptr.new ptr)
+ string_from_char_ptr(n, Ptr ptr)
instance Show(Nat)
def show(x) = show $ n_to_i64 x
@@ -1542,14 +1539,14 @@ foreign "showFloat32" showFloat32 : (Float32) -> {IO} (Word32, RawPtr)
instance Show(Float32)
def show(x) = unsafe_io \.
(n, ptr) = showFloat32 x
- string_from_char_ptr(n, Ptr.new ptr)
+ string_from_char_ptr(n, Ptr ptr)
foreign "showFloat64" showFloat64 : (Float64) -> {IO} (Word32, RawPtr)
instance Show(Float64)
def show(x) = unsafe_io \.
(n, ptr) = showFloat64 x
- string_from_char_ptr(n, Ptr.new ptr)
+ string_from_char_ptr(n, Ptr ptr)
instance Show(())
def show(_) = "()"
@@ -1625,7 +1622,7 @@ def is_null_raw_ptr(ptr:RawPtr) -> Bool =
def from_nullable_raw_ptr(ptr:RawPtr) -> Maybe (Ptr a) given (a) =
if is_null_raw_ptr ptr
then Nothing
- else Just $ Ptr.new ptr
+ else Just $ Ptr ptr
def c_string_ptr(s:CString) -> Maybe (Ptr Char) = from_nullable_raw_ptr s.ptr
@@ -1650,7 +1647,7 @@ def fopen(path:String, mode:StreamMode) -> {IO} (Stream mode) =
WriteMode -> "w"
with_c_string path \cPath.
with_c_string modeStr \cMode.
- Stream.new $ fopenFFI(cPath.ptr, cMode.ptr)
+ Stream $ fopenFFI(cPath.ptr, cMode.ptr)
def fclose(stream:Stream mode) -> {IO} () given (mode) =
fcloseFFI stream.ptr
@@ -1723,7 +1720,7 @@ foreign "getenv" getenvFFI : (RawPtr) -> {IO} RawPtr
def get_env(name:String) -> {IO} Maybe String =
cStr <- with_c_string name
- getenvFFI cStr.ptr | CString.new | from_c_string
+ getenvFFI cStr.ptr | CString | from_c_string
def check_env(name:String) -> {IO} Bool =
is_just $ get_env name
@@ -1746,7 +1743,7 @@ def fread(stream:Stream ReadMode) -> {IO} String =
'### Print
def get_output_stream() -> {IO} Stream WriteMode =
- Stream.new $ %outputStream()
+ Stream $ %outputStream()
@noinline
def print(s:String) -> {IO} () =
@@ -1765,7 +1762,7 @@ def shell_out(command:String) -> {IO} String =
modeStr = "r"
with_c_string command \command'.
with_c_string modeStr \modeStr'.
- pipe = Stream.new $ popenFFI(command'.ptr, modeStr'.ptr)
+ pipe = Stream $ popenFFI(command'.ptr, modeStr'.ptr)
fread pipe
'## Partial functions
@@ -1821,7 +1818,7 @@ def new_temp_file() -> {IO} FilePath =
s <- with_c_string "/tmp/dex-XXXXXX"
fd = mkstempFFI s.ptr
closeFFI fd
- string_from_char_ptr(15, (Ptr.new s.ptr))
+ string_from_char_ptr(15, (Ptr s.ptr))
def with_temp_file(action: (FilePath) -> {IO} a) -> {IO} a given (a) =
tmpFile = new_temp_file()
@@ -2274,9 +2271,9 @@ def cat_maybes(xs:n=>Maybe a) -> List a given (n|Ix, a|Data) =
(num_res, res_inds) = yield_state (0::Nat, for i:n. Nothing) \ref.
for i. case xs[i] of
Just(_) ->
- ix = get $ fst_ref ref
- (snd_ref ref) ! (unsafe_from_ordinal ix) := Just i
- fst_ref ref := ix + 1
+ ix = get ref.0
+ ref.1 ! (unsafe_from_ordinal ix) := Just i
+ ref.0 := ix + 1
Nothing -> ()
to_list $ for i:(Fin num_res).
case res_inds[unsafe_from_ordinal $ ordinal i] of
@@ -2431,61 +2428,63 @@ instance NonEmpty(a=>b) given (a|Ix, b|NonEmpty)
'### stack
-- TODO: replace `DynBuffer` with this?
-def Stack(h:Heap, a|Data) = Ref(h, (Nat, List a))
-
-def stack_size(stack:Stack(h, a)) -> {State h} Nat given (h, a) = get $ fst_ref stack
-
-def unsafe_get_stack_buffer(stack:Stack(h, a)) -> {State h} (Ref(h, Fin 0 => a)) given (h, a|Data) =
- get $ snd_ref $ unsafe_coerce(to=Ref h (Nat, Ref h (Fin 0 => a)), snd_ref stack)
-
-def stack_buf_size(stack:Stack(h, a)) -> {State h} Nat given (h, a|Data) =
- get $ fst_ref $ unsafe_coerce(to=Ref h (Nat, Ref h (Fin 0 => a)), snd_ref stack)
-
-def ensure_size_at_least(stack:Stack(h, a), req_size:Nat) -> {State h} () given (h, a|Data) =
- if req_size > stack_buf_size stack then
- -- TODO: maybe this should use integer arithmetic?
- new_buf_size = f_to_n $ 2.0 `pow` (ceil $ log2 $ n_to_f req_size)
- buf = unsafe_get_stack_buffer stack
- logical_size = stack_size stack
- cur_data = get $ unsafe_coerce(to=Ref(h, Fin logical_size => a), buf)
- snd_ref stack := to_list for i:(Fin new_buf_size).
- case to_ix(n=Fin logical_size, ordinal i) of
- Just(i') -> cur_data[i']
- Nothing -> uninitialized_value()
-
-def read_stack(stack:Stack(h, a)) -> {State h} (List a) given (h, a|Data) =
- n = stack_size stack
- buf = unsafe_coerce(to=Ref(h, Fin n => a), unsafe_get_stack_buffer stack)
- AsList(n, get buf)
-
-@noinline
-def stack_push(stack:Stack(h, a), x:a) -> {State h} () given (a|Data, h) =
- n_old = stack_size stack
- n_new = n_old + 1
- ensure_size_at_least(stack, n_new)
- buf = unsafe_get_stack_buffer stack
- buf ! (unsafe_from_ordinal n_old) := x
- fst_ref stack := n_new
-
-@noinline
-def stack_extend(stack:Stack(h, a), x:n=>a) -> {State h} () given (a|Data, n|Ix, h) =
- n_old = stack_size stack
- n_new = n_old + size n
- ensure_size_at_least(stack, n_new)
- buf = unsafe_get_stack_buffer stack
- buf_slice = unsafe_coerce(to=Ref(h,n=>a), buf ! (unsafe_from_ordinal n_old))
- buf_slice := x
- fst_ref stack := n_new
-
-def stack_pop(stack:Stack(h, a)) -> {State h} Maybe a given (a|Data, h) =
- n_old = stack_size stack
- case n_old == 0 of
- True -> Nothing
- False ->
- n_new = unsafe_nat_diff(n_old, 1)
- buf = unsafe_get_stack_buffer stack
- fst_ref stack := n_new
- Just $ get buf!(unsafe_from_ordinal n_new)
+struct Stack(h:Heap, a|Data) =
+ size_ref : Ref h Nat
+ buf_ref : Ref h (List a)
+
+ def size() -> {State h} Nat = get self.size_ref
+
+ def unsafe_get_buffer() -> {State h} (Ref(h, Fin 0 => a)) =
+ get $ snd_ref $ unsafe_coerce(to=Ref h (Nat, Ref h (Fin 0 => a)), self.buf_ref)
+
+ def buf_size() -> {State h} Nat =
+ get $ fst_ref $ unsafe_coerce(to=Ref h (Nat, Ref h (Fin 0 => a)), self.buf_ref)
+
+ def ensure_size_at_least(req_size:Nat) -> {State h} () =
+ if req_size > self.buf_size() then
+ -- TODO: maybe this should use integer arithmetic?
+ new_buf_size = f_to_n $ 2.0 `pow` (ceil $ log2 $ n_to_f req_size)
+ buf = self.unsafe_get_buffer()
+ logical_size = self.size()
+ cur_data = get $ unsafe_coerce(to=Ref(h, Fin logical_size => a), buf)
+ self.buf_ref := to_list for i:(Fin new_buf_size).
+ case to_ix(n=Fin logical_size, ordinal i) of
+ Just(i') -> cur_data[i']
+ Nothing -> uninitialized_value()
+
+ def read() -> {State h} (List a) =
+ n = self.size()
+ buf = unsafe_coerce(to=Ref(h, Fin n => a), self.unsafe_get_buffer())
+ AsList(n, get buf)
+
+ @noinline
+ def push(x:a) -> {State h} () =
+ n_old = self.size()
+ n_new = n_old + 1
+ self.ensure_size_at_least(n_new)
+ buf = self.unsafe_get_buffer()
+ buf ! (unsafe_from_ordinal n_old) := x
+ self.size_ref := n_new
+
+ @noinline
+ def extend(x:n=>a) -> {State h} () given (n|Ix) =
+ n_old = self.size()
+ n_new = n_old + size n
+ self.ensure_size_at_least(n_new)
+ buf = self.unsafe_get_buffer()
+ buf_slice = unsafe_coerce(to=Ref(h,n=>a), buf ! (unsafe_from_ordinal n_old))
+ buf_slice := x
+ self.size_ref := n_new
+
+ def pop() -> {State h} Maybe a =
+ n_old = self.size()
+ case n_old == 0 of
+ True -> Nothing
+ False ->
+ n_new = unsafe_nat_diff(n_old, 1)
+ buf = self.unsafe_get_buffer()
+ self.size_ref := n_new
+ Just $ get buf!(unsafe_from_ordinal n_new)
stack_init_size = 16
def with_stack(
@@ -2493,18 +2492,18 @@ def with_stack(
action:(given (h:Heap), Stack(h, a)) -> {State h|eff} r
) -> {|eff} r given (eff, r) =
init_stack = to_list for i:(Fin stack_init_size). uninitialized_value()
- with_state (0, init_stack) \ref . action ref
+ with_state (0, init_stack) \ref . action(Stack(ref.0, ref.1))
def stack_extend_internal(stack:Stack(h, Char), x:Fin n=>Char) -> {State h} () given (n, h) =
- stack_extend(stack, x)
+ stack.extend(x)
def stack_push_internal(stack:Stack(h, Char), x:Char) -> {State h} () given (h) =
- stack_push(stack, x)
+ stack.push(x)
def with_stack_internal(f:(given (h:Heap), Stack(h, Char)) -> {State h} ()) -> List Char =
with_stack Char \stack.
f stack
- read_stack stack
+ stack.read()
def show_any(x:a) -> String given (a) = unsafe_coerce(to=String, %showAny(x))
@@ -2535,7 +2534,7 @@ struct FullTileIx(n|Ix, tile_size:Nat, tile_ix:Nat) =
instance Ix(FullTileIx(n, tile_size, tile_ix)) given (n|Ix, tile_size:Nat, tile_ix:Nat)
def size'() = tile_size
def ordinal(i) = ordinal i.unwrap
- def unsafe_from_ordinal(i) =FullTileIx.new $ unsafe_from_ordinal i
+ def unsafe_from_ordinal(i) = FullTileIx $ unsafe_from_ordinal i
instance Subset(FullTileIx(n, tile_size, tile_ix), n) given (n|Ix, tile_size:Nat, tile_ix:Nat)
def inject'(i) = unsafe_from_ordinal $ tile_size * tile_ix + ordinal i.unwrap
@@ -2551,7 +2550,7 @@ struct CodaIx(n|Ix, coda_offset:Nat, coda_size:Nat) =
instance Ix(CodaIx(n, coda_offset, coda_size)) given (n|Ix, coda_offset:Nat, coda_size:Nat)
def size'() = coda_size
def ordinal(i) = ordinal i.unwrap
- def unsafe_from_ordinal(i) = CodaIx.new $ unsafe_from_ordinal i
+ def unsafe_from_ordinal(i) = CodaIx $ unsafe_from_ordinal i
instance Subset(CodaIx(n, coda_offset, coda_size), n) given (n|Ix, coda_offset:Nat, coda_size:Nat)
def inject'(i) = unsafe_from_ordinal $ coda_offset + ordinal i.unwrap