summaryrefslogtreecommitdiff
path: root/lib/sort.dx
blob: 5f712c3f21e361610a96de97b4409bcf1c7e4d7a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
'# Monoidal Merge Sort

'Warning: Very slow for now!!

'Because merging sorted lists is associative, we can expose the
parallelism of merge sort to the Dex compiler by making a monoid
and using `reduce` or `yield_accum`.
However, this approach puts a lot of pressure on the compiler.
As noted on [StackOverflow](https://stackoverflow.com/questions/21877572/can-you-formulate-the-bubble-sort-as-a-monoid-or-semigroup),
if the compiler does the reduction one element at a time,
it is effectively doing bubble / insertion sort, with quadratic time cost.
If it instead splits the list in half recursively, it is doing parallel merge sort.
Currently the Dex compiler does the quadratic-time version.

def concat_table(leftin: a=>v, rightin: b=>v) -> (Either a b=>v) given (a|Ix, b|Ix, v:Type) =
  # Builds a single table indexed by `Either a b`: a `Left` index reads the
  # corresponding entry of `leftin`, a `Right` index reads from `rightin`.
  for tagged:(Either a b).
    case tagged of
      Left l  -> leftin[l]
      Right r -> rightin[r]

def merge_sorted_tables(xs:m=>a, ys:n=>a) -> (Either m n=>a) given (a|Ord, m|Ix, n|Ix) =
  # Merges `xs` and `ys` into one table indexed by `Either m n`.
  # Both inputs are expected to be sorted already (per the function name);
  # nothing in this function checks that.
  #
  # The output slots are visited in index order while `countrefs` holds a pair
  # of cursors: the ordinal of the next unread element of `xs` and of `ys`.
  # Each visited slot receives the smaller of the two elements under the
  # cursors (ties take the element from `xs`), and that cursor is advanced.
  #
  # One case has no explicit branch: when `xs` is exhausted but `ys` is not,
  # the loop body writes nothing and advances neither cursor, so those slots
  # keep the value that `init` placed there.
  # NOTE(review): this appears to rely on `init` already holding the remaining
  # `ys` in exactly those slots, i.e. on the `Either m n` index order putting
  # every `Left` before every `Right` -- confirm before attempting
  # improvement 2 below, which would remove that initial content.
  #
  # Possible improvements:
  #  1) Using a SortedTable type.
  #  2) Avoid needlessly initializing the return array.
  init = concat_table xs ys  # Initialize array of correct size.
  yield_state init \buf.
    with_state (0, 0) \countrefs.
      for i:(Either m n).
        (cur_x, cur_y) = get countrefs
        if cur_y >= size n  # no ys left
          then
            countrefs := (cur_x + 1, cur_y)
            buf!i := xs[unsafe_from_ordinal cur_x]
          else
            if cur_x < size m # still xs left
              then
                if xs[unsafe_from_ordinal cur_x] <= ys[unsafe_from_ordinal cur_y]  # ties take from xs
                  then
                    countrefs := (cur_x + 1, cur_y)
                    buf!i := xs[unsafe_from_ordinal cur_x]
                  else
                    countrefs := (cur_x, cur_y + 1)
                    buf!i := ys[unsafe_from_ordinal cur_y]

def merge_sorted_lists(lx: List a, ly: List a) -> List a given (a|Ord) =
  # `merge_sorted_tables` yields a table indexed by `Either m n`, whereas the
  # result is repackaged here as a `Fin`-indexed table inside a `List`.
  # Dex will not convert between those index sets on its own, so this wrapper
  # unpacks both lists, merges their tables, and casts the result explicitly.
  AsList(len_x, tab_x) = lx
  AsList(len_y, tab_y) = ly
  total = len_x + len_y
  merged = merge_sorted_tables tab_x tab_y
  AsList _ $ unsafe_cast_table(to=Fin total, merged)

# Warning: Has quadratic runtime cost for now.
def sort(xs: n=>a) -> n=>a given (n|Ix, a|Ord) =
  # Wrap every element in its own one-element list.
  singletons = for i:n. to_list([xs[i]])
  # Combine the singleton lists with the merge operation, starting from the
  # empty list. reduce might someday recursively subdivide the problem.
  AsList(_, merged) = reduce singletons mempty merge_sorted_lists
  # Re-index the merged table by the caller's original index set.
  unsafe_cast_table(to=n, merged)

def (+|)(i:n, delta:Nat) -> n  given (n|Ix) =
  # Saturating index advance: moves `i` forward by `delta` positions, but
  # stops at the last index of `n` instead of running past the end.
  limit = size n
  shifted = ordinal i + delta
  clamped = select (shifted >= limit) (limit -| 1) shifted
  from_ordinal clamped

def is_sorted(xs:n=>a) -> Bool given (a|Ord, n|Ix) =
  # Check every element against its successor. `+|` clamps at the last index,
  # so the final element is compared with itself.
  pair_ok = for i:n. xs[i] <= xs[i +| 1]
  all pair_ok