@@ -255,18 +255,36 @@ fn validate_request_data(request_data: &RequestData) -> Result<(), ValidationErr
255255 }
256256 }
257257 ( Some ( shape) , ReductionAxes :: Multi ( axes) ) => {
258+ // Check we've not been given too many axes
258259 if axes. len ( ) >= shape. len ( ) {
259260 return Err ( ValidationError :: new (
260261 "Number of reduction axes must be less than length of shape - to reduce over all axes omit the axis field completely" ,
261262 ) ) ;
262263 }
264+ // Check axes are ordered correctly
265+ // NOTE(sd109): We could mutate request data to sort the axes
266+ // but it's also trivial to do on the Python client side
267+ let mut sorted_axes = axes. clone ( ) ;
268+ sorted_axes. sort ( ) ;
269+ if & sorted_axes != axes {
270+ return Err ( ValidationError :: new (
271+ "Reduction axes must be provided in ascending order" ,
272+ ) ) ;
273+ }
274+ // Check axes are valid for given shape
263275 for ax in axes {
264276 if * ax > shape. len ( ) - 1 {
265277 return Err ( ValidationError :: new (
266278 "All reduction axes must be within shape" ,
267279 ) ) ;
268280 }
269281 }
282+ // Check we've not been given duplicate axes
283+ for ax in axes {
284+ if axes. iter ( ) . filter ( |val| * val == ax) . count ( ) != 1 {
285+ return Err ( ValidationError :: new ( "Reduction axes contains duplicates" ) ) ;
286+ }
287+ }
270288 }
271289 ( None , ReductionAxes :: One ( _) | ReductionAxes :: Multi ( _) ) => {
272290 return Err ( ValidationError :: new ( "Axis requires shape to be specified" ) ) ;
@@ -730,6 +748,24 @@ mod tests {
730748 request_data. validate ( ) . unwrap ( )
731749 }
732750
751+ #[ test]
752+ #[ should_panic( expected = "Reduction axes must be provided in ascending order" ) ]
753+ fn test_axis_unsorted ( ) {
754+ let mut request_data = test_utils:: get_test_request_data ( ) ;
755+ request_data. axis = ReductionAxes :: Multi ( vec ! [ 1 , 0 ] ) ;
756+ request_data. shape = Some ( vec ! [ 2 , 5 , 1 ] ) ;
757+ request_data. validate ( ) . unwrap ( )
758+ }
759+
760+ #[ test]
761+ #[ should_panic( expected = "Reduction axes contains duplicates" ) ]
762+ fn test_axis_duplicated ( ) {
763+ let mut request_data = test_utils:: get_test_request_data ( ) ;
764+ request_data. axis = ReductionAxes :: Multi ( vec ! [ 1 , 1 ] ) ;
765+ request_data. shape = Some ( vec ! [ 2 , 5 , 1 ] ) ;
766+ request_data. validate ( ) . unwrap ( )
767+ }
768+
733769 #[ test]
734770 fn test_invalid_compression ( ) {
735771 assert_de_tokens_error :: < RequestData > (
0 commit comments