|
| 1 | +#ifndef WEILYCODER_POLY_KARATSUBA_HPP |
| 2 | +#define WEILYCODER_POLY_KARATSUBA_HPP |
| 3 | + |
| 4 | +#include <algorithm> |
| 5 | +#include <iterator> |
| 6 | +#include <type_traits> |
| 7 | +#include <vector> |
| 8 | + |
namespace weilycoder {
/**
 * @brief Multiplies two polynomials (coefficient ranges) with the Karatsuba algorithm.
 *
 * The coefficients of the product are *accumulated* (with `+=`) into the range
 * starting at `result_begin`. The caller must therefore provide a range that
 *  - holds at least `a_size + b_size - 1` elements, and
 *  - is zero-initialised if only the plain product is wanted.
 * Nothing is written when either input range is empty.
 *
 * @tparam InputIt   Random access iterator over the input coefficients.
 * @tparam OutputIt  Random access iterator over the output coefficients.
 * @tparam Threshold When either operand has at most this many coefficients the
 *                   schoolbook O(n*m) multiplication is used. The value is
 *                   forwarded to every recursive call.
 *
 * @param a_begin,a_end  Coefficients of the first polynomial (lowest degree first).
 * @param b_begin,b_end  Coefficients of the second polynomial (lowest degree first).
 * @param result_begin   Start of the output range that receives the product.
 *
 * The coefficient type `T` (the value type of `InputIt`) must be value-initialisable
 * to the additive identity and support `+=`, `-=`, binary `+` and binary `*`.
 * Temporary `std::vector<T>` buffers are allocated on every recursion level.
 */
template <typename InputIt, typename OutputIt, size_t Threshold = 32>
void karatsuba_multiply(InputIt a_begin, InputIt a_end, InputIt b_begin, InputIt b_end,
                        OutputIt result_begin) {
  using T = typename std::iterator_traits<InputIt>::value_type;
  using Diff = typename std::iterator_traits<InputIt>::difference_type;
  using BufferIt = typename std::vector<T>::iterator;

  static_assert(
      std::is_base_of<std::random_access_iterator_tag,
                      typename std::iterator_traits<InputIt>::iterator_category>::value,
      "karatsuba_multiply requires InputIt to be a random access iterator");
  static_assert(std::is_base_of<
                    std::random_access_iterator_tag,
                    typename std::iterator_traits<OutputIt>::iterator_category>::value,
                "karatsuba_multiply requires OutputIt to be a random access iterator");

  const size_t a_size = static_cast<size_t>(std::distance(a_begin, a_end));
  const size_t b_size = static_cast<size_t>(std::distance(b_begin, b_end));
  const size_t half_size = std::max(a_size, b_size) / 2;

  // Base case: schoolbook multiplication. `half_size == 0` (both operands have a
  // single coefficient) can only get past the threshold test when Threshold == 0;
  // splitting would not shrink the problem then, so it is handled here as well.
  if (a_size <= Threshold || b_size <= Threshold || half_size == 0) {
    for (size_t i = 0; i < a_size; ++i)
      for (size_t j = 0; j < b_size; ++j)
        result_begin[i + j] += a_begin[i] * b_begin[j];
    return;
  }

  // Both sizes are >= 1 from here on, so the subtraction cannot wrap around.
  const size_t res_size = a_size + b_size - 1;

  // Split each operand at `half_size`: x = x_low + X^half_size * x_high.
  // An operand that is not longer than `half_size` has an empty high part.
  const size_t a_low_size = std::min(a_size, half_size);
  const size_t b_low_size = std::min(b_size, half_size);
  const size_t a_high_size = a_size - a_low_size;
  const size_t b_high_size = b_size - b_low_size;
  const InputIt a_mid = a_begin + static_cast<Diff>(a_low_size);
  const InputIt b_mid = b_begin + static_cast<Diff>(b_low_size);

  const size_t a_max_size = std::max(a_low_size, a_high_size);
  const size_t b_max_size = std::max(b_low_size, b_high_size);
  // Large enough for each of the three partial products.
  const size_t part_size = a_max_size + b_max_size - 1;

  std::vector<T> z0(part_size);
  std::vector<T> z1(part_size);
  std::vector<T> z2(part_size);

  // Template arguments are spelled out so that Threshold is propagated instead
  // of silently falling back to the default in nested calls.
  // z0 = a_low * b_low
  karatsuba_multiply<InputIt, BufferIt, Threshold>(a_begin, a_mid, b_begin, b_mid,
                                                   z0.begin());
  // z2 = a_high * b_high
  karatsuba_multiply<InputIt, BufferIt, Threshold>(a_mid, a_end, b_mid, b_end,
                                                   z2.begin());

  // z1 = (a_low + a_high) * (b_low + b_high) - z0 - z2
  std::vector<T> a_sum(a_max_size);
  for (size_t i = 0; i < a_low_size; ++i)
    a_sum[i] += a_begin[i];
  for (size_t i = 0; i < a_high_size; ++i)
    a_sum[i] += a_mid[i];
  std::vector<T> b_sum(b_max_size);
  for (size_t i = 0; i < b_low_size; ++i)
    b_sum[i] += b_begin[i];
  for (size_t i = 0; i < b_high_size; ++i)
    b_sum[i] += b_mid[i];
  karatsuba_multiply<BufferIt, BufferIt, Threshold>(a_sum.begin(), a_sum.end(),
                                                    b_sum.begin(), b_sum.end(),
                                                    z1.begin());
  for (size_t i = 0; i < part_size; ++i)
    z1[i] -= z0[i] + z2[i];

  // Combine: result += z0 + X^half_size * z1 + X^(2 * half_size) * z2.
  // Writes are clipped to `res_size` so the caller's range is never overrun.
  const auto add_shifted = [&](const std::vector<T> &part, size_t shift) {
    for (size_t i = 0; i < part_size && i + shift < res_size; ++i)
      result_begin[i + shift] += part[i];
  };
  add_shifted(z0, 0);
  add_shifted(z1, half_size);
  add_shifted(z2, 2 * half_size);
}

/**
 * @brief Convenience overload: multiplies two coefficient vectors.
 *
 * @param a Coefficients of the first polynomial (lowest degree first).
 * @param b Coefficients of the second polynomial (lowest degree first).
 * @return A vector with `a.size() + b.size() - 1` coefficients of the product,
 *         or an empty vector when either input is empty.
 */
template <typename T>
std::vector<T> karatsuba_multiply(const std::vector<T> &a, const std::vector<T> &b) {
  // Without this guard `a.size() + b.size() - 1` wraps around for two empty
  // inputs and yields spurious zero coefficients when only one is empty.
  if (a.empty() || b.empty())
    return {};

  std::vector<T> result(a.size() + b.size() - 1);
  karatsuba_multiply(a.begin(), a.end(), b.begin(), b.end(), result.begin());
  return result;
}
} // namespace weilycoder
| 106 | + |
| 107 | +#endif |
0 commit comments