forked from JuliaMath/FFTA.jl
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathplan.jl
More file actions
469 lines (413 loc) · 15.9 KB
/
plan.jl
File metadata and controls
469 lines (413 loc) · 15.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
# Plans
"""
    FFTAPlan{T,N}

Abstract supertype of the FFTA plans defined in this file. `T` is the element
type the plan's call graphs operate on and `N` is the number of transformed
dimensions.
"""
abstract type FFTAPlan{T,N} <: AbstractFFTs.Plan{T} end
"""
    FFTAInvPlan{T,N}

Field-less plan type stored in the `pinv` field of the concrete plans; it carries
no data of its own. NOTE(review): presumably a placeholder for a cached inverse
plan — confirm against the rest of the package.
"""
struct FFTAInvPlan{T,N} <: FFTAPlan{T,N} end
"""
    FFTAPlan_cx{T,N,R}

Plan for complex-to-complex transforms over `N` dimensions.

# Fields
- `callgraph`: one `CallGraph{T}` per transformed dimension, in `region` order.
- `region`: the transformed dimension(s); an `Int` for 1D plans, otherwise an
  `AbstractVector{Int}` (stored in ascending order by the planning functions).
- `dir`: transform direction (`FFT_FORWARD` or `FFT_BACKWARD`).
- `pinv`: an `FFTAInvPlan{T,N}` instance.
"""
struct FFTAPlan_cx{T,N,R<:Union{Int,AbstractVector{Int}}} <: FFTAPlan{T,N}
callgraph::NTuple{N,CallGraph{T}}
region::R
dir::Direction
pinv::FFTAInvPlan{T,N}
end
# Outer constructor for `FFTAPlan_cx` that infers the region type `R` from `r`,
# so callers only need to spell out the element type and the dimensionality.
FFTAPlan_cx{T,N}(
    cg::NTuple{N,CallGraph{T}}, r::R, dir::Direction, pinv::FFTAInvPlan{T,N}
) where {T,N,R<:Union{Int,AbstractVector{Int}}} = FFTAPlan_cx{T,N,R}(cg, r, dir, pinv)
"""
    FFTAPlan_re{T,N,R}

Plan for real-input forward (`rfft`) or real-output backward (`brfft`)
transforms. `T` is the complex element type the call graphs operate on.

# Fields
- `callgraph`: one `CallGraph{T}` per transformed dimension, in `region` order.
- `region`: the transformed dimension(s); an `Int` for 1D plans, otherwise an
  `AbstractVector{Int}` in ascending order.
- `dir`: `FFT_FORWARD` for rfft plans, `FFT_BACKWARD` for brfft plans.
- `pinv`: an `FFTAInvPlan{T,N}` instance.
- `flen`: logical length of the real signal along the first transformed
  dimension. For even-length 1D plans the call graph is built for only
  `flen ÷ 2` points, so this field is the record of the full length.
"""
struct FFTAPlan_re{T,N,R<:Union{Int,AbstractVector{Int}}} <: FFTAPlan{T,N}
callgraph::NTuple{N,CallGraph{T}}
region::R
dir::Direction
pinv::FFTAInvPlan{T,N}
flen::Int
end
# Outer constructor for `FFTAPlan_re` that infers the region type `R` from `r`;
# `flen` is the full real signal length along the first transformed dimension.
FFTAPlan_re{T,N}(
    cg::NTuple{N,CallGraph{T}}, r::R, dir::Direction, pinv::FFTAInvPlan{T,N}, flen::Int
) where {T,N,R<:Union{Int,AbstractVector{Int}}} = FFTAPlan_re{T,N,R}(cg, r, dir, pinv, flen)
# Length of the transform along plan dimension `i`.
# Dimensions past the plan's rank report 1, mirroring `size(A, i)` for arrays;
# `i < 1` throws a `DomainError`.
function Base.size(p::FFTAPlan{<:Any,N}, i::Int) where N
    i >= 1 || throw(DomainError(i, "No non-positive dimensions"))
    i <= N || return 1
    # Real plans may hold a half-length call graph along their first dimension
    # (even-length rfft/brfft), so the full logical length is kept in `flen`.
    if i == 1 && p isa FFTAPlan_re
        return p.flen
    end
    return first(p.callgraph[i].nodes).sz
