@@ -53,6 +53,14 @@ use serde::{Deserialize, Serialize};
5353///
5454/// # Notes
5555///
56+ /// When the shape (`k`) or scale (`θ`) parameters are close to the upper limits
57+ /// of the floating point type `F`, the implementation may overflow and produce
58+ /// `inf`. On the other hand, when `k` is relatively close to zero (like 0.005)
59+ /// and `θ` is huge (like 1e200), the implementation is likely be affected by
60+ /// underflow and may fail to produce tiny floating point values (like 1e-200),
61+ /// returning 0.0 for them instead. The exact thresholds for this to occur
62+ /// depend on `F`.
63+ ///
5664/// The algorithm used is that described by Marsaglia & Tsang 2000[^1],
5765/// falling back to directly sampling from an Exponential for `shape
5866/// == 1`, and using the boosting technique described in that paper for
@@ -173,8 +181,10 @@ where
173181 return Err ( Error :: ScaleTooSmall ) ;
174182 }
175183
176- let repr = if shape == F :: one ( ) {
177- One ( Exp :: new ( F :: one ( ) / scale) . map_err ( |_| Error :: ScaleTooLarge ) ?)
184+ let repr = if shape == F :: infinity ( ) || scale == F :: infinity ( ) {
185+ One ( Exp :: new ( F :: zero ( ) ) . unwrap ( ) )
186+ } else if shape == F :: one ( ) {
187+ One ( Exp :: new ( F :: one ( ) / scale) . unwrap ( ) )
178188 } else if shape < F :: one ( ) {
179189 Small ( GammaSmallShape :: new_raw ( shape, scale) )
180190 } else {
@@ -212,6 +222,28 @@ where
212222 d,
213223 }
214224 }
225+
226+ fn sample_unscaled < R : Rng + ?Sized > ( & self , rng : & mut R ) -> F {
227+ // Marsaglia & Tsang method, 2000
228+ loop {
229+ let x: F = rng. sample ( StandardNormal ) ;
230+ let v_cbrt = F :: one ( ) + self . c * x;
231+ if v_cbrt <= F :: zero ( ) {
232+ continue ;
233+ }
234+
235+ let v = v_cbrt * v_cbrt * v_cbrt;
236+ let u: F = rng. sample ( Open01 ) ;
237+
238+ let x_sqr = x * x;
239+ if u < F :: one ( ) - F :: from ( 0.0331 ) . unwrap ( ) * x_sqr * x_sqr
240+ || u. ln ( ) < F :: from ( 0.5 ) . unwrap ( ) * x_sqr + self . d * ( F :: one ( ) - v + v. ln ( ) )
241+ {
242+ // `x` is concentrated enough that `v` should always be finite
243+ return v;
244+ }
245+ }
246+ }
215247}
216248
217249impl < F > Distribution < F > for Gamma < F >
@@ -238,35 +270,22 @@ where
238270 fn sample < R : Rng + ?Sized > ( & self , rng : & mut R ) -> F {
239271 let u: F = rng. sample ( Open01 ) ;
240272
241- self . large_shape . sample ( rng) * u. powf ( self . inv_shape )
273+ let a = self . large_shape . sample_unscaled ( rng) ;
274+ let b = u. powf ( self . inv_shape ) ;
275+ // Multiplying numbers with `scale` can overflow, so do it last to avoid
276+ // producing NaN = inf * 0.0. All the other terms are finite and small.
277+ ( a * b * self . large_shape . d ) * self . large_shape . scale
242278 }
243279}
280+
244281impl < F > Distribution < F > for GammaLargeShape < F >
245282where
246283 F : Float ,
247284 StandardNormal : Distribution < F > ,
248285 Open01 : Distribution < F > ,
249286{
250287 fn sample < R : Rng + ?Sized > ( & self , rng : & mut R ) -> F {
251- // Marsaglia & Tsang method, 2000
252- loop {
253- let x: F = rng. sample ( StandardNormal ) ;
254- let v_cbrt = F :: one ( ) + self . c * x;
255- if v_cbrt <= F :: zero ( ) {
256- // a^3 <= 0 iff a <= 0
257- continue ;
258- }
259-
260- let v = v_cbrt * v_cbrt * v_cbrt;
261- let u: F = rng. sample ( Open01 ) ;
262-
263- let x_sqr = x * x;
264- if u < F :: one ( ) - F :: from ( 0.0331 ) . unwrap ( ) * x_sqr * x_sqr
265- || u. ln ( ) < F :: from ( 0.5 ) . unwrap ( ) * x_sqr + self . d * ( F :: one ( ) - v + v. ln ( ) )
266- {
267- return self . d * v * self . scale ;
268- }
269- }
288+ self . sample_unscaled ( rng) * ( self . d * self . scale )
270289 }
271290}
272291
@@ -278,4 +297,13 @@ mod test {
278297 fn gamma_distributions_can_be_compared ( ) {
279298 assert_eq ! ( Gamma :: new( 1.0 , 2.0 ) , Gamma :: new( 1.0 , 2.0 ) ) ;
280299 }
300+
301+ #[ test]
302+ fn gamma_extreme_values ( ) {
303+ let d = Gamma :: new ( f64:: infinity ( ) , 2.0 ) . unwrap ( ) ;
304+ assert_eq ! ( d. sample( & mut crate :: test:: rng( 21 ) ) , f64 :: infinity( ) ) ;
305+
306+ let d = Gamma :: new ( 2.0 , f64:: infinity ( ) ) . unwrap ( ) ;
307+ assert_eq ! ( d. sample( & mut crate :: test:: rng( 21 ) ) , f64 :: infinity( ) ) ;
308+ }
281309}
0 commit comments