Skip to content

Commit 09c967f

Browse files
committed
Implement max over multiple axes
1 parent 456c74e commit 09c967f

File tree

1 file changed

+197
-43
lines changed

1 file changed

+197
-43
lines changed

src/operations.rs

Lines changed: 197 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
//! Each operation is implemented as a struct that implements the
44
//! [Operation](crate::operation::Operation) trait.
55
6-
use std::cmp::min_by;
6+
use std::cmp::{max_by, min_by};
77

88
use crate::array;
99
use crate::error::ActiveStorageError;
@@ -13,7 +13,6 @@ use crate::types::Missing;
1313

1414
use axum::body::Bytes;
1515
use ndarray::{ArrayView, Axis};
16-
use ndarray_stats::{errors::MinMaxError, QuantileExt};
1716
// Bring trait into scope to use as_bytes method.
1817
use zerocopy::AsBytes;
1918

@@ -80,6 +79,72 @@ impl NumOperation for Count {
8079
/// Return the maximum of selected elements in the array.
8180
pub struct Max {}
8281

82+
fn max_element_pairwise<T: Element>(x: &&T, y: &&T) -> std::cmp::Ordering {
83+
// TODO: How to handle NaN correctly?
84+
// Numpy seems to behave as follows:
85+
//
86+
// np.min([np.nan, 1]) == np.nan
87+
// np.max([np.nan, 1]) == np.nan
88+
// np.nan != np.nan
89+
// np.min([np.nan, 1]) != np.max([np.nan, 1])
90+
//
91+
// There are also separate np.nan{min,max} functions
92+
// which ignore nans instead.
93+
//
94+
// Which behaviour do we want to follow?
95+
//
96+
// Panic for now (TODO: Make this a user-facing error response instead)
97+
x.partial_cmp(y)
98+
// .unwrap_or(std::cmp::Ordering::Less)
99+
.unwrap_or_else(|| panic!("unexpected undefined order error for min"))
100+
}
101+
102+
/// Performs a max over one or more axes of the provided array
103+
fn max_array_multi_axis<T: Element>(
104+
array: ndarray::ArrayView<T, ndarray::IxDyn>,
105+
axes: &[usize],
106+
missing: Option<Missing<T>>,
107+
) -> (Vec<T>, Vec<i64>, Vec<usize>) {
108+
// Find maximum over first axis and count elements operated on
109+
let init = T::min_value();
110+
let mut result = array
111+
.fold_axis(Axis(axes[0]), (init, 0), |(running_max, count), val| {
112+
if let Some(missing) = &missing {
113+
if !missing.is_missing(val) {
114+
let new_max = max_by(running_max, val, max_element_pairwise);
115+
(*new_max, count + 1)
116+
} else {
117+
(*running_max, *count)
118+
}
119+
} else {
120+
let new_max = max_by(running_max, val, max_element_pairwise);
121+
(*new_max, count + 1)
122+
}
123+
})
124+
.into_dyn();
125+
// Find max over remaining axes (where total count is now sum of counts)
126+
if let Some(remaining_axes) = axes.get(1..) {
127+
for (n, axis) in remaining_axes.iter().enumerate() {
128+
result = result
129+
.fold_axis(
130+
Axis(axis - n - 1),
131+
(init, 0),
132+
|(global_max, total_count), (running_max, count)| {
133+
let new_max = max_by(global_max, running_max, max_element_pairwise);
134+
(*new_max, total_count + count)
135+
},
136+
)
137+
.into_dyn();
138+
}
139+
}
140+
141+
// Result is array of (max, count) tuples so separate them here
142+
let maxes = result.iter().map(|(max, _)| *max).collect::<Vec<T>>();
143+
let counts = result.iter().map(|(_, count)| *count).collect::<Vec<i64>>();
144+
145+
(maxes, counts, result.shape().into())
146+
}
147+
83148
impl NumOperation for Max {
84149
fn execute_t<T: Element>(
85150
request_data: &models::RequestData,
@@ -88,44 +153,74 @@ impl NumOperation for Max {
88153
let array = array::build_array::<T>(request_data, &mut data)?;
89154
let slice_info = array::build_slice_info::<T>(&request_data.selection, array.shape());
90155
let sliced = array.slice(slice_info);
91-
let (max, count) = if let Some(missing) = &request_data.missing {
92-
let missing = Missing::<T>::try_from(missing)?;
93-
// Use a fold to simultaneously max and count the non-missing data.
94-
// TODO: separate float impl?
95-
// TODO: inifinite/NaN
96-
let (max, count) = sliced
97-
.iter()
98-
.copied()
99-
.filter(missing_filter(&missing))
100-
.fold((None, 0), |(a, count), b| {
101-
let max = match (a, b) {
102-
(None, b) => Some(b), //FIXME: if b.is_finite() { Some(b) } else { None },
103-
(Some(a), b) => Some(std::cmp::max_by(a, b, |x, y| {
104-
x.partial_cmp(y).unwrap_or(std::cmp::Ordering::Greater)
105-
})),
106-
};
107-
(max, count + 1)
108-
});
109-
let max = max.ok_or(ActiveStorageError::EmptyArray { operation: "max" })?;
110-
(max, count)
156+
157+
let typed_missing: Option<Missing<T>> = if let Some(missing) = &request_data.missing {
158+
let m = Missing::try_from(missing)?;
159+
Some(m)
111160
} else {
112-
let max = *sliced.max().map_err(|err| match err {
113-
MinMaxError::EmptyInput => ActiveStorageError::EmptyArray { operation: "max" },
114-
MinMaxError::UndefinedOrder => panic!("unexpected undefined order error for max"),
115-
})?;
116-
let count = sliced.len();
117-
(max, count)
161+
None
118162
};
119-
let count = i64::try_from(count)?;
120-
let body = max.as_bytes();
121-
// Need to copy to provide ownership to caller.
122-
let body = Bytes::copy_from_slice(body);
123-
Ok(models::Response::new(
124-
body,
125-
request_data.dtype,
126-
vec![],
127-
vec![count],
128-
))
163+
164+
match &request_data.axis {
165+
ReductionAxes::One(axis) => {
166+
let init = T::min_value();
167+
let result =
168+
sliced.fold_axis(Axis(*axis), (init, 0), |(running_max, count), val| {
169+
if let Some(missing) = &typed_missing {
170+
if !missing.is_missing(val) {
171+
(*max_by(running_max, val, max_element_pairwise), count + 1)
172+
} else {
173+
(*running_max, *count)
174+
}
175+
} else {
176+
(*max_by(running_max, val, max_element_pairwise), count + 1)
177+
}
178+
});
179+
let maxes = result.iter().map(|(max, _)| *max).collect::<Vec<T>>();
180+
let counts = result.iter().map(|(_, count)| *count).collect::<Vec<i64>>();
181+
let body = maxes.as_bytes();
182+
let body = Bytes::copy_from_slice(body);
183+
Ok(models::Response::new(
184+
body,
185+
request_data.dtype,
186+
result.shape().into(),
187+
counts,
188+
))
189+
}
190+
ReductionAxes::Multi(axes) => {
191+
let (maxes, counts, shape) = max_array_multi_axis(sliced, axes, typed_missing);
192+
let body = Bytes::copy_from_slice(maxes.as_bytes());
193+
Ok(models::Response::new(
194+
body,
195+
request_data.dtype,
196+
shape,
197+
counts,
198+
))
199+
}
200+
ReductionAxes::All => {
201+
let init = T::min_value();
202+
let (max, count) = sliced.fold((init, 0_i64), |(running_max, count), val| {
203+
if let Some(missing) = &typed_missing {
204+
if !missing.is_missing(val) {
205+
(*max_by(&running_max, val, max_element_pairwise), count + 1)
206+
} else {
207+
(running_max, count)
208+
}
209+
} else {
210+
(*max_by(&running_max, val, max_element_pairwise), count + 1)
211+
}
212+
});
213+
214+
let body = max.as_bytes();
215+
let body = Bytes::copy_from_slice(body);
216+
Ok(models::Response::new(
217+
body,
218+
request_data.dtype,
219+
vec![],
220+
vec![count],
221+
))
222+
}
223+
}
129224
}
130225
}
131226

@@ -146,7 +241,7 @@ fn min_element_pairwise<T: Element>(x: &&T, y: &&T) -> std::cmp::Ordering {
146241
//
147242
// Which behaviour do we want to follow?
148243
//
149-
// Panic is probably the best option for now...
244+
// Panic for now (TODO: Make this a user-facing error response instead)
150245
x.partial_cmp(y)
151246
// .unwrap_or(std::cmp::Ordering::Less)
152247
.unwrap_or_else(|| panic!("unexpected undefined order error for min"))
@@ -164,13 +259,13 @@ fn min_array_multi_axis<T: Element>(
164259
.fold_axis(Axis(axes[0]), (init, 0), |(running_min, count), val| {
165260
if let Some(missing) = &missing {
166261
if !missing.is_missing(val) {
167-
let new_min = std::cmp::min_by(running_min, val, min_element_pairwise);
262+
let new_min = min_by(running_min, val, min_element_pairwise);
168263
(*new_min, count + 1)
169264
} else {
170265
(*running_min, *count)
171266
}
172267
} else {
173-
let new_min = std::cmp::min_by(running_min, val, min_element_pairwise);
268+
let new_min = min_by(running_min, val, min_element_pairwise);
174269
(*new_min, count + 1)
175270
}
176271
})
@@ -184,8 +279,7 @@ fn min_array_multi_axis<T: Element>(
184279
(init, 0),
185280
|(global_min, total_count), (running_min, count)| {
186281
// (*global_min.min(running_min), total_count + count)
187-
let new_min =
188-
std::cmp::min_by(global_min, running_min, min_element_pairwise);
282+
let new_min = min_by(global_min, running_min, min_element_pairwise);
189283
(*new_min, total_count + count)
190284
},
191285
)
@@ -1004,4 +1098,64 @@ mod tests {
10041098
assert_eq!(counts, vec![6, 5]);
10051099
assert_eq!(shape, vec![2]);
10061100
}
1101+
1102+
#[test]
1103+
#[should_panic(expected = "assertion failed: axis.index() < self.ndim()")]
1104+
fn max_multi_axis_2d_wrong_axis() {
1105+
// Arrange
1106+
let array = ndarray::Array::from_shape_vec((2, 2), (0..4).collect())
1107+
.unwrap()
1108+
.into_dyn();
1109+
let axes = vec![2];
1110+
// Act
1111+
let _ = max_array_multi_axis(array.view(), &axes, None);
1112+
}
1113+
1114+
#[test]
1115+
fn max_multi_axis_2d_2ax() {
1116+
// Arrange
1117+
let axes = vec![0, 1];
1118+
let missing = None;
1119+
let arr = ndarray::Array::from_shape_vec((2, 3), (0..6).collect())
1120+
.unwrap()
1121+
.into_dyn();
1122+
// Act
1123+
let (result, counts, shape) = max_array_multi_axis(arr.view(), &axes, missing);
1124+
// Assert
1125+
assert_eq!(result, vec![5]);
1126+
assert_eq!(counts, vec![6]);
1127+
assert_eq!(shape, Vec::<usize>::new());
1128+
}
1129+
1130+
#[test]
1131+
fn max_multi_axis_2d_1ax_missing() {
1132+
// Arrange
1133+
let axes = vec![1];
1134+
let missing = Missing::MissingValue(0);
1135+
let arr = ndarray::Array::from_shape_vec((2, 3), (0..6).collect())
1136+
.unwrap()
1137+
.into_dyn();
1138+
// Act
1139+
let (result, counts, shape) = max_array_multi_axis(arr.view(), &axes, Some(missing));
1140+
// Assert
1141+
assert_eq!(result, vec![2, 5]);
1142+
assert_eq!(counts, vec![2, 3]);
1143+
assert_eq!(shape, vec![2]);
1144+
}
1145+
1146+
#[test]
1147+
fn max_multi_axis_4d_3ax_missing() {
1148+
// Arrange
1149+
let arr = ndarray::Array::from_shape_vec((2, 3, 2, 1), (0..12).collect())
1150+
.unwrap()
1151+
.into_dyn();
1152+
let axes = vec![0, 1, 3];
1153+
let missing = Missing::MissingValue(10);
1154+
// Act
1155+
let (result, counts, shape) = max_array_multi_axis(arr.view(), &axes, Some(missing));
1156+
// Assert
1157+
assert_eq!(result, vec![8, 11]);
1158+
assert_eq!(counts, vec![5, 6]);
1159+
assert_eq!(shape, vec![2]);
1160+
}
10071161
}

0 commit comments

Comments
 (0)