end
# Transform lengths along every plan dimension, as an `N`-tuple.
Base.size(p::FFTAPlan{<:Any,N}) where N = ntuple(i -> size(p, i), Val(N))
# Reinterpret a real plan as a complex-to-complex plan over the same call
# graphs, region, direction and `pinv`; the real length `flen` is dropped.
function Base.complex(p::FFTAPlan_re{T,N,R}) where {T,N,R}
    graphs = p.callgraph
    reg = p.region
    return FFTAPlan_cx{T,N,R}(graphs, reg, p.dir, p.pinv)
end
# Complex forward and backward plans share `_plan_fft`; they differ only in the
# direction flag passed along. Keyword arguments are forwarded unchanged.
function AbstractFFTs.plan_fft(x::AbstractArray{T,N}, region::R; kwargs...) where {T<:Complex,N,R}
    return _plan_fft(x, region, FFT_FORWARD; kwargs...)
end
function AbstractFFTs.plan_bfft(x::AbstractArray{T,N}, region::R; kwargs...) where {T<:Complex,N,R}
    return _plan_fft(x, region, FFT_BACKWARD; kwargs...)
end
"""
    _plan_fft(x, region, dir; BLUESTEIN_CUTOFF=DEFAULT_BLUESTEIN_CUTOFF)

Build an `FFTAPlan_cx` for the complex array `x`, transforming along the
dimension(s) in `region` in direction `dir`. One `CallGraph` is built per
transformed dimension from `size(x, d)`. Multi-dimensional regions are stored
in ascending order.

`region` may be an integer, a vector or range of integers, or another iterable of
integers such as a tuple. The caller's `region` is never modified. Any other
keyword arguments are accepted and ignored.
"""
function _plan_fft(x::AbstractArray{T,N}, region::R, dir::Direction; BLUESTEIN_CUTOFF=DEFAULT_BLUESTEIN_CUTOFF, _kwargs...) where {T<:Complex,N,R}
    FFTN = length(region)
    if FFTN == 1
        # `only` handles an Int, a 1-element vector and a 1-tuple alike.
        R1 = Int(only(region))
        g = CallGraph{T}(size(x, R1), BLUESTEIN_CUTOFF)
        pinv = FFTAInvPlan{T,1}()
        return FFTAPlan_cx{T,1,Int}((g,), R1, dir, pinv)
    end
    # Sort a copy: `sort!(region)` would reorder the caller's vector in place and
    # fails outright for tuples. `sort` returns ranges unchanged.
    sregion = region isa AbstractVector ? sort(region) : sort!(collect(Int, region))
    if FFTN == 2
        g1 = CallGraph{T}(size(x, sregion[1]), BLUESTEIN_CUTOFF)
        g2 = CallGraph{T}(size(x, sregion[2]), BLUESTEIN_CUTOFF)
        pinv = FFTAInvPlan{T,2}()
        return FFTAPlan_cx{T,2,typeof(sregion)}((g1, g2), sregion, dir, pinv)
    else
        return FFTAPlan_cx{T,FFTN,typeof(sregion)}(
            ntuple(i -> CallGraph{T}(size(x, sregion[i]), BLUESTEIN_CUTOFF), Val(FFTN)),
            sregion, dir, FFTAInvPlan{T,FFTN}()
        )
    end
