Skip to content

Commit 95415b5

Browse files
committed
feat: Implement Karatsuba multiplication algorithm and add test cases
1 parent 9155c79 commit 95415b5

File tree

2 files changed

+133
-0
lines changed

2 files changed

+133
-0
lines changed

test/convolution_mod_2_64.test.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
#define PROBLEM "https://judge.yosupo.jp/problem/convolution_mod_2_64"
2+
3+
#include "../weilycoder/poly/karatsuba.hpp"
4+
#include <cstdint>
5+
#include <iostream>
6+
#include <vector>
7+
using namespace std;
8+
using namespace weilycoder;
9+
10+
int main() {
11+
cin.tie(nullptr)->sync_with_stdio(false);
12+
cin.exceptions(cin.badbit | cin.failbit);
13+
size_t n, m;
14+
cin >> n >> m;
15+
16+
vector<uint64_t> a(n), b(m);
17+
for (size_t i = 0; i < n; ++i)
18+
cin >> a[i];
19+
for (size_t j = 0; j < m; ++j)
20+
cin >> b[j];
21+
22+
auto c = karatsuba_multiply(a, b);
23+
for (size_t k = 0; k < n + m - 1; ++k)
24+
cout << c[k] << (k + 1 == n + m - 1 ? '\n' : ' ');
25+
return 0;
26+
}

weilycoder/poly/karatsuba.hpp

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
#ifndef WEILYCODER_POLY_KARATSUBA_HPP
2+
#define WEILYCODER_POLY_KARATSUBA_HPP
3+
4+
#include <algorithm>
5+
#include <iterator>
6+
#include <type_traits>
7+
#include <vector>
8+
9+
namespace weilycoder {
10+
template <typename InputIt, typename OutputIt, size_t Threshold = 32>
11+
void karatsuba_multiply(InputIt a_begin, InputIt a_end, InputIt b_begin, InputIt b_end,
12+
OutputIt result_begin) {
13+
using T = typename std::iterator_traits<InputIt>::value_type;
14+
15+
static_assert(
16+
std::is_base_of<std::random_access_iterator_tag,
17+
typename std::iterator_traits<InputIt>::iterator_category>::value,
18+
"karatsuba_multiply requires InputIt to be a random access iterator");
19+
static_assert(std::is_base_of<
20+
std::random_access_iterator_tag,
21+
typename std::iterator_traits<OutputIt>::iterator_category>::value,
22+
"karatsuba_multiply requires OutputIt to be a random access iterator");
23+
24+
size_t a_size = std::distance(a_begin, a_end);
25+
size_t b_size = std::distance(b_begin, b_end);
26+
27+
if (a_size <= Threshold || b_size <= Threshold) {
28+
// Base case: use standard multiplication
29+
for (size_t i = 0; i < a_size; ++i)
30+
for (size_t j = 0; j < b_size; ++j)
31+
result_begin[i + j] += a_begin[i] * b_begin[j];
32+
return;
33+
}
34+
35+
size_t res_size = a_size + b_size - 1;
36+
size_t half_size = std::max(a_size, b_size) / 2;
37+
38+
// Split the polynomials
39+
auto a_low_begin = a_begin;
40+
auto a_low_end = (a_size > half_size) ? a_begin + half_size : a_end;
41+
auto a_high_begin = (a_size > half_size) ? a_begin + half_size : a_end;
42+
auto a_high_end = a_end;
43+
auto b_low_begin = b_begin;
44+
auto b_low_end = (b_size > half_size) ? b_begin + half_size : b_end;
45+
auto b_high_begin = (b_size > half_size) ? b_begin + half_size : b_end;
46+
auto b_high_end = b_end;
47+
48+
size_t a_low_size = std::distance(a_low_begin, a_low_end);
49+
size_t b_low_size = std::distance(b_low_begin, b_low_end);
50+
size_t a_high_size = std::distance(a_high_begin, a_high_end);
51+
size_t b_high_size = std::distance(b_high_begin, b_high_end);
52+
size_t a_max_size = std::max(a_low_size, a_high_size);
53+
size_t b_max_size = std::max(b_low_size, b_high_size);
54+
size_t part_size = a_max_size + b_max_size - 1;
55+
56+
std::vector<T> z0(part_size);
57+
std::vector<T> z1(part_size);
58+
std::vector<T> z2(part_size);
59+
60+
// z0 = a_low * b_low
61+
karatsuba_multiply(a_low_begin, a_low_end, b_low_begin, b_low_end, z0.begin());
62+
// z2 = a_high * b_high
63+
karatsuba_multiply(a_high_begin, a_high_end, b_high_begin, b_high_end, z2.begin());
64+
65+
// z1 = (a_low + a_high) * (b_low + b_high) - z0 - z2
66+
std::vector<T> a_sum(std::max(a_low_size, a_high_size));
67+
for (size_t i = 0; i < a_low_size; ++i)
68+
a_sum[i] += a_low_begin[i];
69+
for (size_t i = 0; i < a_high_size; ++i)
70+
a_sum[i] += a_high_begin[i];
71+
std::vector<T> b_sum(std::max(b_low_size, b_high_size));
72+
for (size_t i = 0; i < b_low_size; ++i)
73+
b_sum[i] += b_low_begin[i];
74+
for (size_t i = 0; i < b_high_size; ++i)
75+
b_sum[i] += b_high_begin[i];
76+
karatsuba_multiply(a_sum.begin(), a_sum.end(), b_sum.begin(), b_sum.end(),
77+
z1.begin());
78+
for (size_t i = 0; i < part_size; ++i)
79+
z1[i] -= z0[i] + z2[i];
80+
81+
// Combine results
82+
for (size_t i = 0; i < part_size; ++i) {
83+
if (i >= res_size)
84+
break;
85+
result_begin[i] += z0[i];
86+
}
87+
for (size_t i = 0; i < part_size; ++i) {
88+
if (i + half_size >= res_size)
89+
break;
90+
result_begin[i + half_size] += z1[i];
91+
}
92+
for (size_t i = 0; i < part_size; ++i) {
93+
if (i + 2 * half_size >= res_size)
94+
break;
95+
result_begin[i + 2 * half_size] += z2[i];
96+
}
97+
}
98+
99+
template <typename T>
100+
std::vector<T> karatsuba_multiply(const std::vector<T> &a, const std::vector<T> &b) {
101+
std::vector<T> result(a.size() + b.size() - 1);
102+
karatsuba_multiply(a.begin(), a.end(), b.begin(), b.end(), result.begin());
103+
return result;
104+
}
105+
} // namespace weilycoder
106+
107+
#endif

0 commit comments

Comments
 (0)