Skip to content

Commit 8632cd5

Browse files
committed
Implement sum over multiple axes
1 parent df7d6d1 commit 8632cd5

File tree

8 files changed

+235
-82
lines changed

8 files changed

+235
-82
lines changed

benches/byte_order.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ fn get_test_request_data() -> RequestData {
1414
offset: None,
1515
size: None,
1616
shape: None,
17-
axis: None,
17+
axes: reductionist::models::ReductionAxes::All,
1818
order: None,
1919
selection: None,
2020
compression: None,

benches/operations.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ fn get_test_request_data() -> RequestData {
1919
offset: None,
2020
size: None,
2121
shape: None,
22-
axis: None,
22+
axes: reductionist::models::ReductionAxes::All,
2323
order: None,
2424
selection: None,
2525
compression: None,
@@ -35,7 +35,7 @@ fn criterion_benchmark(c: &mut Criterion) {
3535
let size = size_k * 1024;
3636
let data: Vec<i64> = (0_i64..size).map(|i| i % 256).collect::<Vec<i64>>();
3737
let data: Vec<u8> = data.as_bytes().into();
38-
let missings = vec![
38+
let missing_types = vec![
3939
None,
4040
Some(Missing::MissingValue(42.into())),
4141
Some(Missing::MissingValues(vec![42.into()])),
@@ -51,7 +51,7 @@ fn criterion_benchmark(c: &mut Criterion) {
5151
("sum", Box::new(operations::Sum::execute)),
5252
];
5353
for (op_name, execute) in operations {
54-
for missing in missings.clone() {
54+
for missing in missing_types.clone() {
5555
let name = format!("{}({}, {:?})", op_name, size, missing);
5656
c.bench_function(&name, |b| {
5757
b.iter(|| {

scripts/client.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def get_args() -> argparse.Namespace:
3636
parser.add_argument("--offset", type=int)
3737
parser.add_argument("--size", type=int)
3838
parser.add_argument("--shape", type=str)
39-
parser.add_argument("--axis", type=int)
39+
parser.add_argument("--axes", type=str)
4040
parser.add_argument("--order", default="C") #, choices=["C", "F"]) allow invalid for testing
4141
parser.add_argument("--selection", type=str)
4242
parser.add_argument("--compression", type=str)
@@ -73,8 +73,8 @@ def build_request_data(args: argparse.Namespace) -> dict:
7373
request_data["byte_order"] = args.byte_order
7474
if args.shape:
7575
request_data["shape"] = json.loads(args.shape)
76-
if args.axis is not None:
77-
request_data["axis"] = args.axis
76+
if args.axes is not None:
77+
request_data["axes"] = json.loads(args.axes)
7878
if args.selection:
7979
request_data["selection"] = json.loads(args.selection)
8080
if args.compression:

scripts/parallel-client.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def get_args() -> argparse.Namespace:
4848
parser.add_argument("--offset", type=int)
4949
parser.add_argument("--size", type=int)
5050
parser.add_argument("--shape", type=str)
51-
parser.add_argument("--axis", type=int)
51+
parser.add_argument("--axes", type=str)
5252
parser.add_argument("--order", default="C") #, choices=["C", "F"]) allow invalid for testing
5353
parser.add_argument("--selection", type=str)
5454
parser.add_argument("--compression", type=str)
@@ -91,8 +91,8 @@ def build_request_data(args: argparse.Namespace) -> dict:
9191
request_data["byte_order"] = args.byte_order
9292
if args.shape:
9393
request_data["shape"] = json.loads(args.shape)
94-
if args.axis is not None:
95-
request_data["axis"] = args.axis
94+
if args.axes is not None:
95+
request_data["axes"] = json.loads(args.axes)
9696
if args.selection:
9797
request_data["selection"] = json.loads(args.selection)
9898
if args.compression:

src/models.rs

Lines changed: 56 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,16 @@ pub enum Filter {
107107
Shuffle { element_size: usize },
108108
}
109109

110+
/// Axes over which to perform the reduction
111+
#[derive(Debug, PartialEq, Default, Deserialize)]
112+
#[serde(rename_all = "lowercase", untagged)]
113+
pub enum ReductionAxes {
114+
#[default]
115+
All,
116+
One(usize),
117+
Multi(Vec<usize>),
118+
}
119+
110120
/// Request data for operations
111121
#[derive(Debug, Deserialize, PartialEq, Validate)]
112122
#[serde(deny_unknown_fields)]
@@ -137,7 +147,9 @@ pub struct RequestData {
137147
)]
138148
pub shape: Option<Vec<usize>>,
139149
/// Axis over which to perform the reduction operation
140-
pub axis: Option<usize>,
150+
// pub axis: Option<usize>,
151+
#[serde(default)]
152+
pub axes: ReductionAxes,
141153
/// Order of the multi-dimensional array
142154
pub order: Option<Order>,
143155
/// Subset of the data to operate on
@@ -237,20 +249,34 @@ fn validate_request_data(request_data: &RequestData) -> Result<(), ValidationErr
237249
_ => (),
238250
};
239251
// Check axis is compatible with shape
240-
match (&request_data.shape, &request_data.axis) {
241-
(Some(shape), Some(axis)) => {
252+
match (&request_data.shape, &request_data.axes) {
253+
(Some(shape), ReductionAxes::One(axis)) => {
242254
if *axis > shape.len() - 1 {
243255
return Err(ValidationError::new("Axis must be within shape"));
244256
}
245257
}
246-
(None, Some(_)) => {
258+
(Some(shape), ReductionAxes::Multi(axes)) => {
259+
if axes.len() >= shape.len() {
260+
return Err(ValidationError::new(
261+
"Number of reduction axes must be less than length of shape",
262+
));
263+
}
264+
for ax in axes {
265+
if *ax > shape.len() - 1 {
266+
return Err(ValidationError::new("All axes must be within shape"));
267+
}
268+
}
269+
}
270+
(None, ReductionAxes::One(_) | ReductionAxes::Multi(_)) => {
247271
return Err(ValidationError::new("Axis requires shape to be specified"));
248272
}
249-
_ => (),
273+
(_, ReductionAxes::All) => (),
250274
};
275+
// Validate missing specification
251276
if let Some(missing) = &request_data.missing {
252277
missing.validate(request_data.dtype)?;
253-
};
278+
}
279+
254280
Ok(())
255281
}
256282

@@ -346,21 +372,24 @@ mod tests {
346372
Token::U32(8),
347373
Token::Str("shape"),
348374
Token::Some,
349-
Token::Seq { len: Some(2) },
375+
Token::Seq { len: Some(3) },
350376
Token::U32(2),
351377
Token::U32(5),
378+
Token::U32(1),
352379
Token::SeqEnd,
353-
Token::Str("axis"),
354-
Token::Some,
380+
Token::Str("axes"),
381+
Token::Seq { len: Some(2) },
355382
Token::U32(1),
383+
Token::U32(2),
384+
Token::SeqEnd,
356385
Token::Str("order"),
357386
Token::Some,
358387
Token::Enum { name: "Order" },
359388
Token::Str("C"),
360389
Token::Unit,
361390
Token::Str("selection"),
362391
Token::Some,
363-
Token::Seq { len: Some(2) },
392+
Token::Seq { len: Some(3) },
364393
Token::Seq { len: Some(3) },
365394
Token::U32(1),
366395
Token::U32(2),
@@ -371,6 +400,11 @@ mod tests {
371400
Token::U32(5),
372401
Token::U32(6),
373402
Token::SeqEnd,
403+
Token::Seq { len: Some(3) },
404+
Token::U32(1),
405+
Token::U32(1),
406+
Token::U32(1),
407+
Token::SeqEnd,
374408
Token::SeqEnd,
375409
Token::Str("compression"),
376410
Token::Some,
@@ -590,7 +624,7 @@ mod tests {
590624

591625
#[test]
592626
fn test_selection_end_lt_start() {
593-
// Numpy sementics: start >= end yields an empty array
627+
// Numpy semantics: start >= end yields an empty array
594628
let mut request_data = test_utils::get_test_request_data();
595629
request_data.shape = Some(vec![1]);
596630
request_data.selection = Some(vec![Slice::new(1, 0, 1)]);
@@ -635,7 +669,7 @@ mod tests {
635669

636670
#[test]
637671
fn test_selection_start_gt_shape() {
638-
// Numpy sementics: start > length yields an empty array
672+
// Numpy semantics: start > length yields an empty array
639673
let mut request_data = test_utils::get_test_request_data();
640674
request_data.shape = Some(vec![4]);
641675
request_data.selection = Some(vec![Slice::new(5, 5, 1)]);
@@ -644,7 +678,7 @@ mod tests {
644678

645679
#[test]
646680
fn test_selection_start_lt_negative_shape() {
647-
// Numpy sementics: start < -length gets clamped to zero
681+
// Numpy semantics: start < -length gets clamped to zero
648682
let mut request_data = test_utils::get_test_request_data();
649683
request_data.shape = Some(vec![4]);
650684
request_data.selection = Some(vec![Slice::new(-5, 5, 1)]);
@@ -681,15 +715,15 @@ mod tests {
681715
#[should_panic(expected = "Axis requires shape to be specified")]
682716
fn test_axis_without_shape() {
683717
let mut request_data = test_utils::get_test_request_data();
684-
request_data.axis = Some(1);
718+
request_data.axes = ReductionAxes::One(1);
685719
request_data.validate().unwrap()
686720
}
687721

688722
#[test]
689723
#[should_panic(expected = "Axis must be within shape")]
690724
fn test_axis_gt_shape() {
691725
let mut request_data = test_utils::get_test_request_data();
692-
request_data.axis = Some(2);
726+
request_data.axes = ReductionAxes::One(2);
693727
request_data.shape = Some(vec![2, 5]);
694728
request_data.validate().unwrap()
695729
}
@@ -756,6 +790,7 @@ mod tests {
756790
fn test_missing_invalid_value_for_dtype() {
757791
let mut request_data = test_utils::get_test_request_data();
758792
request_data.missing = Some(Missing::MissingValue(i64::max_value().into()));
793+
println!("{:?}", request_data.validate());
759794
request_data.validate().unwrap()
760795
}
761796

@@ -766,7 +801,7 @@ mod tests {
766801
Token::Str("foo"),
767802
Token::StructEnd
768803
],
769-
"unknown field `foo`, expected one of `source`, `bucket`, `object`, `dtype`, `byte_order`, `offset`, `size`, `shape`, `axis`, `order`, `selection`, `compression`, `filters`, `missing`"
804+
"unknown field `foo`, expected one of `source`, `bucket`, `object`, `dtype`, `byte_order`, `offset`, `size`, `shape`, `axes`, `order`, `selection`, `compression`, `filters`, `missing`"
770805
)
771806
}
772807

@@ -789,10 +824,10 @@ mod tests {
789824
"byte_order": "little",
790825
"offset": 4,
791826
"size": 8,
792-
"shape": [2, 5],
793-
"axis": 1,
827+
"shape": [2, 5, 1],
828+
"axes": [1, 2],
794829
"order": "C",
795-
"selection": [[1, 2, 3], [4, 5, 6]],
830+
"selection": [[1, 2, 3], [4, 5, 6], [1, 1, 1]],
796831
"compression": {"id": "gzip"},
797832
"filters": [{"id": "shuffle", "element_size": 4}],
798833
"missing": {"missing_value": 42}
@@ -812,7 +847,7 @@ mod tests {
812847
"offset": 4,
813848
"size": 8,
814849
"shape": [2, 5, 10],
815-
"axis": 2,
850+
"axes": 2,
816851
"order": "F",
817852
"selection": [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
818853
"compression": {"id": "zlib"},
@@ -824,7 +859,7 @@ mod tests {
824859
expected.dtype = DType::Float64;
825860
expected.byte_order = Some(ByteOrder::Big);
826861
expected.shape = Some(vec![2, 5, 10]);
827-
expected.axis = Some(2);
862+
expected.axes = ReductionAxes::One(2);
828863
expected.order = Some(Order::F);
829864
expected.selection = Some(vec![
830865
Slice::new(1, 2, 3),

0 commit comments

Comments
 (0)