|
| 1 | +use core::{ |
| 2 | + cmp::PartialOrd, |
| 3 | + ops::{Add, Mul, Neg}, |
| 4 | +}; |
| 5 | + |
| 6 | +use crate::BaseRng; |
| 7 | + |
| 8 | +#[cfg(any(feature = "std", feature = "libm"))] |
| 9 | +pub(super) fn f32(rng: &mut impl BaseRng, mu: f32, sigma: f32) -> f32 { |
| 10 | + float_normal_impl(rng, mu, sigma) |
| 11 | +} |
| 12 | + |
| 13 | +#[cfg(any(feature = "std", feature = "libm"))] |
| 14 | +pub(super) fn f64(rng: &mut impl BaseRng, mu: f64, sigma: f64) -> f64 { |
| 15 | + float_normal_impl(rng, mu, sigma) |
| 16 | +} |
| 17 | + |
| 18 | +pub(super) fn f32_approx(rng: &mut impl BaseRng, mu: f32, sigma: f32) -> f32 { |
| 19 | + float_normal_approx_impl(rng, mu, sigma) |
| 20 | +} |
| 21 | + |
| 22 | +pub(super) fn f64_approx(rng: &mut impl BaseRng, mu: f64, sigma: f64) -> f64 { |
| 23 | + float_normal_approx_impl(rng, mu, sigma) |
| 24 | +} |
| 25 | + |
| 26 | +trait FloatExt: |
| 27 | + Add<Self, Output = Self> + Mul<Self, Output = Self> + Neg<Output = Self> + PartialOrd<Self> + Sized |
| 28 | +{ |
| 29 | + const EPSILON: Self; |
| 30 | + |
| 31 | + fn from_f64(x: f64) -> Self; |
| 32 | + fn gen(rng: &mut impl BaseRng) -> Self; |
| 33 | +} |
| 34 | + |
| 35 | +#[cfg(any(feature = "std", feature = "libm"))] |
| 36 | +trait FloatMathExt: FloatExt { |
| 37 | + const TAU: Self; |
| 38 | + |
| 39 | + fn ln(self) -> Self; |
| 40 | + fn sqrt(self) -> Self; |
| 41 | + fn cos(self) -> Self; |
| 42 | +} |
| 43 | + |
| 44 | +macro_rules! impl_float_ext { |
| 45 | + ($float:ident) => { |
| 46 | + impl FloatExt for $float { |
| 47 | + const EPSILON: Self = $float::EPSILON; |
| 48 | + |
| 49 | + #[inline] |
| 50 | + fn from_f64(x: f64) -> Self { |
| 51 | + x as $float |
| 52 | + } |
| 53 | + #[inline] |
| 54 | + fn gen(rng: &mut impl BaseRng) -> Self { |
| 55 | + rng.$float() |
| 56 | + } |
| 57 | + } |
| 58 | + }; |
| 59 | +} |
| 60 | + |
| 61 | +macro_rules! impl_float_math_ext { |
| 62 | + ($float:ident, $tau:ident) => { |
| 63 | + #[cfg(all(feature = "std", not(feature = "libm")))] |
| 64 | + impl FloatMathExt for $float { |
| 65 | + const TAU: Self = $tau; |
| 66 | + |
| 67 | + #[inline] |
| 68 | + fn ln(self) -> Self { |
| 69 | + $float::ln(self) |
| 70 | + } |
| 71 | + #[inline] |
| 72 | + fn sqrt(self) -> Self { |
| 73 | + $float::sqrt(self) |
| 74 | + } |
| 75 | + #[inline] |
| 76 | + fn cos(self) -> Self { |
| 77 | + $float::cos(self) |
| 78 | + } |
| 79 | + } |
| 80 | + |
| 81 | + #[cfg(feature = "libm")] |
| 82 | + impl FloatMathExt for $float { |
| 83 | + const TAU: Self = $tau; |
| 84 | + |
| 85 | + #[inline] |
| 86 | + fn ln(self) -> Self { |
| 87 | + libm_dep::Libm::<$float>::log(self) |
| 88 | + } |
| 89 | + #[inline] |
| 90 | + fn sqrt(self) -> Self { |
| 91 | + libm_dep::Libm::<$float>::sqrt(self) |
| 92 | + } |
| 93 | + #[inline] |
| 94 | + fn cos(self) -> Self { |
| 95 | + libm_dep::Libm::<$float>::cos(self) |
| 96 | + } |
| 97 | + } |
| 98 | + }; |
| 99 | +} |
| 100 | + |
| 101 | +// TAU constant was stabilized in Rust 1.47. Our current MSRV is 1.43. |
| 102 | +#[cfg(any(feature = "std", feature = "libm"))] |
| 103 | +#[allow(clippy::excessive_precision)] |
| 104 | +const F32_TAU: f32 = 6.28318530717958647692528676655900577_f32; |
| 105 | +#[cfg(any(feature = "std", feature = "libm"))] |
| 106 | +#[allow(clippy::excessive_precision)] |
| 107 | +const F64_TAU: f64 = 6.28318530717958647692528676655900577_f64; |
| 108 | + |
| 109 | +impl_float_ext!(f32); |
| 110 | +impl_float_ext!(f64); |
| 111 | +impl_float_math_ext!(f32, F32_TAU); |
| 112 | +impl_float_math_ext!(f64, F64_TAU); |
| 113 | + |
| 114 | +#[cfg(any(feature = "std", feature = "libm"))] |
| 115 | +fn float_normal_impl<T: FloatMathExt>(rng: &mut impl BaseRng, mu: T, sigma: T) -> T { |
| 116 | + // https://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform |
| 117 | + let u1 = loop { |
| 118 | + let u1 = T::gen(rng); |
| 119 | + |
| 120 | + if u1 > T::EPSILON { |
| 121 | + break u1; |
| 122 | + } |
| 123 | + }; |
| 124 | + |
| 125 | + let u2 = T::gen(rng); |
| 126 | + let mag = sigma * (-T::from_f64(2.0) * u1.ln()).sqrt(); |
| 127 | + mag * (T::TAU * u2).cos() + mu |
| 128 | +} |
| 129 | + |
| 130 | +fn float_normal_approx_impl<T: FloatExt>(rng: &mut impl BaseRng, mu: T, sigma: T) -> T { |
| 131 | + // http://marc-b-reynolds.github.io/distribution/2021/03/18/CheapGaussianApprox.html |
| 132 | + let u = rng.u128(); |
| 133 | + |
| 134 | + // Counting ones in a u64 half of the generated number gives us binomial |
| 135 | + // distribution with p = 1/2 and n = 64. Subtracting 32 centers the |
| 136 | + // distribution on [-32, 32]. Shifting the lower u64 by 64 bits discard the |
| 137 | + // other half and `count_ones` can be used without any masking. |
| 138 | + let bd = (u << 64).count_ones() as i64 - 32; |
| 139 | + |
| 140 | + // Sample two u32 integers from uniform distribution. |
| 141 | + let a = ((u >> 64) & 0xffffffff) as i64; |
| 142 | + let b = (u >> 96) as i64; |
| 143 | + |
| 144 | + // First iteration of Central limit theorem (summing two uniform random |
| 145 | + // variables) _often_ gives triangular distribution. By using subtraction |
| 146 | + // instead of addition, the triangular distribution is centered around zero. |
| 147 | + let td = a - b; |
| 148 | + |
| 149 | + // Sum the binomial and triangular distributions. |
| 150 | + let r = ((bd << 32) + td) as f64; |
| 151 | + |
| 152 | + // Magic constant for scaling which is a result of minimizing the maximum |
| 153 | + // error with respect to the reference normal distribution. |
| 154 | + let k = 5.76917e-11; |
| 155 | + |
| 156 | + T::from_f64(k * r) * sigma + mu |
| 157 | +} |
| 158 | + |
| 159 | +#[cfg(test)] |
| 160 | +mod tests { |
| 161 | + use fastrand::Rng; |
| 162 | + |
| 163 | + use super::*; |
| 164 | + |
| 165 | + fn normal_distribution_test<F>(sample: F) |
| 166 | + where |
| 167 | + F: Fn(&mut Rng, f32, f32) -> f32, |
| 168 | + { |
| 169 | + // The test based on the following picture from Wikipedia: |
| 170 | + // https://upload.wikimedia.org/wikipedia/commons/thumb/3/3a/Standard_deviation_diagram_micro.svg/1920px-Standard_deviation_diagram_micro.svg.png |
| 171 | + let mut rng = Rng::with_seed(42); |
| 172 | + |
| 173 | + let mu = 10.0; |
| 174 | + let sigma = 3.0; |
| 175 | + |
| 176 | + let total = 10000; |
| 177 | + let mut in_one_sigma_range = 0; |
| 178 | + let mut in_two_sigma_range = 0; |
| 179 | + let mut in_three_sigma_range = 0; |
| 180 | + for _ in 0..total { |
| 181 | + let value = sample(&mut rng, mu, sigma); |
| 182 | + |
| 183 | + if (mu - sigma..=mu + sigma).contains(&value) { |
| 184 | + in_one_sigma_range += 1; |
| 185 | + } else if (mu - sigma * 2.0..=mu + sigma * 2.0).contains(&value) { |
| 186 | + in_two_sigma_range += 1; |
| 187 | + } else if (mu - sigma * 3.0..=mu + sigma * 3.0).contains(&value) { |
| 188 | + in_three_sigma_range += 1; |
| 189 | + } |
| 190 | + } |
| 191 | + |
| 192 | + let in_one_sigma_range = in_one_sigma_range as f32 / total as f32 * 100.0; |
| 193 | + let in_two_sigma_range = in_two_sigma_range as f32 / total as f32 * 100.0; |
| 194 | + let in_three_sigma_range = in_three_sigma_range as f32 / total as f32 * 100.0; |
| 195 | + assert!( |
| 196 | + (64.0..=72.0).contains(&in_one_sigma_range), |
| 197 | + "value in \"one sigma range\" should be sampled ~68.2%, but is {}%", |
| 198 | + in_one_sigma_range |
| 199 | + ); |
| 200 | + assert!( |
| 201 | + (23.0..=31.0).contains(&in_two_sigma_range), |
| 202 | + "value in \"two sigma range\" should be sampled ~27.2%, but is {}%", |
| 203 | + in_two_sigma_range |
| 204 | + ); |
| 205 | + assert!( |
| 206 | + (1.0..=7.0).contains(&in_three_sigma_range), |
| 207 | + "value in \"three sigma range\" should be sampled ~4.2%, but is {}%", |
| 208 | + in_three_sigma_range |
| 209 | + ); |
| 210 | + } |
| 211 | + |
| 212 | + #[test] |
| 213 | + #[cfg(any(feature = "std", feature = "libm"))] |
| 214 | + fn normal_is_actually_normal() { |
| 215 | + normal_distribution_test(float_normal_impl); |
| 216 | + } |
| 217 | + |
| 218 | + #[test] |
| 219 | + fn normal_approx_is_actually_normal() { |
| 220 | + normal_distribution_test(float_normal_approx_impl); |
| 221 | + } |
| 222 | +} |
0 commit comments