Skip to content

Commit c6725e0

Browse files
authored
Merge pull request #7 from pnevyk/float_normal
Normal sampling
2 parents e9f99fb + e605385 commit c6725e0

File tree

5 files changed

+379
-4
lines changed

5 files changed

+379
-4
lines changed

Cargo.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,11 @@ rust-version = "1.43.1"
66

77
[dependencies]
88
fastrand = { version = "2.0.0", default-features = false }
9+
libm_dep = { package = "libm", version = "0.2.7", optional = true }
910

1011
[features]
1112
default = ["std"]
1213
std = ["alloc", "fastrand/std"]
1314
alloc = ["fastrand/alloc"]
14-
15+
# The `dep:` syntax was added in Rust 1.60. Our current MSRV is 1.43.
16+
libm = ["libm_dep"]

benches/normal.rs

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
#![feature(test)]
2+
3+
extern crate test;
4+
5+
use fastrand::Rng;
6+
use fastrand_contrib::RngExt;
7+
use test::Bencher;
8+
9+
const SEED: u64 = 42;
10+
const MU: f64 = 10.0;
11+
const SIGMA: f64 = 3.0;
12+
13+
#[bench]
14+
fn box_muller(b: &mut Bencher) {
15+
let mut rng = Rng::with_seed(SEED);
16+
17+
b.iter(|| {
18+
let mu = core::hint::black_box(MU);
19+
let sigma = core::hint::black_box(SIGMA);
20+
21+
// https://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform
22+
let u1 = loop {
23+
let u1 = rng.f64();
24+
25+
if u1 > f64::EPSILON {
26+
break u1;
27+
}
28+
};
29+
30+
let u2 = rng.f64();
31+
let mag = sigma * (-2.0 * u1.ln()).sqrt();
32+
let output = mag * (core::f64::consts::TAU * u2).cos() + mu;
33+
34+
core::hint::black_box(output);
35+
});
36+
}
37+
38+
#[bench]
39+
fn standard_approximation(b: &mut Bencher) {
40+
let mut rng = Rng::with_seed(SEED);
41+
42+
b.iter(|| {
43+
let mu = core::hint::black_box(MU);
44+
let sigma = core::hint::black_box(SIGMA);
45+
46+
// http://marc-b-reynolds.github.io/distribution/2021/03/18/CheapGaussianApprox.html
47+
let u = rng.u128(..);
48+
49+
let mask = 0xffffffff;
50+
let a = (u & mask) as i64;
51+
let b = ((u >> 32) & mask) as i64;
52+
let c = ((u >> 64) & mask) as i64;
53+
let d = (u >> 96) as i64;
54+
55+
// Magic constant.
56+
let k = 3.97815e-10;
57+
58+
let output = k * ((a + b) - (c + d)) as f64 * sigma + mu;
59+
60+
core::hint::black_box(output);
61+
});
62+
}
63+
64+
#[bench]
65+
fn popcount_approximation(b: &mut Bencher) {
66+
let mut rng = Rng::with_seed(SEED);
67+
68+
b.iter(|| {
69+
let mu = core::hint::black_box(MU);
70+
let sigma = core::hint::black_box(SIGMA);
71+
72+
// http://marc-b-reynolds.github.io/distribution/2021/03/18/CheapGaussianApprox.html
73+
let u = rng.u128(..);
74+
75+
let bd = (u << 64).count_ones() as i64 - 32;
76+
77+
let a = ((u >> 64) & 0xffffffff) as i64;
78+
let b = (u >> 96) as i64;
79+
80+
let td = a - b;
81+
82+
let r = ((bd << 32) + td) as f64;
83+
84+
// Magic constant.
85+
let k = 5.76917e-11;
86+
87+
let output = k * r * sigma + mu;
88+
89+
core::hint::black_box(output);
90+
})
91+
}
92+
93+
#[bench]
94+
fn f64_normal(b: &mut Bencher) {
95+
let mut rng = Rng::with_seed(SEED);
96+
97+
b.iter(|| {
98+
let mu = core::hint::black_box(MU);
99+
let sigma = core::hint::black_box(SIGMA);
100+
101+
let output = rng.f64_normal(mu, sigma);
102+
103+
core::hint::black_box(output);
104+
});
105+
}
106+
107+
#[bench]
108+
fn f64_normal_approx(b: &mut Bencher) {
109+
let mut rng = Rng::with_seed(SEED);
110+
111+
b.iter(|| {
112+
let mu = core::hint::black_box(MU);
113+
let sigma = core::hint::black_box(SIGMA);
114+
115+
let output = rng.f64_normal_approx(mu, sigma);
116+
117+
core::hint::black_box(output);
118+
});
119+
}

src/float_normal.rs

