Skip to content

Commit a7055c6

Browse files
committed
Avoid returning NaN from Gamma::sample
This changes the order of multiplications used to compute the result to avoid multiplying zero with an expression that can overflow to +inf. Note that the parameter combinations which could lead to this (shape very close to zero, scale very close to the max float value) continue to be inaccurately handled; the Gamma distribution sampler will now tend to return zero instead of NaN for them. The limit (shape 0, scale inf) is not well defined.
1 parent be28239 commit a7055c6

File tree

2 files changed

+51
-22
lines changed

2 files changed

+51
-22
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
2424
- Fix hang and debug assertion in `Zipf::new` on invalid parameters (#41)
2525
- Fix panic in `Binomial::sample` with `n ≥ 2^63`; this is a Value-breaking change (#43)
2626
- Error instead of producing `-inf` output for `Exp` when `lambda` is `-0.0` (#44)
27+
- Avoid returning NaN from `Gamma::sample`; this is a Value-breaking change and also affects `ChiSquared` and `Dirichlet` (#46)
2728

2829
## [0.5.2]
2930

src/gamma.rs

Lines changed: 50 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,14 @@ use serde::{Deserialize, Serialize};
5353
///
5454
/// # Notes
5555
///
56+
/// When the shape (`k`) or scale (`θ`) parameters are close to the upper limits
57+
/// of the floating point type `F`, the implementation may overflow and produce
58+
/// `inf`. On the other hand, when `k` is relatively close to zero (like 0.005)
59+
/// and `θ` is huge (like 1e200), the implementation is likely be affected by
60+
/// underflow and may fail to produce tiny floating point values (like 1e-200),
61+
/// returning 0.0 for them instead. The exact thresholds for this to occur
62+
/// depend on `F`.
63+
///
5664
/// The algorithm used is that described by Marsaglia & Tsang 2000[^1],
5765
/// falling back to directly sampling from an Exponential for `shape
5866
/// == 1`, and using the boosting technique described in that paper for
@@ -173,8 +181,10 @@ where
173181
return Err(Error::ScaleTooSmall);
174182
}
175183

176-
let repr = if shape == F::one() {
177-
One(Exp::new(F::one() / scale).map_err(|_| Error::ScaleTooLarge)?)
184+
let repr = if shape == F::infinity() || scale == F::infinity() {
185+
One(Exp::new(F::zero()).unwrap())
186+
} else if shape == F::one() {
187+
One(Exp::new(F::one() / scale).unwrap())
178188
} else if shape < F::one() {
179189
Small(GammaSmallShape::new_raw(shape, scale))
180190
} else {
@@ -212,6 +222,28 @@ where
212222
d,
213223
}
214224
}
225+
226+
fn sample_unscaled<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
227+
// Marsaglia & Tsang method, 2000
228+
loop {
229+
let x: F = rng.sample(StandardNormal);
230+
let v_cbrt = F::one() + self.c * x;
231+
if v_cbrt <= F::zero() {
232+
continue;
233+
}
234+
235+
let v = v_cbrt * v_cbrt * v_cbrt;
236+
let u: F = rng.sample(Open01);
237+
238+
let x_sqr = x * x;
239+
if u < F::one() - F::from(0.0331).unwrap() * x_sqr * x_sqr
240+
|| u.ln() < F::from(0.5).unwrap() * x_sqr + self.d * (F::one() - v + v.ln())
241+
{
242+
// `x` is concentrated enough that `v` should always be finite
243+
return v;
244+
}
245+
}
246+
}
215247
}
216248

217249
impl<F> Distribution<F> for Gamma<F>
@@ -238,35 +270,22 @@ where
238270
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
239271
let u: F = rng.sample(Open01);
240272

241-
self.large_shape.sample(rng) * u.powf(self.inv_shape)
273+
let a = self.large_shape.sample_unscaled(rng);
274+
let b = u.powf(self.inv_shape);
275+
// Multiplying numbers with `scale` can overflow, so do it last to avoid
276+
// producing NaN = inf * 0.0. All the other terms are finite and small.
277+
(a * b * self.large_shape.d) * self.large_shape.scale
242278
}
243279
}
280+
244281
impl<F> Distribution<F> for GammaLargeShape<F>
245282
where
246283
F: Float,
247284
StandardNormal: Distribution<F>,
248285
Open01: Distribution<F>,
249286
{
250287
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
251-
// Marsaglia & Tsang method, 2000
252-
loop {
253-
let x: F = rng.sample(StandardNormal);
254-
let v_cbrt = F::one() + self.c * x;
255-
if v_cbrt <= F::zero() {
256-
// a^3 <= 0 iff a <= 0
257-
continue;
258-
}
259-
260-
let v = v_cbrt * v_cbrt * v_cbrt;
261-
let u: F = rng.sample(Open01);
262-
263-
let x_sqr = x * x;
264-
if u < F::one() - F::from(0.0331).unwrap() * x_sqr * x_sqr
265-
|| u.ln() < F::from(0.5).unwrap() * x_sqr + self.d * (F::one() - v + v.ln())
266-
{
267-
return self.d * v * self.scale;
268-
}
269-
}
288+
self.sample_unscaled(rng) * (self.d * self.scale)
270289
}
271290
}
272291

@@ -278,4 +297,13 @@ mod test {
278297
fn gamma_distributions_can_be_compared() {
279298
assert_eq!(Gamma::new(1.0, 2.0), Gamma::new(1.0, 2.0));
280299
}
300+
301+
#[test]
302+
fn gamma_extreme_values() {
303+
let d = Gamma::new(f64::infinity(), 2.0).unwrap();
304+
assert_eq!(d.sample(&mut crate::test::rng(21)), f64::infinity());
305+
306+
let d = Gamma::new(2.0, f64::infinity()).unwrap();
307+
assert_eq!(d.sample(&mut crate::test::rng(21)), f64::infinity());
308+
}
281309
}

0 commit comments

Comments
 (0)