end
# Build a forward real-to-complex plan (`FFTAPlan_re`) for the real array `x`
# over 1 or 2 dimensions in `region`. The caller's `region` is not modified;
# integer, vector/range and tuple regions are accepted.
# Throws `ArgumentError` for regions with more than two dimensions.
function AbstractFFTs.plan_rfft(x::AbstractArray{T,N}, region::R; BLUESTEIN_CUTOFF=DEFAULT_BLUESTEIN_CUTOFF, _kwargs...) where {T<:Real,N,R}
    FFTN = length(region)
    if FFTN == 1
        R1 = Int(only(region))
        n = size(x, R1)
        # For even lengths, the even- and odd-indexed samples are packed into one
        # complex vector of length n/2. One complex FFT of that size followed by
        # a butterfly then solves the real problem. Odd lengths are solved as a
        # single complex FFT of length n.
        nn = iseven(n) ? n >> 1 : n
        g = CallGraph{Complex{T}}(nn, BLUESTEIN_CUTOFF)
        pinv = FFTAInvPlan{Complex{T},1}()
        return FFTAPlan_re{Complex{T},1,Int}((g,), R1, FFT_FORWARD, pinv, n)
    elseif FFTN == 2
        # Sort a copy so the caller's region is left untouched (and tuples work).
        sregion = region isa AbstractVector ? sort(region) : sort!(collect(Int, region))
        g1 = CallGraph{Complex{T}}(size(x, sregion[1]), BLUESTEIN_CUTOFF)
        g2 = CallGraph{Complex{T}}(size(x, sregion[2]), BLUESTEIN_CUTOFF)
        pinv = FFTAInvPlan{Complex{T},2}()
        return FFTAPlan_re{Complex{T},2,typeof(sregion)}((g1, g2), sregion, FFT_FORWARD, pinv, size(x, sregion[1]))
    else
        throw(ArgumentError("only supports 1D and 2D FFTs"))
    end
end
# Build a backward complex-to-real plan (`FFTAPlan_re`) producing real output of
# length `len` along the first transformed dimension. The caller's `region` is
# not modified; integer, vector/range and tuple regions are accepted.
# In 2D, the first region dimension uses the full output length `len` and the
# second uses `size(x, region[2])`.
# Throws `ArgumentError` for regions with more than two dimensions.
function AbstractFFTs.plan_brfft(x::AbstractArray{T,N}, len, region::R; BLUESTEIN_CUTOFF=DEFAULT_BLUESTEIN_CUTOFF, _kwargs...) where {T,N,R}
    FFTN = length(region)
    if FFTN == 1
        # Mirror of `plan_rfft`: even lengths use one complex FFT of length
        # len/2 on a packed spectrum, and odd lengths use one complex FFT of
        # length len.
        R1 = Int(only(region))
        nn = iseven(len) ? len >> 1 : len
        g = CallGraph{T}(nn, BLUESTEIN_CUTOFF)
        pinv = FFTAInvPlan{T,1}()
        return FFTAPlan_re{T,1,Int}((g,), R1, FFT_BACKWARD, pinv, len)
    elseif FFTN == 2
        # Sort a copy so the caller's region is left untouched (and tuples work).
        sregion = region isa AbstractVector ? sort(region) : sort!(collect(Int, region))
        g1 = CallGraph{T}(len, BLUESTEIN_CUTOFF)
        g2 = CallGraph{T}(size(x, sregion[2]), BLUESTEIN_CUTOFF)
        pinv = FFTAInvPlan{T,2}()
        return FFTAPlan_re{T,2,typeof(sregion)}((g1, g2), sregion, FFT_BACKWARD, pinv, len)
    else
        throw(ArgumentError("only supports 1D and 2D FFTs"))
    end