Lines changed: 222 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,222 @@
1+
use core::{
2+
cmp::PartialOrd,
3+
ops::{Add, Mul, Neg},
4+
};
5+
6+
use crate::BaseRng;
7+
8+
#[cfg(any(feature = "std", feature = "libm"))]
9+
pub(super) fn f32(rng: &mut impl BaseRng, mu: f32, sigma: f32) -> f32 {
10+
float_normal_impl(rng, mu, sigma)
11+
}
12+
13+
#[cfg(any(feature = "std", feature = "libm"))]
14+
pub(super) fn f64(rng: &mut impl BaseRng, mu: f64, sigma: f64) -> f64 {
15+
float_normal_impl(rng, mu, sigma)
16+
}
17+
18+
pub(super) fn f32_approx(rng: &mut impl BaseRng, mu: f32, sigma: f32) -> f32 {
19+
float_normal_approx_impl(rng, mu, sigma)
20+
}
21+
22+
pub(super) fn f64_approx(rng: &mut impl BaseRng, mu: f64, sigma: f64) -> f64 {
23+
float_normal_approx_impl(rng, mu, sigma)
24+
}
25+
26+
trait FloatExt:
27+
Add<Self, Output = Self> + Mul<Self, Output = Self> + Neg<Output = Self> + PartialOrd<Self> + Sized
28+
{
29+
const EPSILON: Self;
30+
31+
fn from_f64(x: f64) -> Self;
32+
fn gen(rng: &mut impl BaseRng) -> Self;
33+
}
34+
35+
#[cfg(any(feature = "std", feature = "libm"))]
36+
trait FloatMathExt: FloatExt {
37+
const TAU: Self;
38+
39+
fn ln(self) -> Self;
40+
fn sqrt(self) -> Self;
41+
fn cos(self) -> Self;
42+
}
43+
44+
macro_rules! impl_float_ext {
45+
($float:ident) => {
46+
impl FloatExt for $float {
47+
const EPSILON: Self = $float::EPSILON;
48+
49+
#[inline]
50+
fn from_f64(x: f64) -> Self {
51+
x as $float
52+
}
53+
#[inline]
54+
fn gen(rng: &mut impl BaseRng) -> Self {
55+
rng.$float()
56+
}
57+
}
58+
};
59+
}
60+
61+
macro_rules! impl_float_math_ext {
62+
($float:ident, $tau:ident) => {
63+
#[cfg(all(feature = "std", not(feature = "libm")))]
64+
impl FloatMathExt for $float {
65+
const TAU: Self = $tau;
66+
67+
#[inline]
68+
fn ln(self) -> Self {
69+
$float::ln(self)
70+
}
71+
#[inline]
72+
fn sqrt(self) -> Self {
73+
$float::sqrt(self)
74+
}
75+
#[inline]
76+
fn cos(self) -> Self {
77+
$float::cos(self)
78+
}
79+
}
80+
81+
#[cfg(feature = "libm")]
82+
impl FloatMathExt for $float {
83+
const TAU: Self = $tau;
84+
85+
#[inline]
86+
fn ln(self) -> Self {
87+
libm_dep::Libm::<$float>::log(self)
88+
}
89+
#[inline]
90+
fn sqrt(self) -> Self {
91+
libm_dep::Libm::<$float>::sqrt(self)
92+
}
93+
#[inline]
94+
fn cos(self) -> Self {
95+
libm_dep::Libm::<$float>::cos(self)
96+
}
97+
}
98+
};
99+
}
100+
101+
// TAU constant was stabilized in Rust 1.47. Our current MSRV is 1.43.
102+
#[cfg(any(feature = "std", feature = "libm"))]
103+
#[allow(clippy::excessive_precision)]
104+
const F32_TAU: f32 = 6.28318530717958647692528676655900577_f32;
105+
#[cfg(any(feature = "std", feature = "libm"))]
106+
#[allow(clippy::excessive_precision)]
107+
const F64_TAU: f64 = 6.28318530717958647692528676655900577_f64;
108+
109+
impl_float_ext!(f32);
110+
impl_float_ext!(f64);
111+
impl_float_math_ext!(f32, F32_TAU);
112+
impl_float_math_ext!(f64, F64_TAU);
113+
114+
#[cfg(any(feature = "std", feature = "libm"))]
115+
fn float_normal_impl<T: FloatMathExt>(rng: &mut impl BaseRng, mu: T, sigma: T) -> T {
116+
// https://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform
117+
let u1 = loop {
118+
let u1 = T::gen(rng);
119+
120+
if u1 > T::EPSILON {
121+
break u1;
122+
}
123+
};
124+
125+
let u2 = T::gen(rng);
126+
let mag = sigma * (-T::from_f64(2.0) * u1.ln()).sqrt();
127+
mag * (T::TAU * u2).cos() + mu
128+
}
129+
130+
fn float_normal_approx_impl<T: FloatExt>(rng: &mut impl BaseRng, mu: T, sigma: T) -> T {
131+
// http://marc-b-reynolds.github.io/distribution/2021/03/18/CheapGaussianApprox.html
132+
let u = rng.u128();
133+
134+
// Counting ones in a u64 half of the generated number gives us binomial
135+
// distribution with p = 1/2 and n = 64. Subtracting 32 centers the
136+
// distribution on [-32, 32]. Shifting the lower u64 by 64 bits discard the
137+
// other half and `count_ones` can be used without any masking.
138+
let bd = (u << 64).count_ones() as i64 - 32;
139+
140+
// Sample two u32 integers from uniform distribution.
141+
let a = ((u >> 64) & 0xffffffff) as i64;
142+
let b = (u >> 96) as i64;
143+
144+
// First iteration of Central limit theorem (summing two uniform random
145+
// variables) _often_ gives triangular distribution. By using subtraction
146+
// instead of addition, the triangular distribution is centered around zero.
147+
let td = a - b;
148+
149+
// Sum the binomial and triangular distributions.
150+
let r = ((bd << 32) + td) as f64;
151+
152+
// Magic constant for scaling which is a result of minimizing the maximum
153+
// error with respect to the reference normal distribution.
154+
let k = 5.76917e-11;
155+
156+
T::from_f64(k * r) * sigma + mu
157+
}
158+
159+
#[cfg(test)]
160+
mod tests {
161+
use fastrand::Rng;
162+
163+
use super::*;
164+
165+
fn normal_distribution_test<F>(sample: F)
166+
where
167+
F: Fn(&mut Rng, f32, f32) -> f32,
168+
{
169+
// The test based on the following picture from Wikipedia:
170+
// https://upload.wikimedia.org/wikipedia/commons/thumb/3/3a/Standard_deviation_diagram_micro.svg/1920px-Standard_deviation_diagram_micro.svg.png
171+
let mut rng = Rng::with_seed(42);
172+
173+
let mu = 10.0;
174+
let sigma = 3.0;
175+
176+
let total = 10000;
177+
let mut in_one_sigma_range = 0;
178+
let mut in_two_sigma_range = 0;
179+
let mut in_three_sigma_range = 0;
180+
for _ in 0..total {
181+
let value = sample(&mut rng, mu, sigma);
182+
183+
if (mu - sigma..=mu + sigma).contains(&value) {
184+
in_one_sigma_range += 1;
185+
} else if (mu - sigma * 2.0..=mu + sigma * 2.0).contains(&value) {
186+
in_two_sigma_range += 1;
187+
} else if (mu - sigma * 3.0..=mu + sigma * 3.0).contains(&value) {
188+
in_three_sigma_range += 1;
189+
}
190+
}
191+
192+
let in_one_sigma_range = in_one_sigma_range as f32 / total as f32 * 100.0;
193+
let in_two_sigma_range = in_two_sigma_range as f32 / total as f32 * 100.0;
194+
let in_three_sigma_range = in_three_sigma_range as f32 / total as f32 * 100.0;
195+
assert!(
196+
(64.0..=72.0).contains(&in_one_sigma_range),
197+
"value in \"one sigma range\" should be sampled ~68.2%, but is {}%",
198+
in_one_sigma_range
199+
);
200+
assert!(
201+
(23.0..=31.0).contains(&in_two_sigma_range),
202+
"value in \"two sigma range\" should be sampled ~27.2%, but is {}%",
203+
in_two_sigma_range
204+
);
205+
assert!(
206+
(1.0..=7.0).contains(&in_three_sigma_range),
207+
"value in \"three sigma range\" should be sampled ~4.2%, but is {}%",
208+
in_three_sigma_range
209+
);
210+
}
211+
212+
#[test]
213+
#[cfg(any(feature = "std", feature = "libm"))]
214+
fn normal_is_actually_normal() {
215+
normal_distribution_test(float_normal_impl);
216+
}
217+
218+
#[test]
219+
fn normal_approx_is_actually_normal() {
220+
normal_distribution_test(float_normal_approx_impl);
221+
}
222+
}

src/float_range.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,7 @@ mod tests {
213213

214214
#[test]
215215
fn f32_range_in_bounds() {
216-
let mut rng = Rng::new();
216+
let mut rng = Rng::with_seed(42);
217217

218218
let range = -2.0..2.0;
219219
for _ in 0..10000 {
@@ -223,7 +223,7 @@ mod tests {
223223

224224
#[test]
225225
fn f32_range_wide_range_in_bounds() {
226-
let mut rng = Rng::new();
226+
let mut rng = Rng::with_seed(42);
227227

228228
let range = f32::MIN..f32::MAX;
229229
for _ in 0..10000 {
@@ -233,7 +233,7 @@ mod tests {
233233

234234
#[test]
235235
fn f32_range_unbounded_finite() {
236-
let mut rng = Rng::new();
236+
let mut rng = Rng::with_seed(42);
237237

238238
let range = ..;
239239
for _ in 0..10000 {

0 commit comments

Comments
 (0)