diff --git a/libm/src/math/generic/fmod.rs b/libm/src/math/generic/fmod.rs
index 29acc8a4..3c3fd44b 100644
--- a/libm/src/math/generic/fmod.rs
+++ b/libm/src/math/generic/fmod.rs
@@ -1,8 +1,12 @@
 /* SPDX-License-Identifier: MIT OR Apache-2.0 */
-use crate::support::{CastFrom, Float, Int, MinInt};
+use crate::support::{CastFrom, CastInto, Float, HInt, Int, MinInt, NarrowingDiv};
 
 #[inline]
-pub fn fmod<F: Float>(x: F, y: F) -> F {
+pub fn fmod<F: Float>(x: F, y: F) -> F
+where
+    F::Int: HInt,
+    <F::Int as HInt>::D: NarrowingDiv,
+{
     let _1 = F::Int::ONE;
     let sx = x.to_bits() & F::SIGN_MASK;
     let ux = x.to_bits() & !F::SIGN_MASK;
@@ -29,7 +33,7 @@ pub fn fmod<F: Float>(x: F, y: F) -> F {
 
     // To compute `(num << ex) % (div << ey)`, first
     // evaluate `rem = (num << (ex - ey)) % div` ...
-    let rem = reduction(num, ex - ey, div);
+    let rem = reduction::<F>(num, ex - ey, div);
     // ... so the result will be `rem << ey`
 
     if rem.is_zero() {
@@ -58,11 +62,55 @@ fn into_sig_exp<F: Float>(mut bits: F::Int) -> (F::Int, u32) {
 }
 
 /// Compute the remainder `(x * 2.pow(e)) % y` without overflow.
-fn reduction<I: Int>(mut x: I, e: u32, y: I) -> I {
-    x %= y;
-    for _ in 0..e {
-        x <<= 1;
-        x = x.checked_sub(y).unwrap_or(x);
+fn reduction<F>(mut x: F::Int, e: u32, y: F::Int) -> F::Int
+where
+    F: Float,
+    F::Int: HInt,
+    <<F as Float>::Int as HInt>::D: NarrowingDiv,
+{
+    // `f16` only has 5 exponent bits, so even `f16::MAX = 65504.0` is only
+    // a 40-bit integer multiple of the smallest subnormal.
+    if F::BITS == 16 {
+        debug_assert!(F::EXP_MAX - F::EXP_MIN == 29);
+        debug_assert!(e <= 29);
+        let u: u16 = x.cast();
+        let v: u16 = y.cast();
+        let u = (u as u64) << e;
+        let v = v as u64;
+        return F::Int::cast_from((u % v) as u16);
     }
-    x
+
+    // Ensure `x < 2y` for later steps
+    if x >= (y << 1) {
+        // This case is only reached with subnormal divisors,
+        // but it might be better to just normalize all significands
+        // to make this unnecessary. The further calls could potentially
+        // benefit from assuming a specific fixed leading bit position.
+        x %= y;
+    }
+
+    // The simple implementation seems to be fastest for a short reduction
+    // at this size. The limit here was chosen empirically on an Intel Nehalem.
+    // Newer CPUs with a faster `u64 * u64 -> u128` multiply might not benefit,
+    // and 32-bit systems or architectures without hardware multipliers might
+    // want to do this in more cases.
+    if F::BITS == 64 && e < 32 {
+        // Assumes `x < 2y`
+        for _ in 0..e {
+            x = x.checked_sub(y).unwrap_or(x);
+            x <<= 1;
+        }
+        return x.checked_sub(y).unwrap_or(x);
+    }
+
+    // Fast path for short reductions
+    if e < F::BITS {
+        let w = x.widen() << e;
+        if let Some((_, r)) = w.checked_narrowing_div_rem(y) {
+            return r;
+        }
+    }
+
+    // Assumes `x < 2y`
+    crate::support::linear_mul_reduction(x, e, y)
 }
diff --git a/libm/src/math/support/big/tests.rs b/libm/src/math/support/big/tests.rs
index d54706c7..0c32f445 100644
--- a/libm/src/math/support/big/tests.rs
+++ b/libm/src/math/support/big/tests.rs
@@ -3,6 +3,7 @@ use std::string::String;
 use std::{eprintln, format};
 
 use super::{HInt, MinInt, i256, u256};
+use crate::support::{Int as _, NarrowingDiv};
 
 const LOHI_SPLIT: u128 = 0xaaaaaaaaaaaaaaaaffffffffffffffff;
 
@@ -336,3 +337,28 @@ fn i256_shifts() {
         x = y;
     }
 }
+#[test]
+fn div_u256_by_u128() {
+    for j in i8::MIN..=i8::MAX {
+        let y: u128 = (j as i128).rotate_right(4).unsigned();
+        if y == 0 {
+            continue;
+        }
+        for i in i8::MIN..=i8::MAX {
+            let x: u128 = (i as i128).rotate_right(4).unsigned();
+            let xy = x.widen_mul(y);
+            assert_eq!(xy.checked_narrowing_div_rem(y), Some((x, 0)));
+            if y != 1 {
+                assert_eq!((xy + u256::ONE).checked_narrowing_div_rem(y), Some((x, 1)));
+            }
+            if x != 0 {
+                assert_eq!(
+                    (xy - u256::ONE).checked_narrowing_div_rem(y),
+                    Some((x - 1, y - 1))
+                );
+            }
+            let r = ((y as f64) * 0.12345) as u128;
+            assert_eq!((xy + r.widen()).checked_narrowing_div_rem(y), Some((x, r)));
+        }
+    }
+}
diff --git a/libm/src/math/support/int_traits.rs b/libm/src/math/support/int_traits.rs
index 9d8826df..1559231a 100644
--- a/libm/src/math/support/int_traits.rs
+++ b/libm/src/math/support/int_traits.rs
@@ -1,5 +1,8 @@
 use core::{cmp, fmt, ops};
 
+mod narrowing_div;
+pub use narrowing_div::NarrowingDiv;
+
 /// Minimal integer implementations needed on all integer types, including wide integers.
 #[allow(dead_code)] // Some constants are only used with tests
 pub trait MinInt:
@@ -293,7 +296,14 @@ int_impl!(i128, u128);
 
 /// Trait for integers twice the bit width of another integer. This is implemented for all
 /// primitives except for `u8`, because there is not a smaller primitive.
-pub trait DInt: MinInt {
+pub trait DInt:
+    MinInt
+    + core::ops::Add<Output = Self>
+    + core::ops::Sub<Output = Self>
+    + core::ops::Shl<u32, Output = Self>
+    + core::ops::Shr<u32, Output = Self>
+    + Ord
+{
     /// Integer that is half the bit width of the integer this trait is implemented for
     type H: HInt<D = Self>;
 
diff --git a/libm/src/math/support/int_traits/narrowing_div.rs b/libm/src/math/support/int_traits/narrowing_div.rs
new file mode 100644
index 00000000..8d015842
--- /dev/null
+++ b/libm/src/math/support/int_traits/narrowing_div.rs
@@ -0,0 +1,162 @@
+use crate::support::{DInt, HInt, Int, MinInt, u256};
+
+/// Trait for unsigned division of a double-wide integer
+/// when the quotient doesn't overflow.
+///
+/// This is the inverse of widening multiplication:
+/// - for any `x` and nonzero `y`: `x.widen_mul(y).checked_narrowing_div_rem(y) == Some((x, 0))`,
+/// - and for any `r` in `0..y`: `x.carrying_mul(y, r).checked_narrowing_div_rem(y) == Some((x, r))`.
+pub trait NarrowingDiv: DInt + MinInt {
+    /// Computes `(self / n, self % n)`
+    ///
+    /// # Safety
+    /// The caller must ensure that `self.hi() < n`, or equivalently,
+    /// that the quotient does not overflow.
+    unsafe fn unchecked_narrowing_div_rem(self, n: Self::H) -> (Self::H, Self::H);
+
+    /// Returns `Some((self / n, self % n))` when `self.hi() < n`.
+    fn checked_narrowing_div_rem(self, n: Self::H) -> Option<(Self::H, Self::H)> {
+        if self.hi() < n {
+            Some(unsafe { self.unchecked_narrowing_div_rem(n) })
+        } else {
+            None
+        }
+    }
+}
+
+macro_rules! impl_narrowing_div_primitive {
+    ($D:ident) => {
+        impl NarrowingDiv for $D {
+            unsafe fn unchecked_narrowing_div_rem(self, n: Self::H) -> (Self::H, Self::H) {
+                if self.hi() >= n {
+                    unsafe { core::hint::unreachable_unchecked() }
+                }
+                ((self / n as $D) as Self::H, (self % n as $D) as Self::H)
+            }
+        }
+    };
+}
+
+// Extend division from `u2N / uN` to `u4N / u2N`
+// This is not the most efficient algorithm, but it is
+// relatively simple.
+macro_rules! impl_narrowing_div_recurse {
+    ($D:ident) => {
+        impl NarrowingDiv for $D {
+            unsafe fn unchecked_narrowing_div_rem(self, n: Self::H) -> (Self::H, Self::H) {
+                if self.hi() >= n {
+                    unsafe { core::hint::unreachable_unchecked() }
+                }
+
+                // Normalize the divisor by shifting the most significant one
+                // to the leading position. `n != 0` is implied by `self.hi() < n`
+                let lz = n.leading_zeros();
+                let a = self << lz;
+                let b = n << lz;
+
+                let ah = a.hi();
+                let (a0, a1) = a.lo().lo_hi();
+                // SAFETY: For both calls, `b.leading_zeros() == 0` by the above shift.
+                // SAFETY: `ah < b` follows from `self.hi() < n`
+                let (q1, r) = unsafe { div_three_digits_by_two(a1, ah, b) };
+                // SAFETY: `r < b` is given as the postcondition of the previous call
+                let (q0, r) = unsafe { div_three_digits_by_two(a0, r, b) };
+
+                // Undo the earlier normalization for the remainder
+                (Self::H::from_lo_hi(q0, q1), r >> lz)
+            }
+        }
+    };
+}
+
+impl_narrowing_div_primitive!(u16);
+impl_narrowing_div_primitive!(u32);
+impl_narrowing_div_primitive!(u64);
+impl_narrowing_div_primitive!(u128);
+impl_narrowing_div_recurse!(u256);
+
+/// Implement `u3N / u2N`-division on top of `u2N / uN`-division.
+///
+/// Returns the quotient and remainder of `(a * R + a0) / n`,
+/// where `R = (1 << U::BITS)` is the digit size.
+///
+/// # Safety
+/// Requires that `n.leading_zeros() == 0` and `a < n`.
+unsafe fn div_three_digits_by_two<U>(a0: U, a: U::D, n: U::D) -> (U, U::D)
+where
+    U: HInt,
+    U::D: Int + NarrowingDiv,
+{
+    if n.leading_zeros() > 0 || a >= n {
+        debug_assert!(false, "unsafe preconditions not met");
+        unsafe { core::hint::unreachable_unchecked() }
+    }
+
+    // n = n1R + n0
+    let (n0, n1) = n.lo_hi();
+    // a = a2R + a1
+    let (a1, a2) = a.lo_hi();
+
+    let mut q;
+    let mut r;
+    let mut wrap;
+    // `a < n` is guaranteed by the caller, but `a2 == n1 && a1 < n0` is possible
+    if let Some((q0, r1)) = a.checked_narrowing_div_rem(n1) {
+        q = q0;
+        // a = qn1 + r1, where 0 <= r1 < n1
+
+        // Include the remainder with the low bits:
+        // r = a0 + r1R
+        r = U::D::from_lo_hi(a0, r1);
+
+        // Subtract the contribution of the divisor low bits with the estimated quotient
+        let d = q.widen_mul(n0);
+        (r, wrap) = r.overflowing_sub(d);
+
+        // Since `q` is the quotient of dividing with a slightly smaller divisor,
+        // it may be an overapproximation, but is never too small, and similarly,
+        // `r` is now either the correct remainder ...
+        if !wrap {
+            return (q, r);
+        }
+        // ... or the remainder went "negative" (by as much as `d = qn0 < RR`)
+        // and we have to adjust.
+        q -= U::ONE;
+    } else {
+        debug_assert!(a2 == n1 && a1 < n0);
+        // Otherwise, `a2 == n1`, and the estimated quotient would be
+        // `R + (a1 / n1)`, but the correct quotient can't overflow.
+        // We'll start from `q = R = (1 << U::BITS)`,
+        // so `r = aR + a0 - qn = (a - n)R + a0`
+        r = U::D::from_lo_hi(a0, a1.wrapping_sub(n0));
+        // Since `a < n`, the first decrement is always needed:
+        q = U::MAX; /* R - 1 */
+    }
+
+    (r, wrap) = r.overflowing_add(n);
+    if wrap {
+        return (q, r);
+    }
+
+    // If the remainder still didn't wrap, we need another step.
+    q -= U::ONE;
+    (r, wrap) = r.overflowing_add(n);
+    // Since `n >= RR/2`, at least one of the two `r += n` must have wrapped.
+    debug_assert!(wrap, "estimated quotient should be off by at most two");
+    (q, r)
+}
+
+#[cfg(test)]
+mod test {
+    use super::{HInt, NarrowingDiv};
+
+    #[test]
+    fn inverse_mul() {
+        for x in 0..=u8::MAX {
+            for y in 1..=u8::MAX {
+                let xy = x.widen_mul(y);
+                assert_eq!(xy.checked_narrowing_div_rem(y), Some((x, 0)));
+            }
+        }
+    }
+}
diff --git a/libm/src/math/support/mod.rs b/libm/src/math/support/mod.rs
index b2d7bd8d..15ab010d 100644
--- a/libm/src/math/support/mod.rs
+++ b/libm/src/math/support/mod.rs
@@ -8,6 +8,7 @@ pub(crate) mod feature_detect;
 mod float_traits;
 pub mod hex_float;
 mod int_traits;
+mod modular;
 
 #[allow(unused_imports)]
 pub use big::{i256, u256};
@@ -28,7 +29,8 @@ pub use hex_float::hf16;
 pub use hex_float::hf128;
 #[allow(unused_imports)]
 pub use hex_float::{hf32, hf64};
-pub use int_traits::{CastFrom, CastInto, DInt, HInt, Int, MinInt};
+pub use int_traits::{CastFrom, CastInto, DInt, HInt, Int, MinInt, NarrowingDiv};
+pub use modular::linear_mul_reduction;
 
 /// Hint to the compiler that the current path is cold.
 pub fn cold_path() {
diff --git a/libm/src/math/support/modular.rs b/libm/src/math/support/modular.rs
new file mode 100644
index 00000000..5b14b9ec
--- /dev/null
+++ b/libm/src/math/support/modular.rs
@@ -0,0 +1,194 @@
+use crate::support::int_traits::NarrowingDiv;
+use crate::support::{DInt, HInt, Int};
+
+/// Contains:
+/// n in (R/8, R/4)
+/// x in [0, 2n)
+#[derive(Debug, Clone, PartialEq, Eq)]
+struct Reducer<U: HInt> {
+    // let m = 2n
+    m: U,
+    // RR/2 = qm + r
+    r: U,
+    xq2: U::D,
+}
+
+impl<U> Reducer<U>
+where
+    U: HInt,
+    U: Int,
+{
+    /// Construct a reducer for `(x << _) mod n`.
+    ///
+    /// Requires `R/8 < n < R/4` and `x < 2n`.
+    fn new(x: U, n: U) -> Self
+    where
+        U::D: NarrowingDiv,
+    {
+        let _1 = U::ONE;
+        assert!(n > (_1 << (U::BITS - 3)));
+        assert!(n < (_1 << (U::BITS - 2)));
+        let m = n << 1;
+        assert!(x < m);
+
+        // We need q and r s.t. RR/2 = qm + r
+        // As R/4 < m < R/2,
+        // we have R <= q < 2R
+        // so let q = R + f
+        // RR/2 = (R + f)m + r
+        // R(R/2 - m) = fm + r
+
+        // v = R/2 - m < R/4 < m
+        let v = (_1 << (U::BITS - 1)) - m;
+        let (f, r) = v.widen_hi().checked_narrowing_div_rem(m).unwrap();
+
+        // xq < qm <= RR/2
+        // 2xq < RR
+        // 2xq = 2xR + 2xf;
+        let x2: U = x << 1;
+        let xq2 = x2.widen_hi() + x2.widen_mul(f);
+        Self { m, r, xq2 }
+    }
+
+    /// Extract the current remainder in the range `[0, 2n)`
+    fn partial_remainder(&self) -> U {
+        // RR/2 = qm + r, 0 <= r < m
+        // 2xq = uR + v, 0 <= v < R
+        // muR = 2mxq - mv
+        //     = xRR - 2xr - mv
+        // mu + (2xr + mv)/R == xR
+
+        // 0 <= 2xq < RR
+        // R <= q < 2R
+        // 0 <= x < R/2
+        // R/4 < m < R/2
+        // 0 <= r < m
+        // 0 <= mv < mR
+        // 0 <= 2xr < rR < mR
+
+        // 0 <= (2xr + mv)/R < 2m
+        // Add `mu` to each term to obtain:
+        // mu <= xR < mu + 2m
+
+        // Since `0 <= 2m < R`, `xR` is the only multiple of `R` between
+        // `mu` and `m(u+2)`, so we can truncate the latter to find `x`.
+        let _1 = U::ONE;
+        self.m.widen_mul(self.xq2.hi() + (_1 + _1)).hi()
+    }
+
+    /// Maps the remainder `x` to `(x << k) - un`,
+    /// for a suitable quotient `u`, which is returned.
+    fn shift_reduce(&mut self, k: u32) -> U {
+        assert!(k < U::BITS);
+        // 2xq << k = aRR/2 + b;
+        let a = self.xq2.hi() >> (U::BITS - 1 - k);
+        let (lo, hi) = (self.xq2 << k).lo_hi();
+        let b = U::D::from_lo_hi(lo, hi & (U::MAX >> 1));
+
+        // (2xq << k) - aqm
+        // = aRR/2 + b - aqm
+        // = a(RR/2 - qm) + b
+        // = ar + b
+        self.xq2 = a.widen_mul(self.r) + b;
+        a
+    }
+
+    /// Maps the remainder `x` to `x(R/2) - un`,
+    /// for a suitable quotient `u`, which is returned.
+    fn word_reduce(&mut self) -> U {
+        // 2xq = uR + v
+        let (v, u) = self.xq2.lo_hi();
+        // xqR - uqm
+        // = uRR/2 + vR/2 - uRR/2 + ur
+        // = ur + (v/2)R
+        self.xq2 = u.widen_mul(self.r) + U::widen_hi(v >> 1);
+        u
+    }
+}
+
+/// Compute the remainder `(x << e) % y`, as if evaluated with unbounded integers.
+/// Requires `x < 2y` and `y.leading_zeros() >= 2`
+pub fn linear_mul_reduction<U>(x: U, mut e: u32, y: U) -> U
+where
+    U: HInt + Int,
+    U::D: NarrowingDiv,
+{
+    assert!(y <= U::MAX >> 2);
+    assert!(x < (y << 1));
+    let _0 = U::ZERO;
+    let _1 = U::ONE;
+
+    // power of two divisor
+    if (y & (y - _1)).is_zero() {
+        if e < U::BITS {
+            return (x << e) & (y - _1);
+        } else {
+            return _0;
+        }
+    }
+
+    // shift the divisor so it has exactly two leading zeros
+    let y_shift = y.leading_zeros() - 2;
+    let mut m = Reducer::new(x, y << y_shift);
+    e += y_shift;
+
+    while e >= U::BITS - 1 {
+        m.word_reduce();
+        e -= U::BITS - 1;
+    }
+    m.shift_reduce(e);
+
+    let rem = m.partial_remainder() >> y_shift;
+    rem.checked_sub(y).unwrap_or(rem)
+}
+
+#[cfg(test)]
+mod test {
+    use crate::support::linear_mul_reduction;
+    use crate::support::modular::Reducer;
+
+    #[test]
+    fn reducer_ops() {
+        for n in 33..=63_u8 {
+            for x in 0..2 * n {
+                let temp = Reducer::new(x, n);
+                let n = n as u32;
+                let x0 = temp.partial_remainder() as u32;
+                assert_eq!(x as u32, x0);
+                for k in 0..=7 {
+                    let mut red = temp.clone();
+                    let u = red.shift_reduce(k) as u32;
+                    let x1 = red.partial_remainder() as u32;
+                    assert_eq!(x1, (x0 << k) - u * n);
+                    assert!(x1 < 2 * n);
+                    assert!((red.xq2 as u32).is_multiple_of(2 * x1));
+
+                    // `word_reduce` is equivalent to
+                    // `shift_reduce(U::BITS - 1)`
+                    if k == 7 {
+                        let mut alt = temp.clone();
+                        let w = alt.word_reduce();
+                        assert_eq!(u, w as u32);
+                        assert_eq!(alt, red);
+                    }
+                }
+            }
+        }
+    }
+    #[test]
+    fn reduction() {
+        for y in 1..64u8 {
+            for x in 0..2 * y {
+                let mut r = x % y;
+                for e in 0..100 {
+                    assert_eq!(r, linear_mul_reduction(x, e, y));
+                    // maintain the correct expected remainder
+                    r <<= 1;
+                    if r >= y {
+                        r -= y;
+                    }
+                }
+            }
+        }
+    }
+}
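
Note for reviewers (not part of the patch): the `Reducer` bookkeeping in modular.rs is easiest to sanity-check with concrete numbers. The standalone sketch below redoes the `Reducer::new`, `partial_remainder`, and `shift_reduce` arithmetic for `U = u8` (digit size `R = 256`) in plain `u32` arithmetic, assuming only the relations stated in the comments above: `R*R/2 = q*m + r` with `0 <= r < m`, and `xq2 == 2*x*q`. The `main` function and every name in it are illustrative; it does not use the crate's traits. The `reducer_ops` test in modular.rs checks the same properties through the actual implementation.

fn main() {
    const R: u32 = 256; // digit size for U = u8
    for n in 33..64u32 {
        // R/8 < n < R/4, as asserted by `Reducer::new`
        let m = 2 * n;
        // R*R/2 = q*m + r with 0 <= r < m; only `r` is stored, `q` stays implicit
        let q = (R * R / 2) / m;
        let r = (R * R / 2) % m;
        assert!(R <= q && q < 2 * R);
        for x in 0..m {
            // The reducer only ever stores `m`, `r`, and `xq2 = 2*x*q`.
            let xq2 = 2 * x * q;

            // `partial_remainder`: with u = floor(2xq / R), the comments derive
            // m*u <= x*R < m*u + 2*m, so truncating m*(u + 2)/R recovers x.
            let u = xq2 / R;
            assert!(m * u <= x * R && x * R < m * u + 2 * m);
            assert_eq!(m * (u + 2) / R, x);

            // `shift_reduce(k)`: split 2xq << k = a*(R*R/2) + b, return the
            // quotient a, and keep xq2 = a*r + b for the new remainder
            // x' = (x << k) - a*n, which stays in [0, 2n).
            for k in 0..8u32 {
                let a = (xq2 << k) / (R * R / 2);
                let b = (xq2 << k) % (R * R / 2);
                let x_next = (x << k).checked_sub(a * n).unwrap();
                assert!(x_next < 2 * n);
                // The update preserves the invariant xq2 == 2*x*q.
                assert_eq!(a * r + b, 2 * x_next * q);
            }
        }
    }
    println!("Reducer invariants hold for R = 256, n in (R/8, R/4), x in [0, 2n)");
}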