Skip to content

Commit b8ca242

Browse files
committed
Implement count over multiple axes
1 parent 24c5e49 commit b8ca242

File tree

2 files changed

+173
-15
lines changed

2 files changed

+173
-15
lines changed

scripts/client.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,9 @@ def display(response, verbose=False):
117117
dtype = response.headers['x-activestorage-dtype']
118118
shape = json.loads(response.headers['x-activestorage-shape'])
119119
counts = json.loads(response.headers['x-activestorage-count'])
120-
counts = np.array(counts).reshape(shape)
120+
counts = np.array(counts)
121+
if len(counts) > 1:
122+
counts = counts.reshape(shape)
121123
result = np.frombuffer(response.content, dtype=dtype).reshape(shape)
122124
if verbose:
123125
sep = "\n" if len(counts.shape) > 1 else " "

src/operations.rs

Lines changed: 170 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,43 @@ fn count_non_missing<T: Element>(
4646
Ok(array.iter().copied().filter(filter).count())
4747
}
4848

49+
/// Counts the number of non-missing elements along
50+
/// one or more axes of the provided array
51+
fn count_array_multi_axis<T: Element>(
52+
array: ndarray::ArrayView<T, ndarray::IxDyn>,
53+
axes: &[usize],
54+
missing: Option<Missing<T>>,
55+
) -> (Vec<i64>, Vec<usize>) {
56+
// Count non-missing over first axis
57+
let mut result = array
58+
.fold_axis(Axis(axes[0]), 0, |running_count, val| {
59+
if let Some(missing) = &missing {
60+
if !missing.is_missing(val) {
61+
running_count + 1
62+
} else {
63+
*running_count
64+
}
65+
} else {
66+
running_count + 1
67+
}
68+
})
69+
.into_dyn();
70+
// Sum counts over remaining axes
71+
if let Some(remaining_axes) = axes.get(1..) {
72+
for (n, axis) in remaining_axes.iter().enumerate() {
73+
result = result
74+
.fold_axis(Axis(axis - n - 1), 0, |total_count, count| {
75+
total_count + count
76+
})
77+
.into_dyn();
78+
}
79+
}
80+
81+
// Convert result to owned vec
82+
let counts = result.iter().copied().collect::<Vec<i64>>();
83+
(counts, result.shape().into())
84+
}
85+
4986
/// Return the number of selected elements in the array.
5087
pub struct Count {}
5188

@@ -57,22 +94,69 @@ impl NumOperation for Count {
5794
let array = array::build_array::<T>(request_data, &mut data)?;
5895
let slice_info = array::build_slice_info::<T>(&request_data.selection, array.shape());
5996
let sliced = array.slice(slice_info);
60-
let count = if let Some(missing) = &request_data.missing {
61-
let missing = Missing::<T>::try_from(missing)?;
62-
count_non_missing(&sliced, &missing)?
97+
98+
let typed_missing: Option<Missing<T>> = if let Some(missing) = &request_data.missing {
99+
let m = Missing::try_from(missing)?;
100+
Some(m)
63101
} else {
64-
sliced.len()
102+
None
65103
};
66-
let count = i64::try_from(count)?;
67-
let body = count.to_ne_bytes();
68-
// Need to copy to provide ownership to caller.
69-
let body = Bytes::copy_from_slice(&body);
70-
Ok(models::Response::new(
71-
body,
72-
models::DType::Int64,
73-
vec![],
74-
vec![count],
75-
))
104+
105+
match &request_data.axis {
106+
ReductionAxes::All => {
107+
let count = if let Some(missing) = &request_data.missing {
108+
let missing = Missing::<T>::try_from(missing)?;
109+
count_non_missing(&sliced, &missing)?
110+
} else {
111+
sliced.len()
112+
};
113+
let count = i64::try_from(count)?;
114+
let body = count.to_ne_bytes();
115+
// Need to copy to provide ownership to caller.
116+
let body = Bytes::copy_from_slice(&body);
117+
Ok(models::Response::new(
118+
body,
119+
models::DType::Int64,
120+
vec![],
121+
vec![count],
122+
))
123+
}
124+
ReductionAxes::One(axis) => {
125+
let result = sliced.fold_axis(Axis(*axis), 0, |count, val| {
126+
if let Some(missing) = &typed_missing {
127+
if !missing.is_missing(val) {
128+
count + 1
129+
} else {
130+
*count
131+
}
132+
} else {
133+
count + 1
134+
}
135+
});
136+
let counts = result.iter().copied().collect::<Vec<i64>>();
137+
let body = counts.as_bytes();
138+
// Need to copy to provide ownership to caller.
139+
let body = Bytes::copy_from_slice(body);
140+
Ok(models::Response::new(
141+
body,
142+
models::DType::Int64,
143+
result.shape().into(),
144+
counts,
145+
))
146+
}
147+
ReductionAxes::Multi(axes) => {
148+
let (counts, shape) = count_array_multi_axis(sliced.view(), axes, typed_missing);
149+
let body = counts.as_bytes();
150+
// Need to copy to provide ownership to caller.
151+
let body = Bytes::copy_from_slice(body);
152+
Ok(models::Response::new(
153+
body,
154+
models::DType::Int64,
155+
shape,
156+
counts,
157+
))
158+
}
159+
}
76160
}
77161
}
78162

