Skip to content

libm: optimize fmod #1002

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 57 additions & 9 deletions libm/src/math/generic/fmod.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
/* SPDX-License-Identifier: MIT OR Apache-2.0 */
use crate::support::{CastFrom, Float, Int, MinInt};
use crate::support::{CastFrom, CastInto, Float, HInt, Int, MinInt, NarrowingDiv};

#[inline]
pub fn fmod<F: Float>(x: F, y: F) -> F {
pub fn fmod<F: Float>(x: F, y: F) -> F
where
F::Int: HInt,
<F::Int as HInt>::D: NarrowingDiv,
{
let _1 = F::Int::ONE;
let sx = x.to_bits() & F::SIGN_MASK;
let ux = x.to_bits() & !F::SIGN_MASK;
Expand All @@ -29,7 +33,7 @@ pub fn fmod<F: Float>(x: F, y: F) -> F {

// To compute `(num << ex) % (div << ey)`, first
// evaluate `rem = (num << (ex - ey)) % div` ...
let rem = reduction(num, ex - ey, div);
let rem = reduction::<F>(num, ex - ey, div);
// ... so the result will be `rem << ey`

if rem.is_zero() {
Expand Down Expand Up @@ -58,11 +62,55 @@ fn into_sig_exp<F: Float>(mut bits: F::Int) -> (F::Int, u32) {
}

/// Compute the remainder `(x * 2.pow(e)) % y` without overflow.
fn reduction<I: Int>(mut x: I, e: u32, y: I) -> I {
x %= y;
for _ in 0..e {
x <<= 1;
x = x.checked_sub(y).unwrap_or(x);
fn reduction<F>(mut x: F::Int, e: u32, y: F::Int) -> F::Int
where
F: Float,
F::Int: HInt,
<<F as Float>::Int as HInt>::D: NarrowingDiv,
{
// `f16` only has 5 exponent bits, so even `f16::MAX = 65504.0` is only
// a 40-bit integer multiple of the smallest subnormal.
if F::BITS == 16 {
debug_assert!(F::EXP_MAX - F::EXP_MIN == 29);
debug_assert!(e <= 29);
let u: u16 = x.cast();
let v: u16 = y.cast();
let u = (u as u64) << e;
let v = v as u64;
return F::Int::cast_from((u % v) as u16);
}
x

// Ensure `x < 2y` for later steps
if x >= (y << 1) {
// This case is only reached with subnormal divisors,
// but it might be better to just normalize all significands
// to make this unnecessary. The further calls could potentially
// benefit from assuming a specific fixed leading bit position.
x %= y;
}

// The simple implementation seems to be fastest for a short reduction
// at this size. The limit here was chosen empirically on an Intel Nehalem.
// Less old CPUs that have faster `u64 * u64 -> u128` might not benefit,
// and 32-bit systems or architectures without hardware multipliers might
// want to do this in more cases.
if F::BITS == 64 && e < 32 {
// Assumes `x < 2y`
for _ in 0..e {
x = x.checked_sub(y).unwrap_or(x);
x <<= 1;
}
return x.checked_sub(y).unwrap_or(x);
}

// Fast path for short reductions
if e < F::BITS {
let w = x.widen() << e;
if let Some((_, r)) = w.checked_narrowing_div_rem(y) {
return r;
}
}

// Assumes `x < 2y`
crate::support::linear_mul_reduction(x, e, y)
}
26 changes: 26 additions & 0 deletions libm/src/math/support/big/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::string::String;
use std::{eprintln, format};

use super::{HInt, MinInt, i256, u256};
use crate::support::{Int as _, NarrowingDiv};

const LOHI_SPLIT: u128 = 0xaaaaaaaaaaaaaaaaffffffffffffffff;

Expand Down Expand Up @@ -336,3 +337,28 @@ fn i256_shifts() {
x = y;
}
}
#[test]
fn div_u256_by_u128() {
for j in i8::MIN..=i8::MAX {
let y: u128 = (j as i128).rotate_right(4).unsigned();
if y == 0 {
continue;
}
for i in i8::MIN..=i8::MAX {
let x: u128 = (i as i128).rotate_right(4).unsigned();
let xy = x.widen_mul(y);
assert_eq!(xy.checked_narrowing_div_rem(y), Some((x, 0)));
if y != 1 {
assert_eq!((xy + u256::ONE).checked_narrowing_div_rem(y), Some((x, 1)));
}
if x != 0 {
assert_eq!(
(xy - u256::ONE).checked_narrowing_div_rem(y),
Some((x - 1, y - 1))
);
}
let r = ((y as f64) * 0.12345) as u128;
assert_eq!((xy + r.widen()).checked_narrowing_div_rem(y), Some((x, r)));
}
}
}
12 changes: 11 additions & 1 deletion libm/src/math/support/int_traits.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
use core::{cmp, fmt, ops};

mod narrowing_div;
pub use narrowing_div::NarrowingDiv;

/// Minimal integer implementations needed on all integer types, including wide integers.
#[allow(dead_code)] // Some constants are only used with tests
pub trait MinInt:
Expand Down Expand Up @@ -293,7 +296,14 @@ int_impl!(i128, u128);

/// Trait for integers twice the bit width of another integer. This is implemented for all
/// primitives except for `u8`, because there is not a smaller primitive.
pub trait DInt: MinInt {
pub trait DInt:
MinInt
+ core::ops::Add<Output = Self>
+ core::ops::Sub<Output = Self>
+ core::ops::Shl<u32, Output = Self>
+ core::ops::Shr<u32, Output = Self>
+ Ord
{
/// Integer that is half the bit width of the integer this trait is implemented for
type H: HInt<D = Self>;

Expand Down
162 changes: 162 additions & 0 deletions libm/src/math/support/int_traits/narrowing_div.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
use crate::support::{DInt, HInt, Int, MinInt, u256};

/// Trait for unsigned division of a double-wide integer
/// when the quotient doesn't overflow.
///
/// This is the inverse of widening multiplication:
/// - for any `x` and nonzero `y`: `x.widen_mul(y).checked_narrowing_div_rem(y) == Some((x, 0))`,
/// - and for any `r in 0..y`: `x.carrying_mul(y, r).checked_narrowing_div_rem(y) == Some((x, r))`,
pub trait NarrowingDiv: DInt + MinInt<Unsigned = Self> {
/// Computes `(self / n, self % n))`
///
/// # Safety
/// The caller must ensure that `self.hi() < n`, or equivalently,
/// that the quotient does not overflow.
unsafe fn unchecked_narrowing_div_rem(self, n: Self::H) -> (Self::H, Self::H);

/// Returns `Some((self / n, self % n))` when `self.hi() < n`.
fn checked_narrowing_div_rem(self, n: Self::H) -> Option<(Self::H, Self::H)> {
if self.hi() < n {
Some(unsafe { self.unchecked_narrowing_div_rem(n) })
} else {
None
}
}
}

macro_rules! impl_narrowing_div_primitive {
($D:ident) => {
impl NarrowingDiv for $D {
unsafe fn unchecked_narrowing_div_rem(self, n: Self::H) -> (Self::H, Self::H) {
if self.hi() >= n {
unsafe { core::hint::unreachable_unchecked() }
}
((self / n as $D) as Self::H, (self % n as $D) as Self::H)
}
}
};
}

// Extend division from `u2N / uN` to `u4N / u2N`
// This is not the most efficient algorithm, but it is
// relatively simple.
macro_rules! impl_narrowing_div_recurse {
($D:ident) => {
impl NarrowingDiv for $D {
unsafe fn unchecked_narrowing_div_rem(self, n: Self::H) -> (Self::H, Self::H) {
if self.hi() >= n {
unsafe { core::hint::unreachable_unchecked() }
}

// Normalize the divisor by shifting the most significant one
// to the leading position. `n != 0` is implied by `self.hi() < n`
let lz = n.leading_zeros();
let a = self << lz;
let b = n << lz;

let ah = a.hi();
let (a0, a1) = a.lo().lo_hi();
// SAFETY: For both calls, `b.leading_zeros() == 0` by the above shift.
// SAFETY: `ah < b` follows from `self.hi() < n`
let (q1, r) = unsafe { div_three_digits_by_two(a1, ah, b) };
// SAFETY: `r < b` is given as the postcondition of the previous call
let (q0, r) = unsafe { div_three_digits_by_two(a0, r, b) };

// Undo the earlier normalization for the remainder
(Self::H::from_lo_hi(q0, q1), r >> lz)
}
}
};
}

impl_narrowing_div_primitive!(u16);
impl_narrowing_div_primitive!(u32);
impl_narrowing_div_primitive!(u64);
impl_narrowing_div_primitive!(u128);
impl_narrowing_div_recurse!(u256);

/// Implement `u3N / u2N`-division on top of `u2N / uN`-division.
///
/// Returns the quotient and remainder of `(a * R + a0) / n`,
/// where `R = (1 << U::BITS)` is the digit size.
///
/// # Safety
/// Requires that `n.leading_zeros() == 0` and `a < n`.
unsafe fn div_three_digits_by_two<U>(a0: U, a: U::D, n: U::D) -> (U, U::D)
where
U: HInt,
U::D: Int + NarrowingDiv,
{
if n.leading_zeros() > 0 || a >= n {
debug_assert!(false, "unsafe preconditions not met");
unsafe { core::hint::unreachable_unchecked() }
}

// n = n1R + n0
let (n0, n1) = n.lo_hi();
// a = a2R + a1
let (a1, a2) = a.lo_hi();

let mut q;
let mut r;
let mut wrap;
// `a < n` is guaranteed by the caller, but `a2 == n1 && a1 < n0` is possible
if let Some((q0, r1)) = a.checked_narrowing_div_rem(n1) {
q = q0;
// a = qn1 + r1, where 0 <= r1 < n1

// Include the remainder with the low bits:
// r = a0 + r1R
r = U::D::from_lo_hi(a0, r1);

// Subtract the contribution of the divisor low bits with the estimated quotient
let d = q.widen_mul(n0);
(r, wrap) = r.overflowing_sub(d);

// Since `q` is the quotient of dividing with a slightly smaller divisor,
// it may be an overapproximation, but is never too small, and similarly,
// `r` is now either the correct remainder ...
if !wrap {
return (q, r);
}
// ... or the remainder went "negative" (by as much as `d = qn0 < RR`)
// and we have to adjust.
q -= U::ONE;
} else {
debug_assert!(a2 == n1 && a1 < n0);
// Otherwise, `a2 == n1`, and the estimated quotient would be
// `R + (a1 % n1)`, but the correct quotient can't overflow.
// We'll start from `q = R = (1 << U::BITS)`,
// so `r = aR + a0 - qn = (a - n)R + a0`
r = U::D::from_lo_hi(a0, a1.wrapping_sub(n0));
// Since `a < n`, the first decrement is always needed:
q = U::MAX; /* R - 1 */
}

(r, wrap) = r.overflowing_add(n);
if wrap {
return (q, r);
}

// If the remainder still didn't wrap, we need another step.
q -= U::ONE;
(r, wrap) = r.overflowing_add(n);
// Since `n >= RR/2`, at least one of the two `r += n` must have wrapped.
debug_assert!(wrap, "estimated quotient should be off by at most two");
(q, r)
}

#[cfg(test)]
mod test {
use super::{HInt, NarrowingDiv};

#[test]
fn inverse_mul() {
for x in 0..=u8::MAX {
for y in 1..=u8::MAX {
let xy = x.widen_mul(y);
assert_eq!(xy.checked_narrowing_div_rem(y), Some((x, 0)));
}
}
}
}
4 changes: 3 additions & 1 deletion libm/src/math/support/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ pub(crate) mod feature_detect;
mod float_traits;
pub mod hex_float;
mod int_traits;
mod modular;

#[allow(unused_imports)]
pub use big::{i256, u256};
Expand All @@ -28,7 +29,8 @@ pub use hex_float::hf16;
pub use hex_float::hf128;
#[allow(unused_imports)]
pub use hex_float::{hf32, hf64};
pub use int_traits::{CastFrom, CastInto, DInt, HInt, Int, MinInt};
pub use int_traits::{CastFrom, CastInto, DInt, HInt, Int, MinInt, NarrowingDiv};
pub use modular::linear_mul_reduction;

/// Hint to the compiler that the current path is cold.
pub fn cold_path() {
Expand Down
Loading
Loading