@@ -97,78 +97,128 @@ where
97
97
let m = n << 1 ;
98
98
assert ! ( x < m) ;
99
99
100
- // We need q and r s.t. RR/2 = qm + r, and `0 <= r < m`
101
- // As R/4 < m < R/2,
102
- // we have R <= q < 2R
103
- // so let q = R + f
104
- // RR/2 = (R + f)m + r
105
- // R(R/2 - m) = fm + r
106
-
107
- // v = R/2 - m < R/4 < m
108
- let v = ( _1 << ( U :: BITS - 1 ) ) - m;
109
- let ( f, r) = v. widen_hi ( ) . checked_narrowing_div_rem ( m) . unwrap ( ) ;
110
-
111
- // xq < qm <= RR/2
112
- // 2xq < RR
113
- // 2xq = 2xR + 2xf;
114
- let _2x: U = x << 1 ;
100
+ // We need to compute the parameters
101
+ // `q = (RR/2) / m`
102
+ // `r = (RR/2) % m`
103
+
104
+ // Since `m` is in `(R/4, R/2)`, the quotient `q` is in `[R, 2R)`, and
105
+ // it would overflow in `U` if computed directly. Instead, we compute
106
+ // `f = q - R`, which is in `[0, R)`. To do so, we simply subtract `Rm`
107
+ // from the dividend, which doesn't change the remainder:
108
+ // `f = R(R/2 - m) / m`
109
+ // `r = R(R/2 - m) % m`
110
+ let dividend = ( ( _1 << ( U :: BITS - 1 ) ) - m) . widen_hi ( ) ;
111
+ let ( f, r) = dividend. checked_narrowing_div_rem ( m) . unwrap ( ) ;
112
+
113
+ // As `x < m`, `xq < qm <= RR/2`
114
+ // Thus `2xq = 2xR + 2xf` does not overflow in `U::D`.
115
+ let _2x = x + x;
115
116
let _2xq = _2x. widen_hi ( ) + _2x. widen_mul ( f) ;
116
117
Self { m, r, _2xq }
117
118
}
118
119
119
- /// Extract the current remainder in the range `[0, 2n)`
120
+ /// Extract the current remainder `x` in the range `[0, 2n)`
120
121
fn partial_remainder ( & self ) -> U {
121
- // RR/2 = qm + r, 0 <= r < m
122
- // 2xq = uR + v, 0 <= v < R
123
- // muR = 2mxq - mv
124
- // = xRR - 2xr - mv
125
- // mu + (2xr + mv)/R == xR
126
-
127
- // 0 <= 2xq < RR
128
- // R <= q < 2R
129
- // 0 <= x < R/2
130
- // R/4 < m < R/2
131
- // 0 <= r < m
132
- // 0 <= mv < mR
133
- // 0 <= 2xr < rR < mR
134
-
135
- // 0 <= (2xr + mv)/R < 2m
136
- // Add `mu` to each term to obtain:
137
- // mu <= xR < mu + 2m
138
-
139
- // Since `0 <= 2m < R`, `xR` is the only multiple of `R` between
140
- // `mu` and `m(u+2)`, so the high half of `m(u+2)` must equal `x`.
141
- let _1 = U :: ONE ;
142
- self . m . widen_mul ( self . _2xq . hi ( ) + ( _1 + _1) ) . hi ( )
122
+ // `RR/2 = qm + r`, where `0 <= r < m`
123
+ // `2xq = uR + v`, where `0 <= v < R`
124
+
125
+ // The goal is to extract the current value of `x` from the value `2xq`
126
+ // that we actually have. A bit simplified, we could multiply it by `m`
127
+ // to obtain `2xqm == 2x(RR/2 - r) == xRR - 2xr`, where `2xr < RR`.
128
+ // We could just round that up to the next multiple of `RR` to get `x`,
129
+ // but we can avoid having to multiply the full double-wide `2xq` by
130
+ // making a couple of adjustments:
131
+
132
+ // First, let's only use the high half `u` for the product, and
133
+ // include an additional error term due to the truncation:
134
+ // `mu = xR - (2xr + mv)/R`
135
+
136
+ // Next, show bounds for the error term
137
+ // `0 <= mv < mR` follows from `0 <= v < R`
138
+ // `0 <= 2xr < mR` follows from `0 <= x < m < R/2` and `0 <= r < m`
139
+ // Adding those together, we have:
140
+ // `0 <= (mv + 2xr)/R < 2m`
141
+ // Which also implies:
142
+ // `0 < 2m - (mv + 2xr)/R <= 2m < R`
143
+
144
+ // For that reason, we can use `u + 2` as the factor to obtain
145
+ // `m(u + 2) = xR + (2m - (mv + 2xr)/R)`
146
+ // By the previous inequality, the second term fits neatly in the lower
147
+ // half, so we get exactly `x` as the high half.
148
+ let u = self . _2xq . hi ( ) ;
149
+ let _2 = U :: ONE + U :: ONE ;
150
+ self . m . widen_mul ( u + _2) . hi ( )
151
+
152
+ // Additionally, we should ensure that `u + 2` cannot overflow:
153
+ // Since `x < m` and `2qm <= RR`,
154
+ // `2xq <= 2q(m-1) <= RR - 2q`
155
+ // As we also have `q > R`,
156
+ // `2xq < RR - 2R`
157
+ // which is sufficient.
143
158
}
144
159
145
160
/// Replace the remainder `x` with `(x << k) - un`,
146
161
/// for a suitable quotient `u`, which is returned.
162
+ ///
163
+ /// Requires that `k < U::BITS`.
147
164
fn shift_reduce ( & mut self , k : u32 ) -> U {
148
165
assert ! ( k < U :: BITS ) ;
149
- // 2xq << k = aRR/2 + b;
166
+
167
+ // First, split the shifted value:
168
+ // `2xq << k = aRR/2 + b`, where `0 <= b < RR/2`
150
169
let a = self . _2xq . hi ( ) >> ( U :: BITS - 1 - k) ;
151
170
let ( low, high) = ( self . _2xq << k) . lo_hi ( ) ;
152
171
let b = U :: D :: from_lo_hi ( low, high & ( U :: MAX >> 1 ) ) ;
153
172
173
+ // Then, subtract `2anq = aqm`:
174
+ // ```
154
175
// (2xq << k) - aqm
155
176
// = aRR/2 + b - aqm
156
177
// = a(RR/2 - qm) + b
157
178
// = ar + b
179
+ // ```
158
180
self . _2xq = a. widen_mul ( self . r ) + b;
159
181
a
182
+
183
+ // Since `a` is at most the high half of `2xq`, we have
184
+ // `a + 2 < R` (shown above, in `partial_remainder`)
185
+ // Using that together with `b < RR/2` and `r < m < R/2`,
186
+ // we get `(a + 2)r + b < RR`, so
187
+ // `ar + b < RR - 2r = 2mq`
188
+ // which shows that the new remainder still satisfies `x < m`.
160
189
}
161
190
191
+ // NB: `word_reduce()` is just the special case `shift_reduce(U::BITS - 1)`
192
+ // that optimizes especially well. The correspondence is that `a == u` and
193
+ // `b == (v >> 1).widen_hi()`
194
+ //
162
195
/// Replace the remainder `x` with `x(R/2) - un`,
163
196
/// for a suitable quotient `u`, which is returned.
164
197
fn word_reduce ( & mut self ) -> U {
165
- // 2xq = uR + v
166
- let ( v, u) = self . _2xq . lo_hi ( ) ;
167
- // xqR - uqm
198
+ // To do so, we replace `2xq = uR + v` with
199
+ // ```
200
+ // 2 * (x(R/2) - un) * q
201
+ // = xqR - 2unq
202
+ // = xqR - uqm
168
203
// = uRR/2 + vR/2 - uRR/2 + ur
169
204
// = ur + (v/2)R
205
+ // ```
206
+ let ( v, u) = self . _2xq . lo_hi ( ) ;
170
207
self . _2xq = u. widen_mul ( self . r ) + U :: widen_hi ( v >> 1 ) ;
171
208
u
209
+
210
+ // Additional notes:
211
+ // 1. As `v` is the low bits of `2xq`, it is even and can be halved.
212
+ // 2. The new remainder is `(xr + mv/2) / R` (see below)
213
+ // and since `v < R`, `r < m`, `x < m < R/2`,
214
+ // that is also strictly less than `m`.
215
+ // ```
216
+ // (x(R/2) - un)R
217
+ // = xRR/2 - (m/2)uR
218
+ // = x(qm + r) - (m/2)(2xq - v)
219
+ // = xqm + xr - xqm + mv/2
220
+ // = xr + mv/2
221
+ // ```
172
222
}
173
223
}
174
224
0 commit comments