Skip to content

Commit 8c8b81b

Browse files
committed
Implement min over multiple axes
1 parent 8632cd5 commit 8c8b81b

File tree

2 files changed

+208
-41
lines changed

2 files changed

+208
-41
lines changed

src/operation.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ pub trait Element:
1313
+ num_traits::FromPrimitive
1414
+ num_traits::ToBytes
1515
+ num_traits::Zero
16+
+ num_traits::Bounded
1617
+ std::convert::From<u16>
1718
+ std::fmt::Debug
1819
+ std::iter::Sum
@@ -34,6 +35,7 @@ impl<T> Element for T where
3435
+ num_traits::One
3536
+ num_traits::ToBytes
3637
+ num_traits::Zero
38+
+ num_traits::Bounded
3739
+ std::convert::From<u16>
3840
+ std::fmt::Debug
3941
+ std::iter::Sum

src/operations.rs

Lines changed: 206 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
//! Each operation is implemented as a struct that implements the
44
//! [Operation](crate::operation::Operation) trait.
55
6+
use std::cmp::min_by;
7+
68
use crate::array;
79
use crate::error::ActiveStorageError;
810
use crate::models::{self, ReductionAxes};
@@ -130,6 +132,74 @@ impl NumOperation for Max {
130132
/// Return the minimum of selected elements in the array.
131133
pub struct Min {}
132134

135+
fn min_element_pairwise<T: Element>(x: &&T, y: &&T) -> std::cmp::Ordering {
136+
// TODO: How to handle NaN correctly?
137+
// Numpy seems to behave as follows:
138+
//
139+
// np.min([np.nan, 1]) == np.nan
140+
// np.max([np.nan, 1]) == np.nan
141+
// np.nan != np.nan
142+
// np.min([np.nan, 1]) != np.max([np.nan, 1])
143+
//
144+
// There are also separate np.nan{min,max} functions
145+
// which ignore nans instead.
146+
//
147+
// Which behaviour do we want to follow?
148+
//
149+
// Panic is probably the best option for now...
150+
x.partial_cmp(y)
151+
// .unwrap_or(std::cmp::Ordering::Less)
152+
.unwrap_or_else(|| panic!("unexpected undefined order error for min"))
153+
}
154+
155+
/// Finds the minimum value over one or more axes of the provided array
156+
fn min_array_multi_axis<T: Element>(
157+
array: ndarray::ArrayView<T, ndarray::IxDyn>,
158+
axes: &[usize],
159+
missing: Option<Missing<T>>,
160+
) -> (Vec<T>, Vec<i64>, Vec<usize>) {
161+
// Find minimum over first axis and count elements operated on
162+
let init = T::max_value();
163+
let mut result = array
164+
.fold_axis(Axis(axes[0]), (init, 0), |(running_min, count), val| {
165+
if let Some(missing) = &missing {
166+
if !missing.is_missing(val) {
167+
let new_min = std::cmp::min_by(running_min, val, min_element_pairwise);
168+
(*new_min, count + 1)
169+
} else {
170+
(*running_min, *count)
171+
}
172+
} else {
173+
let new_min = std::cmp::min_by(running_min, val, min_element_pairwise);
174+
(*new_min, count + 1)
175+
}
176+
})
177+
.into_dyn();
178+
// Find min over remaining axes (where total count is now sum of counts)
179+
if let Some(remaining_axes) = axes.get(1..) {
180+
for (n, axis) in remaining_axes.iter().enumerate() {
181+
result = result
182+
.fold_axis(
183+
Axis(axis - n - 1),
184+
(init, 0),
185+
|(global_min, total_count), (running_min, count)| {
186+
// (*global_min.min(running_min), total_count + count)
187+
let new_min =
188+
std::cmp::min_by(global_min, running_min, min_element_pairwise);
189+
(*new_min, total_count + count)
190+
},
191+
)
192+
.into_dyn();
193+
}
194+
}
195+
196+
// Result is array of (mins, count) tuples so separate them here
197+
let mins = result.iter().map(|(min, _)| *min).collect::<Vec<T>>();
198+
let counts = result.iter().map(|(_, count)| *count).collect::<Vec<i64>>();
199+
200+
(mins, counts, result.shape().into())
201+
}
202+
133203
impl NumOperation for Min {
134204
fn execute_t<T: Element>(
135205
request_data: &models::RequestData,
@@ -138,44 +208,80 @@ impl NumOperation for Min {
138208
let array = array::build_array::<T>(request_data, &mut data)?;
139209
let slice_info = array::build_slice_info::<T>(&request_data.selection, array.shape());
140210
let sliced = array.slice(slice_info);
141-
let (min, count) = if let Some(missing) = &request_data.missing {
142-
let missing = Missing::<T>::try_from(missing)?;
143-
// Use a fold to simultaneously min and count the non-missing data.
144-
// TODO: separate float impl?
145-
// TODO: inifinite/NaN
146-
let (min, count) = sliced
147-
.iter()
148-
.copied()
149-
.filter(missing_filter(&missing))
150-
.fold((None, 0), |(a, count), b| {
151-
let min = match (a, b) {
152-
(None, b) => Some(b), //FIXME: if b.is_finite() { Some(b) } else { None },
153-
(Some(a), b) => Some(std::cmp::min_by(a, b, |x, y| {
154-
x.partial_cmp(y).unwrap_or(std::cmp::Ordering::Less)
155-
})),
156-
};
157-
(min, count + 1)
158-
});
159-
let min = min.ok_or(ActiveStorageError::EmptyArray { operation: "min" })?;
160-
(min, count)
211+
212+
// Convert Missing<Dtype> to Missing<T: Element>
213+
let typed_missing: Option<Missing<T>> = if let Some(missing) = &request_data.missing {
214+
let m = Missing::try_from(missing)?;
215+
Some(m)
161216
} else {
162-
let min = *sliced.min().map_err(|err| match err {
163-
MinMaxError::EmptyInput => ActiveStorageError::EmptyArray { operation: "min" },
164-
MinMaxError::UndefinedOrder => panic!("unexpected undefined order error for min"),
165-
})?;
166-
let count = sliced.len();
167-
(min, count)
217+
None
168218
};
169-
let count = i64::try_from(count)?;
170-
let body = min.as_bytes();
171-
// Need to copy to provide ownership to caller.
172-
let body = Bytes::copy_from_slice(body);
173-
Ok(models::Response::new(
174-
body,
175-
request_data.dtype,
176-
vec![],
177-
vec![count],
178-
))
219+
220+
// Use ndarray::fold, ndarray::fold_axis or dispatch to specialised
221+
// multi-axis function depending on whether we're performing reduction
222+
// over all axes or only a subset
223+
match &request_data.axes {
224+
ReductionAxes::One(axis) => {
225+
let init = T::max_value();
226+
let result =
227+
sliced.fold_axis(Axis(*axis), (init, 0), |(running_min, count), val| {
228+
if let Some(missing) = &typed_missing {
229+
if !missing.is_missing(val) {
230+
(*min_by(running_min, val, min_element_pairwise), count + 1)
231+
} else {
232+
(*running_min, *count)
233+
}
234+
} else {
235+
(*min_by(running_min, val, min_element_pairwise), count + 1)
236+
}
237+
});
238+
// Unpack the result tuples into separate vectors
239+
let mins = result.iter().map(|(min, _)| *min).collect::<Vec<T>>();
240+
let counts = result.iter().map(|(_, count)| *count).collect::<Vec<i64>>();
241+
let body = mins.as_bytes();
242+
let body = Bytes::copy_from_slice(body);
243+
Ok(models::Response::new(
244+
body,
245+
request_data.dtype,
246+
result.shape().into(),
247+
counts,
248+
))
249+
}
250+
ReductionAxes::Multi(axes) => {
251+
let (mins, counts, shape) = min_array_multi_axis(sliced, axes, typed_missing);
252+
let body = Bytes::copy_from_slice(mins.as_bytes());
253+
Ok(models::Response::new(
254+
body,
255+
request_data.dtype,
256+
shape,
257+
counts,
258+
))
259+
}
260+
ReductionAxes::All => {
261+
let init = T::max_value();
262+
let (min, count) = sliced.fold((init, 0_i64), |(running_min, count), val| {
263+
if let Some(missing) = &typed_missing {
264+
if !missing.is_missing(val) {
265+
(*min_by(&running_min, val, min_element_pairwise), count + 1)
266+
} else {
267+
(running_min, count)
268+
}
269+
} else {
270+
(*min_by(&running_min, val, min_element_pairwise), count + 1)
271+
}
272+
});
273+
274+
let body = min.as_bytes();
275+
// Need to copy to provide ownership to caller.
276+
let body = Bytes::copy_from_slice(body);
277+
Ok(models::Response::new(
278+
body,
279+
request_data.dtype,
280+
vec![],
281+
vec![count],
282+
))
283+
}
284+
}
179285
}
180286
}
181287

@@ -220,6 +326,7 @@ impl NumOperation for Select {
220326
/// Return the sum of selected elements in the array.
221327
pub struct Sum {}
222328

329+
/// Performs a sum over one or more axes of the provided array
223330
fn sum_array_multi_axis<T: Element>(
224331
array: ndarray::ArrayView<T, ndarray::IxDyn>,
225332
axes: &[usize],
@@ -551,6 +658,7 @@ mod tests {
551658
}
552659

553660
#[test]
661+
#[should_panic(expected = "unexpected undefined order error for min")]
554662
fn min_f32_1d_nan_missing_value() {
555663
let mut request_data = test_utils::get_test_request_data();
556664
request_data.dtype = models::DType::Float32;
@@ -567,6 +675,7 @@ mod tests {
567675
}
568676

569677
#[test]
678+
#[should_panic(expected = "unexpected undefined order error for min")]
570679
fn min_f32_1d_nan_first_missing_value() {
571680
let mut request_data = test_utils::get_test_request_data();
572681
request_data.dtype = models::DType::Float32;
@@ -783,7 +892,7 @@ mod tests {
783892

784893
#[test]
785894
#[should_panic(expected = "assertion failed: axis.index() < self.ndim()")]
786-
fn test_sum_multi_axis_2d_wrong_axis() {
895+
fn sum_multi_axis_2d_wrong_axis() {
787896
let array = ndarray::Array::from_shape_vec((2, 2), (0..4).collect())
788897
.unwrap()
789898
.into_dyn();
@@ -792,7 +901,7 @@ mod tests {
792901
}
793902

794903
#[test]
795-
fn test_sum_multi_axis_2d_2ax() {
904+
fn sum_multi_axis_2d_2ax() {
796905
let array = ndarray::Array::from_shape_vec((2, 2), (0..4).collect())
797906
.unwrap()
798907
.into_dyn();
@@ -804,7 +913,7 @@ mod tests {
804913
}
805914

806915
#[test]
807-
fn test_sum_multi_axis_2d_2ax_missing() {
916+
fn sum_multi_axis_2d_2ax_missing() {
808917
let array = ndarray::Array::from_shape_vec((2, 2), (0..4).collect())
809918
.unwrap()
810919
.into_dyn();
@@ -817,7 +926,7 @@ mod tests {
817926
}
818927

819928
#[test]
820-
fn test_sum_multi_axis_4d_1ax() {
929+
fn sum_multi_axis_4d_1ax() {
821930
let array = ndarray::Array::from_shape_vec((2, 3, 2, 1), (0..12).collect())
822931
.unwrap()
823932
.into_dyn();
@@ -829,7 +938,7 @@ mod tests {
829938
}
830939

831940
#[test]
832-
fn test_sum_multi_axis_4d_3ax() {
941+
fn sum_multi_axis_4d_3ax() {
833942
let array = ndarray::Array::from_shape_vec((2, 3, 2, 1), (0..12).collect())
834943
.unwrap()
835944
.into_dyn();
@@ -839,4 +948,60 @@ mod tests {
839948
assert_eq!(count, vec![6, 6]);
840949
assert_eq!(shape, vec![2]);
841950
}
951+
952+
#[test]
953+
#[should_panic(expected = "assertion failed: axis.index() < self.ndim()")]
954+
fn min_multi_axis_2d_wrong_axis() {
955+
let array = ndarray::Array::from_shape_vec((2, 2), (0..4).collect())
956+
.unwrap()
957+
.into_dyn();
958+
let axes = vec![2];
959+
let _ = min_array_multi_axis(array.view(), &axes, None);
960+
}
961+
962+
#[test]
963+
fn min_multi_axis_2d_2ax() {
964+
// Arrrange
965+
let axes = vec![0, 1];
966+
let missing = None;
967+
let arr = ndarray::Array::from_shape_vec((2, 3), (0..6).collect())
968+
.unwrap()
969+
.into_dyn();
970+
// Act
971+
let (result, counts, shape) = min_array_multi_axis(arr.view(), &axes, missing);
972+
// Assert
973+
assert_eq!(result, vec![0]);
974+
assert_eq!(counts, vec![6]);
975+
assert_eq!(shape, Vec::<usize>::new());
976+
}
977+
978+
#[test]
979+
fn min_multi_axis_2d_1ax_missing() {
980+
// Arrange
981+
let axes = vec![1];
982+
let missing = Missing::MissingValue(0);
983+
let arr = ndarray::Array::from_shape_vec((2, 3), (0..6).collect())
984+
.unwrap()
985+
.into_dyn();
986+
// Act
987+
let (result, counts, shape) = min_array_multi_axis(arr.view(), &axes, Some(missing));
988+
// Assert
989+
assert_eq!(result, vec![1, 3]);
990+
assert_eq!(counts, vec![2, 3]);
991+
assert_eq!(shape, vec![2]);
992+
}
993+
994+
#[test]
995+
fn min_multi_axis_4d_3ax_missing() {
996+
let arr = ndarray::Array::from_shape_vec((2, 3, 2, 1), (0..12).collect())
997+
.unwrap()
998+
.into_dyn();
999+
let axes = vec![0, 1, 3];
1000+
let missing = Missing::MissingValue(1);
1001+
let (result, counts, shape) = min_array_multi_axis(arr.view(), &axes, Some(missing));
1002+
1003+
assert_eq!(result, vec![0, 3]);
1004+
assert_eq!(counts, vec![6, 5]);
1005+
assert_eq!(shape, vec![2]);
1006+
}
8421007
}

0 commit comments

Comments
 (0)