Skip to content

Commit a6093e7

Browse files
committed
fix: Validate input vector length in FFT function to ensure it is a power of two
1 parent add9355 commit a6093e7

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

weilycoder/poly/fft.hpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include "fft_utility.hpp"
55
#include <complex>
66
#include <cstddef>
7+
#include <stdexcept>
78
#include <vector>
89

910
/**
@@ -23,17 +24,19 @@ template <int32_t on = 1, typename float_t = double>
2324
void fft(std::vector<std::complex<float_t>> &y) {
2425
static_assert(on == 1 || on == -1, "on must be 1 or -1");
2526
fft_change(y);
26-
for (size_t h = 2; h <= y.size(); h <<= 1) {
27+
size_t len = y.size();
28+
if (len == 0 || (len & (len - 1)) != 0)
29+
throw std::invalid_argument("Length of input vector must be a power of two");
30+
for (size_t h = 2; h <= len; h <<= 1) {
2731
std::complex<float_t> wn(cos(2 * PI<float_t> / h), sin(on * 2 * PI<float_t> / h));
28-
for (size_t j = 0; j < y.size(); j += h) {
32+
for (size_t j = 0; j < len; j += h) {
2933
std::complex<float_t> w(1, 0);
3034
for (size_t k = j; k < j + (h >> 1); ++k, w *= wn) {
3135
std::complex<float_t> u = y[k], t = w * y[k + (h >> 1)];
3236
y[k] = u + t, y[k + (h >> 1)] = u - t;
3337
}
3438
}
3539
}
36-
size_t len = y.size();
3740
if constexpr (on == -1)
3841
for (size_t i = 0; i < len; ++i)
3942
y[i] /= len;

0 commit comments

Comments
 (0)