Skip to content

Commit 21987d8

Browse files
committed
Rewrite explanatory comments in Reducer
1 parent c69e44a commit 21987d8

File tree

1 file changed

+92
-42
lines changed

1 file changed

+92
-42
lines changed

libm/src/math/support/modular.rs

Lines changed: 92 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -97,78 +97,128 @@ where
9797
let m = n << 1;
9898
assert!(x < m);
9999

100-
// We need q and r s.t. RR/2 = qm + r, and `0 <= r < m`
101-
// As R/4 < m < R/2,
102-
// we have R <= q < 2R
103-
// so let q = R + f
104-
// RR/2 = (R + f)m + r
105-
// R(R/2 - m) = fm + r
106-
107-
// v = R/2 - m < R/4 < m
108-
let v = (_1 << (U::BITS - 1)) - m;
109-
let (f, r) = v.widen_hi().checked_narrowing_div_rem(m).unwrap();
110-
111-
// xq < qm <= RR/2
112-
// 2xq < RR
113-
// 2xq = 2xR + 2xf;
114-
let _2x: U = x << 1;
100+
// We need to compute the parameters
101+
// `q = (RR/2) / m`
102+
// `r = (RR/2) % m`
103+
104+
// Since `m` is in `(R/4, R/2)`, the quotient `q` is in `[R, 2R)`, and
105+
// it would overflow in `U` if computed directly. Instead, we compute
106+
// `f = q - R`, which is in `[0, R)`. To do so, we simply subtract `Rm`
107+
// from the dividend, which doesn't change the remainder:
108+
// `f = R(R/2 - m) / m`
109+
// `r = R(R/2 - m) % m`
110+
let dividend = ((_1 << (U::BITS - 1)) - m).widen_hi();
111+
let (f, r) = dividend.checked_narrowing_div_rem(m).unwrap();
112+
113+
// As `x < m`, `xq < qm <= RR/2`
114+
// Thus `2xq = 2xR + 2xf` does not overflow in `U::D`.
115+
let _2x = x + x;
115116
let _2xq = _2x.widen_hi() + _2x.widen_mul(f);
116117
Self { m, r, _2xq }
117118
}
118119

119-
/// Extract the current remainder in the range `[0, 2n)`
120+
/// Extract the current remainder `x` in the range `[0, 2n)`
120121
fn partial_remainder(&self) -> U {
121-
// RR/2 = qm + r, 0 <= r < m
122-
// 2xq = uR + v, 0 <= v < R
123-
// muR = 2mxq - mv
124-
// = xRR - 2xr - mv
125-
// mu + (2xr + mv)/R == xR
126-
127-
// 0 <= 2xq < RR
128-
// R <= q < 2R
129-
// 0 <= x < R/2
130-
// R/4 < m < R/2
131-
// 0 <= r < m
132-
// 0 <= mv < mR
133-
// 0 <= 2xr < rR < mR
134-
135-
// 0 <= (2xr + mv)/R < 2m
136-
// Add `mu` to each term to obtain:
137-
// mu <= xR < mu + 2m
138-
139-
// Since `0 <= 2m < R`, `xR` is the only multiple of `R` between
140-
// `mu` and `m(u+2)`, so the high half of `m(u+2)` must equal `x`.
141-
let _1 = U::ONE;
142-
self.m.widen_mul(self._2xq.hi() + (_1 + _1)).hi()
122+
// `RR/2 = qm + r`, where `0 <= r < m`
123+
// `2xq = uR + v`, where `0 <= v < R`
124+
125+
// The goal is to extract the current value of `x` from the value `2xq`
126+
// that we actually have. A bit simplified, we could multiply it by `m`
127+
// to obtain `2xqm == 2x(RR/2 - r) == xRR - 2xr`, where `2xr < RR`.
128+
// We could just round that up to the next multiple of `RR` to get `x`,
129+
// but we can avoid having to multiply the full double-wide `2xq` by
130+
// making a couple of adjustments:
131+
132+
// First, let's only use the high half `u` for the product, and
133+
// include an additional error term due to the truncation:
134+
// `mu = xR - (2xr + mv)/R`
135+
136+
// Next, show bounds for the error term
137+
// `0 <= mv < mR` follows from `0 <= v < R`
138+
// `0 <= 2xr < mR` follows from `0 <= x < m < R/2` and `0 <= r < m`
139+
// Adding those together, we have:
140+
// `0 <= (mv + 2xr)/R < 2m`
141+
// Which also implies:
142+
// `0 < 2m - (mv + 2xr)/R <= 2m < R`
143+
144+
// For that reason, we can use `u + 2` as the factor to obtain
145+
// `m(u + 2) = xR + (2m - (mv + 2xr)/R)`
146+
// By the previous inequality, the second term fits neatly in the lower
147+
// half, so we get exactly `x` as the high half.
148+
let u = self._2xq.hi();
149+
let _2 = U::ONE + U::ONE;
150+
self.m.widen_mul(u + _2).hi()
151+
152+
// Additionally, we should ensure that `u + 2` cannot overflow:
153+
// Since `x < m` and `2qm <= RR`,
154+
// `2xq <= 2q(m-1) <= RR - 2q`
155+
// As we also have `q > R`,
156+
// `2xq < RR - 2R`
157+
// which is sufficient.
143158
}
144159

