Skip to content

libm: implement accelerated computation of (x << e) % y #1012

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions libm/src/math/support/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ pub(crate) mod feature_detect;
mod float_traits;
pub mod hex_float;
mod int_traits;
mod modular;

#[allow(unused_imports)]
pub use big::{i256, u256};
Expand All @@ -30,6 +31,8 @@ pub use hex_float::hf128;
pub use hex_float::{hf32, hf64};
#[allow(unused_imports)]
pub use int_traits::{CastFrom, CastInto, DInt, HInt, Int, MinInt, NarrowingDiv};
#[allow(unused_imports)]
pub use modular::linear_mul_reduction;

/// Hint to the compiler that the current path is cold.
pub fn cold_path() {
Expand Down
195 changes: 195 additions & 0 deletions libm/src/math/support/modular.rs
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this commit should be able to allow the #[allow(dead_code)] on NarrowingDiv

Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
use crate::support::int_traits::NarrowingDiv;
use crate::support::{DInt, HInt, Int};
Comment on lines +1 to +2
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add a module-level doc comment with some of the common names used here? E.g. R, RR (if that's not R*R) r, m, q, xq. I'm unfortunately a bit lost :) (but I don't need to understand it in detail)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you've authored this, add /* SPDX-License-Identifier: MIT OR Apache-2.0 */ as well (or only MIT if it's derived, as appropriate)


/// Contains:
/// n in (R/8, R/4)
/// x in [0, 2n)
#[derive(Debug, Clone, PartialEq, Eq)]
struct Reducer<U: HInt> {
// let m = 2n
m: U,
// RR/2 = qm + r
r: U,
xq2: U::D,
}

impl<U> Reducer<U>
where
U: HInt,
U: Int<Unsigned = U>,
{
/// Construct a reducer for `(x << _) mod n`.
///
/// Requires `R/8 < n < R/4` and `x < 2n`.
fn new(x: U, n: U) -> Self
where
U::D: NarrowingDiv,
{
let _1 = U::ONE;
assert!(n > (_1 << (U::BITS - 3)));
assert!(n < (_1 << (U::BITS - 2)));
let m = n << 1;
assert!(x < m);

// We need q and r s.t. RR/2 = qm + r
// As R/4 < m < R/2,
// we have R <= q < 2R
// so let q = R + f
// RR/2 = (R + f)m + r
// R(R/2 - m) = fm + r

// v = R/2 - m < R/4 < m
let v = (_1 << (U::BITS - 1)) - m;
let (f, r) = v.widen_hi().checked_narrowing_div_rem(m).unwrap();

// xq < qm <= RR/2
// 2xq < RR
// 2xq = 2xR + 2xf;
let x2: U = x << 1;
let xq2 = x2.widen_hi() + x2.widen_mul(f);
Self { m, r, xq2 }
}

/// Extract the current remainder in the range `[0, 2n)`
fn partial_remainder(&self) -> U {
// RR/2 = qm + r, 0 <= r < m
// 2xq = uR + v, 0 <= v < R
// muR = 2mxq - mv
// = xRR - 2xr - mv
// mu + (2xr + mv)/R == xR

// 0 <= 2xq < RR
// R <= q < 2R
// 0 <= x < R/2
// R/4 < m < R/2
// 0 <= r < m
// 0 <= mv < mR
// 0 <= 2xr < rR < mR

// 0 <= (2xr + mv)/R < 2m
// Add `mu` to each term to obtain:
// mu <= xR < mu + 2m

// Since `0 <= 2m < R`, `xR` is the only multiple of `R` between
// `mu` and `m(u+2)`, so we can truncate the latter to find `x`.
let _1 = U::ONE;
self.m.widen_mul(self.xq2.hi() + (_1 + _1)).hi()
}

/// Maps the remainder `x` to `(x << k) - un`,
/// for a suitable quotient `u`, which is returned.
fn shift_reduce(&mut self, k: u32) -> U {
assert!(k < U::BITS);
// 2xq << k = aRR/2 + b;
let a = self.xq2.hi() >> (U::BITS - 1 - k);
let (lo, hi) = (self.xq2 << k).lo_hi();
let b = U::D::from_lo_hi(lo, hi & (U::MAX >> 1));

// (2xq << k) - aqm
// = aRR/2 + b - aqm
// = a(RR/2 - qm) + b
// = ar + b
self.xq2 = a.widen_mul(self.r) + b;
a
}

/// Maps the remainder `x` to `x(R/2) - un`,
/// for a suitable quotient `u`, which is returned.
fn word_reduce(&mut self) -> U {
// 2xq = uR + v
let (v, u) = self.xq2.lo_hi();
// xqR - uqm
// = uRR/2 + vR/2 - uRR/2 + ur
// = ur + (v/2)R
self.xq2 = u.widen_mul(self.r) + U::widen_hi(v >> 1);
u
}
}

/// Compute the remainder `(x << e) % y` with unbounded integers.
/// Requires `x < 2y` and `y.leading_zeros() >= 2`
#[allow(dead_code)]
pub fn linear_mul_reduction<U>(x: U, mut e: u32, y: U) -> U
where
U: HInt + Int<Unsigned = U>,
U::D: NarrowingDiv,
{
assert!(y <= U::MAX >> 2);
assert!(x < (y << 1));
let _0 = U::ZERO;
let _1 = U::ONE;

// power of two divisor
if (y & (y - _1)).is_zero() {
if e < U::BITS {
return (x << e) & (y - _1);
} else {
return _0;
}
}

// shift the divisor so it has exactly two leading zeros
let y_shift = y.leading_zeros() - 2;
let mut m = Reducer::new(x, y << y_shift);
e += y_shift;

while e >= U::BITS - 1 {
m.word_reduce();
e -= U::BITS - 1;
}
m.shift_reduce(e);

let rem = m.partial_remainder() >> y_shift;
rem.checked_sub(y).unwrap_or(rem)
}

#[cfg(test)]
mod test {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could this spot check another integer size as well? Just using constants

use crate::support::linear_mul_reduction;
use crate::support::modular::Reducer;

#[test]
fn reducer_ops() {
for n in 33..=63_u8 {
for x in 0..2 * n {
let temp = Reducer::new(x, n);
let n = n as u32;
let x0 = temp.partial_remainder() as u32;
assert_eq!(x as u32, x0);
for k in 0..=7 {
let mut red = temp.clone();
let u = red.shift_reduce(k) as u32;
let x1 = red.partial_remainder() as u32;
assert_eq!(x1, (x0 << k) - u * n);
assert!(x1 < 2 * n);
assert!((red.xq2 as u32).is_multiple_of(2 * x1));

// `word_reduce` is equivalent to
// `shift_reduce(U::BITS - 1)`
if k == 7 {
let mut alt = temp.clone();
let w = alt.word_reduce();
assert_eq!(u, w as u32);
assert_eq!(alt, red);
}
}
}
}
}
#[test]
fn reduction() {
for y in 1..64u8 {
for x in 0..2 * y {
let mut r = x % y;
for e in 0..100 {
assert_eq!(r, linear_mul_reduction(x, e, y));
// maintain the correct expected remainder
r <<= 1;
if r >= y {
r -= y;
}
}
}
}
}
}
Loading