@@ -1176,4 +1260,76 @@ mod tests {
11761260
assert_eq!(counts, vec![5, 6]);
11771261
assert_eq!(shape, vec![2]);
11781262
}
1263+
1264+
#[test]
1265+
#[should_panic(expected = "assertion failed: axis.index() < self.ndim()")]
1266+
fn count_multi_axis_2d_wrong_axis() {
1267+
// Arrange
1268+
let array = ndarray::Array::from_shape_vec((2, 2), (0..4).collect())
1269+
.unwrap()
1270+
.into_dyn();
1271+
let axes = vec![2];
1272+
// Act
1273+
let _ = count_array_multi_axis(array.view(), &axes, None);
1274+
}
1275+
1276+
#[test]
1277+
fn count_multi_axis_2d_2ax() {
1278+
// Arrange
1279+
let axes = vec![0, 1];
1280+
let missing = None;
1281+
let arr = ndarray::Array::from_shape_vec((2, 3), (0..6).collect())
1282+
.unwrap()
1283+
.into_dyn();
1284+
// Act
1285+
let (counts, shape) = count_array_multi_axis(arr.view(), &axes, missing);
1286+
// Assert
1287+
assert_eq!(counts, vec![6]);
1288+
assert_eq!(shape, Vec::<usize>::new());
1289+
}
1290+
1291+
#[test]
1292+
fn count_multi_axis_2d_1ax_missing() {
1293+
// Arrange
1294+
let axes = vec![1];
1295+
let missing = Missing::MissingValue(0);
1296+
let arr = ndarray::Array::from_shape_vec((2, 3), (0..6).collect())
1297+
.unwrap()
1298+
.into_dyn();
1299+
// Act
1300+
let (counts, shape) = count_array_multi_axis(arr.view(), &axes, Some(missing));
1301+
// Assert
1302+
assert_eq!(counts, vec![2, 3]);
1303+
assert_eq!(shape, vec![2]);
1304+
}
1305+
1306+
#[test]
1307+
fn count_multi_axis_4d_3ax_multi_missing() {
1308+
// Arrange
1309+
let arr = ndarray::Array::from_shape_vec((2, 3, 2, 1), (0..12).collect())
1310+
.unwrap()
1311+
.into_dyn();
1312+
let axes = vec![0, 1, 3];
1313+
let missing = Missing::MissingValues(vec![9, 10, 11]);
1314+
// Act
1315+
let (counts, shape) = count_array_multi_axis(arr.view(), &axes, Some(missing));
1316+
// Assert
1317+
assert_eq!(counts, vec![5, 4]);
1318+
assert_eq!(shape, vec![2]);
1319+
}
1320+
1321+
#[test]
1322+
fn count_multi_axis_4d_3ax_missing() {
1323+
// Arrange
1324+
let arr = ndarray::Array::from_shape_vec((2, 3, 2, 1), (0..12).collect())
1325+
.unwrap()
1326+
.into_dyn();
1327+
let axes = vec![0, 1, 3];
1328+
let missing = Missing::MissingValue(10);
1329+
// Act
1330+
let (counts, shape) = count_array_multi_axis(arr.view(), &axes, Some(missing));
1331+
// Assert
1332+
assert_eq!(counts, vec![5, 6]);
1333+
assert_eq!(shape, vec![2]);
1334+
}
11791335
}

0 commit comments

Comments
 (0)