Skip to content

Commit 496b8ac

Browse files
Nil Goyette authored and LukeMathWalker committed
Weighted mean (#51)
* Move MultiInputError in lib
* Add weighted_mean
* Add weighted_mean tests
* Add weighted_mean_axis
* Add weighted_mean_axis tests
* Use quickcheck in mean tests
* Divide into mean_axis and sum_axis
* Add tests
* Precalculate weights_sum
* Update doc and rules
* Add tests
1 parent 680209f commit 496b8ac

File tree

4 files changed

+334
-30
lines changed

4 files changed

+334
-30
lines changed

src/deviation.rs

Lines changed: 1 addition & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use num_traits::{Signed, ToPrimitive};
33
use std::convert::Into;
44
use std::ops::AddAssign;
55

6-
use crate::errors::{MultiInputError, ShapeMismatch};
6+
use crate::errors::MultiInputError;
77

88
/// An extension trait for `ArrayBase` providing functions
99
/// to compute different deviation measures.
@@ -217,25 +217,6 @@ where
217217
private_decl! {}
218218
}
219219

220-
macro_rules! return_err_if_empty {
221-
($arr:expr) => {
222-
if $arr.len() == 0 {
223-
return Err(MultiInputError::EmptyInput);
224-
}
225-
};
226-
}
227-
macro_rules! return_err_unless_same_shape {
228-
($arr_a:expr, $arr_b:expr) => {
229-
if $arr_a.shape() != $arr_b.shape() {
230-
return Err(ShapeMismatch {
231-
first_shape: $arr_a.shape().to_vec(),
232-
second_shape: $arr_b.shape().to_vec(),
233-
}
234-
.into());
235-
}
236-
};
237-
}
238-
239220
impl<A, S, D> DeviationExt<A, S, D> for ArrayBase<S, D>
240221
where
241222
S: Data<Elem = A>,

src/lib.rs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,29 @@ pub use crate::summary_statistics::SummaryStatisticsExt;
4242
#[macro_use]
4343
extern crate approx;
4444

45+
#[macro_use]
mod multi_input_error_macros {
    // Early-return guards shared by the methods that validate several input
    // arrays. Both macros name the error types through `$crate`, so a call
    // site needs no particular imports and can invoke either macro any number
    // of times within the same function body.

    // Returns `Err(MultiInputError::EmptyInput)` from the enclosing function
    // when `$arr.len()` is zero.
    macro_rules! return_err_if_empty {
        ($arr:expr) => {
            if $arr.len() == 0 {
                return Err($crate::errors::MultiInputError::EmptyInput);
            }
        };
    }

    // Returns a `MultiInputError::ShapeMismatch` (converted with `.into()` to
    // the enclosing function's error type) when `$arr_a.shape()` and
    // `$arr_b.shape()` differ. The error records both shapes.
    macro_rules! return_err_unless_same_shape {
        ($arr_a:expr, $arr_b:expr) => {
            if $arr_a.shape() != $arr_b.shape() {
                return Err($crate::errors::MultiInputError::ShapeMismatch(
                    $crate::errors::ShapeMismatch {
                        first_shape: $arr_a.shape().to_vec(),
                        second_shape: $arr_b.shape().to_vec(),
                    },
                )
                .into());
            }
        };
    }
}
67+
4568
#[macro_use]
4669
mod private {
4770
/// This is a public type in a private module, so it can be included in

src/summary_statistics/means.rs

Lines changed: 213 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
use super::SummaryStatisticsExt;
2-
use crate::errors::EmptyInput;
3-
use ndarray::{ArrayBase, Data, Dimension};
2+
use crate::errors::{EmptyInput, MultiInputError, ShapeMismatch};
3+
use ndarray::{Array, ArrayBase, Axis, Data, Dimension, Ix1, RemoveAxis};
44
use num_integer::IterBinomial;
55
use num_traits::{Float, FromPrimitive, Zero};
6-
use std::ops::{Add, Div};
6+
use std::ops::{Add, Div, Mul};
77

88
impl<A, S, D> SummaryStatisticsExt<A, S, D> for ArrayBase<S, D>
99
where
@@ -24,6 +24,67 @@ where
2424
}
2525
}
2626