end
# Multiplication
## mul!
### Complex
#### 1D plan 1D array
# Apply the 1D complex plan `p` to the vector `x`, writing the transform into `y`.
# Both vectors must be one-based, share axes, and match the plan's size.
function LinearAlgebra.mul!(y::AbstractVector{U}, p::FFTAPlan_cx{T,1}, x::AbstractVector{T}) where {T,U}
    Base.require_one_based_indexing(x, y)
    axes(x) == axes(y) ||
        throw(DimensionMismatch("input array has axes $(axes(x)), but output array has axes $(axes(y))"))
    size(p) == size(x) ||
        throw(DimensionMismatch("plan has axes $(size(p)), but input array has axes $(size(x))"))
    graph = p.callgraph[1]
    fft!(y, x, 1, 1, p.dir, graph[1].type, graph, 1)
    return y
end
#### 1D plan ND array
# Apply the 1D complex plan `p` along dimension `p.region` of the N-dimensional
# array `x`, writing into `y`. The dimensions before and after the region are
# iterated with CartesianIndices and handed to `_mul_loop!`.
function LinearAlgebra.mul!(y::AbstractArray{U,N}, p::FFTAPlan_cx{T,1}, x::AbstractArray{T,N}) where {T,U,N}
    Base.require_one_based_indexing(x, y)
    axes(x) == axes(y) ||
        throw(DimensionMismatch("input array has axes $(axes(x)), but output array has axes $(axes(y))"))
    d = p.region[]
    size(p, 1) == size(x, d) ||
        throw(DimensionMismatch("plan has size $(size(p, 1)), but input array has size $(size(x, d)) along region $(d)"))
    dims = size(x)
    before = CartesianIndices(dims[1:d-1])
    after = CartesianIndices(dims[d+1:end])
    _mul_loop!(y, x, before, after, p)
    return y
end
# Helper for the 1D-plan/ND-array `mul!`: run the 1D plan `p` on every slice
# `x[Ipre, :, Ipost]` and store the result in the matching slice of `y`.
function _mul_loop!(
    y::AbstractArray{U,N},
    x::AbstractArray{T,N},
    Rpre::CartesianIndices,
    Rpost::CartesianIndices,
    p::FFTAPlan_cx{T,1}) where {T,U,N}
    for outer in Rpost
        for inner in Rpre
            dst = view(y, inner, :, outer)
            src = view(x, inner, :, outer)
            fft!(dst, src, 1, 1, p.dir, p.callgraph[1][1].type, p.callgraph[1], 1)
        end
    end
    return nothing
