Skip to content

Commit 2e619d6

Browse files
committed
Handle empty axis list in count operation
1 parent f211d47 commit 2e619d6

File tree

1 file changed

+56
-23
lines changed

1 file changed

+56
-23
lines changed

src/operations.rs

Lines changed: 56 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,9 @@ fn missing_filter<'a, T: Element>(missing: &'a Missing<T>) -> Box<dyn Fn(&T) ->
4141
fn count_non_missing<T: Element>(
4242
array: &ArrayView<T, ndarray::Dim<ndarray::IxDynImpl>>,
4343
missing: &Missing<T>,
44-
) -> Result<usize, ActiveStorageError> {
44+
) -> usize {
4545
let filter = missing_filter(missing);
46-
Ok(array.iter().copied().filter(filter).count())
46+
array.iter().copied().filter(filter).count()
4747
}
4848

4949
/// Counts the number of non-missing elements along
@@ -53,33 +53,52 @@ fn count_array_multi_axis<T: Element>(
5353
axes: &[usize],
5454
missing: Option<Missing<T>>,
5555
) -> (Vec<i64>, Vec<usize>) {
56-
// Count non-missing over first axis
57-
let mut result = array
58-
.fold_axis(Axis(axes[0]), 0, |running_count, val| {
56+
let result = if axes.is_empty() {
57+
// Emulate numpy semantics of axis = () being
58+
// equivalent to a 'reduction over no axes'
59+
array.map(|val| {
5960
if let Some(missing) = &missing {
6061
if !missing.is_missing(val) {
61-
running_count + 1
62+
1
6263
} else {
63-
*running_count
64+
0
6465
}
6566
} else {
66-
running_count + 1
67+
1
6768
}
6869
})
69-
.into_dyn();
70-
// Sum counts over remaining axes
71-
if let Some(remaining_axes) = axes.get(1..) {
72-
for (n, axis) in remaining_axes.iter().enumerate() {
73-
result = result
74-
.fold_axis(Axis(axis - n - 1), 0, |total_count, count| {
75-
total_count + count
76-
})
77-
.into_dyn();
70+
} else {
71+
// Should never panic here due to axis.is_empty() branch above
72+
let first_axis = axes.first().expect("axes list to be non-empty");
73+
// Count non-missing over first axis
74+
let mut result = array
75+
.fold_axis(Axis(*first_axis), 0, |running_count, val| {
76+
if let Some(missing) = &missing {
77+
if !missing.is_missing(val) {
78+
running_count + 1
79+
} else {
80+
*running_count
81+
}
82+
} else {
83+
running_count + 1
84+
}
85+
})
86+
.into_dyn();
87+
// Sum counts over remaining axes
88+
if let Some(remaining_axes) = axes.get(1..) {
89+
for (n, axis) in remaining_axes.iter().enumerate() {
90+
result = result
91+
.fold_axis(Axis(axis - n - 1), 0, |total_count, count| {
92+
total_count + count
93+
})
94+
.into_dyn();
95+
}
7896
}
79-
}
97+
result
98+
};
8099

81100
// Convert result to owned vec
82-
let counts = result.iter().copied().collect::<Vec<i64>>();
101+
let counts = result.iter().copied().collect();
83102
(counts, result.shape().into())
84103
}
85104

@@ -104,9 +123,8 @@ impl NumOperation for Count {
104123

105124
match &request_data.axis {
106125
ReductionAxes::All => {
107-
let count = if let Some(missing) = &request_data.missing {
108-
let missing = Missing::<T>::try_from(missing)?;
109-
count_non_missing(&sliced, &missing)?
126+
let count = if let Some(missing) = typed_missing {
127+
count_non_missing(&sliced, &missing)
110128
} else {
111129
sliced.len()
112130
};
@@ -476,7 +494,7 @@ impl NumOperation for Select {
476494
let sliced = array.slice(slice_info);
477495
let count = if let Some(missing) = &request_data.missing {
478496
let missing = Missing::<T>::try_from(missing)?;
479-
count_non_missing(&sliced, &missing)?
497+
count_non_missing(&sliced, &missing)
480498
} else {
481499
sliced.len()
482500
};
@@ -1288,6 +1306,21 @@ mod tests {
12881306
assert_eq!(shape, Vec::<usize>::new());
12891307
}
12901308

1309+
#[test]
1310+
fn count_multi_axis_2d_no_ax() {
1311+
// Arrange
1312+
let axes = vec![];
1313+
let missing = None;
1314+
let arr = ndarray::Array::from_shape_vec((2, 3), (0..6).collect())
1315+
.unwrap()
1316+
.into_dyn();
1317+
// Act
1318+
let (counts, shape) = count_array_multi_axis(arr.view(), &axes, missing);
1319+
// Assert
1320+
assert_eq!(counts, vec![1, 1, 1, 1, 1, 1]);
1321+
assert_eq!(shape, arr.shape().to_vec());
1322+
}
1323+
12911324
#[test]
12921325
fn count_multi_axis_2d_1ax_missing() {
12931326
// Arrange

0 commit comments

Comments
 (0)