@@ -107,6 +107,16 @@ pub enum Filter {
107107 Shuffle { element_size : usize } ,
108108}
109109
110+ /// Axes over which to perform the reduction
111+ #[ derive( Debug , PartialEq , Default , Deserialize ) ]
112+ #[ serde( rename_all = "lowercase" , untagged) ]
113+ pub enum ReductionAxes {
114+ #[ default]
115+ All ,
116+ One ( usize ) ,
117+ Multi ( Vec < usize > ) ,
118+ }
119+
110120/// Request data for operations
111121#[ derive( Debug , Deserialize , PartialEq , Validate ) ]
112122#[ serde( deny_unknown_fields) ]
@@ -137,7 +147,9 @@ pub struct RequestData {
137147 ) ]
138148 pub shape : Option < Vec < usize > > ,
139149 /// Axis over which to perform the reduction operation
140- pub axis : Option < usize > ,
150+ // pub axis: Option<usize>,
151+ #[ serde( default ) ]
152+ pub axes : ReductionAxes ,
141153 /// Order of the multi-dimensional array
142154 pub order : Option < Order > ,
143155 /// Subset of the data to operate on
@@ -237,20 +249,34 @@ fn validate_request_data(request_data: &RequestData) -> Result<(), ValidationErr
237249 _ => ( ) ,
238250 } ;
239251 // Check axis is compatible with shape
240- match ( & request_data. shape , & request_data. axis ) {
241- ( Some ( shape) , Some ( axis) ) => {
252+ match ( & request_data. shape , & request_data. axes ) {
253+ ( Some ( shape) , ReductionAxes :: One ( axis) ) => {
242254 if * axis > shape. len ( ) - 1 {
243255 return Err ( ValidationError :: new ( "Axis must be within shape" ) ) ;
244256 }
245257 }
246- ( None , Some ( _) ) => {
258+ ( Some ( shape) , ReductionAxes :: Multi ( axes) ) => {
259+ if axes. len ( ) >= shape. len ( ) {
260+ return Err ( ValidationError :: new (
261+ "Number of reduction axes must be less than length of shape" ,
262+ ) ) ;
263+ }
264+ for ax in axes {
265+ if * ax > shape. len ( ) - 1 {
266+ return Err ( ValidationError :: new ( "All axes must be within shape" ) ) ;
267+ }
268+ }
269+ }
270+ ( None , ReductionAxes :: One ( _) | ReductionAxes :: Multi ( _) ) => {
247271 return Err ( ValidationError :: new ( "Axis requires shape to be specified" ) ) ;
248272 }
249- _ => ( ) ,
273+ ( _ , ReductionAxes :: All ) => ( ) ,
250274 } ;
275+ // Validate missing specification
251276 if let Some ( missing) = & request_data. missing {
252277 missing. validate ( request_data. dtype ) ?;
253- } ;
278+ }
279+
254280 Ok ( ( ) )
255281}
256282
@@ -346,21 +372,24 @@ mod tests {
346372 Token :: U32 ( 8 ) ,
347373 Token :: Str ( "shape" ) ,
348374 Token :: Some ,
349- Token :: Seq { len : Some ( 2 ) } ,
375+ Token :: Seq { len : Some ( 3 ) } ,
350376 Token :: U32 ( 2 ) ,
351377 Token :: U32 ( 5 ) ,
378+ Token :: U32 ( 1 ) ,
352379 Token :: SeqEnd ,
353- Token :: Str ( "axis " ) ,
354- Token :: Some ,
380+ Token :: Str ( "axes " ) ,
381+ Token :: Seq { len : Some ( 2 ) } ,
355382 Token :: U32 ( 1 ) ,
383+ Token :: U32 ( 2 ) ,
384+ Token :: SeqEnd ,
356385 Token :: Str ( "order" ) ,
357386 Token :: Some ,
358387 Token :: Enum { name : "Order" } ,
359388 Token :: Str ( "C" ) ,
360389 Token :: Unit ,
361390 Token :: Str ( "selection" ) ,
362391 Token :: Some ,
363- Token :: Seq { len : Some ( 2 ) } ,
392+ Token :: Seq { len : Some ( 3 ) } ,
364393 Token :: Seq { len : Some ( 3 ) } ,
365394 Token :: U32 ( 1 ) ,
366395 Token :: U32 ( 2 ) ,
@@ -371,6 +400,11 @@ mod tests {
371400 Token :: U32 ( 5 ) ,
372401 Token :: U32 ( 6 ) ,
373402 Token :: SeqEnd ,
403+ Token :: Seq { len : Some ( 3 ) } ,
404+ Token :: U32 ( 1 ) ,
405+ Token :: U32 ( 1 ) ,
406+ Token :: U32 ( 1 ) ,
407+ Token :: SeqEnd ,
374408 Token :: SeqEnd ,
375409 Token :: Str ( "compression" ) ,
376410 Token :: Some ,
@@ -590,7 +624,7 @@ mod tests {
590624
591625 #[ test]
592626 fn test_selection_end_lt_start ( ) {
593- // Numpy sementics : start >= end yields an empty array
627+ // Numpy semantics : start >= end yields an empty array
594628 let mut request_data = test_utils:: get_test_request_data ( ) ;
595629 request_data. shape = Some ( vec ! [ 1 ] ) ;
596630 request_data. selection = Some ( vec ! [ Slice :: new( 1 , 0 , 1 ) ] ) ;
@@ -635,7 +669,7 @@ mod tests {
635669
636670 #[ test]
637671 fn test_selection_start_gt_shape ( ) {
638- // Numpy sementics : start > length yields an empty array
672+ // Numpy semantics : start > length yields an empty array
639673 let mut request_data = test_utils:: get_test_request_data ( ) ;
640674 request_data. shape = Some ( vec ! [ 4 ] ) ;
641675 request_data. selection = Some ( vec ! [ Slice :: new( 5 , 5 , 1 ) ] ) ;
@@ -644,7 +678,7 @@ mod tests {
644678
645679 #[ test]
646680 fn test_selection_start_lt_negative_shape ( ) {
647- // Numpy sementics : start < -length gets clamped to zero
681+ // Numpy semantics : start < -length gets clamped to zero
648682 let mut request_data = test_utils:: get_test_request_data ( ) ;
649683 request_data. shape = Some ( vec ! [ 4 ] ) ;
650684 request_data. selection = Some ( vec ! [ Slice :: new( -5 , 5 , 1 ) ] ) ;
@@ -681,15 +715,15 @@ mod tests {
681715 #[ should_panic( expected = "Axis requires shape to be specified" ) ]
682716 fn test_axis_without_shape ( ) {
683717 let mut request_data = test_utils:: get_test_request_data ( ) ;
684- request_data. axis = Some ( 1 ) ;
718+ request_data. axes = ReductionAxes :: One ( 1 ) ;
685719 request_data. validate ( ) . unwrap ( )
686720 }
687721
688722 #[ test]
689723 #[ should_panic( expected = "Axis must be within shape" ) ]
690724 fn test_axis_gt_shape ( ) {
691725 let mut request_data = test_utils:: get_test_request_data ( ) ;
692- request_data. axis = Some ( 2 ) ;
726+ request_data. axes = ReductionAxes :: One ( 2 ) ;
693727 request_data. shape = Some ( vec ! [ 2 , 5 ] ) ;
694728 request_data. validate ( ) . unwrap ( )
695729 }
@@ -756,6 +790,7 @@ mod tests {
756790 fn test_missing_invalid_value_for_dtype ( ) {
757791 let mut request_data = test_utils:: get_test_request_data ( ) ;
758792 request_data. missing = Some ( Missing :: MissingValue ( i64:: max_value ( ) . into ( ) ) ) ;
793+ println ! ( "{:?}" , request_data. validate( ) ) ;
759794 request_data. validate ( ) . unwrap ( )
760795 }
761796
@@ -766,7 +801,7 @@ mod tests {
766801 Token :: Str ( "foo" ) ,
767802 Token :: StructEnd
768803 ] ,
769- "unknown field `foo`, expected one of `source`, `bucket`, `object`, `dtype`, `byte_order`, `offset`, `size`, `shape`, `axis `, `order`, `selection`, `compression`, `filters`, `missing`"
804+ "unknown field `foo`, expected one of `source`, `bucket`, `object`, `dtype`, `byte_order`, `offset`, `size`, `shape`, `axes `, `order`, `selection`, `compression`, `filters`, `missing`"
770805 )
771806 }
772807
@@ -789,10 +824,10 @@ mod tests {
789824 "byte_order": "little",
790825 "offset": 4,
791826 "size": 8,
792- "shape": [2, 5],
793- "axis ": 1 ,
827+ "shape": [2, 5, 1 ],
828+ "axes ": [1, 2] ,
794829 "order": "C",
795- "selection": [[1, 2, 3], [4, 5, 6]],
830+ "selection": [[1, 2, 3], [4, 5, 6], [1, 1, 1] ],
796831 "compression": {"id": "gzip"},
797832 "filters": [{"id": "shuffle", "element_size": 4}],
798833 "missing": {"missing_value": 42}
@@ -812,7 +847,7 @@ mod tests {
812847 "offset": 4,
813848 "size": 8,
814849 "shape": [2, 5, 10],
815- "axis ": 2,
850+ "axes ": 2,
816851 "order": "F",
817852 "selection": [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
818853 "compression": {"id": "zlib"},
@@ -824,7 +859,7 @@ mod tests {
824859 expected. dtype = DType :: Float64 ;
825860 expected. byte_order = Some ( ByteOrder :: Big ) ;
826861 expected. shape = Some ( vec ! [ 2 , 5 , 10 ] ) ;
827- expected. axis = Some ( 2 ) ;
862+ expected. axes = ReductionAxes :: One ( 2 ) ;
828863 expected. order = Some ( Order :: F ) ;
829864 expected. selection = Some ( vec ! [
830865 Slice :: new( 1 , 2 , 3 ) ,
0 commit comments