end
#### ND plan ND array
# Apply an N-dimensional complex plan to an N-dimensional array.
# `X` is copied into `out`, then a 1D transform is run along each dimension of
# `out` in turn, in place. The generator exists so that
# `Base.Cartesian.@nexprs $N` can unroll the per-dimension loop for concrete `N`.
# Throws `DimensionMismatch` if `out`/`X`/plan sizes disagree or if the plan's
# region is not exactly all dimensions `1:N`.
@generated function LinearAlgebra.mul!(
out::AbstractArray{U,N},
p::FFTAPlan_cx{T,N},
X::AbstractArray{T,N}
) where {T,U,N}
quote
Base.require_one_based_indexing(out, X)
if size(out) != size(X)
throw(DimensionMismatch("input array has axes $(axes(X)), but output array has axes $(axes(out))"))
elseif size(p) != size(X)
throw(DimensionMismatch("plan has size $(size(p)), but input array has size $(size(X))"))
elseif !(p.region == N || p.region == 1:N)
throw(DimensionMismatch("Plan region is outside array dimensions."))
end
sz = size(X)
max_sz = maximum(sz)
# Scratch buffers holding one 1D line; resized below to each dimension's length.
obuf = Vector{T}(undef, max_sz)
ibuf = Vector{T}(undef, max_sz)
sizehint!(obuf, max_sz) # not guaranteed but hopefully prevents allocations
sizehint!(ibuf, max_sz)
dir = p.dir
copyto!(out, X) # operate in-place on output array
# Unrolled over dim = 1..N: transform every line of `out` along `dim`.
Base.Cartesian.@nexprs $N dim -> begin
n = size(out, dim)
resize!(obuf, n)
resize!(ibuf, n)
cg = p.callgraph[dim]
Rpre_{dim} = CartesianIndices(sz[1:dim-1])
Rpost_{dim} = CartesianIndices(sz[dim+1:N])
fft_along_dim!(out, ibuf, obuf, cg, dir, Rpre_{dim}, Rpost_{dim})
end
return out
end
end
#### MD plan ND array (M<N)
# Apply an M-dimensional complex plan to an N-dimensional array; `p.region`
# selects which M of the N dimensions are transformed. `X` is copied into `out`
# and each selected dimension is then transformed in place there.
# Throws `DimensionMismatch` when `out` and `X` differ in size, when the region
# lies outside `1:N`, or when the plan's lengths do not match the array's sizes
# along the region.
function LinearAlgebra.mul!(
    out::AbstractArray{U,N},
    p::FFTAPlan_cx{T,M},
    X::AbstractArray{T,N}
) where {T,U,N,M}
    Base.require_one_based_indexing(out, X)
    if size(out) != size(X)
        throw(DimensionMismatch("input array has axes $(axes(X)), but output array has axes $(axes(out))"))
    elseif M > N || any(d -> !(1 <= d <= N), p.region)
        # `any` instead of first/last so an unsorted region is checked correctly too.
        throw(DimensionMismatch("Plan region is outside array dimensions."))
    end
    # Each call graph is built for one specific length (see `_plan_fft`), so the
    # array must have that length along the corresponding region dimension.
    for (dim, pdim) in enumerate(p.region)
        if size(p, dim) != size(X, pdim)
            throw(DimensionMismatch("plan has size $(size(p, dim)) along plan dimension $dim, but input array has size $(size(X, pdim)) along region $pdim"))
        end
    end
    sz = size(X)
    max_sz = maximum(Base.Fix1(size, out), p.region)
    obuf = Vector{T}(undef, max_sz)
    ibuf = Vector{T}(undef, max_sz)
    sizehint!(obuf, max_sz) # not guaranteed but hopefully prevents allocations
    sizehint!(ibuf, max_sz)
    dir = p.dir
    copyto!(out, X) # operate in-place on output array
    # No generated function here: the region is a runtime value, so this loop
    # cannot be made type-stable by unrolling anyway.
    for dim in 1:M
        pdim = p.region[dim]
        n = size(out, pdim)
        resize!(obuf, n)
        resize!(ibuf, n)
        cg = p.callgraph[dim]
        Rpre = CartesianIndices(sz[1:pdim-1])
        Rpost = CartesianIndices(sz[pdim+1:N])
        fft_along_dim!(out, ibuf, obuf, cg, dir, Rpre, Rpost)
    end
    return out
end
# Transform, in place, every 1D line of `A` along dimension `M + 1`, the
# dimension sitting between the `Rpre` and `Rpost` index ranges, using call
# graph `cg` and direction `d`.
# Each line is gathered into `ibuf`, transformed into `obuf`, and written back.
# `eachindex` throws if the buffers do not span `axes(A, M + 1)`.
function fft_along_dim!(
    A::AbstractArray,
    ibuf::Vector{T}, obuf::Vector{T},
    cg::CallGraph{T}, d::Direction,
    Rpre::CartesianIndices{M}, Rpost::CartesianIndices
) where {T <: Complex{<:AbstractFloat}, M}
    kind = cg[1].type
    line = eachindex(axes(A, M + 1), ibuf, obuf)
    for hi in Rpost
        for lo in Rpre
            for k in line
                ibuf[k] = A[lo, k, hi]
            end
            fft!(obuf, ibuf, 1, 1, d, kind, cg, 1)
            for k in line
                A[lo, k, hi] = obuf[k]
            end
        end
    end
    return nothing
end
## *
### Complex
# Out-of-place application of a 1D complex plan to a vector.
function Base.:*(p::FFTAPlan_cx{T,1}, x::AbstractVector{T}) where {T<:Complex}
    result = similar(x)
    LinearAlgebra.mul!(result, p, x)
    return result
