Skip to content

Commit bd58139

Browse files
authored
Merge pull request #28 from manta1130/feature/math
Add math
2 parents 0312d8d + c2267fd commit bd58139

File tree

2 files changed

+200
-0
lines changed

2 files changed

+200
-0
lines changed

src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ pub(crate) mod internal_type_traits;
2020

2121
pub use dsu::Dsu;
2222
pub use fenwicktree::FenwickTree;
23+
pub use math::{crt, floor_sum, inv_mod, pow_mod};
2324
pub use mincostflow::MinCostFlowGraph;
2425
pub use string::{
2526
lcp_array, lcp_array_arbitrary, suffix_array, suffix_array_arbitrary, suffix_array_manual,

src/math.rs

Lines changed: 199 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,200 @@
1+
use crate::internal_math;
12

3+
use std::mem::swap;
4+
5+
#[allow(clippy::many_single_char_names)]
6+
pub fn pow_mod(x: i64, mut n: i64, m: u32) -> u32 {
7+
assert!(0 <= n && 1 <= m && m <= 2u32.pow(31));
8+
if m == 1 {
9+
return 0;
10+
}
11+
let bt = internal_math::Barrett::new(m);
12+
let mut r = 1;
13+
let mut y = internal_math::safe_mod(x, m as i64) as u32;
14+
while n != 0 {
15+
if n & 1 != 0 {
16+
r = bt.mul(r, y);
17+
}
18+
y = bt.mul(y, y);
19+
n >>= 1;
20+
}
21+
r
22+
}
23+
24+
pub fn inv_mod(x: i64, m: i64) -> i64 {
25+
assert!(1 <= m);
26+
let z = internal_math::inv_gcd(x, m);
27+
assert!(z.0 == 1);
28+
z.1
29+
}
30+
31+
pub fn crt(r: &[i64], m: &[i64]) -> (i64, i64) {
32+
assert_eq!(r.len(), m.len());
33+
// Contracts: 0 <= r0 < m0
34+
let (mut r0, mut m0) = (0, 1);
35+
for (&(mut ri), &(mut mi)) in r.iter().zip(m.iter()) {
36+
assert!(1 < mi);
37+
ri = internal_math::safe_mod(ri, mi);
38+
if m0 < mi {
39+
swap(&mut r0, &mut ri);
40+
swap(&mut m0, &mut mi);
41+
}
42+
if m0 % mi == 0 {
43+
if r0 % mi != ri {
44+
return (0, 0);
45+
}
46+
continue;
47+
}
48+
// assume: m0 > mi, lcm(m0, mi) >= 2 * max(m0, mi)
49+
50+
// (r0, m0), (ri, mi) -> (r2, m2 = lcm(m0, m1));
51+
// r2 % m0 = r0
52+
// r2 % mi = ri
53+
// -> (r0 + x*m0) % mi = ri
54+
// -> x*u0*g % (u1*g) = (ri - r0) (u0*g = m0, u1*g = mi)
55+
// -> x = (ri - r0) / g * inv(u0) (mod u1)
56+
57+
// im = inv(u0) (mod u1) (0 <= im < u1)
58+
let (g, im) = internal_math::inv_gcd(m0, mi);
59+
let u1 = mi / g;
60+
// |ri - r0| < (m0 + mi) <= lcm(m0, mi)
61+
if (ri - r0) % g != 0 {
62+
return (0, 0);
63+
}
64+
// u1 * u1 <= mi * mi / g / g <= m0 * mi / g = lcm(m0, mi)
65+
let x = (ri - r0) / g % u1 * im % u1;
66+
67+
// |r0| + |m0 * x|
68+
// < m0 + m0 * (u1 - 1)
69+
// = m0 + m0 * mi / g - m0
70+
// = lcm(m0, mi)
71+
r0 += x * m0;
72+
m0 *= u1; // -> lcm(m0, mi)
73+
if r0 < 0 {
74+
r0 += m0
75+
};
76+
}
77+
78+
(r0, m0)
79+
}
80+
81+
pub fn floor_sum(n: i64, m: i64, mut a: i64, mut b: i64) -> i64 {
82+
let mut ans = 0;
83+
if a >= m {
84+
ans += (n - 1) * n * (a / m) / 2;
85+
a %= m;
86+
}
87+
if b >= m {
88+
ans += n * (b / m);
89+
b %= m;
90+
}
91+
92+
let y_max = (a * n + b) / m;
93+
let x_max = y_max * m - b;
94+
if y_max == 0 {
95+
return ans;
96+
}
97+
ans += (n - (x_max + a - 1) / a) * y_max;
98+
ans += floor_sum(y_max, a, m, (a - x_max % a) % a);
99+
ans
100+
}
101+
102+
#[cfg(test)]
103+
mod tests {
104+
#![allow(clippy::unreadable_literal)]
105+
#![allow(clippy::cognitive_complexity)]
106+
use super::*;
107+
#[test]
108+
fn test_pow_mod() {
109+
assert_eq!(pow_mod(0, 0, 1), 0);
110+
assert_eq!(pow_mod(0, 0, 3), 1);
111+
assert_eq!(pow_mod(0, 0, 723), 1);
112+
assert_eq!(pow_mod(0, 0, 998244353), 1);
113+
assert_eq!(pow_mod(0, 0, 2u32.pow(31)), 1);
114+
115+
assert_eq!(pow_mod(0, 1, 1), 0);
116+
assert_eq!(pow_mod(0, 1, 3), 0);
117+
assert_eq!(pow_mod(0, 1, 723), 0);
118+
assert_eq!(pow_mod(0, 1, 998244353), 0);
119+
assert_eq!(pow_mod(0, 1, 2u32.pow(31)), 0);
120+
121+
assert_eq!(pow_mod(0, i64::max_value(), 1), 0);
122+
assert_eq!(pow_mod(0, i64::max_value(), 3), 0);
123+
assert_eq!(pow_mod(0, i64::max_value(), 723), 0);
124+
assert_eq!(pow_mod(0, i64::max_value(), 998244353), 0);
125+
assert_eq!(pow_mod(0, i64::max_value(), 2u32.pow(31)), 0);
126+
127+
assert_eq!(pow_mod(1, 0, 1), 0);
128+
assert_eq!(pow_mod(1, 0, 3), 1);
129+
assert_eq!(pow_mod(1, 0, 723), 1);
130+
assert_eq!(pow_mod(1, 0, 998244353), 1);
131+
assert_eq!(pow_mod(1, 0, 2u32.pow(31)), 1);
132+
133+
assert_eq!(pow_mod(1, 1, 1), 0);
134+
assert_eq!(pow_mod(1, 1, 3), 1);
135+
assert_eq!(pow_mod(1, 1, 723), 1);
136+
assert_eq!(pow_mod(1, 1, 998244353), 1);
137+
assert_eq!(pow_mod(1, 1, 2u32.pow(31)), 1);
138+
139+
assert_eq!(pow_mod(1, i64::max_value(), 1), 0);
140+
assert_eq!(pow_mod(1, i64::max_value(), 3), 1);
141+
assert_eq!(pow_mod(1, i64::max_value(), 723), 1);
142+
assert_eq!(pow_mod(1, i64::max_value(), 998244353), 1);
143+
assert_eq!(pow_mod(1, i64::max_value(), 2u32.pow(31)), 1);
144+
145+
assert_eq!(pow_mod(i64::max_value(), 0, 1), 0);
146+
assert_eq!(pow_mod(i64::max_value(), 0, 3), 1);
147+
assert_eq!(pow_mod(i64::max_value(), 0, 723), 1);
148+
assert_eq!(pow_mod(i64::max_value(), 0, 998244353), 1);
149+
assert_eq!(pow_mod(i64::max_value(), 0, 2u32.pow(31)), 1);
150+
151+
assert_eq!(pow_mod(i64::max_value(), i64::max_value(), 1), 0);
152+
assert_eq!(pow_mod(i64::max_value(), i64::max_value(), 3), 1);
153+
assert_eq!(pow_mod(i64::max_value(), i64::max_value(), 723), 640);
154+
assert_eq!(
155+
pow_mod(i64::max_value(), i64::max_value(), 998244353),
156+
683296792
157+
);
158+
assert_eq!(
159+
pow_mod(i64::max_value(), i64::max_value(), 2u32.pow(31)),
160+
2147483647
161+
);
162+
163+
assert_eq!(pow_mod(2, 3, 1_000_000_007), 8);
164+
assert_eq!(pow_mod(5, 7, 1_000_000_007), 78125);
165+
assert_eq!(pow_mod(123, 456, 1_000_000_007), 565291922);
166+
}
167+
168+
#[test]
169+
#[should_panic]
170+
fn test_inv_mod_1() {
171+
inv_mod(271828, 0);
172+
}
173+
174+
#[test]
175+
#[should_panic]
176+
fn test_inv_mod_2() {
177+
inv_mod(3141592, 1000000008);
178+
}
179+
180+
#[test]
181+
fn test_crt() {
182+
let a = [44, 23, 13];
183+
let b = [13, 50, 22];
184+
assert_eq!(crt(&a, &b), (1773, 7150));
185+
let a = [12345, 67890, 99999];
186+
let b = [13, 444321, 95318];
187+
assert_eq!(crt(&a, &b), (103333581255, 550573258014));
188+
}
189+
190+
#[test]
191+
fn test_floor_sum() {
192+
assert_eq!(floor_sum(0, 1, 0, 0), 0);
193+
assert_eq!(floor_sum(1_000_000_000, 1, 1, 1), 500_000_000_500_000_000);
194+
assert_eq!(
195+
floor_sum(1_000_000_000, 1_000_000_000, 999_999_999, 999_999_999),
196+
499_999_999_500_000_000
197+
);
198+
assert_eq!(floor_sum(332955, 5590132, 2231, 999423), 22014575);
199+
}
200+
}

0 commit comments

Comments
 (0)