@@ -41,9 +41,9 @@ fn missing_filter<'a, T: Element>(missing: &'a Missing<T>) -> Box<dyn Fn(&T) ->
4141fn count_non_missing < T : Element > (
4242 array : & ArrayView < T , ndarray:: Dim < ndarray:: IxDynImpl > > ,
4343 missing : & Missing < T > ,
44- ) -> Result < usize , ActiveStorageError > {
44+ ) -> usize {
4545 let filter = missing_filter ( missing) ;
46- Ok ( array. iter ( ) . copied ( ) . filter ( filter) . count ( ) )
46+ array. iter ( ) . copied ( ) . filter ( filter) . count ( )
4747}
4848
4949/// Counts the number of non-missing elements along
@@ -53,33 +53,52 @@ fn count_array_multi_axis<T: Element>(
5353 axes : & [ usize ] ,
5454 missing : Option < Missing < T > > ,
5555) -> ( Vec < i64 > , Vec < usize > ) {
56- // Count non-missing over first axis
57- let mut result = array
58- . fold_axis ( Axis ( axes[ 0 ] ) , 0 , |running_count, val| {
56+ let result = if axes. is_empty ( ) {
57+ // Emulate numpy semantics of axis = () being
58+ // equivalent to a 'reduction over no axes'
59+ array. map ( |val| {
5960 if let Some ( missing) = & missing {
6061 if !missing. is_missing ( val) {
61- running_count + 1
62+ 1
6263 } else {
63- * running_count
64+ 0
6465 }
6566 } else {
66- running_count + 1
67+ 1
6768 }
6869 } )
69- . into_dyn ( ) ;
70- // Sum counts over remaining axes
71- if let Some ( remaining_axes) = axes. get ( 1 ..) {
72- for ( n, axis) in remaining_axes. iter ( ) . enumerate ( ) {
73- result = result
74- . fold_axis ( Axis ( axis - n - 1 ) , 0 , |total_count, count| {
75- total_count + count
76- } )
77- . into_dyn ( ) ;
70+ } else {
71+ // Should never panic here due to axis.is_empty() branch above
72+ let first_axis = axes. first ( ) . expect ( "axes list to be non-empty" ) ;
73+ // Count non-missing over first axis
74+ let mut result = array
75+ . fold_axis ( Axis ( * first_axis) , 0 , |running_count, val| {
76+ if let Some ( missing) = & missing {
77+ if !missing. is_missing ( val) {
78+ running_count + 1
79+ } else {
80+ * running_count
81+ }
82+ } else {
83+ running_count + 1
84+ }
85+ } )
86+ . into_dyn ( ) ;
87+ // Sum counts over remaining axes
88+ if let Some ( remaining_axes) = axes. get ( 1 ..) {
89+ for ( n, axis) in remaining_axes. iter ( ) . enumerate ( ) {
90+ result = result
91+ . fold_axis ( Axis ( axis - n - 1 ) , 0 , |total_count, count| {
92+ total_count + count
93+ } )
94+ . into_dyn ( ) ;
95+ }
7896 }
79- }
97+ result
98+ } ;
8099
81100 // Convert result to owned vec
82- let counts = result. iter ( ) . copied ( ) . collect :: < Vec < i64 > > ( ) ;
101+ let counts = result. iter ( ) . copied ( ) . collect ( ) ;
83102 ( counts, result. shape ( ) . into ( ) )
84103}
85104
@@ -104,9 +123,8 @@ impl NumOperation for Count {
104123
105124 match & request_data. axis {
106125 ReductionAxes :: All => {
107- let count = if let Some ( missing) = & request_data. missing {
108- let missing = Missing :: < T > :: try_from ( missing) ?;
109- count_non_missing ( & sliced, & missing) ?
126+ let count = if let Some ( missing) = typed_missing {
127+ count_non_missing ( & sliced, & missing)
110128 } else {
111129 sliced. len ( )
112130 } ;
@@ -476,7 +494,7 @@ impl NumOperation for Select {
476494 let sliced = array. slice ( slice_info) ;
477495 let count = if let Some ( missing) = & request_data. missing {
478496 let missing = Missing :: < T > :: try_from ( missing) ?;
479- count_non_missing ( & sliced, & missing) ?
497+ count_non_missing ( & sliced, & missing)
480498 } else {
481499 sliced. len ( )
482500 } ;
@@ -1288,6 +1306,21 @@ mod tests {
12881306 assert_eq ! ( shape, Vec :: <usize >:: new( ) ) ;
12891307 }
12901308
1309+ #[ test]
1310+ fn count_multi_axis_2d_no_ax ( ) {
1311+ // Arrange
1312+ let axes = vec ! [ ] ;
1313+ let missing = None ;
1314+ let arr = ndarray:: Array :: from_shape_vec ( ( 2 , 3 ) , ( 0 ..6 ) . collect ( ) )
1315+ . unwrap ( )
1316+ . into_dyn ( ) ;
1317+ // Act
1318+ let ( counts, shape) = count_array_multi_axis ( arr. view ( ) , & axes, missing) ;
1319+ // Assert
1320+ assert_eq ! ( counts, vec![ 1 , 1 , 1 , 1 , 1 , 1 ] ) ;
1321+ assert_eq ! ( shape, arr. shape( ) . to_vec( ) ) ;
1322+ }
1323+
12911324 #[ test]
12921325 fn count_multi_axis_2d_1ax_missing ( ) {
12931326 // Arrange
0 commit comments