Skip to content

Commit 31f712f

Browse files
committed
define and implement trait NarrowingDiv for unsigned integer division
1 parent 9c176c2 commit 31f712f

File tree

4 files changed

+192
-1
lines changed

4 files changed

+192
-1
lines changed

libm/src/math/support/big/tests.rs

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use std::string::String;
33
use std::{eprintln, format};
44

55
use super::{HInt, MinInt, i256, u256};
6+
use crate::support::{Int as _, NarrowingDiv};
67

78
const LOHI_SPLIT: u128 = 0xaaaaaaaaaaaaaaaaffffffffffffffff;
89

@@ -336,3 +337,28 @@ fn i256_shifts() {
336337
x = y;
337338
}
338339
}
340+
#[test]
341+
fn div_u256_by_u128() {
342+
for j in i8::MIN..=i8::MAX {
343+
let y: u128 = (j as i128).rotate_right(4).unsigned();
344+
if y == 0 {
345+
continue;
346+
}
347+
for i in i8::MIN..=i8::MAX {
348+
let x: u128 = (i as i128).rotate_right(4).unsigned();
349+
let xy = x.widen_mul(y);
350+
assert_eq!(xy.checked_narrowing_div_rem(y), Some((x, 0)));
351+
if y != 1 {
352+
assert_eq!((xy + u256::ONE).checked_narrowing_div_rem(y), Some((x, 1)));
353+
}
354+
if x != 0 {
355+
assert_eq!(
356+
(xy - u256::ONE).checked_narrowing_div_rem(y),
357+
Some((x - 1, y - 1))
358+
);
359+
}
360+
let r = ((y as f64) * 0.12345) as u128;
361+
assert_eq!((xy + r.widen()).checked_narrowing_div_rem(y), Some((x, r)));
362+
}
363+
}
364+
}

libm/src/math/support/int_traits.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
use core::{cmp, fmt, ops};
22

