Skip to content

Commit e0742b7

Browse files
committed
Add axis field to request data model
1 parent ff4f335 commit e0742b7

File tree

4 files changed

+43
-1
lines changed

4 files changed

+43
-1
lines changed

benches/byte_order.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ fn get_test_request_data() -> RequestData {
1414
offset: None,
1515
size: None,
1616
shape: None,
17+
axis: None,
1718
order: None,
1819
selection: None,
1920
compression: None,

benches/operations.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ fn get_test_request_data() -> RequestData {
1919
offset: None,
2020
size: None,
2121
shape: None,
22+
axis: None,
2223
order: None,
2324
selection: None,
2425
compression: None,

src/models.rs

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,8 @@ pub struct RequestData {
136136
custom = "validate_shape"
137137
)]
138138
pub shape: Option<Vec<usize>>,
139+
/// Axis over which to perform the reduction operation
140+
pub axis: Option<usize>,
139141
/// Order of the multi-dimensional array
140142
pub order: Option<Order>,
141143
/// Subset of the data to operate on
@@ -222,6 +224,7 @@ fn validate_request_data(request_data: &RequestData) -> Result<(), ValidationErr
222224
validate_raw_size(*size, request_data.dtype, &request_data.shape)?;
223225
}
224226
};
227+
// Check selection is compatible with shape
225228
match (&request_data.shape, &request_data.selection) {
226229
(Some(shape), Some(selection)) => {
227230
validate_shape_selection(shape, selection)?;
@@ -233,6 +236,18 @@ fn validate_request_data(request_data: &RequestData) -> Result<(), ValidationErr
233236
}
234237
_ => (),
235238
};
239+
// Check axis is compatible with shape
240+
match (&request_data.shape, &request_data.axis) {
241+
(Some(shape), Some(axis)) => {
242+
if *axis > shape.len() - 1 {
243+
return Err(ValidationError::new("Axis must be within shape"));
244+
}
245+
}
246+
(None, Some(_)) => {
247+
return Err(ValidationError::new("Axis requires shape to be specified"));
248+
}
249+
_ => (),
250+
};
236251
if let Some(missing) = &request_data.missing {
237252
missing.validate(request_data.dtype)?;
238253
};
@@ -335,6 +350,9 @@ mod tests {
335350
Token::U32(2),
336351
Token::U32(5),
337352
Token::SeqEnd,
353+
Token::Str("axis"),
354+
Token::Some,
355+
Token::U32(1),
338356
Token::Str("order"),
339357
Token::Some,
340358
Token::Enum { name: "Order" },
@@ -659,6 +677,23 @@ mod tests {
659677
request_data.validate().unwrap()
660678
}
661679

680+
#[test]
681+
#[should_panic(expected = "Axis requires shape to be specified")]
682+
fn test_axis_without_shape() {
683+
let mut request_data = test_utils::get_test_request_data();
684+
request_data.axis = Some(1);
685+
request_data.validate().unwrap()
686+
}
687+
688+
#[test]
689+
#[should_panic(expected = "Axis must be within shape")]
690+
fn test_axis_gt_shape() {
691+
let mut request_data = test_utils::get_test_request_data();
692+
request_data.axis = Some(2);
693+
request_data.shape = Some(vec![2, 5]);
694+
request_data.validate().unwrap()
695+
}
696+
662697
#[test]
663698
fn test_invalid_compression() {
664699
assert_de_tokens_error::<RequestData>(
@@ -731,7 +766,7 @@ mod tests {
731766
Token::Str("foo"),
732767
Token::StructEnd
733768
],
734-
"unknown field `foo`, expected one of `source`, `bucket`, `object`, `dtype`, `byte_order`, `offset`, `size`, `shape`, `order`, `selection`, `compression`, `filters`, `missing`"
769+
"unknown field `foo`, expected one of `source`, `bucket`, `object`, `dtype`, `byte_order`, `offset`, `size`, `shape`, `axis`, `order`, `selection`, `compression`, `filters`, `missing`"
735770
)
736771
}
737772

@@ -755,6 +790,7 @@ mod tests {
755790
"offset": 4,
756791
"size": 8,
757792
"shape": [2, 5],
793+
"axis": 1,
758794
"order": "C",
759795
"selection": [[1, 2, 3], [4, 5, 6]],
760796
"compression": {"id": "gzip"},
@@ -776,6 +812,7 @@ mod tests {
776812
"offset": 4,
777813
"size": 8,
778814
"shape": [2, 5, 10],
815+
"axis": 2,
779816
"order": "F",
780817
"selection": [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
781818
"compression": {"id": "zlib"},
@@ -787,6 +824,7 @@ mod tests {
787824
expected.dtype = DType::Float64;
788825
expected.byte_order = Some(ByteOrder::Big);
789826
expected.shape = Some(vec![2, 5, 10]);
827+
expected.axis = Some(2);
790828
expected.order = Some(Order::F);
791829
expected.selection = Some(vec![
792830
Slice::new(1, 2, 3),

src/test_utils.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ pub(crate) fn get_test_request_data() -> RequestData {
1313
byte_order: None,
1414
offset: None,
1515
size: None,
16+
axis: None,
1617
shape: None,
1718
order: None,
1819
selection: None,
@@ -32,6 +33,7 @@ pub(crate) fn get_test_request_data_optional() -> RequestData {
3233
byte_order: Some(ByteOrder::Little),
3334
offset: Some(4),
3435
size: Some(8),
36+
axis: Some(1),
3537
shape: Some(vec![2, 5]),
3638
order: Some(Order::C),
3739
selection: Some(vec![Slice::new(1, 2, 3), Slice::new(4, 5, 6)]),

0 commit comments

Comments
 (0)