Skip to content

Commit db3fb80

Browse files
committed
feat: Add modular arithmetic utilities
1 parent f49a549 commit db3fb80

File tree

2 files changed

+55
-41
lines changed

2 files changed

+55
-41
lines changed
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
#ifndef WEILYCODER_MODINT_HPP
2+
#define WEILYCODER_MODINT_HPP
3+
4+
#include <cstdint>
5+
6+
/**
7+
* @file modint.hpp
8+
* @brief Modular Integer Arithmetic Utilities
9+
*/
10+
11+
namespace weilycoder {
12+
/**
13+
* @brief Perform modular multiplication for 64-bit integers.
14+
* @tparam bit32 If true, won't use 128-bit arithmetic. You should ensure that
15+
* all inputs are small enough to avoid overflow (i.e. bit-32).
16+
* @param a The first multiplicand.
17+
* @param b The second multiplicand.
18+
* @param modulus The modulus.
19+
* @return (a * b) % modulus
20+
*/
21+
template <bool bit32 = false>
22+
uint64_t modular_multiply_64(uint64_t a, uint64_t b, uint64_t modulus) {
23+
if constexpr (bit32)
24+
return a * b % modulus;
25+
else
26+
return static_cast<unsigned __int128>(a) * b % modulus;
27+
}
28+
29+
/**
30+
* @brief Perform modular exponentiation for 64-bit integers.
31+
* @tparam bit32 If true, won't use 128-bit arithmetic. You should ensure that
32+
* all inputs are small enough to avoid overflow (i.e. bit-32).
33+
* @param base The base number.
34+
* @param exponent The exponent.
35+
* @param modulus The modulus.
36+
* @return (base^exponent) % modulus
37+
*/
38+
template <bool bit32 = false>
39+
constexpr uint64_t fast_power_64(uint64_t base, uint64_t exponent, uint64_t modulus) {
40+
uint64_t result = 1 % modulus;
41+
base %= modulus;
42+
while (exponent > 0) {
43+
if (exponent & 1)
44+
result = modular_multiply_64<bit32>(result, base, modulus);
45+
base = modular_multiply_64<bit32>(base, base, modulus);
46+
exponent >>= 1;
47+
}
48+
return result;
49+
}
50+
} // namespace weilycoder
51+
52+
#endif

weilycoder/number-theory/prime.hpp

Lines changed: 3 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -6,49 +6,11 @@
66
* @brief Prime Number Utilities
77
*/
88

9+
#include "modint.hpp"
910
#include <cstdint>
1011
#include <type_traits>
1112

1213
namespace weilycoder {
13-
/**
14-
* @brief Perform modular multiplication for 64-bit integers.
15-
* @tparam bit32 If true, won't use 128-bit arithmetic. You should ensure that
16-
* all inputs are small enough to avoid overflow (i.e. bit-32).
17-
* @param a The first multiplicand.
18-
* @param b The second multiplicand.
19-
* @param modulus The modulus.
20-
* @return (a * b) % modulus
21-
*/
22-
template <bool bit32 = false>
23-
uint64_t modular_multiply(uint64_t a, uint64_t b, uint64_t modulus) {
24-
if constexpr (bit32)
25-
return a * b % modulus;
26-
else
27-
return static_cast<unsigned __int128>(a) * b % modulus;
28-
}
29-
30-
/**
31-
* @brief Perform modular exponentiation for 64-bit integers.
32-
* @tparam bit32 If true, won't use 128-bit arithmetic. You should ensure that
33-
* all inputs are small enough to avoid overflow (i.e. bit-32).
34-
* @param base The base number.
35-
* @param exponent The exponent.
36-
* @param modulus The modulus.
37-
* @return (base^exponent) % modulus
38-
*/
39-
template <bool bit32 = false>
40-
constexpr uint64_t fast_power(uint64_t base, uint64_t exponent, uint64_t modulus) {
41-
uint64_t result = 1 % modulus;
42-
base %= modulus;
43-
while (exponent > 0) {
44-
if (exponent & 1)
45-
result = modular_multiply<bit32>(result, base, modulus);
46-
base = modular_multiply<bit32>(base, base, modulus);
47-
exponent >>= 1;
48-
}
49-
return result;
50-
}
51-
5214
/**
5315
* @brief Miller-Rabin primality test for a given base.
5416
* @tparam bit32 If true, won't use 128-bit arithmetic. You should ensure that
@@ -61,11 +23,11 @@ constexpr uint64_t fast_power(uint64_t base, uint64_t exponent, uint64_t modulus
6123
*/
6224
template <bool bit32, uint64_t base>
6325
constexpr bool miller_rabin_test(uint64_t n, uint64_t d, uint32_t s) {
64-
uint64_t x = fast_power<bit32>(base, d, n);
26+
uint64_t x = fast_power_64<bit32>(base, d, n);
6527
if (x == 0 || x == 1 || x == n - 1)
6628
return true;
6729
for (uint32_t r = 1; r < s; ++r) {
68-
x = modular_multiply<bit32>(x, x, n);
30+
x = modular_multiply_64<bit32>(x, x, n);
6931
if (x == n - 1)
7032
return true;
7133
}

0 commit comments

Comments
 (0)