1
1
use super :: SummaryStatisticsExt ;
2
- use crate :: errors:: EmptyInput ;
3
- use ndarray:: { ArrayBase , Data , Dimension } ;
2
+ use crate :: errors:: { EmptyInput , MultiInputError , ShapeMismatch } ;
3
+ use ndarray:: { Array , ArrayBase , Axis , Data , Dimension , Ix1 , RemoveAxis } ;
4
4
use num_integer:: IterBinomial ;
5
5
use num_traits:: { Float , FromPrimitive , Zero } ;
6
- use std:: ops:: { Add , Div } ;
6
+ use std:: ops:: { Add , Div , Mul } ;
7
7
8
8
impl < A , S , D > SummaryStatisticsExt < A , S , D > for ArrayBase < S , D >
9
9
where
24
24
}
25
25
}
26
26
27
+ fn weighted_mean ( & self , weights : & Self ) -> Result < A , MultiInputError >
28
+ where
29
+ A : Copy + Div < Output = A > + Mul < Output = A > + Zero ,
30
+ {
31
+ return_err_if_empty ! ( self ) ;
32
+ let weighted_sum = self . weighted_sum ( weights) ?;
33
+ Ok ( weighted_sum / weights. sum ( ) )
34
+ }
35
+
36
+ fn weighted_sum ( & self , weights : & ArrayBase < S , D > ) -> Result < A , MultiInputError >
37
+ where
38
+ A : Copy + Mul < Output = A > + Zero ,
39
+ {
40
+ return_err_unless_same_shape ! ( self , weights) ;
41
+ Ok ( self
42
+ . iter ( )
43
+ . zip ( weights)
44
+ . fold ( A :: zero ( ) , |acc, ( & d, & w) | acc + d * w) )
45
+ }
46
+
47
+ fn weighted_mean_axis (
48
+ & self ,
49
+ axis : Axis ,
50
+ weights : & ArrayBase < S , Ix1 > ,
51
+ ) -> Result < Array < A , D :: Smaller > , MultiInputError >
52
+ where
53
+ A : Copy + Div < Output = A > + Mul < Output = A > + Zero ,
54
+ D : RemoveAxis ,
55
+ {
56
+ return_err_if_empty ! ( self ) ;
57
+ let mut weighted_sum = self . weighted_sum_axis ( axis, weights) ?;
58
+ let weights_sum = weights. sum ( ) ;
59
+ weighted_sum. mapv_inplace ( |v| v / weights_sum) ;
60
+ Ok ( weighted_sum)
61
+ }
62
+
63
+ fn weighted_sum_axis (
64
+ & self ,
65
+ axis : Axis ,
66
+ weights : & ArrayBase < S , Ix1 > ,
67
+ ) -> Result < Array < A , D :: Smaller > , MultiInputError >
68
+ where
69
+ A : Copy + Mul < Output = A > + Zero ,
70
+ D : RemoveAxis ,
71
+ {
72
+ if self . shape ( ) [ axis. index ( ) ] != weights. len ( ) {
73
+ return Err ( MultiInputError :: ShapeMismatch ( ShapeMismatch {
74
+ first_shape : self . shape ( ) . to_vec ( ) ,
75
+ second_shape : weights. shape ( ) . to_vec ( ) ,
76
+ } ) ) ;
77
+ }
78
+
79
+ // We could use `lane.weighted_sum` here, but we're avoiding 2
80
+ // conditions and an unwrap per lane.
81
+ Ok ( self . map_axis ( axis, |lane| {
82
+ lane. iter ( )
83
+ . zip ( weights)
84
+ . fold ( A :: zero ( ) , |acc, ( & d, & w) | acc + d * w)
85
+ } ) )
86
+ }
87
+
27
88
fn harmonic_mean ( & self ) -> Result < A , EmptyInput >
28
89
where
29
90
A : Float + FromPrimitive ,
@@ -194,18 +255,31 @@ where
194
255
#[ cfg( test) ]
195
256
mod tests {
196
257
use super :: SummaryStatisticsExt ;
197
- use crate :: errors:: EmptyInput ;
198
- use approx:: assert_abs_diff_eq;
199
- use ndarray:: { array, Array , Array1 } ;
258
+ use crate :: errors:: { EmptyInput , MultiInputError , ShapeMismatch } ;
259
+ use approx:: { abs_diff_eq , assert_abs_diff_eq} ;
260
+ use ndarray:: { arr0 , array, Array , Array1 , Array2 , Axis } ;
200
261
use ndarray_rand:: RandomExt ;
201
262
use noisy_float:: types:: N64 ;
263
+ use quickcheck:: { quickcheck, TestResult } ;
202
264
use rand:: distributions:: Uniform ;
203
265
use std:: f64;
204
266
205
267
// Every mean/sum variant must yield NaN when a NaN is present in the data
// or in the weights.
#[test]
fn test_means_with_nan_values() {
    let a = array![f64::NAN, 1.];
    let nan_weights = array![1.0, f64::NAN];

    assert!(a.mean().unwrap().is_nan());

    let weighted_mean = a.weighted_mean(&nan_weights).unwrap();
    assert!(weighted_mean.is_nan());

    let weighted_sum = a.weighted_sum(&nan_weights).unwrap();
    assert!(weighted_sum.is_nan());

    let weighted_mean_axis = a
        .weighted_mean_axis(Axis(0), &nan_weights)
        .unwrap()
        .into_scalar();
    assert!(weighted_mean_axis.is_nan());

    let weighted_sum_axis = a
        .weighted_sum_axis(Axis(0), &nan_weights)
        .unwrap()
        .into_scalar();
    assert!(weighted_sum_axis.is_nan());

    assert!(a.harmonic_mean().unwrap().is_nan());
    assert!(a.geometric_mean().unwrap().is_nan());
}
@@ -214,16 +288,40 @@ mod tests {
214
288
fn test_means_with_empty_array_of_floats ( ) {
215
289
let a: Array1 < f64 > = array ! [ ] ;
216
290
assert_eq ! ( a. mean( ) , None ) ;
291
+ assert_eq ! (
292
+ a. weighted_mean( & array![ 1.0 ] ) ,
293
+ Err ( MultiInputError :: EmptyInput )
294
+ ) ;
295
+ assert_eq ! (
296
+ a. weighted_mean_axis( Axis ( 0 ) , & array![ 1.0 ] ) ,
297
+ Err ( MultiInputError :: EmptyInput )
298
+ ) ;
217
299
assert_eq ! ( a. harmonic_mean( ) , Err ( EmptyInput ) ) ;
218
300
assert_eq ! ( a. geometric_mean( ) , Err ( EmptyInput ) ) ;
301
+
302
+ // The sum methods accept empty arrays
303
+ assert_eq ! ( a. weighted_sum( & array![ ] ) , Ok ( 0.0 ) ) ;
304
+ assert_eq ! ( a. weighted_sum_axis( Axis ( 0 ) , & array![ ] ) , Ok ( arr0( 0.0 ) ) ) ;
219
305
}
220
306
221
307
// Same expectations as the `f64` variant, exercised with the `N64`
// noisy-float element type.
#[test]
fn test_means_with_empty_array_of_noisy_floats() {
    let a: Array1<N64> = array![];
    let no_weights: Array1<N64> = array![];
    let zero = N64::new(0.0);

    assert_eq!(a.mean(), None);
    assert_eq!(
        a.weighted_mean(&no_weights),
        Err(MultiInputError::EmptyInput)
    );
    assert_eq!(
        a.weighted_mean_axis(Axis(0), &no_weights),
        Err(MultiInputError::EmptyInput)
    );
    assert_eq!(a.harmonic_mean(), Err(EmptyInput));
    assert_eq!(a.geometric_mean(), Err(EmptyInput));

    // The sum methods accept empty arrays
    assert_eq!(a.weighted_sum(&no_weights), Ok(zero));
    assert_eq!(a.weighted_sum_axis(Axis(0), &no_weights), Ok(arr0(zero)));
}
228
326
229
327
#[ test]
@@ -240,9 +338,9 @@ mod tests {
240
338
] ;
241
339
// Computed using NumPy
242
340
let expected_mean = 0.5475494059146699 ;
341
+ let expected_weighted_mean = 0.6782420496397121 ;
243
342
// Computed using SciPy
244
343
let expected_harmonic_mean = 0.21790094950226022 ;
245
- // Computed using SciPy
246
344
let expected_geometric_mean = 0.4345897639796527 ;
247
345
248
346
assert_abs_diff_eq ! ( a. mean( ) . unwrap( ) , expected_mean, epsilon = 1e-9 ) ;
@@ -256,6 +354,114 @@ mod tests {
256
354
expected_geometric_mean,
257
355
epsilon = 1e-12
258
356
) ;
357
+
358
+ // weighted_mean with itself, normalized
359
+ let weights = & a / a. sum ( ) ;
360
+ assert_abs_diff_eq ! (
361
+ a. weighted_sum( & weights) . unwrap( ) ,
362
+ expected_weighted_mean,
363
+ epsilon = 1e-12
364
+ ) ;
365
+
366
+ let data = a. into_shape ( ( 2 , 5 , 5 ) ) . unwrap ( ) ;
367
+ let weights = array ! [ 0.1 , 0.5 , 0.25 , 0.15 , 0.2 ] ;
368
+ assert_abs_diff_eq ! (
369
+ data. weighted_mean_axis( Axis ( 1 ) , & weights) . unwrap( ) ,
370
+ array![
371
+ [ 0.50202721 , 0.53347361 , 0.29086033 , 0.56995637 , 0.37087139 ] ,
372
+ [ 0.58028328 , 0.50485216 , 0.59349973 , 0.70308937 , 0.72280630 ]
373
+ ] ,
374
+ epsilon = 1e-8
375
+ ) ;
376
+ assert_abs_diff_eq ! (
377
+ data. weighted_mean_axis( Axis ( 2 ) , & weights) . unwrap( ) ,
378
+ array![
379
+ [ 0.33434378 , 0.38365259 , 0.56405781 , 0.48676574 , 0.55016179 ] ,
380
+ [ 0.71112376 , 0.55134174 , 0.45566513 , 0.74228516 , 0.68405851 ]
381
+ ] ,
382
+ epsilon = 1e-8
383
+ ) ;
384
+ assert_abs_diff_eq ! (
385
+ data. weighted_sum_axis( Axis ( 1 ) , & weights) . unwrap( ) ,
386
+ array![
387
+ [ 0.60243266 , 0.64016833 , 0.34903240 , 0.68394765 , 0.44504567 ] ,
388
+ [ 0.69633993 , 0.60582259 , 0.71219968 , 0.84370724 , 0.86736757 ]
389
+ ] ,
390
+ epsilon = 1e-8
391
+ ) ;
392
+ assert_abs_diff_eq ! (
393
+ data. weighted_sum_axis( Axis ( 2 ) , & weights) . unwrap( ) ,
394
+ array![
395
+ [ 0.40121254 , 0.46038311 , 0.67686937 , 0.58411889 , 0.66019415 ] ,
396
+ [ 0.85334851 , 0.66161009 , 0.54679815 , 0.89074219 , 0.82087021 ]
397
+ ] ,
398
+ epsilon = 1e-8
399
+ ) ;
400
+ }
401
+
402
// Weighted sums on a `(0, 20)` array: summing along either axis succeeds
// with all-zero / empty results, and mismatched weights are reported as
// `ShapeMismatch` carrying both shapes.
#[test]
fn weighted_sum_dimension_zero() {
    let a = Array2::<usize>::zeros((0, 20));

    // Builds the error expected when `a` meets weights of `second_shape`.
    let mismatch = |second_shape: Vec<usize>| {
        MultiInputError::ShapeMismatch(ShapeMismatch {
            first_shape: vec![0, 20],
            second_shape,
        })
    };

    let along_rows = a.weighted_sum_axis(Axis(0), &Array1::zeros(0)).unwrap();
    assert_eq!(along_rows, Array1::from_elem(20, 0));

    let along_columns = a.weighted_sum_axis(Axis(1), &Array1::zeros(20)).unwrap();
    assert_eq!(along_columns, Array1::from_elem(0, 0));

    let wrong_weight_count = a.weighted_sum_axis(Axis(0), &Array1::zeros(1));
    assert_eq!(wrong_weight_count, Err(mismatch(vec![1])));

    let wrong_weight_shape = a.weighted_sum(&Array2::zeros((10, 20)));
    assert_eq!(wrong_weight_shape, Err(mismatch(vec![10, 20])));
}
428
+
429
// Property: with uniform weights `1 / len`, the plain mean, the weighted
// mean and the weighted sum all agree (within 1e-9).
#[test]
fn mean_eq_if_uniform_weights() {
    fn prop(a: Vec<f64>) -> TestResult {
        // An empty input has no mean; let quickcheck draw another sample.
        if a.is_empty() {
            return TestResult::discard();
        }
        let a = Array1::from(a);
        let len = a.len();
        let weights = Array1::from_elem(len, 1.0 / len as f64);

        let m = a.mean().unwrap();
        let wm = a.weighted_mean(&weights).unwrap();
        let ws = a.weighted_sum(&weights).unwrap();

        let mean_matches = abs_diff_eq!(m, wm, epsilon = 1e-9);
        let sum_matches = abs_diff_eq!(wm, ws, epsilon = 1e-9);
        TestResult::from_bool(mean_matches && sum_matches)
    }
    quickcheck(prop as fn(Vec<f64>) -> TestResult);
}
446
+
447
// Property: with uniform weights `1 / depth` along axis 0, `mean_axis`,
// `weighted_mean_axis` and `weighted_sum_axis` all agree.
#[test]
fn mean_axis_eq_if_uniform_weights() {
    fn prop(mut a: Vec<f64>) -> TestResult {
        // At least 24 values are needed for a `(depth, 3, 4)` array with
        // `depth >= 2`; otherwise let quickcheck draw another sample.
        if a.len() < 24 {
            return TestResult::discard();
        }
        // Drop the tail so the length is an exact multiple of 3 * 4.
        let depth = a.len() / 12;
        a.truncate(depth * 3 * 4);
        let weights = Array1::from_elem(depth, 1.0 / depth as f64);
        let a = Array1::from(a).into_shape((depth, 3, 4)).unwrap();
        let ma = a.mean_axis(Axis(0)).unwrap();
        let wm = a.weighted_mean_axis(Axis(0), &weights).unwrap();
        let ws = a.weighted_sum_axis(Axis(0), &weights).unwrap();
        // Both comparisons use the same tight tolerance. The second one
        // previously used `1e12`, which accepted virtually any finite pair
        // of results and so checked nothing.
        TestResult::from_bool(
            abs_diff_eq!(ma, wm, epsilon = 1e-12) && abs_diff_eq!(wm, ws, epsilon = 1e-12),
        )
    }
    quickcheck(prop as fn(Vec<f64>) -> TestResult);
}
260
466
261
467
#[ test]
0 commit comments