Skip to content

Commit 241ea3f

Browse files
committed
implement math
1 parent 90bc429 commit 241ea3f

File tree

2 files changed

+99
-1
lines changed

2 files changed

+99
-1
lines changed

src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,4 @@ pub(crate) mod internal_scc;
1818
pub(crate) mod internal_type_traits;
1919

2020
pub use fenwicktree::FenwickTree;
21+
pub use math::{crt, floor_sum, inv_mod, pow_mod};

src/math.rs

Lines changed: 98 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,98 @@
1-
use crate::internal_math::*;
1+
use crate::internal_math;
2+
3+
use std::mem::swap;
4+
5+
pub fn pow_mod(x: i64, mut n: i64, m: u32) -> u32 {
6+
assert!(0 <= n && 1 <= m);
7+
8+
let bt = internal_math::Barrett::new(m);
9+
let mut r = 1;
10+
let mut y = internal_math::safe_mod(x, m as i64) as u32;
11+
while n != 0 {
12+
if n & 1 != 0 {
13+
r = bt.mul(r, y);
14+
}
15+
y = bt.mul(y, y);
16+
n >>= 1;
17+
}
18+
r
19+
}
20+
21+
pub fn inv_mod(x: i64, m: i64) -> i64 {
22+
assert!(1 <= m);
23+
let z = internal_math::inv_gcd(x, m);
24+
assert!(z.0 == 1);
25+
z.1
26+
}
27+
28+
pub fn crt(r: &Vec<i64>, m: &Vec<i64>) -> (i64, i64) {
29+
assert!(r.len() == m.len());
30+
// Contracts: 0 <= r0 < m0
31+
let (mut r0, mut m0) = (0, 1);
32+
for (ri, mi) in r.iter().zip(m.iter()) {
33+
assert!(1 < *mi);
34+
let mut r1 = internal_math::safe_mod(*ri, *mi);
35+
let mut m1 = *mi;
36+
if m0 < m1 {
37+
swap(&mut r0, &mut r1);
38+
swap(&mut m0, &mut m1);
39+
}
40+
if m0 % m1 == 0 {
41+
if r0 % m1 != r1 {
42+
return (0, 0);
43+
}
44+
continue;
45+
}
46+
// assume: m0 > m1, lcm(m0, m1) >= 2 * max(m0, m1)
47+
48+
// (r0, m0), (r1, m1) -> (r2, m2 = lcm(m0, m1));
49+
// r2 % m0 = r0
50+
// r2 % m1 = r1
51+
// -> (r0 + x*m0) % m1 = r1
52+
// -> x*u0*g % (u1*g) = (r1 - r0) (u0*g = m0, u1*g = m1)
53+
// -> x = (r1 - r0) / g * inv(u0) (mod u1)
54+
55+
// im = inv(u0) (mod u1) (0 <= im < u1)
56+
let (g, im) = internal_math::inv_gcd(m0, m1);
57+
let u1 = m1 / g;
58+
// |r1 - r0| < (m0 + m1) <= lcm(m0, m1)
59+
if (r1 - r0) % g != 0 {
60+
return (0, 0);
61+
}
62+
// u1 * u1 <= m1 * m1 / g / g <= m0 * m1 / g = lcm(m0, m1)
63+
let x = (r1 - r0) / g % u1 * im % u1;
64+
65+
// |r0| + |m0 * x|
66+
// < m0 + m0 * (u1 - 1)
67+
// = m0 + m0 * m1 / g - m0
68+
// = lcm(m0, m1)
69+
r0 += x * m0;
70+
m0 *= u1; // -> lcm(m0, m1)
71+
if r0 < 0 {
72+
r0 += m0
73+
};
74+
}
75+
76+
(r0, m0)
77+
}
78+
79+
pub fn floor_sum(n: i64, m: i64, mut a: i64, mut b: i64) -> i64 {
80+
let mut ans = 0;
81+
if a >= m {
82+
ans += (n - 1) * n * (a / m) / 2;
83+
a %= m;
84+
}
85+
if b >= m {
86+
ans += n * (b / m);
87+
b %= m;
88+
}
89+
90+
let y_max = (a * n + b) / m;
91+
let x_max = y_max * m - b;
92+
if y_max == 0 {
93+
return ans;
94+
}
95+
ans += (n - (x_max + a - 1) / a) * y_max;
96+
ans += floor_sum(y_max, a, m, (a - x_max % a) % a);
97+
return ans;
98+
}

0 commit comments

Comments
 (0)