Skip to content

Commit 7c360c6

Browse files
authored
don't cache FFT of size 3, update SmallFft class (#75)
1 parent fa61aa0 commit 7c360c6

File tree

5 files changed

+104
-53
lines changed

5 files changed

+104
-53
lines changed

CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,5 @@
11
cmake_minimum_required(VERSION 3.10)
2-
project(dsplib LANGUAGES CXX VERSION 0.54.6)
2+
project(dsplib LANGUAGES CXX VERSION 0.54.7)
33

44
set(CMAKE_CXX_STANDARD 17)
55
set(CMAKE_CXX_STANDARD_REQUIRED ON)

lib/fft/factory.cpp

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -46,8 +46,8 @@ std::shared_ptr<BaseFftPlanR> _get_rfft_plan(int n) {
4646
//-------------------------------------------------------------------------------------------------
4747
std::shared_ptr<BaseFftPlanC> create_fft_plan(int n) {
4848
//dont cache small fft plans
49-
if ((n == 1) || (n == 2) || (n == 4) || (n == 8)) {
50-
return std::make_shared<SmallFftPow2C>(n);
49+
if (SmallFftC::is_supported(n)) {
50+
return std::make_shared<SmallFftC>(n);
5151
}
5252

5353
//TODO: use weak_ptr cache to prevent duplication
@@ -61,8 +61,8 @@ std::shared_ptr<BaseFftPlanC> create_fft_plan(int n) {
6161
}
6262

6363
std::shared_ptr<BaseFftPlanR> create_rfft_plan(int n) {
64-
if ((n == 1) || (n == 2) || (n == 4) || (n == 8)) {
65-
return std::make_shared<SmallFftPow2R>(n);
64+
if (SmallFftR::is_supported(n)) {
65+
return std::make_shared<SmallFftR>(n);
6666
}
6767

6868
thread_local LRUCache<int, std::shared_ptr<BaseFftPlanR>> cache{FFT_CACHE_SIZE};

lib/fft/primes-fft.h

Lines changed: 3 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -20,7 +20,7 @@ class PrimesFftC : public BaseFftPlanC
2020
explicit PrimesFftC(int n)
2121
: n_{n} {
2222
DSPLIB_ASSERT(isprime(n_), "`n` must be a prime number");
23-
DSPLIB_ASSERT(n_ >= 3, "`n` must be greater than or equal to 3");
23+
DSPLIB_ASSERT(n_ >= 5, "`n` must be greater than or equal to 5");
2424
if (n > MAX_DFT_SIZE) {
2525
const cmplx_t w = expj(-2 * pi / n);
2626
czt_ = std::make_shared<CztPlan>(n, n, w);
@@ -45,31 +45,10 @@ class PrimesFftC : public BaseFftPlanC
4545
}
4646

4747
private:
48-
static void _dft_n3(const cmplx_t* restrict x, cmplx_t* restrict y) noexcept {
49-
constexpr real_t c = -0.5;
50-
constexpr real_t d = 0.866025403784439;
51-
52-
y[0].re = x[0].re + x[1].re + x[2].re;
53-
y[0].im = x[0].im + x[1].im + x[2].im;
54-
55-
const real_t re1_c = x[1].re * c;
56-
const real_t im1_d = x[1].im * d;
57-
const real_t re2_c = x[2].re * c;
58-
const real_t im2_d = x[2].im * d;
59-
y[1].re = x[0].re + (re1_c + im1_d) + (re2_c - im2_d);
60-
y[2].re = x[0].re + (re1_c - im1_d) + (re2_c + im2_d);
61-
62-
const real_t re1_d = x[1].re * d;
63-
const real_t im1_c = x[1].im * c;
64-
const real_t re2_d = x[2].re * d;
65-
const real_t im2_c = x[2].im * c;
66-
y[1].im = x[0].im + (-re1_d + im1_c) + (re2_d + im2_c);
67-
y[2].im = x[0].im + (re1_d + im1_c) + (-re2_d + im2_c);
68-
}
69-
7048
//TODO: add dft5, dft7
7149

72-
static void _dft_slow(const cmplx_t* restrict x, cmplx_t* restrict y, uint32_t n, const cmplx_t* restrict tw) noexcept {
50+
static void _dft_slow(const cmplx_t* restrict x, cmplx_t* restrict y, uint32_t n,
51+
const cmplx_t* restrict tw) noexcept {
7352
DSPLIB_ASSUME(n <= MAX_DFT_SIZE);
7453
std::memset(reinterpret_cast<real_t*>(y), 0, n * sizeof(cmplx_t));
7554

@@ -90,11 +69,6 @@ class PrimesFftC : public BaseFftPlanC
9069
void _dft(const cmplx_t* restrict x, cmplx_t* restrict y, int n) const {
9170
assert(n == n_);
9271

93-
if (n == 3) {
94-
_dft_n3(x, y);
95-
return;
96-
}
97-
9872
if (n <= MAX_DFT_SIZE) {
9973
assert(!w_.empty());
10074
_dft_slow(x, y, n, w_.data());

lib/fft/small-fft.h

Lines changed: 65 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -9,17 +9,17 @@ namespace dsplib {
99

1010
//-------------------------------------------------------------------------------------------------------------
1111
//FFT implementation for small sizes
12-
class SmallFftPow2C : public BaseFftPlanC
12+
class SmallFftC : public BaseFftPlanC
1313
{
1414
public:
15-
friend class SmallFftPow2R;
15+
friend class SmallFftR;
1616

17-
explicit SmallFftPow2C(int n)
17+
explicit SmallFftC(int n)
1818
: n_{n} {
19-
DSPLIB_ASSERT((n >= 1) && (n <= 8), "only small power-of-two sizes are supported: 1, 2, 4, 8");
19+
DSPLIB_ASSERT(is_supported(n), "only small sizes are supported: 1, 2, 3, 4, 8");
2020
}
2121

22-
~SmallFftPow2C() override {
22+
~SmallFftC() override {
2323
}
2424

2525
void solve(const cmplx_t* x, cmplx_t* y, int n) const final {
@@ -31,6 +31,9 @@ class SmallFftPow2C : public BaseFftPlanC
3131
case 2:
3232
_fft_n2(x, y);
3333
break;
34+
case 3:
35+
_fft_n3(x, y);
36+
break;
3437
case 4:
3538
_fft_n4(x, y);
3639
break;
@@ -53,6 +56,10 @@ class SmallFftPow2C : public BaseFftPlanC
5356
return n_;
5457
}
5558

59+
static bool is_supported(int n) noexcept {
60+
return (n == 1 || n == 2 || n == 3 || n == 4 || n == 8);
61+
}
62+
5663
private:
5764
static void _fft_n2(const cmplx_t* restrict x, cmplx_t* restrict y) noexcept {
5865
y[0].re = x[0].re + x[1].re;
@@ -61,6 +68,28 @@ class SmallFftPow2C : public BaseFftPlanC
6168
y[1].im = x[0].im - x[1].im;
6269
}
6370

71+
static void _fft_n3(const cmplx_t* restrict x, cmplx_t* restrict y) noexcept {
72+
constexpr real_t c = -0.5;
73+
constexpr real_t d = 0.866025403784439;
74+
75+
y[0].re = x[0].re + x[1].re + x[2].re;
76+
y[0].im = x[0].im + x[1].im + x[2].im;
77+
78+
const real_t re1_c = x[1].re * c;
79+
const real_t im1_d = x[1].im * d;
80+
const real_t re2_c = x[2].re * c;
81+
const real_t im2_d = x[2].im * d;
82+
y[1].re = x[0].re + (re1_c + im1_d) + (re2_c - im2_d);
83+
y[2].re = x[0].re + (re1_c - im1_d) + (re2_c + im2_d);
84+
85+
const real_t re1_d = x[1].re * d;
86+
const real_t im1_c = x[1].im * c;
87+
const real_t re2_d = x[2].re * d;
88+
const real_t im2_c = x[2].im * c;
89+
y[1].im = x[0].im + (-re1_d + im1_c) + (re2_d + im2_c);
90+
y[2].im = x[0].im + (re1_d + im1_c) + (-re2_d + im2_c);
91+
}
92+
6493
static void _fft_n4(const cmplx_t* restrict x, cmplx_t* restrict y) noexcept {
6594
y[0].re = x[0].re + x[1].re + x[2].re + x[3].re;
6695
y[0].im = x[0].im + x[1].im + x[2].im + x[3].im;
@@ -101,15 +130,15 @@ class SmallFftPow2C : public BaseFftPlanC
101130
};
102131

103132
//-------------------------------------------------------------------------------------------------------------
104-
class SmallFftPow2R : public BaseFftPlanR
133+
class SmallFftR : public BaseFftPlanR
105134
{
106135
public:
107-
explicit SmallFftPow2R(int n)
136+
explicit SmallFftR(int n)
108137
: n_{n} {
109-
DSPLIB_ASSERT((n >= 1) && (n <= 8), "only small power-of-two sizes are supported: 1, 2, 4, 8");
138+
DSPLIB_ASSERT(is_supported(n), "only small sizes are supported: 1, 2, 3, 4, 8");
110139
}
111140

112-
~SmallFftPow2R() override {
141+
~SmallFftR() override {
113142
}
114143

115144
void solve(const real_t* x, cmplx_t* y, int n) const final {
@@ -121,6 +150,9 @@ class SmallFftPow2R : public BaseFftPlanR
121150
case 2:
122151
_fft_n2(x, y);
123152
break;
153+
case 3:
154+
_fft_n3(x, y);
155+
break;
124156
case 4:
125157
_fft_n4(x, y);
126158
break;
@@ -143,6 +175,10 @@ class SmallFftPow2R : public BaseFftPlanR
143175
return n_;
144176
}
145177

178+
static bool is_supported(int n) noexcept {
179+
return (n == 1 || n == 2 || n == 3 || n == 4 || n == 8);
180+
}
181+
146182
private:
147183
static void _fft_n2(const real_t* restrict x, cmplx_t* restrict y) noexcept {
148184
y[0].re = x[0] + x[1];
@@ -151,6 +187,24 @@ class SmallFftPow2R : public BaseFftPlanR
151187
y[1].im = 0;
152188
}
153189

190+
static void _fft_n3(const real_t* restrict x, cmplx_t* restrict y) noexcept {
191+
constexpr real_t c = -0.5;
192+
constexpr real_t d = 0.866025403784439;
193+
194+
y[0].re = x[0] + x[1] + x[2];
195+
y[0].im = 0;
196+
197+
const real_t re1_c = x[1] * c;
198+
const real_t re2_c = x[2] * c;
199+
y[1].re = x[0] + re1_c + re2_c;
200+
y[2].re = y[1].re;
201+
202+
const real_t re1_d = x[1] * d;
203+
const real_t re2_d = x[2] * d;
204+
y[1].im = (-re1_d) + (re2_d);
205+
y[2].im = -y[1].im;
206+
}
207+
154208
static void _fft_n4(const real_t* restrict x, cmplx_t* restrict y) noexcept {
155209
y[0].re = x[0] + x[1] + x[2] + x[3];
156210
y[0].im = 0;
@@ -169,7 +223,7 @@ class SmallFftPow2R : public BaseFftPlanR
169223
p1[1] = x[1] + x[5];
170224
p1[2] = x[2] + x[6];
171225
p1[3] = x[3] + x[7];
172-
SmallFftPow2R::_fft_n4(p1, r1);
226+
SmallFftR::_fft_n4(p1, r1);
173227

174228
cmplx_t p2[4];
175229
cmplx_t r2[4];
@@ -178,7 +232,7 @@ class SmallFftPow2R : public BaseFftPlanR
178232
p2[2].re = 0;
179233
p2[2].im = x[6] - x[2];
180234
p2[3] = (x[3] - x[7]) * cmplx_t{-0.707106781186548, -0.707106781186548};
181-
SmallFftPow2C::_fft_n4(p2, r2);
235+
SmallFftC::_fft_n4(p2, r2);
182236

183237
for (int i = 0; i < 4; ++i) {
184238
*y++ = r1[i];

tests/fft_test.cpp

Lines changed: 31 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -250,8 +250,8 @@ TEST(FFT, SmallFft) {
250250
using namespace std::complex_literals;
251251

252252
{
253-
SmallFftPow2R plan_r(1);
254-
SmallFftPow2C plan_c(1);
253+
SmallFftR plan_r(1);
254+
SmallFftC plan_c(1);
255255
arr_real x = {10};
256256
arr_cmplx ref = {10};
257257
auto y1 = plan_r.solve(x);
@@ -262,8 +262,8 @@ TEST(FFT, SmallFft) {
262262
ASSERT_EQ_ARR_CMPLX(y3, ref);
263263
}
264264
{
265-
SmallFftPow2R plan_r(2);
266-
SmallFftPow2C plan_c(2);
265+
SmallFftR plan_r(2);
266+
SmallFftC plan_c(2);
267267
arr_real x = {1, 2};
268268
arr_cmplx ref = {3, -1};
269269
auto y1 = plan_r.solve(x);
@@ -274,8 +274,31 @@ TEST(FFT, SmallFft) {
274274
ASSERT_EQ_ARR_CMPLX(y3, ref);
275275
}
276276
{
277-
SmallFftPow2R plan_r(4);
278-
SmallFftPow2C plan_c(4);
277+
SmallFftR plan_r(3);
278+
SmallFftC plan_c(3);
279+
arr_real x = {1, 2, 3};
280+
arr_cmplx ref = {6.00000000000000 + 0.00000000000000i, -1.50000000000000 + 0.866025403784439i,
281+
-1.50000000000000 - 0.866025403784439i};
282+
auto y1 = plan_r.solve(x);
283+
auto y2 = plan_c.solve(complex(x));
284+
auto y3 = fft(x);
285+
ASSERT_EQ_ARR_CMPLX(y1, ref);
286+
ASSERT_EQ_ARR_CMPLX(y2, ref);
287+
ASSERT_EQ_ARR_CMPLX(y3, ref);
288+
}
289+
{
290+
SmallFftC plan(3);
291+
arr_cmplx x = {1 + 1i, 2 + 2i, 3 - 3i};
292+
arr_cmplx ref = {6.00000000000000 + 0.00000000000000i, 2.83012701892219 + 2.36602540378444i,
293+
-5.83012701892219 + 0.633974596215561i};
294+
auto y1 = plan.solve(x);
295+
auto y2 = fft(x);
296+
ASSERT_EQ_ARR_CMPLX(y1, ref);
297+
ASSERT_EQ_ARR_CMPLX(y2, ref);
298+
}
299+
{
300+
SmallFftR plan_r(4);
301+
SmallFftC plan_c(4);
279302
arr_real x = {1, 2, 3, 4};
280303
arr_cmplx ref = {10.0000000000000 + 0.00000000000000i, -2.00000000000000 + 2.00000000000000i,
281304
-2.00000000000000 + 0.00000000000000i, -2.00000000000000 - 2.00000000000000i};
@@ -287,8 +310,8 @@ TEST(FFT, SmallFft) {
287310
ASSERT_EQ_ARR_CMPLX(y3, ref);
288311
}
289312
{
290-
SmallFftPow2R plan_r(8);
291-
SmallFftPow2C plan_c(8);
313+
SmallFftR plan_r(8);
314+
SmallFftC plan_c(8);
292315
arr_real x = {1, 2, 3, 4, 5, 6, 7, 8};
293316
arr_cmplx ref = {36.0000000000000 + 0.00000000000000i, -4.00000000000000 + 9.65685424949238i,
294317
-4.00000000000000 + 4.00000000000000i, -4.00000000000000 + 1.65685424949238i,

0 commit comments

Comments
 (0)