end
# Out-of-place application of a complex plan of any rank to an array; the
# matching in-place `mul!` method is selected by dispatch.
function Base.:*(p::FFTAPlan_cx{T,N1}, x::AbstractArray{T,N2}) where {T<:Complex,N1,N2}
    result = similar(x)
    LinearAlgebra.mul!(result, p, x)
    return result
end
### Real
# By converting the problem to complex and back to real
#### 1D plan 1D array
##### Forward
# Forward real-to-complex transform of the real vector `x` with a 1D real plan.
# Returns the first `n ÷ 2 + 1` DFT coefficients, where `n = p.flen`; the rest
# follow by Hermitian symmetry.
# Throws `ArgumentError` if `p` is not a forward plan, and `DimensionMismatch`
# if `length(x) != p.flen`.
function Base.:*(p::FFTAPlan_re{Complex{T},1}, x::AbstractVector{T}) where {T<:Real}
    if p.dir !== FFT_FORWARD
        throw(ArgumentError("only FFT_FORWARD supported for real vectors"))
    end
    Base.require_one_based_indexing(x)
    n = p.flen
    # Without this check, a longer non-`Vector` input would be silently truncated
    # to its first `n` samples below, and a mismatched `Vector` would fail inside
    # `reinterpret` with an unrelated error.
    if length(x) != n
        throw(DimensionMismatch("real 1D plan has size $n, but input vector has length $(length(x))"))
    end
    if iseven(n)
        # For problems of even size, we solve the rfft problem by splitting the
        # input into its even and odd parts and solving both simultaneously as
        # a single (complex) FFT of half the size, see equations (6)-(8) of
        # Sorensen, H. V., D. L. Jones, M. T. Heideman, and C. S. Burrus.
        # "Real-valued fast Fourier transform algorithms."
        # IEEE Transactions on Acoustics, Speech, and Signal Processing 35, no. 6 (1987): 849-863.
        if x isa Vector && isbitstype(T)
            # For a `Vector` of bits types we can reinterpret the memory directly.
            # This gives the even (zero-based) samples as the real parts and the
            # odd samples as the imaginary parts, without copying.
            x_c = reinterpret(Complex{T}, x)
        else
            # Otherwise, build the packed complex vector explicitly (this copies).
            x_c = complex.(view(x, 1:2:n), view(x, 2:2:n))
        end
        m = n >> 1
        # Allocate the complex result: half the input size plus one.
        y = similar(x_c, m + 1)
        # Solve the complex FFT of half the size.
        LinearAlgebra.mul!(view(y, 1:m), complex(p), x_c)
        # The `w` stored in the plan is for m, not n, so it is probably cheapest
        # to recompute it instead of taking a square root.
        wj = w = cispi(-T(2) / n)
        # Construct the result by first separating the spectra of the even and
        # odd parts, followed by the usual radix-2 assembly; see eq. (9).
        y1 = y[1]
        y[1] = real(y1) + imag(y1)
        y[end] = real(y1) - imag(y1)
        @inbounds for j in 2:((m >> 1) + 1)
            yj = y[j]
            ymj = y[m-j+2]
            XX = T(0.5) * ( yj + conj(ymj))
            XY = T(0.5) * (-yj + conj(ymj)) * im
            y[j] = XX + wj * XY
            y[m-j+2] = conj(XX - wj * XY)
            wj *= w
        end
        return y
    else
        # When the problem cannot be split into two equal-size chunks, we
        # convert it to a complex FFT and truncate the redundant part of the
        # result vector.
        x_c = similar(x, Complex{T})
        y = similar(x_c)
        copyto!(x_c, x)
        LinearAlgebra.mul!(y, complex(p), x_c)
        return y[1:end÷2+1]
    end
