@@ -46,6 +46,43 @@ fn count_non_missing<T: Element>(
4646 Ok ( array. iter ( ) . copied ( ) . filter ( filter) . count ( ) )
4747}
4848
49+ /// Counts the number of non-missing elements along
50+ /// one or more axes of the provided array
51+ fn count_array_multi_axis < T : Element > (
52+ array : ndarray:: ArrayView < T , ndarray:: IxDyn > ,
53+ axes : & [ usize ] ,
54+ missing : Option < Missing < T > > ,
55+ ) -> ( Vec < i64 > , Vec < usize > ) {
56+ // Count non-missing over first axis
57+ let mut result = array
58+ . fold_axis ( Axis ( axes[ 0 ] ) , 0 , |running_count, val| {
59+ if let Some ( missing) = & missing {
60+ if !missing. is_missing ( val) {
61+ running_count + 1
62+ } else {
63+ * running_count
64+ }
65+ } else {
66+ running_count + 1
67+ }
68+ } )
69+ . into_dyn ( ) ;
70+ // Sum counts over remaining axes
71+ if let Some ( remaining_axes) = axes. get ( 1 ..) {
72+ for ( n, axis) in remaining_axes. iter ( ) . enumerate ( ) {
73+ result = result
74+ . fold_axis ( Axis ( axis - n - 1 ) , 0 , |total_count, count| {
75+ total_count + count
76+ } )
77+ . into_dyn ( ) ;
78+ }
79+ }
80+
81+ // Convert result to owned vec
82+ let counts = result. iter ( ) . copied ( ) . collect :: < Vec < i64 > > ( ) ;
83+ ( counts, result. shape ( ) . into ( ) )
84+ }
85+
4986/// Return the number of selected elements in the array.
5087pub struct Count { }
5188
@@ -57,22 +94,69 @@ impl NumOperation for Count {
5794 let array = array:: build_array :: < T > ( request_data, & mut data) ?;
5895 let slice_info = array:: build_slice_info :: < T > ( & request_data. selection , array. shape ( ) ) ;
5996 let sliced = array. slice ( slice_info) ;
60- let count = if let Some ( missing) = & request_data. missing {
61- let missing = Missing :: < T > :: try_from ( missing) ?;
62- count_non_missing ( & sliced, & missing) ?
97+
98+ let typed_missing: Option < Missing < T > > = if let Some ( missing) = & request_data. missing {
99+ let m = Missing :: try_from ( missing) ?;
100+ Some ( m)
63101 } else {
64- sliced . len ( )
102+ None
65103 } ;
66- let count = i64:: try_from ( count) ?;
67- let body = count. to_ne_bytes ( ) ;
68- // Need to copy to provide ownership to caller.
69- let body = Bytes :: copy_from_slice ( & body) ;
70- Ok ( models:: Response :: new (
71- body,
72- models:: DType :: Int64 ,
73- vec ! [ ] ,
74- vec ! [ count] ,
75- ) )
104+
105+ match & request_data. axis {
106+ ReductionAxes :: All => {
107+ let count = if let Some ( missing) = & request_data. missing {
108+ let missing = Missing :: < T > :: try_from ( missing) ?;
109+ count_non_missing ( & sliced, & missing) ?
110+ } else {
111+ sliced. len ( )
112+ } ;
113+ let count = i64:: try_from ( count) ?;
114+ let body = count. to_ne_bytes ( ) ;
115+ // Need to copy to provide ownership to caller.
116+ let body = Bytes :: copy_from_slice ( & body) ;
117+ Ok ( models:: Response :: new (
118+ body,
119+ models:: DType :: Int64 ,
120+ vec ! [ ] ,
121+ vec ! [ count] ,
122+ ) )
123+ }
124+ ReductionAxes :: One ( axis) => {
125+ let result = sliced. fold_axis ( Axis ( * axis) , 0 , |count, val| {
126+ if let Some ( missing) = & typed_missing {
127+ if !missing. is_missing ( val) {
128+ count + 1
129+ } else {
130+ * count
131+ }
132+ } else {
133+ count + 1
134+ }
135+ } ) ;
136+ let counts = result. iter ( ) . copied ( ) . collect :: < Vec < i64 > > ( ) ;
137+ let body = counts. as_bytes ( ) ;
138+ // Need to copy to provide ownership to caller.
139+ let body = Bytes :: copy_from_slice ( body) ;
140+ Ok ( models:: Response :: new (
141+ body,
142+ models:: DType :: Int64 ,
143+ result. shape ( ) . into ( ) ,
144+ counts,
145+ ) )
146+ }
147+ ReductionAxes :: Multi ( axes) => {
148+ let ( counts, shape) = count_array_multi_axis ( sliced. view ( ) , axes, typed_missing) ;
149+ let body = counts. as_bytes ( ) ;
150+ // Need to copy to provide ownership to caller.
151+ let body = Bytes :: copy_from_slice ( body) ;
152+ Ok ( models:: Response :: new (
153+ body,
154+ models:: DType :: Int64 ,
155+ shape,
156+ counts,
157+ ) )
158+ }
159+ }
76160 }
77161}
78162
@@ -1176,4 +1260,76 @@ mod tests {
11761260 assert_eq ! ( counts, vec![ 5 , 6 ] ) ;
11771261 assert_eq ! ( shape, vec![ 2 ] ) ;
11781262 }
1263+
1264+ #[ test]
1265+ #[ should_panic( expected = "assertion failed: axis.index() < self.ndim()" ) ]
1266+ fn count_multi_axis_2d_wrong_axis ( ) {
1267+ // Arrange
1268+ let array = ndarray:: Array :: from_shape_vec ( ( 2 , 2 ) , ( 0 ..4 ) . collect ( ) )
1269+ . unwrap ( )
1270+ . into_dyn ( ) ;
1271+ let axes = vec ! [ 2 ] ;
1272+ // Act
1273+ let _ = count_array_multi_axis ( array. view ( ) , & axes, None ) ;
1274+ }
1275+
1276+ #[ test]
1277+ fn count_multi_axis_2d_2ax ( ) {
1278+ // Arrange
1279+ let axes = vec ! [ 0 , 1 ] ;
1280+ let missing = None ;
1281+ let arr = ndarray:: Array :: from_shape_vec ( ( 2 , 3 ) , ( 0 ..6 ) . collect ( ) )
1282+ . unwrap ( )
1283+ . into_dyn ( ) ;
1284+ // Act
1285+ let ( counts, shape) = count_array_multi_axis ( arr. view ( ) , & axes, missing) ;
1286+ // Assert
1287+ assert_eq ! ( counts, vec![ 6 ] ) ;
1288+ assert_eq ! ( shape, Vec :: <usize >:: new( ) ) ;
1289+ }
1290+
1291+ #[ test]
1292+ fn count_multi_axis_2d_1ax_missing ( ) {
1293+ // Arrange
1294+ let axes = vec ! [ 1 ] ;
1295+ let missing = Missing :: MissingValue ( 0 ) ;
1296+ let arr = ndarray:: Array :: from_shape_vec ( ( 2 , 3 ) , ( 0 ..6 ) . collect ( ) )
1297+ . unwrap ( )
1298+ . into_dyn ( ) ;
1299+ // Act
1300+ let ( counts, shape) = count_array_multi_axis ( arr. view ( ) , & axes, Some ( missing) ) ;
1301+ // Assert
1302+ assert_eq ! ( counts, vec![ 2 , 3 ] ) ;
1303+ assert_eq ! ( shape, vec![ 2 ] ) ;
1304+ }
1305+
1306+ #[ test]
1307+ fn count_multi_axis_4d_3ax_multi_missing ( ) {
1308+ // Arrange
1309+ let arr = ndarray:: Array :: from_shape_vec ( ( 2 , 3 , 2 , 1 ) , ( 0 ..12 ) . collect ( ) )
1310+ . unwrap ( )
1311+ . into_dyn ( ) ;
1312+ let axes = vec ! [ 0 , 1 , 3 ] ;
1313+ let missing = Missing :: MissingValues ( vec ! [ 9 , 10 , 11 ] ) ;
1314+ // Act
1315+ let ( counts, shape) = count_array_multi_axis ( arr. view ( ) , & axes, Some ( missing) ) ;
1316+ // Assert
1317+ assert_eq ! ( counts, vec![ 5 , 4 ] ) ;
1318+ assert_eq ! ( shape, vec![ 2 ] ) ;
1319+ }
1320+
1321+ #[ test]
1322+ fn count_multi_axis_4d_3ax_missing ( ) {
1323+ // Arrange
1324+ let arr = ndarray:: Array :: from_shape_vec ( ( 2 , 3 , 2 , 1 ) , ( 0 ..12 ) . collect ( ) )
1325+ . unwrap ( )
1326+ . into_dyn ( ) ;
1327+ let axes = vec ! [ 0 , 1 , 3 ] ;
1328+ let missing = Missing :: MissingValue ( 10 ) ;
1329+ // Act
1330+ let ( counts, shape) = count_array_multi_axis ( arr. view ( ) , & axes, Some ( missing) ) ;
1331+ // Assert
1332+ assert_eq ! ( counts, vec![ 5 , 6 ] ) ;
1333+ assert_eq ! ( shape, vec![ 2 ] ) ;
1334+ }
11791335}
0 commit comments