145160
/// Replace the remainder `x` with `(x << k) - un`,
146161
/// for a suitable quotient `u`, which is returned.
162+
///
163+
/// Requires that `k < U::BITS`.
147164
fn shift_reduce(&mut self, k: u32) -> U {
148165
assert!(k < U::BITS);
149-
// 2xq << k = aRR/2 + b;
166+
167+
// First, split the shifted value:
168+
// `2xq << k = aRR/2 + b`, where `0 <= b < RR/2`
150169
let a = self._2xq.hi() >> (U::BITS - 1 - k);
151170
let (low, high) = (self._2xq << k).lo_hi();
152171
let b = U::D::from_lo_hi(low, high & (U::MAX >> 1));
153172

173+
// Then, subtract `2anq = aqm`:
174+
// ```
154175
// (2xq << k) - aqm
155176
// = aRR/2 + b - aqm
156177
// = a(RR/2 - qm) + b
157178
// = ar + b
179+
// ```
158180
self._2xq = a.widen_mul(self.r) + b;
159181
a
182+
183+
// Since `a` is at most the high half of `2xq`, we have
184+
// `a + 2 < R` (shown above, in `partial_remainder`)
185+
// Using that together with `b < RR/2` and `r < m < R/2`,
186+
// we get `(a + 2)r + b < RR`, so
187+
// `ar + b < RR - 2r = 2mq`
188+
// which shows that the new remainder still satisfies `x < m`.
160189
}
161190

191+
// NB: `word_reduce()` is just the special case `shift_reduce(U::BITS - 1)`
192+
// that optimizes especially well. The correspondence is that `a == u` and
193+
// `b == (v >> 1).widen_hi()`
194+
//
162195
/// Replace the remainder `x` with `x(R/2) - un`,
163196
/// for a suitable quotient `u`, which is returned.
164197
fn word_reduce(&mut self) -> U {
165-
// 2xq = uR + v
166-
let (v, u) = self._2xq.lo_hi();
167-
// xqR - uqm
198+
// To do so, we replace `2xq = uR + v` with
199+
// ```
200+
// 2 * (x(R/2) - un) * q
201+
// = xqR - 2unq
202+
// = xqR - uqm
168203
// = uRR/2 + vR/2 - uRR/2 + ur
169204
// = ur + (v/2)R
205+
// ```
206+
let (v, u) = self._2xq.lo_hi();
170207
self._2xq = u.widen_mul(self.r) + U::widen_hi(v >> 1);
171208
u
209+
210+
// Additional notes:
211+
// 1. As `v` is the low bits of `2xq`, it is even and can be halved.
212+
// 2. The new remainder is `(xr + mv/2) / R` (see below)
213+
// and since `v < R`, `r < m`, `x < m < R/2`,
214+
// that is also strictly less than `m`.
215+
// ```
216+
// (x(R/2) - un)R
217+
// = xRR/2 - (m/2)uR
218+
// = x(qm + r) - (m/2)(2xq - v)
219+
// = xqm + xr - xqm + mv/2
220+
// = xr + mv/2
221+
// ```
172222
}
173223
}
174224

0 commit comments

Comments
 (0)