Skip to content

Commit b0c75b5

Browse files
committed
Improve validation of multi-axis reduction requests
1 parent 3d9625d commit b0c75b5

File tree

1 file changed

+36
-0
lines changed

1 file changed

+36
-0
lines changed

src/models.rs

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,18 +255,36 @@ fn validate_request_data(request_data: &RequestData) -> Result<(), ValidationErr
255255
}
256256
}
257257
(Some(shape), ReductionAxes::Multi(axes)) => {
258+
// Check we've not been given too many axes
258259
if axes.len() >= shape.len() {
259260
return Err(ValidationError::new(
260261
"Number of reduction axes must be less than length of shape - to reduce over all axes omit the axis field completely",
261262
));
262263
}
264+
// Check axes are ordered correctly
265+
// NOTE(sd109): We could mutate request data to sort the axes
266+
// but it's also trivial to do on the Python client side
267+
let mut sorted_axes = axes.clone();
268+
sorted_axes.sort();
269+
if &sorted_axes != axes {
270+
return Err(ValidationError::new(
271+
"Reduction axes must be provided in ascending order",
272+
));
273+
}
274+
// Check axes are valid for given shape
263275
for ax in axes {
264276
if *ax > shape.len() - 1 {
265277
return Err(ValidationError::new(
266278
"All reduction axes must be within shape",
267279
));
268280
}
269281
}
282+
// Check we've not been given duplicate axes
283+
for ax in axes {
284+
if axes.iter().filter(|val| *val == ax).count() != 1 {
285+
return Err(ValidationError::new("Reduction axes contains duplicates"));
286+
}
287+
}
270288
}
271289
(None, ReductionAxes::One(_) | ReductionAxes::Multi(_)) => {
272290
return Err(ValidationError::new("Axis requires shape to be specified"));
@@ -730,6 +748,24 @@ mod tests {
730748
request_data.validate().unwrap()
731749
}
732750

751+
#[test]
752+
#[should_panic(expected = "Reduction axes must be provided in ascending order")]
753+
fn test_axis_unsorted() {
754+
let mut request_data = test_utils::get_test_request_data();
755+
request_data.axis = ReductionAxes::Multi(vec![1, 0]);
756+
request_data.shape = Some(vec![2, 5, 1]);
757+
request_data.validate().unwrap()
758+
}
759+
760+
#[test]
761+
#[should_panic(expected = "Reduction axes contains duplicates")]
762+
fn test_axis_duplicated() {
763+
let mut request_data = test_utils::get_test_request_data();
764+
request_data.axis = ReductionAxes::Multi(vec![1, 1]);
765+
request_data.shape = Some(vec![2, 5, 1]);
766+
request_data.validate().unwrap()
767+
}
768+
733769
#[test]
734770
fn test_invalid_compression() {
735771
assert_de_tokens_error::<RequestData>(

0 commit comments

Comments
 (0)