Skip to content

Commit ac1a6e0

Browse files
committed
feat: Implement NTT and convolution functions for polynomial multiplication
1 parent 84657b2 commit ac1a6e0

File tree

3 files changed

+118
-0
lines changed

3 files changed

+118
-0
lines changed

test/convolution_mod.test.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
#define PROBLEM "https://judge.yosupo.jp/problem/convolution_mod"
2+
3+
#include "../weilycoder/poly/ntt_convolve.hpp"
4+
#include <iostream>
5+
#include <vector>
6+
using namespace std;
7+
using namespace weilycoder;
8+
9+
int main() {
10+
cin.tie(nullptr)->sync_with_stdio(false);
11+
cin.exceptions(cin.failbit | cin.badbit);
12+
size_t n, m;
13+
cin >> n >> m;
14+
vector<uint64_t> a(n), b(m);
15+
for (size_t i = 0; i < n; ++i)
16+
cin >> a[i];
17+
for (size_t i = 0; i < m; ++i)
18+
cin >> b[i];
19+
auto c = ntt_convolve_32<998244353>(a, b);
20+
for (size_t i = 0; i < n + m - 1; ++i)
21+
cout << c[i] << " \n"[i + 1 == n + m - 1];
22+
return 0;
23+
}

weilycoder/poly/ntt.hpp

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
#ifndef WEILYCODER_POLY_NTT_HPP
2+
#define WEILYCODER_POLY_NTT_HPP
3+
4+
#include "../number-theory/mod_utility.hpp"
#include "../number-theory/primitive_root.hpp"
#include "fft_utility.hpp"
#include <cstdint>
#include <stdexcept>
#include <vector>
9+
10+
namespace weilycoder {
11+
/**
12+
* @brief Number Theoretic Transform (NTT)
13+
* @tparam mod The prime modulus
14+
* @tparam inverse Whether to perform the inverse NTT
15+
* @tparam bit32 Whether to use 32-bit modular multiplication
16+
* @tparam root A primitive root modulo mod
17+
* @param y The input/output vector to be transformed
18+
*/
19+
template <uint64_t mod, bool inverse = false, bool bit32 = false,
20+
uint64_t root = prime_primitive_root<mod>()>
21+
void ntt(std::vector<uint64_t> &y) {
22+
static_assert(is_prime(mod), "mod must be a prime");
23+
fft_change(y);
24+
size_t len = y.size();
25+
if (len == 0 || (len & (len - 1)) != 0)
26+
throw std::invalid_argument("Length of input vector must be a power of two");
27+
if ((mod - 1) % len != 0)
28+
throw std::invalid_argument(
29+
"mod - 1 must be divisible by the length of input vector");
30+
constexpr uint64_t g = inverse ? mod_pow<bit32>(root, mod - 2, mod) : root;
31+
for (size_t h = 2; h <= len; h <<= 1) {
32+
uint64_t wn = mod_pow<bit32>(g, (mod - 1) / h, mod);
33+
for (size_t j = 0; j < len; j += h) {
34+
uint64_t w = 1;
35+
for (size_t k = j; k < j + (h >> 1); ++k) {
36+
uint64_t u = y[k];
37+
uint64_t t = mod_mul<bit32>(w, y[k + (h >> 1)], mod);
38+
y[k] = mod_add<bit32>(u, t, mod);
39+
y[k + (h >> 1)] = mod_sub<bit32>(u, t, mod);
40+
w = mod_mul<bit32>(w, wn, mod);
41+
}
42+
}
43+
}
44+
if constexpr (inverse) {
45+
uint64_t inv_len = mod_pow<bit32>(len, mod - 2, mod);
46+
for (size_t i = 0; i < len; ++i)
47+
y[i] = mod_mul<bit32>(y[i], inv_len, mod);
48+
}
49+
}
50+
51+
/**
52+
* @brief Number Theoretic Transform (NTT) using 32-bit modular multiplication
53+
* @tparam mod The prime modulus
54+
* @tparam inverse Whether to perform the inverse NTT
55+
* @tparam root A primitive root modulo mod
56+
* @param y The input/output vector to be transformed
57+
*/
58+
template <uint64_t mod, bool inverse = false, uint64_t root = prime_primitive_root(mod)>
59+
void ntt_32(std::vector<uint64_t> &y) {
60+
ntt<mod, inverse, true, root>(y);
61+
}
62+
} // namespace weilycoder
63+
64+
#endif

weilycoder/poly/ntt_convolve.hpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
#ifndef WEILYCODER_POLY_NTT_CONVOLVE_HPP
2+
#define WEILYCODER_POLY_NTT_CONVOLVE_HPP
3+
4+
#include "ntt.hpp"
5+
#include <cstdint>
6+
7+
namespace weilycoder {
8+
/**
 * @brief Multiply two polynomials modulo mod using the NTT
 *
 * Both operands are zero-padded to the smallest power of two that can hold
 * the product, transformed, multiplied pointwise, and transformed back.
 *
 * @tparam mod The prime modulus
 * @tparam bit32 Whether to use 32-bit modular multiplication
 * @tparam root A primitive root modulo mod
 * @param a Coefficients of the first polynomial (taken by value; used as the
 *          working buffer)
 * @param b Coefficients of the second polynomial (taken by value)
 * @return The a.size() + b.size() - 1 coefficients of the product, or an
 *         empty vector if either operand is empty
 * @throws std::invalid_argument propagated from ntt when the padded length
 *         does not divide mod - 1
 */
template <uint64_t mod, bool bit32 = false, uint64_t root = prime_primitive_root<mod>()>
std::vector<uint64_t> ntt_convolve(std::vector<uint64_t> a, std::vector<uint64_t> b) {
  // An empty operand yields an empty product. This also keeps the unsigned
  // expression a.size() + b.size() - 1 below from wrapping around.
  if (a.empty() || b.empty())
    return {};

  // Capture the true result length now: after padding, a.size() and b.size()
  // both equal n and no longer describe the inputs.
  const size_t result_len = a.size() + b.size() - 1;

  size_t n = 1;
  while (n < result_len)
    n <<= 1;

  a.resize(n, 0);
  b.resize(n, 0);
  ntt<mod, false, bit32, root>(a);
  ntt<mod, false, bit32, root>(b);
  for (size_t i = 0; i < n; ++i)
    a[i] = mod_mul<bit32>(a[i], b[i], mod);
  ntt<mod, true, bit32, root>(a);

  // Drop the padding so callers get exactly the product's coefficients.
  a.resize(result_len);
  return a;
}
23+
24+
/**
 * @brief Polynomial multiplication via NTT with 32-bit modular multiplication
 *
 * Thin wrapper that forwards both operands to ntt_convolve with the bit32
 * template flag switched on.
 *
 * @tparam mod The prime modulus
 * @tparam root A primitive root modulo mod
 * @param a Coefficients of the first polynomial
 * @param b Coefficients of the second polynomial
 * @return Whatever ntt_convolve<mod, true, root> returns for (a, b)
 */
template <uint64_t mod, uint64_t root = prime_primitive_root<mod>()>
std::vector<uint64_t> ntt_convolve_32(std::vector<uint64_t> a,
                                      std::vector<uint64_t> b) {
  constexpr bool use_bit32 = true;
  return ntt_convolve<mod, use_bit32, root>(a, b);
}
29+
} // namespace weilycoder
30+
31+
#endif

0 commit comments

Comments
 (0)