Skip to content

Commit a2d7c81

Browse files
Added tests
1 parent bd447c0 commit a2d7c81

File tree

1 file changed

+223
-2
lines changed

1 file changed

+223
-2
lines changed

src/internal_math.rs

Lines changed: 223 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,7 @@ fn primitive_root(m: i32) -> i32 {
196196
x /= 2;
197197
}
198198
for i in (3..std::i32::MAX).step_by(2) {
199-
if (i as i64) * (i as i64) <= (x as i64) {
199+
if i as i64 * i as i64 > x as i64 {
200200
break;
201201
}
202202
if x % i == 0 {
@@ -213,11 +213,232 @@ fn primitive_root(m: i32) -> i32 {
213213
}
214214
let mut g = 2;
215215
loop {
216-
if (0..cnt).any(|i| pow_mod(g, ((m - 1) / divs[i]) as i64, m) == 1) {
216+
if (0..cnt).all(|i| pow_mod(g, ((m - 1) / divs[i]) as i64, m) != 1) {
217217
break g as i32;
218218
}
219219
g += 1;
220220
}
221221
}
222222
// omitted
223223
// template <int m> constexpr int primitive_root = primitive_root_constexpr(m);
224+
225+
#[cfg(test)]
226+
mod tests {
227+
use crate::internal_math::{inv_gcd, is_prime, pow_mod, primitive_root, safe_mod, Barrett};
228+
use std::collections::HashSet;
229+
230+
#[test]
231+
fn test_safe_mod() {
232+
assert_eq!(safe_mod(0, 3), 0);
233+
assert_eq!(safe_mod(1, 3), 1);
234+
assert_eq!(safe_mod(2, 3), 2);
235+
assert_eq!(safe_mod(3, 3), 0);
236+
assert_eq!(safe_mod(4, 3), 1);
237+
assert_eq!(safe_mod(5, 3), 2);
238+
assert_eq!(safe_mod(73, 11), 7);
239+
assert_eq!(safe_mod(2306249155046129918, 6620319213327), 1374210749525);
240+
241+
assert_eq!(safe_mod(-1, 3), 2);
242+
assert_eq!(safe_mod(-2, 3), 1);
243+
assert_eq!(safe_mod(-3, 3), 0);
244+
assert_eq!(safe_mod(-4, 3), 2);
245+
assert_eq!(safe_mod(-5, 3), 1);
246+
assert_eq!(safe_mod(-7170500492396019511, 777567337), 333221848);
247+
}
248+
249+
#[test]
250+
fn test_barrett() {
251+
let b = Barrett::new(7);
252+
assert_eq!(b.umod(), 7);
253+
assert_eq!(b.mul(2, 3), 6);
254+
assert_eq!(b.mul(4, 6), 3);
255+
assert_eq!(b.mul(5, 0), 0);
256+
257+
let b = Barrett::new(998244353);
258+
assert_eq!(b.umod(), 998244353);
259+
assert_eq!(b.mul(2, 3), 6);
260+
assert_eq!(b.mul(3141592, 653589), 919583920);
261+
assert_eq!(b.mul(323846264, 338327950), 568012980);
262+
263+
// make `z - x * self._m as u64` overflow.
264+
// Thanks @koba-e964 (at https://github.com/rust-lang-ja/ac-library-rs/pull/3#discussion_r484932161)
265+
let b = Barrett::new(2147483647);
266+
assert_eq!(b.umod(), 2147483647);
267+
assert_eq!(b.mul(1073741824, 2147483645), 2147483646);
268+
}
269+
270+
#[test]
271+
fn test_pow_mod() {
272+
assert_eq!(pow_mod(0, 0, 1), 0);
273+
assert_eq!(pow_mod(0, 0, 3), 1);
274+
assert_eq!(pow_mod(0, 0, 723), 1);
275+
assert_eq!(pow_mod(0, 0, 998244353), 1);
276+
assert_eq!(pow_mod(0, 0, i32::max_value()), 1);
277+
278+
assert_eq!(pow_mod(0, 1, 1), 0);
279+
assert_eq!(pow_mod(0, 1, 3), 0);
280+
assert_eq!(pow_mod(0, 1, 723), 0);
281+
assert_eq!(pow_mod(0, 1, 998244353), 0);
282+
assert_eq!(pow_mod(0, 1, i32::max_value()), 0);
283+
284+
assert_eq!(pow_mod(0, i64::max_value(), 1), 0);
285+
assert_eq!(pow_mod(0, i64::max_value(), 3), 0);
286+
assert_eq!(pow_mod(0, i64::max_value(), 723), 0);
287+
assert_eq!(pow_mod(0, i64::max_value(), 998244353), 0);
288+
assert_eq!(pow_mod(0, i64::max_value(), i32::max_value()), 0);
289+
290+
assert_eq!(pow_mod(1, 0, 1), 0);
291+
assert_eq!(pow_mod(1, 0, 3), 1);
292+
assert_eq!(pow_mod(1, 0, 723), 1);
293+
assert_eq!(pow_mod(1, 0, 998244353), 1);
294+
assert_eq!(pow_mod(1, 0, i32::max_value()), 1);
295+
296+
assert_eq!(pow_mod(1, 1, 1), 0);
297+
assert_eq!(pow_mod(1, 1, 3), 1);
298+
assert_eq!(pow_mod(1, 1, 723), 1);
299+
assert_eq!(pow_mod(1, 1, 998244353), 1);
300+
assert_eq!(pow_mod(1, 1, i32::max_value()), 1);
301+
302+
assert_eq!(pow_mod(1, i64::max_value(), 1), 0);
303+
assert_eq!(pow_mod(1, i64::max_value(), 3), 1);
304+
assert_eq!(pow_mod(1, i64::max_value(), 723), 1);
305+
assert_eq!(pow_mod(1, i64::max_value(), 998244353), 1);
306+
assert_eq!(pow_mod(1, i64::max_value(), i32::max_value()), 1);
307+
308+
assert_eq!(pow_mod(i64::max_value(), 0, 1), 0);
309+
assert_eq!(pow_mod(i64::max_value(), 0, 3), 1);
310+
assert_eq!(pow_mod(i64::max_value(), 0, 723), 1);
311+
assert_eq!(pow_mod(i64::max_value(), 0, 998244353), 1);
312+
assert_eq!(pow_mod(i64::max_value(), 0, i32::max_value()), 1);
313+
314+
assert_eq!(pow_mod(i64::max_value(), i64::max_value(), 1), 0);
315+
assert_eq!(pow_mod(i64::max_value(), i64::max_value(), 3), 1);
316+
assert_eq!(pow_mod(i64::max_value(), i64::max_value(), 723), 640);
317+
assert_eq!(
318+
pow_mod(i64::max_value(), i64::max_value(), 998244353),
319+
683296792
320+
);
321+
assert_eq!(
322+
pow_mod(i64::max_value(), i64::max_value(), i32::max_value()),
323+
1
324+
);
325+
326+
assert_eq!(pow_mod(2, 3, 1_000_000_007), 8);
327+
assert_eq!(pow_mod(5, 7, 1_000_000_007), 78125);
328+
assert_eq!(pow_mod(123, 456, 1_000_000_007), 565291922);
329+
}
330+
331+
#[test]
332+
fn test_is_prime() {
333+
assert!(!is_prime(0));
334+
assert!(!is_prime(1));
335+
assert!(is_prime(2));
336+
assert!(is_prime(3));
337+
assert!(!is_prime(4));
338+
assert!(is_prime(5));
339+
assert!(!is_prime(6));
340+
assert!(is_prime(7));
341+
assert!(!is_prime(8));
342+
assert!(!is_prime(9));
343+
344+
// assert!(is_prime(57));
345+
assert!(!is_prime(57));
346+
assert!(!is_prime(58));
347+
assert!(is_prime(59));
348+
assert!(!is_prime(60));
349+
assert!(is_prime(61));
350+
assert!(!is_prime(62));
351+
352+
assert!(!is_prime(701928443));
353+
assert!(is_prime(998244353));
354+
assert!(!is_prime(1_000_000_000));
355+
assert!(is_prime(1_000_000_007));
356+
357+
assert!(is_prime(i32::max_value()));
358+
}
359+
360+
#[test]
361+
fn is_prime_sieve() {
362+
let n = 1_000_000;
363+
let mut prime = vec![true; n];
364+
prime[0] = false;
365+
prime[1] = false;
366+
for i in 0..n {
367+
assert_eq!(prime[i], is_prime(i as i32));
368+
if prime[i] {
369+
for j in (2 * i..n).step_by(i) {
370+
prime[j] = false;
371+
}
372+
}
373+
}
374+
}
375+
376+
#[test]
377+
fn test_inv_gcd() {
378+
for &(a, b, g) in &[
379+
(0, 1, 1),
380+
(0, 4, 4),
381+
(0, 7, 7),
382+
(2, 3, 1),
383+
(-2, 3, 1),
384+
(4, 6, 2),
385+
(-4, 6, 2),
386+
(13, 23, 1),
387+
(57, 81, 3),
388+
(12345, 67890, 15),
389+
(-3141592 * 6535, 3141592 * 8979, 3141592),
390+
(i64::max_value(), i64::max_value(), i64::max_value()),
391+
(i64::min_value(), i64::max_value(), 1),
392+
] {
393+
let (g_, x) = inv_gcd(a, b);
394+
assert_eq!(g, g_);
395+
let b_ = b as i128;
396+
assert_eq!(((x as i128 * a as i128) % b_ + b_) % b_, g as i128 % b_);
397+
}
398+
}
399+
400+
#[test]
401+
fn test_primitive_root() {
402+
for &p in &[
403+
2,
404+
3,
405+
5,
406+
7,
407+
233,
408+
200003,
409+
998244353,
410+
1_000_000_007,
411+
i32::max_value(),
412+
] {
413+
assert!(is_prime(p));
414+
let g = primitive_root(p);
415+
if p != 2 {
416+
assert_ne!(g, 1);
417+
}
418+
419+
let q = p - 1;
420+
for i in (2..i32::max_value()).take_while(|i| i * i <= q) {
421+
if q % i != 0 {
422+
break;
423+
}
424+
for &r in &[i, q / i] {
425+
assert_ne!(pow_mod(g as i64, r as i64, p), 1);
426+
}
427+
}
428+
assert_eq!(pow_mod(g as i64, q as i64, p), 1);
429+
430+
if p < 1_000_000 {
431+
assert_eq!(
432+
(0..p - 1)
433+
.scan(1, |i, _| {
434+
*i = *i * g % p;
435+
Some(*i)
436+
})
437+
.collect::<HashSet<_>>()
438+
.len() as i32,
439+
p - 1
440+
);
441+
}
442+
}
443+
}
444+
}

0 commit comments

Comments
 (0)