@@ -136,6 +136,8 @@ pub struct RequestData {
136136 custom = "validate_shape"
137137 ) ]
138138 pub shape : Option < Vec < usize > > ,
139+ /// Axis over which to perform the reduction operation
140+ pub axis : Option < usize > ,
139141 /// Order of the multi-dimensional array
140142 pub order : Option < Order > ,
141143 /// Subset of the data to operate on
@@ -222,6 +224,7 @@ fn validate_request_data(request_data: &RequestData) -> Result<(), ValidationErr
222224 validate_raw_size ( * size, request_data. dtype , & request_data. shape ) ?;
223225 }
224226 } ;
227+ // Check selection is compatible with shape
225228 match ( & request_data. shape , & request_data. selection ) {
226229 ( Some ( shape) , Some ( selection) ) => {
227230 validate_shape_selection ( shape, selection) ?;
@@ -233,6 +236,18 @@ fn validate_request_data(request_data: &RequestData) -> Result<(), ValidationErr
233236 }
234237 _ => ( ) ,
235238 } ;
239+ // Check axis is compatible with shape
240+ match ( & request_data. shape , & request_data. axis ) {
241+ ( Some ( shape) , Some ( axis) ) => {
242+ if * axis > shape. len ( ) - 1 {
243+ return Err ( ValidationError :: new ( "Axis must be within shape" ) ) ;
244+ }
245+ }
246+ ( None , Some ( _) ) => {
247+ return Err ( ValidationError :: new ( "Axis requires shape to be specified" ) ) ;
248+ }
249+ _ => ( ) ,
250+ } ;
236251 if let Some ( missing) = & request_data. missing {
237252 missing. validate ( request_data. dtype ) ?;
238253 } ;
@@ -335,6 +350,9 @@ mod tests {
335350 Token :: U32 ( 2 ) ,
336351 Token :: U32 ( 5 ) ,
337352 Token :: SeqEnd ,
353+ Token :: Str ( "axis" ) ,
354+ Token :: Some ,
355+ Token :: U32 ( 1 ) ,
338356 Token :: Str ( "order" ) ,
339357 Token :: Some ,
340358 Token :: Enum { name : "Order" } ,
@@ -659,6 +677,23 @@ mod tests {
659677 request_data. validate ( ) . unwrap ( )
660678 }
661679
680+ #[ test]
681+ #[ should_panic( expected = "Axis requires shape to be specified" ) ]
682+ fn test_axis_without_shape ( ) {
683+ let mut request_data = test_utils:: get_test_request_data ( ) ;
684+ request_data. axis = Some ( 1 ) ;
685+ request_data. validate ( ) . unwrap ( )
686+ }
687+
688+ #[ test]
689+ #[ should_panic( expected = "Axis must be within shape" ) ]
690+ fn test_axis_gt_shape ( ) {
691+ let mut request_data = test_utils:: get_test_request_data ( ) ;
692+ request_data. axis = Some ( 2 ) ;
693+ request_data. shape = Some ( vec ! [ 2 , 5 ] ) ;
694+ request_data. validate ( ) . unwrap ( )
695+ }
696+
662697 #[ test]
663698 fn test_invalid_compression ( ) {
664699 assert_de_tokens_error :: < RequestData > (
@@ -731,7 +766,7 @@ mod tests {
731766 Token :: Str ( "foo" ) ,
732767 Token :: StructEnd
733768 ] ,
734- "unknown field `foo`, expected one of `source`, `bucket`, `object`, `dtype`, `byte_order`, `offset`, `size`, `shape`, `order`, `selection`, `compression`, `filters`, `missing`"
769+ "unknown field `foo`, expected one of `source`, `bucket`, `object`, `dtype`, `byte_order`, `offset`, `size`, `shape`, `axis`, ` order`, `selection`, `compression`, `filters`, `missing`"
735770 )
736771 }
737772
@@ -755,6 +790,7 @@ mod tests {
755790 "offset": 4,
756791 "size": 8,
757792 "shape": [2, 5],
793+ "axis": 1,
758794 "order": "C",
759795 "selection": [[1, 2, 3], [4, 5, 6]],
760796 "compression": {"id": "gzip"},
@@ -776,6 +812,7 @@ mod tests {
776812 "offset": 4,
777813 "size": 8,
778814 "shape": [2, 5, 10],
815+ "axis": 2,
779816 "order": "F",
780817 "selection": [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
781818 "compression": {"id": "zlib"},
@@ -787,6 +824,7 @@ mod tests {
787824 expected. dtype = DType :: Float64 ;
788825 expected. byte_order = Some ( ByteOrder :: Big ) ;
789826 expected. shape = Some ( vec ! [ 2 , 5 , 10 ] ) ;
827+ expected. axis = Some ( 2 ) ;
790828 expected. order = Some ( Order :: F ) ;
791829 expected. selection = Some ( vec ! [
792830 Slice :: new( 1 , 2 , 3 ) ,
0 commit comments