end
##### Backward
# Backward (unnormalized) complex-to-real transform with a 1D real plan.
# `x` holds the first `n ÷ 2 + 1` coefficients of a Hermitian spectrum with
# `n = p.flen`; the real length-`n` result is returned.
# Throws `ArgumentError` if `p` is not a backward plan, and `DimensionMismatch`
# if `x` does not have `n ÷ 2 + 1` entries.
function Base.:*(p::FFTAPlan_re{T,1}, x::AbstractVector{T}) where {T<:Complex}
    if p.dir !== FFT_BACKWARD
        throw(ArgumentError("only FFT_BACKWARD supported for complex vectors"))
    end
    Base.require_one_based_indexing(x)
    n = p.flen
    rlen = n ÷ 2 + 1
    # Same requirement as the ND backward method; otherwise a wrong length either
    # misreads `x[end]` as the Nyquist term or fails deep inside `mul!`.
    if length(x) != rlen
        throw(DimensionMismatch("real 1D plan has size $n. Input vector should have length $rlen, but has length $(length(x))"))
    end
    if iseven(n)
        # Invert the even-length trick of the forward method:
        # 1. Rebuild the packed half-length complex spectrum.
        # 2. Run one complex FFT of length n/2.
        # 3. Unpack the real/imaginary parts as the even/odd output samples.
        m = n >> 1
        wj = w = cispi(T(2) / n)
        x_tmp = similar(x, m)
        x_tmp[1] = complex(
            (real(x[1]) + real(x[end])),
            (real(x[1]) - real(x[end]))
        )
        for j in 2:((m >> 1) + 1)
            XX = x[j] + conj(x[m-j+2])
            XY = wj * (x[j] - conj(x[m-j+2]))
            x_tmp[j] = XX + im * XY
            x_tmp[m-j+2] = conj(XX - im * XY)
            wj *= w
        end
        y_c = complex(p) * x_tmp
        if isbitstype(T)
            # Reinterpreting complex numbers as pairs of reals interleaves the
            # real and imaginary parts, which is exactly the output order.
            return copy(reinterpret(real(T), y_c))
        else
            # Interleave into a preallocated vector; a `vcat` reduction would
            # repeatedly reallocate intermediate arrays.
            y = similar(y_c, real(T), 2 * length(y_c))
            for (k, v) in enumerate(y_c)
                y[2k-1] = real(v)
                y[2k] = imag(v)
            end
            return y
        end
    else
        # Odd length: rebuild the full Hermitian spectrum and run one complex FFT.
        x_tmp = similar(x, n)
        x_tmp[1:end÷2+1] .= x
        x_tmp[end÷2+2:end] .= conj.(x[end:-1:2])
        y = similar(x_tmp)
        LinearAlgebra.mul!(y, complex(p), x_tmp)
        return real(y)
    end
end
#### 1D plan ND array
##### Forward
# Forward real transform of every 1D slice of `x` along the plan's region,
# applied through `mapslices` and the vector method. Along the region, the
# result has length `p.flen ÷ 2 + 1`.
# Throws `ArgumentError` for a non-forward plan, and `DimensionMismatch` when
# `x` does not have length `p.flen` along the region.
function Base.:*(p::FFTAPlan_re{Complex{T},1}, x::AbstractArray{T,N}) where {T<:Real,N}
    if p.dir !== FFT_FORWARD
        throw(ArgumentError("only FFT_FORWARD supported for real arrays"))
    end
    Base.require_one_based_indexing(x)
    dim1 = only(p.region)
    # Mirror the check in the backward method so that a mismatch reports the
    # offending dimension instead of failing inside a single slice.
    if p.flen != size(x, dim1)
        throw(DimensionMismatch("real 1D plan has size $(p.flen), but input array has size $(size(x, dim1)) along region $dim1"))
    end
    return mapslices(Base.Fix1(*, p), x; dims=dim1)