27+
fn weighted_mean(&self, weights: &Self) -> Result<A, MultiInputError>
28+
where
29+
A: Copy + Div<Output = A> + Mul<Output = A> + Zero,
30+
{
31+
return_err_if_empty!(self);
32+
let weighted_sum = self.weighted_sum(weights)?;
33+
Ok(weighted_sum / weights.sum())
34+
}
35+
36+
fn weighted_sum(&self, weights: &ArrayBase<S, D>) -> Result<A, MultiInputError>
37+
where
38+
A: Copy + Mul<Output = A> + Zero,
39+
{
40+
return_err_unless_same_shape!(self, weights);
41+
Ok(self
42+
.iter()
43+
.zip(weights)
44+
.fold(A::zero(), |acc, (&d, &w)| acc + d * w))
45+
}
46+
47+
fn weighted_mean_axis(
48+
&self,
49+
axis: Axis,
50+
weights: &ArrayBase<S, Ix1>,
51+
) -> Result<Array<A, D::Smaller>, MultiInputError>
52+
where
53+
A: Copy + Div<Output = A> + Mul<Output = A> + Zero,
54+
D: RemoveAxis,
55+
{
56+
return_err_if_empty!(self);
57+
let mut weighted_sum = self.weighted_sum_axis(axis, weights)?;
58+
let weights_sum = weights.sum();
59+
weighted_sum.mapv_inplace(|v| v / weights_sum);
60+
Ok(weighted_sum)
61+
}
62+
63+
fn weighted_sum_axis(
64+
&self,
65+
axis: Axis,
66+
weights: &ArrayBase<S, Ix1>,
67+
) -> Result<Array<A, D::Smaller>, MultiInputError>
68+
where
69+
A: Copy + Mul<Output = A> + Zero,
70+
D: RemoveAxis,
71+
{
72+
if self.shape()[axis.index()] != weights.len() {
73+
return Err(MultiInputError::ShapeMismatch(ShapeMismatch {
74+
first_shape: self.shape().to_vec(),
75+
second_shape: weights.shape().to_vec(),
76+
}));
77+
}
78+
79+
// We could use `lane.weighted_sum` here, but we're avoiding 2
80+
// conditions and an unwrap per lane.
81+
Ok(self.map_axis(axis, |lane| {
82+
lane.iter()
83+
.zip(weights)
84+
.fold(A::zero(), |acc, (&d, &w)| acc + d * w)
85+
}))
86+
}
87+
2788
fn harmonic_mean(&self) -> Result<A, EmptyInput>
2889
where
2990
A: Float + FromPrimitive,
@@ -194,18 +255,31 @@ where
194255
#[cfg(test)]
195256
mod tests {
196257
use super::SummaryStatisticsExt;
197-
use crate::errors::EmptyInput;
198-
use approx::assert_abs_diff_eq;
199-
use ndarray::{array, Array, Array1};
258+
use crate::errors::{EmptyInput, MultiInputError, ShapeMismatch};
259+
use approx::{abs_diff_eq, assert_abs_diff_eq};
260+
use ndarray::{arr0, array, Array, Array1, Array2, Axis};
200261
use ndarray_rand::RandomExt;
201262
use noisy_float::types::N64;
263+
use quickcheck::{quickcheck, TestResult};
202264
use rand::distributions::Uniform;
203265
use std::f64;
204266

205267
#[test]
fn test_means_with_nan_values() {
    // A NaN in either the data or the weights must surface as NaN in every mean and sum.
    let a = array![f64::NAN, 1.];
    let weights = array![1.0, f64::NAN];

    assert!(a.mean().unwrap().is_nan());

    let weighted_mean = a.weighted_mean(&weights).unwrap();
    assert!(weighted_mean.is_nan());

    let weighted_sum = a.weighted_sum(&weights).unwrap();
    assert!(weighted_sum.is_nan());

    let weighted_mean_axis = a.weighted_mean_axis(Axis(0), &weights).unwrap();
    assert!(weighted_mean_axis.into_scalar().is_nan());

    let weighted_sum_axis = a.weighted_sum_axis(Axis(0), &weights).unwrap();
    assert!(weighted_sum_axis.into_scalar().is_nan());

    assert!(a.harmonic_mean().unwrap().is_nan());
    assert!(a.geometric_mean().unwrap().is_nan());
}
@@ -214,16 +288,40 @@ mod tests {
214288
fn test_means_with_empty_array_of_floats() {
215289
let a: Array1<f64> = array![];
216290
assert_eq!(a.mean(), None);
291+
assert_eq!(
292+
a.weighted_mean(&array![1.0]),
293+
Err(MultiInputError::EmptyInput)
294+
);
295+
assert_eq!(
296+
a.weighted_mean_axis(Axis(0), &array![1.0]),
297+
Err(MultiInputError::EmptyInput)
298+
);
217299
assert_eq!(a.harmonic_mean(), Err(EmptyInput));
218300
assert_eq!(a.geometric_mean(), Err(EmptyInput));
301+
302+
// The sum methods accept empty arrays
303+
assert_eq!(a.weighted_sum(&array![]), Ok(0.0));
304+
assert_eq!(a.weighted_sum_axis(Axis(0), &array![]), Ok(arr0(0.0)));
219305
}
220306

221307
#[test]
fn test_means_with_empty_array_of_noisy_floats() {
    let a: Array1<N64> = array![];
    let no_weights: Array1<N64> = array![];
    let zero = N64::new(0.0);

    // Every mean rejects empty data.
    assert_eq!(a.mean(), None);
    assert_eq!(
        a.weighted_mean(&no_weights),
        Err(MultiInputError::EmptyInput)
    );
    assert_eq!(
        a.weighted_mean_axis(Axis(0), &no_weights),
        Err(MultiInputError::EmptyInput)
    );
    assert_eq!(a.harmonic_mean(), Err(EmptyInput));
    assert_eq!(a.geometric_mean(), Err(EmptyInput));

    // The sum methods, by contrast, accept empty arrays and return zero.
    assert_eq!(a.weighted_sum(&no_weights), Ok(zero));
    assert_eq!(a.weighted_sum_axis(Axis(0), &no_weights), Ok(arr0(zero)));
}
228326

229327
#[test]
@@ -240,9 +338,9 @@ mod tests {
240338
];
241339
// Computed using NumPy
242340
let expected_mean = 0.5475494059146699;
341+
let expected_weighted_mean = 0.6782420496397121;
243342
// Computed using SciPy
244343
let expected_harmonic_mean = 0.21790094950226022;
245-
// Computed using SciPy
246344
let expected_geometric_mean = 0.4345897639796527;
247345

248346
assert_abs_diff_eq!(a.mean().unwrap(), expected_mean, epsilon = 1e-9);
@@ -256,6 +354,114 @@ mod tests {
256354
expected_geometric_mean,
257355
epsilon = 1e-12
258356
);
357+
358+
// weighted_mean with itself, normalized
359+
let weights = &a / a.sum();
360+
assert_abs_diff_eq!(
361+
a.weighted_sum(&weights).unwrap(),
362+
expected_weighted_mean,
363+
epsilon = 1e-12
364+
);
365+
366+
let data = a.into_shape((2, 5, 5)).unwrap();
367+
let weights = array![0.1, 0.5, 0.25, 0.15, 0.2];
368+
assert_abs_diff_eq!(
369+
data.weighted_mean_axis(Axis(1), &weights).unwrap(),
370+
array![
371+
[0.50202721, 0.53347361, 0.29086033, 0.56995637, 0.37087139],
372+
[0.58028328, 0.50485216, 0.59349973, 0.70308937, 0.72280630]
373+
],
374+
epsilon = 1e-8
375+
);
376+
assert_abs_diff_eq!(
377+
data.weighted_mean_axis(Axis(2), &weights).unwrap(),
378+
array![
379+
[0.33434378, 0.38365259, 0.56405781, 0.48676574, 0.55016179],
380+
[0.71112376, 0.55134174, 0.45566513, 0.74228516, 0.68405851]
381+
],
382+
epsilon = 1e-8
383+
);
384+
assert_abs_diff_eq!(
385+
data.weighted_sum_axis(Axis(1), &weights).unwrap(),
386+
array![
387+
[0.60243266, 0.64016833, 0.34903240, 0.68394765, 0.44504567],
388+
[0.69633993, 0.60582259, 0.71219968, 0.84370724, 0.86736757]
389+
],
390+
epsilon = 1e-8
391+
);
392+
assert_abs_diff_eq!(
393+
data.weighted_sum_axis(Axis(2), &weights).unwrap(),
394+
array![
395+
[0.40121254, 0.46038311, 0.67686937, 0.58411889, 0.66019415],
396+
[0.85334851, 0.66161009, 0.54679815, 0.89074219, 0.82087021]
397+
],
398+
epsilon = 1e-8
399+
);
400+
}
401+
402+
#[test]
fn weighted_sum_dimension_zero() {
    // Builds the error expected for a given pair of shapes.
    let mismatch = |first: Vec<usize>, second: Vec<usize>| {
        MultiInputError::ShapeMismatch(ShapeMismatch {
            first_shape: first,
            second_shape: second,
        })
    };
    let a = Array2::<usize>::zeros((0, 20));

    // Summing over the zero-length axis leaves one zero per remaining column.
    let over_rows = a.weighted_sum_axis(Axis(0), &Array1::zeros(0)).unwrap();
    assert_eq!(over_rows, Array1::<usize>::zeros(20));

    // Summing over the length-20 axis leaves an empty result.
    let over_columns = a.weighted_sum_axis(Axis(1), &Array1::zeros(20)).unwrap();
    assert_eq!(over_columns, Array1::<usize>::zeros(0));

    // Weights of the wrong length are rejected even when the data is empty.
    assert_eq!(
        a.weighted_sum_axis(Axis(0), &Array1::zeros(1)),
        Err(mismatch(vec![0, 20], vec![1]))
    );
    assert_eq!(
        a.weighted_sum(&Array2::zeros((10, 20))),
        Err(mismatch(vec![0, 20], vec![10, 20]))
    );
}
428+
429+
#[test]
fn mean_eq_if_uniform_weights() {
    // Property: with uniform weights `1/n`, the plain mean, the weighted mean
    // and the weighted sum all agree.
    fn prop(a: Vec<f64>) -> TestResult {
        if a.is_empty() {
            return TestResult::discard();
        }
        let n = a.len();
        let data = Array1::from(a);
        let uniform = Array1::from_elem(n, 1.0 / n as f64);

        let plain = data.mean().unwrap();
        let weighted = data.weighted_mean(&uniform).unwrap();
        let summed = data.weighted_sum(&uniform).unwrap();

        let mean_matches = abs_diff_eq!(plain, weighted, epsilon = 1e-9);
        let sum_matches = abs_diff_eq!(weighted, summed, epsilon = 1e-9);
        TestResult::from_bool(mean_matches && sum_matches)
    }
    quickcheck(prop as fn(Vec<f64>) -> TestResult);
}
446+
447+
#[test]
fn mean_axis_eq_if_uniform_weights() {
    // Property: with uniform weights `1/depth` along axis 0, `mean_axis`,
    // `weighted_mean_axis` and `weighted_sum_axis` all agree.
    fn prop(mut a: Vec<f64>) -> TestResult {
        // At least two full 3x4 slices are needed.
        if a.len() < 24 {
            return TestResult::discard();
        }
        // Drop the tail so the data reshapes exactly into (depth, 3, 4).
        let depth = a.len() / 12;
        a.truncate(depth * 3 * 4);
        let weights = Array1::from_elem(depth, 1.0 / depth as f64);
        let a = Array1::from(a).into_shape((depth, 3, 4)).unwrap();
        let ma = a.mean_axis(Axis(0)).unwrap();
        let wm = a.weighted_mean_axis(Axis(0), &weights).unwrap();
        let ws = a.weighted_sum_axis(Axis(0), &weights).unwrap();
        // Both comparisons use the same tight tolerance. The second one used
        // to be `1e12`, which accepted any finite pair of values.
        TestResult::from_bool(
            abs_diff_eq!(ma, wm, epsilon = 1e-12) && abs_diff_eq!(wm, ws, epsilon = 1e-12),
        )
    }
    quickcheck(prop as fn(Vec<f64>) -> TestResult);
}
260466

261467
#[test]

0 commit comments

Comments
 (0)