@@ -46,6 +46,43 @@ fn count_non_missing<T: Element>(
4646    Ok ( array. iter ( ) . copied ( ) . filter ( filter) . count ( ) ) 
4747} 
4848
49+ /// Counts the number of non-missing elements along 
50+ /// one or more axes of the provided array 
51+ fn  count_array_multi_axis < T :  Element > ( 
52+     array :  ndarray:: ArrayView < T ,  ndarray:: IxDyn > , 
53+     axes :  & [ usize ] , 
54+     missing :  Option < Missing < T > > , 
55+ )  -> ( Vec < i64 > ,  Vec < usize > )  { 
56+     // Count non-missing over first axis 
57+     let  mut  result = array
58+         . fold_axis ( Axis ( axes[ 0 ] ) ,  0 ,  |running_count,  val| { 
59+             if  let  Some ( missing)  = & missing { 
60+                 if  !missing. is_missing ( val)  { 
61+                     running_count + 1 
62+                 }  else  { 
63+                     * running_count
64+                 } 
65+             }  else  { 
66+                 running_count + 1 
67+             } 
68+         } ) 
69+         . into_dyn ( ) ; 
70+     // Sum counts over remaining axes 
71+     if  let  Some ( remaining_axes)  = axes. get ( 1 ..)  { 
72+         for  ( n,  axis)  in  remaining_axes. iter ( ) . enumerate ( )  { 
73+             result = result
74+                 . fold_axis ( Axis ( axis - n - 1 ) ,  0 ,  |total_count,  count| { 
75+                     total_count + count
76+                 } ) 
77+                 . into_dyn ( ) ; 
78+         } 
79+     } 
80+ 
81+     // Convert result to owned vec 
82+     let  counts = result. iter ( ) . copied ( ) . collect :: < Vec < i64 > > ( ) ; 
83+     ( counts,  result. shape ( ) . into ( ) ) 
84+ } 
85+ 
4986/// Return the number of selected elements in the array. 
5087pub  struct  Count  { } 
5188
@@ -57,22 +94,69 @@ impl NumOperation for Count {
5794        let  array = array:: build_array :: < T > ( request_data,  & mut  data) ?; 
5895        let  slice_info = array:: build_slice_info :: < T > ( & request_data. selection ,  array. shape ( ) ) ; 
5996        let  sliced = array. slice ( slice_info) ; 
60-         let  count = if  let  Some ( missing)  = & request_data. missing  { 
61-             let  missing = Missing :: < T > :: try_from ( missing) ?; 
62-             count_non_missing ( & sliced,  & missing) ?
97+ 
98+         let  typed_missing:  Option < Missing < T > >  = if  let  Some ( missing)  = & request_data. missing  { 
99+             let  m = Missing :: try_from ( missing) ?; 
100+             Some ( m) 
63101        }  else  { 
64-             sliced . len ( ) 
102+             None 
65103        } ; 
66-         let  count = i64:: try_from ( count) ?; 
67-         let  body = count. to_ne_bytes ( ) ; 
68-         // Need to copy to provide ownership to caller. 
69-         let  body = Bytes :: copy_from_slice ( & body) ; 
70-         Ok ( models:: Response :: new ( 
71-             body, 
72-             models:: DType :: Int64 , 
73-             vec ! [ ] , 
74-             vec ! [ count] , 
75-         ) ) 
104+ 
105+         match  & request_data. axis  { 
106+             ReductionAxes :: All  => { 
107+                 let  count = if  let  Some ( missing)  = & request_data. missing  { 
108+                     let  missing = Missing :: < T > :: try_from ( missing) ?; 
109+                     count_non_missing ( & sliced,  & missing) ?
110+                 }  else  { 
111+                     sliced. len ( ) 
112+                 } ; 
113+                 let  count = i64:: try_from ( count) ?; 
114+                 let  body = count. to_ne_bytes ( ) ; 
115+                 // Need to copy to provide ownership to caller. 
116+                 let  body = Bytes :: copy_from_slice ( & body) ; 
117+                 Ok ( models:: Response :: new ( 
118+                     body, 
119+                     models:: DType :: Int64 , 
120+                     vec ! [ ] , 
121+                     vec ! [ count] , 
122+                 ) ) 
123+             } 
124+             ReductionAxes :: One ( axis)  => { 
125+                 let  result = sliced. fold_axis ( Axis ( * axis) ,  0 ,  |count,  val| { 
126+                     if  let  Some ( missing)  = & typed_missing { 
127+                         if  !missing. is_missing ( val)  { 
128+                             count + 1 
129+                         }  else  { 
130+                             * count
131+                         } 
132+                     }  else  { 
133+                         count + 1 
134+                     } 
135+                 } ) ; 
136+                 let  counts = result. iter ( ) . copied ( ) . collect :: < Vec < i64 > > ( ) ; 
137+                 let  body = counts. as_bytes ( ) ; 
138+                 // Need to copy to provide ownership to caller. 
139+                 let  body = Bytes :: copy_from_slice ( body) ; 
140+                 Ok ( models:: Response :: new ( 
141+                     body, 
142+                     models:: DType :: Int64 , 
143+                     result. shape ( ) . into ( ) , 
144+                     counts, 
145+                 ) ) 
146+             } 
147+             ReductionAxes :: Multi ( axes)  => { 
148+                 let  ( counts,  shape)  = count_array_multi_axis ( sliced. view ( ) ,  axes,  typed_missing) ; 
149+                 let  body = counts. as_bytes ( ) ; 
150+                 // Need to copy to provide ownership to caller. 
151+                 let  body = Bytes :: copy_from_slice ( body) ; 
152+                 Ok ( models:: Response :: new ( 
153+                     body, 
154+                     models:: DType :: Int64 , 
155+                     shape, 
156+                     counts, 
157+                 ) ) 
158+             } 
159+         } 
76160    } 
77161} 
78162
@@ -1176,4 +1260,76 @@ mod tests {
11761260        assert_eq ! ( counts,  vec![ 5 ,  6 ] ) ; 
11771261        assert_eq ! ( shape,  vec![ 2 ] ) ; 
11781262    } 
1263+ 
1264+     #[ test]  
1265+     #[ should_panic( expected = "assertion failed: axis.index() < self.ndim()" ) ]  
1266+     fn  count_multi_axis_2d_wrong_axis ( )  { 
1267+         // Arrange 
1268+         let  array = ndarray:: Array :: from_shape_vec ( ( 2 ,  2 ) ,  ( 0 ..4 ) . collect ( ) ) 
1269+             . unwrap ( ) 
1270+             . into_dyn ( ) ; 
1271+         let  axes = vec ! [ 2 ] ; 
1272+         // Act 
1273+         let  _ = count_array_multi_axis ( array. view ( ) ,  & axes,  None ) ; 
1274+     } 
1275+ 
1276+     #[ test]  
1277+     fn  count_multi_axis_2d_2ax ( )  { 
1278+         // Arrange 
1279+         let  axes = vec ! [ 0 ,  1 ] ; 
1280+         let  missing = None ; 
1281+         let  arr = ndarray:: Array :: from_shape_vec ( ( 2 ,  3 ) ,  ( 0 ..6 ) . collect ( ) ) 
1282+             . unwrap ( ) 
1283+             . into_dyn ( ) ; 
1284+         // Act 
1285+         let  ( counts,  shape)  = count_array_multi_axis ( arr. view ( ) ,  & axes,  missing) ; 
1286+         // Assert 
1287+         assert_eq ! ( counts,  vec![ 6 ] ) ; 
1288+         assert_eq ! ( shape,  Vec :: <usize >:: new( ) ) ; 
1289+     } 
1290+ 
1291+     #[ test]  
1292+     fn  count_multi_axis_2d_1ax_missing ( )  { 
1293+         // Arrange 
1294+         let  axes = vec ! [ 1 ] ; 
1295+         let  missing = Missing :: MissingValue ( 0 ) ; 
1296+         let  arr = ndarray:: Array :: from_shape_vec ( ( 2 ,  3 ) ,  ( 0 ..6 ) . collect ( ) ) 
1297+             . unwrap ( ) 
1298+             . into_dyn ( ) ; 
1299+         // Act 
1300+         let  ( counts,  shape)  = count_array_multi_axis ( arr. view ( ) ,  & axes,  Some ( missing) ) ; 
1301+         // Assert 
1302+         assert_eq ! ( counts,  vec![ 2 ,  3 ] ) ; 
1303+         assert_eq ! ( shape,  vec![ 2 ] ) ; 
1304+     } 
1305+ 
1306+     #[ test]  
1307+     fn  count_multi_axis_4d_3ax_multi_missing ( )  { 
1308+         // Arrange 
1309+         let  arr = ndarray:: Array :: from_shape_vec ( ( 2 ,  3 ,  2 ,  1 ) ,  ( 0 ..12 ) . collect ( ) ) 
1310+             . unwrap ( ) 
1311+             . into_dyn ( ) ; 
1312+         let  axes = vec ! [ 0 ,  1 ,  3 ] ; 
1313+         let  missing = Missing :: MissingValues ( vec ! [ 9 ,  10 ,  11 ] ) ; 
1314+         // Act 
1315+         let  ( counts,  shape)  = count_array_multi_axis ( arr. view ( ) ,  & axes,  Some ( missing) ) ; 
1316+         // Assert 
1317+         assert_eq ! ( counts,  vec![ 5 ,  4 ] ) ; 
1318+         assert_eq ! ( shape,  vec![ 2 ] ) ; 
1319+     } 
1320+ 
1321+     #[ test]  
1322+     fn  count_multi_axis_4d_3ax_missing ( )  { 
1323+         // Arrange 
1324+         let  arr = ndarray:: Array :: from_shape_vec ( ( 2 ,  3 ,  2 ,  1 ) ,  ( 0 ..12 ) . collect ( ) ) 
1325+             . unwrap ( ) 
1326+             . into_dyn ( ) ; 
1327+         let  axes = vec ! [ 0 ,  1 ,  3 ] ; 
1328+         let  missing = Missing :: MissingValue ( 10 ) ; 
1329+         // Act 
1330+         let  ( counts,  shape)  = count_array_multi_axis ( arr. view ( ) ,  & axes,  Some ( missing) ) ; 
1331+         // Assert 
1332+         assert_eq ! ( counts,  vec![ 5 ,  6 ] ) ; 
1333+         assert_eq ! ( shape,  vec![ 2 ] ) ; 
1334+     } 
11791335} 
0 commit comments