1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
|
'# Fast Fourier Transform
For arrays whose size is a power of 2, we use a radix-2 algorithm based
on the [Futhark demo](https://github.com/diku-dk/fft/blob/master/lib/github.com/diku-dk/fft/stockham-radix-2.fut#L30).
That demo also uses types to enforce internally that the array sizes are powers of 2.
'For non-power-of-2 sized arrays, we use
[Bluestein's Algorithm](https://en.wikipedia.org/wiki/Chirp_Z-transform),
which calls the power-of-2 FFT as a subroutine.
import complex
'## Helper functions
def odd_sized_palindrome(mid:a, seq:n=>a) -> ((n `Either` () `Either` n)=>a) given (a, n|Ix) =
  # Builds a mirrored table: `seq` reversed, then `mid`, then `seq` itself.
  # Example: mid = 1 and seq = 2345 gives 543212345.
  for pos.
    case pos of
      # Left part of the index set: the reversed copy followed by the middle.
      Left front -> case front of
        Left k -> seq[reflect k]
        Right () -> mid
      # Right part of the index set: `seq` in its original order.
      Right k -> seq[k]
'## Inner FFT functions
# Selects which transform `power_of_2_fft` computes.
# ForwardFT: twiddle angles use the constant -pi and the result is returned unscaled.
# InverseFT: twiddle angles use the constant +pi and the result is divided
# by the final stride counter before being returned.
data FTDirection =
  ForwardFT
  InverseFT
def butterfly_ixs(j':halfn, pow2:Nat) -> (n, n, n, n) given (halfn|Ix, n|Ix) =
  # Computes the four indices used by one butterfly step, returned as
  # (left read, right read, left write, right write).
  # `halfn` must have half the size of `n`; this is not checked, and the
  # indices are built with `unsafe_from_ordinal`.
  # For explanation, see https://en.wikipedia.org/wiki/Butterfly_diagram
  # Note: with fancier index sets, this might be replaceable by reshapes.
  src = ordinal j'

  # The two inputs sit `size halfn` apart in the source buffer.
  read_lo = unsafe_from_ordinal src
  read_hi = unsafe_from_ordinal (src + size halfn)

  # The two outputs sit `pow2` apart, inside a block of width 2 * pow2.
  block = idiv src pow2
  offset = mod src pow2
  dst = (block * pow2 * 2) + offset
  write_lo = unsafe_from_ordinal dst
  write_hi = unsafe_from_ordinal (dst + pow2)

  (read_lo, read_hi, write_lo, write_hi)
def power_of_2_fft(
    direction: FTDirection,
    x: ((Fin log2_n)=>(Fin 2))=>Complex
    ) -> ((Fin log2_n)=>(Fin 2))=>Complex given (log2_n:Nat) =
  # Radix-2 transform of a table indexed by (Fin log2_n)=>(Fin 2).
  # (Fin k)=>(Fin 2) has 2^k elements, so the input has 2^log2_n entries.
  # Runs log2_n stages; each stage builds a fresh buffer from the previous
  # one and doubles the stride counter, which starts at 1.

  # The sign of the twiddle angle depends on the direction.
  sign_pi = case direction of
    ForwardFT -> -pi
    InverseFT -> pi

  (total, transformed) = yield_state (1, x) \stateRef.
    for stage:(Fin log2_n).
      strideRef = fst_ref stateRef
      dataRef = snd_ref stateRef
      stride = get strideRef

      log2_half_n = unsafe_nat_diff log2_n 1  # TODO: use `stage` as a proof that log2_n > 0
      dataRef := yield_accum (AddMonoid Complex) \outRef.
        for j:((Fin log2_half_n)=>(Fin 2)).  # Executes in parallel.
          (read_lo, read_hi, write_lo, write_hi) = butterfly_ixs j stride

          # Scale the upper input by the twiddle factor for this position.
          k = mod (ordinal j) stride
          angle = sign_pi * (n_to_f k) / n_to_f stride
          twiddle = Complex (cos angle) (sin angle)
          scaled = (get (dataRef!read_hi)) * twiddle

          # Accumulate the sum and the difference into the new buffer.
          lower = get (dataRef!read_lo)
          outRef!write_lo += lower + scaled
          outRef!write_hi += lower - scaled
      strideRef := stride * 2

  # `total` is the final stride counter; only the inverse transform is
  # divided by it.
  case direction of
    ForwardFT -> transformed
    InverseFT -> transformed / (n_to_f total)
def pad_to_power_of_2(
    log2_m:Nat,
    pad_val:a, xs:n=>a
    ) -> ((Fin log2_m)=>(Fin 2))=>a given (a, n|Ix) =
  # Pads `xs` with `pad_val` to `intpow2 log2_m` entries (presumably 2^log2_m),
  # then reinterprets the flat table under the (Fin log2_m)=>(Fin 2) index
  # set expected by `power_of_2_fft`.
  # The final cast is unchecked: it relies on both index sets having the
  # same number of elements.
  target_len = intpow2 log2_m
  flat = pad_to (Fin target_len) pad_val xs
  unsafe_cast_table(to=(Fin log2_m)=>(Fin 2), flat)
def convolve_complex(
    u:n=>Complex,
    v:m=>Complex
    ) -> (Either n m=>Complex) given (n|Ix, m|Ix) =
  # Convolve by pointwise multiplication in the Fourier domain.
  # The result is indexed by `Either n m`, so it has size n + size m entries.
  # A linear convolution only has size n + size m - 1 terms, so the final
  # entry is padding.

  # Pad and convert to Fourier domain.
  # The working buffer must hold at least as many entries as the table
  # returned below.  Sizing it for only (size n + size m) - 1 entries could
  # leave it one element short of the final `slice` when that count is itself
  # a power of 2 (for example size n = 2 and size m = 3).
  min_convolve_size = size n + size m
  log_working_size = nextpow2 min_convolve_size
  u_padded = pad_to_power_of_2 log_working_size zero u
  v_padded = pad_to_power_of_2 log_working_size zero v
  spectral_u = power_of_2_fft ForwardFT u_padded
  spectral_v = power_of_2_fft ForwardFT v_padded

  # Pointwise multiply.
  spectral_conv = for i. spectral_u[i] * spectral_v[i]

  # Convert back to primal domain and undo padding.
  padded_conv = power_of_2_fft InverseFT spectral_conv
  slice padded_conv 0 (Either n m)
def convolve(u:n=>Float, v:m=>Float) -> (Either n m =>Float) given (n|Ix, m|Ix) =
  # Real-valued convolution: lift both inputs to complex numbers with a zero
  # imaginary part, convolve them with `convolve_complex`, and keep only the
  # real part of each output entry.
  lifted_u = for i. Complex u[i] 0.0
  lifted_v = for j. Complex v[j] 0.0
  complex_result = convolve_complex lifted_u lifted_v
  for k. complex_result[k].re
def bluestein(x: n=>Complex) -> n=>Complex given (n|Ix) =
  # Bluestein's algorithm.
  # Converts the general FFT into a convolution,
  # which is then solved with calls to a power-of-2 FFT.
  im = Complex 0.0 1.0

  # Chirp factors wks[k] = exp(-i * pi * k^2 / size n).
  wks = for i.
    # The chirp only depends on k^2 modulo 2 * size n, because
    # exp(-i * pi * (2 * size n) / size n) = 1.  Reducing in exact integer
    # arithmetic keeps the angle below 2 * pi, so precision is not lost to
    # floating-point rounding when k^2 is large.
    i_squared = n_to_f $ mod (sq (ordinal i)) (2 * size n)
    exp $ (-im) * (Complex (pi * i_squared / (n_to_f (size n))) 0.0)

  # Mirror the chirp around its first entry to get a table of odd length.
  AsList(_, tailTable) = tail wks 1
  back_and_forth = odd_sized_palindrome (head wks) tailTable

  # Pre-multiply the input by the chirp and convolve with the conjugated,
  # mirrored chirp.
  xq = for i. x[i] * wks[i]
  back_and_forth_conj = for i. complex_conj back_and_forth[i]
  convolution = convolve_complex xq back_and_forth_conj

  # The wanted outputs start at offset (size n) - 1 of the convolution.
  convslice = slice convolution (unsafe_nat_diff (size n) 1) n

  # Post-multiply by the chirp.
  for i. wks[i] * convslice[i]
'## FFT Interface
def fft(x: n=>Complex) -> n=>Complex given (n|Ix) =
  # Forward transform for any index set `n`.
  # Power-of-2 sizes are handled by the radix-2 routine; all other sizes
  # are handled by `bluestein`.
  if is_power_of_2 (size n)
    then
      # Reinterpret the input under the (Fin k)=>(Fin 2) index set expected
      # by `power_of_2_fft`, transform, and cast back to `n`.
      log2_size = natlog2 (size n)
      reshaped = unsafe_cast_table(to=(Fin log2_size)=>(Fin 2), x)
      transformed = power_of_2_fft ForwardFT reshaped
      unsafe_cast_table(to=n, transformed)
    else
      bluestein x
def ifft(xs: n=>Complex) -> n=>Complex given (n|Ix) =
  # Inverse transform for any index set `n`.
  if is_power_of_2 (size n)
    then
      # Power-of-2 sizes: run the radix-2 routine in the inverse direction.
      log2_size = natlog2 (size n)
      reshaped = unsafe_cast_table(to=(Fin log2_size)=>(Fin 2), xs)
      transformed = power_of_2_fft InverseFT reshaped
      unsafe_cast_table(to=n, transformed)
    else
      # Other sizes: conjugate the input, apply the forward transform,
      # conjugate the result, and divide by the number of elements.
      conj_spectrum = fft (for i. complex_conj xs[i])
      for k. (complex_conj conj_spectrum[k]) / (n_to_f (size n))
def fft_real(x: n=>Float) -> n=>Complex given (n|Ix) =
  # Forward transform of a real table: each entry is lifted to a complex
  # number with a zero imaginary part and passed to `fft`.
  as_complex = for i. Complex x[i] 0.0
  fft as_complex
def ifft_real(x: n=>Float) -> n=>Complex given (n|Ix) =
  # Inverse transform of a real table: each entry is lifted to a complex
  # number with a zero imaginary part and passed to `ifft`.
  as_complex = for i. Complex x[i] 0.0
  ifft as_complex
def fft2(x: n=>m=>Complex) -> n=>m=>Complex given (n|Ix, m|Ix) =
  # Two-dimensional forward transform: apply `fft` to every row, then to
  # every column (by transposing, transforming rows, and transposing back).
  rows_done = for i. fft x[i]
  columns = transpose rows_done
  columns_done = for j. fft columns[j]
  transpose columns_done
def ifft2(x: n=>m=>Complex) -> n=>m=>Complex given (n|Ix, m|Ix) =
  # Two-dimensional inverse transform: apply `ifft` to every row, then to
  # every column (by transposing, transforming rows, and transposing back).
  rows_done = for i. ifft x[i]
  columns = transpose rows_done
  columns_done = for j. ifft columns[j]
  transpose columns_done
def fft2_real(x: n=>m=>Float) -> n=>m=>Complex given (n|Ix, m|Ix) =
  # Two-dimensional forward transform of a real table: each entry is lifted
  # to a complex number with a zero imaginary part and passed to `fft2`.
  as_complex = for i j. Complex x[i,j] 0.0
  fft2 as_complex
def ifft2_real(x: n=>m=>Float) -> n=>m=>Complex given (n|Ix, m|Ix) =
  # Two-dimensional inverse transform of a real table: each entry is lifted
  # to a complex number with a zero imaginary part and passed to `ifft2`.
  as_complex = for i j. Complex x[i,j] 0.0
  ifft2 as_complex
|