end
##### Backward
# Backward complex-to-real transform of every 1D slice of `x` along the plan's
# region, applied through `mapslices` and the vector method. Along the region,
# `x` must have the half-spectrum length `p.flen ÷ 2 + 1`.
function Base.:*(p::FFTAPlan_re{T,1}, x::AbstractArray{T,N}) where {T<:Complex,N}
    p.dir === FFT_BACKWARD ||
        throw(ArgumentError("only FFT_BACKWARD supported for complex arrays"))
    Base.require_one_based_indexing(x)
    d = only(p.region)
    expected = p.flen ÷ 2 + 1
    actual = size(x, d)
    expected == actual ||
        throw(DimensionMismatch("real 1D plan has size $(p.flen). Dimension of input array along region $d should have size $expected, but has size $actual"))
    return mapslices(Base.Fix1(*, p), x; dims=d)
end
#### 2D plan ND array
##### Forward
# Forward 2D real transform:
# 1. Promote `x` to complex.
# 2. Run the full complex 2D transform.
# 3. Keep only the non-redundant first `p.flen ÷ 2 + 1` entries along the
#    first region dimension.
function Base.:*(p::FFTAPlan_re{Complex{T},2}, x::AbstractArray{T,N}) where {T<:Real,N}
    p.dir === FFT_FORWARD ||
        throw(ArgumentError("only FFT_FORWARD supported for real arrays"))
    Base.require_one_based_indexing(x)
    xc = complex(x)
    spectrum = similar(xc)
    LinearAlgebra.mul!(spectrum, complex(p), xc)
    keep = 1:(p.flen ÷ 2 + 1)
    return copy(selectdim(spectrum, first(p.region), keep))
end
##### Backward
# Backward (unnormalized) 2D complex-to-real transform.
# `x` holds the non-redundant half spectrum: `p.flen ÷ 2 + 1` entries along the
# first region dimension, and the plan's full length along the second. The full
# Hermitian-symmetric spectrum is rebuilt and transformed with the complex plan,
# and its real part is returned. Along the first region dimension, the result
# has length `p.flen`.
# Throws `ArgumentError` for a non-backward plan, and `DimensionMismatch` when
# the transform dimensions of `x` have the wrong size.
function Base.:*(p::FFTAPlan_re{T,2}, x::AbstractArray{T,N}) where {T<:Complex,N}
if p.dir !== FFT_BACKWARD
throw(ArgumentError("only FFT_BACKWARD supported for complex arrays"))
end
Base.require_one_based_indexing(x)
dim1 = first(p.region)
dim2 = last(p.region)
x_sz = (xrows, xcols) = (size(x, dim1), size(x, dim2))
flen = p.flen
# Number of stored (non-redundant) entries along the first region dimension.
tlen = flen ÷ 2 + 1
t_sz = (tlen, size(p, 2))
if t_sz != x_sz
throw(DimensionMismatch("real 2D plan has size $(size(p)). Transform dimensions of input array are $x_sz but should be $t_sz"))
end
res_size = ntuple(i -> ifelse(i == dim1, flen, size(x, i)), Val(N))
# for the inverse transformation we have to reconstruct the full array
half_1 = 1:tlen
half_2 = tlen+1:flen
x_full = similar(x, res_size)
# use first half as is
copy!(selectdim(x_full, dim1, half_1), x)
# the second half in the first transform dimension is reversed and conjugated
x_half_2 = selectdim(x_full, dim1, half_2) # view to the second half of x
# Last stored index that gets mirrored: index 1 (DC) never is, and for even
# `flen` the Nyquist entry (`xrows`) is not mirrored either.
start_reverse = xrows - iseven(flen)
map!(conj, x_half_2, selectdim(x, dim1, start_reverse:-1:2))
# for the 2D transform we have to reverse index 2:end of the same block in the second transform dimension as well
reverse!(selectdim(x_half_2, dim2, 2:xcols), dims=dim2)
y = similar(x_full)
LinearAlgebra.mul!(y, complex(p), x_full)
return real(y)
end