Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Got FFT working again with new syntax. #1347

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 26 additions & 24 deletions lib/fft.dx
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import complex

'## Helper functions

def odd_sized_palindrome(mid:a, seq:n=>a) -> ((n `Either` () `Either` n)=>a) given (a, n|Ix) =
def odd_sized_palindrome(mid:a, seq:n=>a) -> ((n `Either` () `Either` n)=>a) given (a:Type, n|Ix) =
# Turns sequence 12345 into 543212345.
for i.
case i of
Expand All @@ -33,11 +33,11 @@ def butterfly_ixs(j':halfn, pow2:Nat) -> (n, n, n, n) given (halfn|Ix, n|Ix) =
# Note: with fancier index sets, this might be replacable by reshapes.
j = ordinal j'
k = ((idiv j pow2) * pow2 * 2) + mod j pow2
left_write_ix = unsafe_from_ordinal k
right_write_ix = unsafe_from_ordinal (k + pow2)
left_write_ix : n = unsafe_from_ordinal k
right_write_ix : n = unsafe_from_ordinal (k + pow2)

left_read_ix = unsafe_from_ordinal j
right_read_ix = unsafe_from_ordinal (j + size halfn)
left_read_ix : n = unsafe_from_ordinal j
right_read_ix : n = unsafe_from_ordinal (j + size halfn)
(left_read_ix, right_read_ix, left_write_ix, right_write_ix)

def power_of_2_fft(
Expand All @@ -59,8 +59,9 @@ def power_of_2_fft(
log2_half_n = unsafe_nat_diff log2_n 1 # TODO: use `i` as a proof that log2_n > 0
xRef := yield_accum (AddMonoid Complex) \bufRef.
for j:((Fin log2_half_n)=>(Fin 2)). # Executes in parallel.
t = (Fin log2_n) => Fin 2
(left_read_ix, right_read_ix,
left_write_ix, right_write_ix) = butterfly_ixs j ipow2
left_write_ix, right_write_ix) : (t, t, t, t) = butterfly_ixs j ipow2

# Read one element from the last buffer, scaled.
angle = dir_const * (n_to_f $ mod (ordinal j) ipow2) / n_to_f ipow2
Expand All @@ -78,7 +79,7 @@ def power_of_2_fft(
def pad_to_power_of_2(
log2_m:Nat,
pad_val:a, xs:n=>a
) -> ((Fin log2_m)=>(Fin 2))=>a given (a, n|Ix) =
) -> ((Fin log2_m)=>(Fin 2))=>a given (a:Type, n|Ix) =
flatsize = intpow2 log2_m
padded_flat = pad_to (Fin flatsize) pad_val xs
unsafe_cast_table(to=(Fin log2_m)=>(Fin 2), padded_flat)
Expand All @@ -91,21 +92,22 @@ def convolve_complex(
# Pad and convert to Fourier domain.
min_convolve_size = (size n + size m) -| 1
log_working_size = nextpow2 min_convolve_size
u_padded = pad_to_power_of_2 log_working_size zero u
v_padded = pad_to_power_of_2 log_working_size zero v
sn = size n
u_padded = pad_to_power_of_2 log_working_size (zero::Complex) u
v_padded = pad_to_power_of_2 log_working_size (zero::Complex) v
spectral_u = power_of_2_fft ForwardFT u_padded
spectral_v = power_of_2_fft ForwardFT v_padded

# Pointwise multiply.
spectral_conv = for i. spectral_u[i] * spectral_v[i]
spectral_conv = for i:(Fin log_working_size)=>(Fin 2). spectral_u[i] * spectral_v[i]

# Convert back to primal domain and undo padding.
padded_conv = power_of_2_fft InverseFT spectral_conv
slice padded_conv 0 (Either n m)

def convolve(u:n=>Float, v:m=>Float) -> (Either n m =>Float) given (n|Ix, m|Ix) =
u' = for i. Complex u[i] 0.0
v' = for i. Complex v[i] 0.0
u' = for i:n. Complex u[i] 0.0
v' = for i:m. Complex v[i] 0.0
ans = convolve_complex u' v'
for i. ans[i].re

Expand All @@ -114,14 +116,14 @@ def bluestein(x: n=>Complex) -> n=>Complex given (n|Ix) =
# Converts the general FFT into a convolution,
# which is then solved with calls to a power-of-2 FFT.
im = Complex 0.0 1.0
wks = for i.
wks = for i:n.
i_squared = n_to_f $ sq $ ordinal i
exp $ (-im) * (Complex (pi * i_squared / (n_to_f (size n))) 0.0)

AsList(_, tailTable) = tail wks 1
back_and_forth = odd_sized_palindrome (head wks) tailTable
xq = for i. x[i] * wks[i]
back_and_forth_conj = for i. complex_conj back_and_forth[i]
xq = for i:n. x[i] * wks[i]
back_and_forth_conj = each back_and_forth complex_conj
convolution = convolve_complex xq back_and_forth_conj
convslice = slice convolution (unsafe_nat_diff (size n) 1) n
for i. wks[i] * convslice[i]
Expand All @@ -147,19 +149,19 @@ def ifft(xs: n=>Complex) -> n=>Complex given (n|Ix) =
ret = power_of_2_fft InverseFT castx
unsafe_cast_table(to=n, ret)
else
unscaled_fft = fft (for i. complex_conj xs[i])
unscaled_fft = fft (each xs complex_conj)
for i. (complex_conj unscaled_fft[i]) / (n_to_f (size n))

def fft_real(x: n=>Float) -> n=>Complex given (n|Ix) = fft for i. Complex x[i] 0.0
def ifft_real(x: n=>Float) -> n=>Complex given (n|Ix) = ifft for i. Complex x[i] 0.0
def fft_real(x: n=>Float) -> n=>Complex given (n|Ix) = fft for i:n. Complex x[i] 0.0
def ifft_real(x: n=>Float) -> n=>Complex given (n|Ix) = ifft for i:n. Complex x[i] 0.0

def fft2(x: n=>m=>Complex) -> n=>m=>Complex given (n|Ix, m|Ix) =
x' = for i. fft x[i]
transpose for i. fft (transpose x')[i]
x' = for i:n. fft x[i]
transpose for i:m. fft (transpose x')[i]

def ifft2(x: n=>m=>Complex) -> n=>m=>Complex given (n|Ix, m|Ix) =
x' = for i. ifft x[i]
transpose for i. ifft (transpose x')[i]
x' = for i:n. ifft x[i]
transpose for i:m. ifft (transpose x')[i]

def fft2_real(x: n=>m=>Float) -> n=>m=>Complex given (n|Ix, m|Ix) = fft2 for i j. Complex x[i,j] 0.0
def ifft2_real(x: n=>m=>Float) -> n=>m=>Complex given (n|Ix, m|Ix) = ifft2 for i j. Complex x[i,j] 0.0
def fft2_real(x: n=>m=>Float) -> n=>m=>Complex given (n|Ix, m|Ix) = fft2 for i:n j:m. Complex x[i,j] 0.0
def ifft2_real(x: n=>m=>Float) -> n=>m=>Complex given (n|Ix, m|Ix) = ifft2 for i:n j:m. Complex x[i,j] 0.0
2 changes: 1 addition & 1 deletion tests/fft-tests.dx
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import complex
import fft

:p map nextpow2 [0, 1, 2, 3, 4, 7, 8, 9, 1023, 1024, 1025]
:p each [0, 1, 2, 3, 4, 7, 8, 9, 1023, 1024, 1025] nextpow2
> [0, 0, 1, 2, 2, 3, 3, 4, 10, 10, 11]

a : (Fin 4)=>Complex = arb $ new_key 0
Expand Down
Loading