3+
mod narrowing_div;
4+
pub use narrowing_div::NarrowingDiv;
5+
36
/// Minimal integer implementations needed on all integer types, including wide integers.
47
#[allow(dead_code)] // Some constants are only used with tests
58
pub trait MinInt:
Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
use crate::support::{DInt, HInt, Int, MinInt, u256};
2+
3+
/// Trait for unsigned division of a double-wide integer
4+
/// when the quotient doesn't overflow.
5+
///
6+
/// This is the inverse of widening multiplication:
7+
/// - for any `x` and nonzero `y`: `x.widen_mul(y).checked_narrowing_div_rem(y) == Some((x, 0))`,
8+
/// - and for any `r in 0..y`: `x.carrying_mul(y, r).checked_narrowing_div_rem(y) == Some((x, r))`,
9+
pub trait NarrowingDiv: DInt + MinInt<Unsigned = Self> {
10+
/// Computes `(self / n, self % n))`
11+
///
12+
/// # Safety
13+
/// The caller must ensure that `self.hi() < n`, or equivalently,
14+
/// that the quotient does not overflow.
15+
unsafe fn unchecked_narrowing_div_rem(self, n: Self::H) -> (Self::H, Self::H);
16+
17+
/// Returns `Some((self / n, self % n))` when `self.hi() < n`.
18+
fn checked_narrowing_div_rem(self, n: Self::H) -> Option<(Self::H, Self::H)> {
19+
if self.hi() < n {
20+
Some(unsafe { self.unchecked_narrowing_div_rem(n) })
21+
} else {
22+
None
23+
}
24+
}
25+
}
26+
27+
macro_rules! impl_narrowing_div_primitive {
28+
($D:ident) => {
29+
impl NarrowingDiv for $D {
30+
unsafe fn unchecked_narrowing_div_rem(self, n: Self::H) -> (Self::H, Self::H) {
31+
if self.hi() >= n {
32+
unsafe { core::hint::unreachable_unchecked() }
33+
}
34+
((self / n as $D) as Self::H, (self % n as $D) as Self::H)
35+
}
36+
}
37+
};
38+
}
39+
40+
// Extend division from `u2N / uN` to `u4N / u2N`
41+
// This is not the most efficient algorithm, but it is
42+
// relatively simple.
43+
macro_rules! impl_narrowing_div_recurse {
44+
($D:ident) => {
45+
impl NarrowingDiv for $D {
46+
unsafe fn unchecked_narrowing_div_rem(self, n: Self::H) -> (Self::H, Self::H) {
47+
if self.hi() >= n {
48+
unsafe { core::hint::unreachable_unchecked() }
49+
}
50+
51+
// Normalize the divisor by shifting the most significant one
52+
// to the leading position. `n != 0` is implied by `self.hi() < n`
53+
let lz = n.leading_zeros();
54+
let a = self << lz;
55+
let b = n << lz;
56+
57+
let ah = a.hi();
58+
let (a0, a1) = a.lo().lo_hi();
59+
// SAFETY: For both calls, `b.leading_zeros() == 0` by the above shift.
60+
// SAFETY: `ah < b` follows from `self.hi() < n`
61+
let (q1, r) = unsafe { div_three_digits_by_two(a1, ah, b) };
62+
// SAFETY: `r < b` is given as the postcondition of the previous call
63+
let (q0, r) = unsafe { div_three_digits_by_two(a0, r, b) };
64+
65+
// Undo the earlier normalization for the remainder
66+
(Self::H::from_lo_hi(q0, q1), r >> lz)
67+
}
68+
}
69+
};
70+
}
71+
72+
impl_narrowing_div_primitive!(u16);
73+
impl_narrowing_div_primitive!(u32);
74+
impl_narrowing_div_primitive!(u64);
75+
impl_narrowing_div_primitive!(u128);
76+
impl_narrowing_div_recurse!(u256);
77+
78+
/// Implement `u3N / u2N`-division on top of `u2N / uN`-division.
79+
///
80+
/// Returns the quotient and remainder of `(a * R + a0) / n`,
81+
/// where `R = (1 << U::BITS)` is the digit size.
82+
///
83+
/// # Safety
84+
/// Requires that `n.leading_zeros() == 0` and `a < n`.
85+
unsafe fn div_three_digits_by_two<U>(a0: U, a: U::D, n: U::D) -> (U, U::D)
86+
where
87+
U: HInt,
88+
U::D: Int + NarrowingDiv,
89+
{
90+
if n.leading_zeros() > 0 || a >= n {
91+
debug_assert!(false, "unsafe preconditions not met");
92+
unsafe { core::hint::unreachable_unchecked() }
93+
}
94+
95+
// n = n1R + n0
96+
let (n0, n1) = n.lo_hi();
97+
// a = a2R + a1
98+
let (a1, a2) = a.lo_hi();
99+
100+
let mut q;
101+
let mut r;
102+
let mut wrap;
103+
// `a < n` is guaranteed by the caller, but `a2 == n1 && a1 < n0` is possible
104+
if let Some((q0, r1)) = a.checked_narrowing_div_rem(n1) {
105+
q = q0;
106+
// a = qn1 + r1, where 0 <= r1 < n1
107+
108+
// Include the remainder with the low bits:
109+
// r = a0 + r1R
110+
r = U::D::from_lo_hi(a0, r1);
111+
112+
// Subtract the contribution of the divisor low bits with the estimated quotient
113+
let d = q.widen_mul(n0);
114+
(r, wrap) = r.overflowing_sub(d);
115+
116+
// Since `q` is the quotient of dividing with a slightly smaller divisor,
117+
// it may be an overapproximation, but is never too small, and similarly,
118+
// `r` is now either the correct remainder ...
119+
if !wrap {
120+
return (q, r);
121+
}
122+
// ... or the remainder went "negative" (by as much as `d = qn0 < RR`)
123+
// and we have to adjust.
124+
q -= U::ONE;
125+
} else {
126+
debug_assert!(a2 == n1 && a1 < n0);
127+
// Otherwise, `a2 == n1`, and the estimated quotient would be
128+
// `R + (a1 % n1)`, but the correct quotient can't overflow.
129+
// We'll start from `q = R = (1 << U::BITS)`,
130+
// so `r = aR + a0 - qn = (a - n)R + a0`
131+
r = U::D::from_lo_hi(a0, a1.wrapping_sub(n0));
132+
// Since `a < n`, the first decrement is always needed:
133+
q = U::MAX; /* R - 1 */
134+
}
135+
136+
(r, wrap) = r.overflowing_add(n);
137+
if wrap {
138+
return (q, r);
139+
}
140+
141+
// If the remainder still didn't wrap, we need another step.
142+
q -= U::ONE;
143+
(r, wrap) = r.overflowing_add(n);
144+
// Since `n >= RR/2`, at least one of the two `r += n` must have wrapped.
145+
debug_assert!(wrap, "estimated quotient should be off by at most two");
146+
(q, r)
147+
}
148+
149+
#[cfg(test)]
150+
mod test {
151+
use super::{HInt, NarrowingDiv};
152+
153+
#[test]
154+
fn inverse_mul() {
155+
for x in 0..=u8::MAX {
156+
for y in 1..=u8::MAX {
157+
let xy = x.widen_mul(y);
158+
assert_eq!(xy.checked_narrowing_div_rem(y), Some((x, 0)));
159+
}
160+
}
161+
}
162+
}

libm/src/math/support/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ pub use hex_float::hf16;
2828
pub use hex_float::hf128;
2929
#[allow(unused_imports)]
3030
pub use hex_float::{hf32, hf64};
31-
pub use int_traits::{CastFrom, CastInto, DInt, HInt, Int, MinInt};
31+
pub use int_traits::{CastFrom, CastInto, DInt, HInt, Int, MinInt, NarrowingDiv};
3232

3333
/// Hint to the compiler that the current path is cold.
3434
pub fn cold_path() {

0 commit comments

Comments
 (0)