summaryrefslogtreecommitdiff
path: root/lib/fft.dx
blob: 87b28c0fd89237d815de70c912ab5888a0fde464 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
'# Fast Fourier Transform
For arrays whose size is a power of 2, we use a radix-2 Stockham algorithm based
on the [Futhark demo](https://github.com/diku-dk/fft/blob/master/lib/github.com/diku-dk/fft/stockham-radix-2.fut#L30).
As in that demo, types enforce internally that the array sizes are powers of 2:
a table indexed by `(Fin k)=>(Fin 2)` has exactly 2^k elements.

'For arrays whose size is not a power of 2, we use
[Bluestein's algorithm](https://en.wikipedia.org/wiki/Chirp_Z-transform),
which rewrites the transform as a convolution and evaluates that
convolution with the power-of-2 FFT.

import complex

'## Helper functions

def odd_sized_palindrome(mid:a, seq:n=>a) -> ((n `Either` () `Either` n)=>a) given (a, n|Ix) =
  # Mirrors `seq` around a centre element: reversed `seq`, then `mid`, then `seq`.
  # Example: mid = 1, seq = [2, 3, 4, 5] gives 5 4 3 2 1 2 3 4 5.
  # The index type (n + 1 + n) makes the result length 2*|n| + 1 by construction.
  for ix.
    case ix of
      Right fwd -> seq[fwd]
      Left left_half -> case left_half of
        Right () -> mid
        Left back -> seq[reflect back]

'## Inner FFT functions

# Direction of a transform computed by `power_of_2_fft`.
# ForwardFT uses a negative twiddle angle (-pi) and no scaling;
# InverseFT uses a positive angle (+pi) and divides the result by the length.
data FTDirection =
  ForwardFT
  InverseFT

def butterfly_ixs(j':halfn, pow2:Nat) -> (n, n, n, n) given (halfn|Ix, n|Ix) =
  # Index arithmetic for one radix-2 butterfly, re-indexing at a finer frequency.
  # `halfn` must have exactly half as many elements as `n`.
  # Returns (left_read, right_read, left_write, right_write):
  #   reads come from positions j and j + |halfn| of the previous buffer;
  #   writes go to k and k + pow2, where k keeps j's offset inside its block
  #   of `pow2` entries while consecutive blocks are spaced 2*pow2 apart.
  # For explanation, see https://en.wikipedia.org/wiki/Butterfly_diagram
  # Note: with fancier index sets, this might be replaceable by reshapes.
  pos = ordinal j'
  blk_ix = idiv pos pow2
  offset = mod pos pow2
  dst = blk_ix * pow2 * 2 + offset

  src_lo = unsafe_from_ordinal pos
  src_hi = unsafe_from_ordinal (pos + size halfn)
  dst_lo = unsafe_from_ordinal dst
  dst_hi = unsafe_from_ordinal (dst + pow2)
  (src_lo, src_hi, dst_lo, dst_hi)

def power_of_2_fft(
  direction: FTDirection,
  x: ((Fin log2_n)=>(Fin 2))=>Complex
  ) -> ((Fin log2_n)=>(Fin 2))=>Complex given (log2_n:Nat) =
  # Radix-2 Stockham FFT on a table of exactly 2^log2_n entries:
  # (Fin k)=>(Fin 2) has 2^k elements, so the index type guarantees a
  # power-of-2 length. ForwardFT is unscaled; InverseFT divides by the length.

  # Sign of the twiddle angle: negative for the forward transform.
  angle_scale = case direction of
    ForwardFT -> -pi
    InverseFT -> pi

  # State carried across stages: (butterfly span, current buffer).
  # The span starts at 1 and doubles after every stage, ending at 2^log2_n.
  (final_span, spectrum) = yield_state (1, x) \stRef.
    for stage:(Fin log2_n).
      spanRef = fst_ref stRef
      bufRef = snd_ref stRef
      span = get spanRef

      log2_half_n = unsafe_nat_diff log2_n 1  # TODO: use `stage` as a proof that log2_n > 0
      # Accumulate this stage's butterfly outputs into a fresh buffer,
      # which then replaces the previous one.
      bufRef := yield_accum (AddMonoid Complex) \outRef.
        for j:((Fin log2_half_n)=>(Fin 2)).  # Executes in parallel.
          (src_lo, src_hi, dst_lo, dst_hi) = butterfly_ixs j span

          angle = angle_scale * (n_to_f $ mod (ordinal j) span) / n_to_f span
          twiddle = Complex (cos angle) (sin angle)
          lo_val = get (bufRef!src_lo)
          hi_val = (get (bufRef!src_hi)) * twiddle

          outRef!dst_lo += lo_val + hi_val
          outRef!dst_hi += lo_val - hi_val
      spanRef := span * 2

  case direction of
    ForwardFT -> spectrum
    InverseFT -> spectrum / (n_to_f final_span)

def pad_to_power_of_2(
    log2_m:Nat,
    pad_val:a, xs:n=>a
    ) -> ((Fin log2_m)=>(Fin 2))=>a given (a, n|Ix) =
  # Extends `xs` with `pad_val` to 2^log2_m entries (via `pad_to`) and
  # reinterprets the flat table as a (Fin log2_m)=>(Fin 2) table.
  # NOTE(review): behaviour when |n| > 2^log2_m is whatever `pad_to` does — confirm.
  target_len = intpow2 log2_m
  unsafe_cast_table(to=(Fin log2_m)=>(Fin 2), pad_to (Fin target_len) pad_val xs)

def convolve_complex(
    u:n=>Complex,
    v:m=>Complex
    ) -> (Either n m=>Complex) given (n|Ix, m|Ix) =
  # Linear convolution of `u` and `v`, by pointwise multiplication in the
  # Fourier domain.
  # The linear convolution has |n| + |m| - 1 meaningful entries; the result is
  # indexed by `Either n m`, so it has one extra trailing entry, which is zero.

  # Pad to a power-of-2 length that holds the whole linear convolution
  # (so there is no circular wrap-around) and convert to Fourier domain.
  min_convolve_size = (size n + size m) -| 1
  log_working_size = nextpow2 min_convolve_size
  working_size = intpow2 log_working_size
  u_padded = pad_to_power_of_2 log_working_size zero u
  v_padded = pad_to_power_of_2 log_working_size zero v
  spectral_u = power_of_2_fft ForwardFT u_padded
  spectral_v = power_of_2_fft ForwardFT v_padded

  # Pointwise multiply.
  spectral_conv = for i. spectral_u[i] * spectral_v[i]

  # Convert back to primal domain and undo padding.
  padded_conv = power_of_2_fft InverseFT spectral_conv
  # The output has |n| + |m| entries, which is one more than the working
  # buffer whenever |n| + |m| - 1 is itself a power of 2 (e.g. sizes 2 and 3),
  # or when an input is empty. Reading that entry from the buffer would be out
  # of bounds, so positions past the buffer are filled with zero instead —
  # the exact value of that trailing entry.
  for i:(Either n m).
    k = ordinal i
    if k < working_size
      then padded_conv[unsafe_from_ordinal k]
      else zero

def convolve(u:n=>Float, v:m=>Float) -> (Either n m =>Float) given (n|Ix, m|Ix) =
  # Real-valued linear convolution: lift both inputs to Complex with zero
  # imaginary part, convolve them, and keep only the real part.
  complex_result = convolve_complex (for i. Complex u[i] 0.0) (for j. Complex v[j] 0.0)
  for k. complex_result[k].re

def bluestein(x: n=>Complex) -> n=>Complex given (n|Ix) =
  # Bluestein's algorithm (chirp-z transform): forward DFT of any length N.
  # Using jk = (j^2 + k^2 - (k-j)^2) / 2, the DFT becomes a convolution with a
  # chirp, which is then solved with calls to a power-of-2 FFT.
  if size n == 0
    then
      # An empty input has an empty transform. Returning early also avoids
      # `head wks` on an empty table and the underflow of `size n - 1` below.
      x
    else
      n_elems = size n
      im = Complex 0.0 1.0
      # Chirp wks[k] = exp(-i*pi*k^2/N). This is periodic in k^2 with period 2N,
      # so reduce k^2 modulo 2N before converting to Float: the angle then stays
      # in [0, 2*pi) and keeps full precision for large N.
      # NOTE(review): `sq k` is Nat arithmetic; if Nat is 32-bit (confirm),
      # it overflows for N > 65536 before the reduction can help.
      wks = for i.
        k = ordinal i
        k_squared_mod = n_to_f $ mod (sq k) (2 * n_elems)
        exp $ (-im) * (Complex (pi * k_squared_mod / (n_to_f n_elems)) 0.0)

      # Chirp at offsets -(N-1)..(N-1): [w[N-1], ..., w[1], w[0], w[1], ..., w[N-1]].
      AsList(_, tailTable) = tail wks 1
      back_and_forth = odd_sized_palindrome (head wks) tailTable
      xq = for i. x[i] * wks[i]
      back_and_forth_conj = for i. complex_conj back_and_forth[i]
      convolution = convolve_complex xq back_and_forth_conj
      # Output k sits at position k + (N-1) of the full linear convolution.
      convslice = slice convolution (unsafe_nat_diff n_elems 1) n
      for i. wks[i] * convslice[i]


'## FFT Interface

def fft(x: n=>Complex) -> n=>Complex given (n|Ix) =
  # Forward discrete Fourier transform of `x`, unnormalized.
  # Power-of-2 sizes are reinterpreted as a (Fin log2 N)=>(Fin 2) table and
  # handled by the radix-2 kernel; every other size falls back to Bluestein.
  if not (is_power_of_2 (size n))
    then bluestein x
    else
      log2_len = natlog2 (size n)
      bits_view = unsafe_cast_table(to=(Fin log2_len)=>(Fin 2), x)
      unsafe_cast_table(to=n, power_of_2_fft ForwardFT bits_view)

def ifft(xs: n=>Complex) -> n=>Complex given (n|Ix) =
  # Inverse discrete Fourier transform of `xs`, normalized by 1/N.
  if not (is_power_of_2 (size n))
    then
      # ifft(x) = conj(fft(conj(x))) / N, reusing the forward transform.
      conj_spectrum = fft (for i. complex_conj xs[i])
      scale = n_to_f (size n)
      for i. (complex_conj conj_spectrum[i]) / scale
    else
      log2_len = natlog2 (size n)
      bits_view = unsafe_cast_table(to=(Fin log2_len)=>(Fin 2), xs)
      unsafe_cast_table(to=n, power_of_2_fft InverseFT bits_view)

def fft_real(x: n=>Float) -> n=>Complex given (n|Ix) =
  # Forward FFT of a real-valued table (imaginary parts set to zero).
  as_complex = for i. Complex x[i] 0.0
  fft as_complex
def ifft_real(x: n=>Float) -> n=>Complex given (n|Ix) =
  # Inverse FFT of a real-valued table (imaginary parts set to zero).
  as_complex = for i. Complex x[i] 0.0
  ifft as_complex

def fft2(x: n=>m=>Complex) -> n=>m=>Complex given (n|Ix, m|Ix) =
  # 2-D FFT: transform every row, then every column.
  by_rows = for r. fft x[r]
  by_cols = transpose by_rows
  transpose for c. fft by_cols[c]

def ifft2(x: n=>m=>Complex) -> n=>m=>Complex given (n|Ix, m|Ix) =
  # 2-D inverse FFT: inverse-transform every row, then every column.
  by_rows = for r. ifft x[r]
  by_cols = transpose by_rows
  transpose for c. ifft by_cols[c]

def fft2_real(x: n=>m=>Float) -> n=>m=>Complex given (n|Ix, m|Ix) =
  # 2-D FFT of a real-valued table (imaginary parts set to zero).
  lifted = for i j. Complex x[i,j] 0.0
  fft2 lifted
def ifft2_real(x: n=>m=>Float) -> n=>m=>Complex given (n|Ix, m|Ix) =
  # 2-D inverse FFT of a real-valued table (imaginary parts set to zero).
  lifted = for i j. Complex x[i,j] 0.0
  ifft2 lifted