Skip to content

Commit cb4553c

Browse files
committed
fix[array]: handle empty/all_invalid sum correctly
Signed-off-by: Joe Isaacs <[email protected]>
1 parent 8ca7f69 commit cb4553c

File tree

3 files changed

+38
-50
lines changed

3 files changed

+38
-50
lines changed

vortex-array/src/arrays/chunked/compute/sum.rs

Lines changed: 28 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -42,37 +42,31 @@ register_kernel!(SumKernelAdapter(ChunkedVTable).lift());
4242
fn sum_int<T: NativePType + PrimInt + FromPrimitiveOrF16>(
4343
chunks: &[ArrayRef],
4444
) -> VortexResult<Option<T>> {
45-
let mut result: Option<T> = None;
45+
let mut result: T = T::zero();
4646
for chunk in chunks {
4747
let chunk_sum = sum(chunk)?;
48-
49-
let Some(chunk_sum) = chunk_sum.as_primitive().as_::<T>() else {
50-
// Skip missing null chunk
51-
continue;
48+
let Some(chunk_sum) = chunk_sum
49+
.as_primitive()
50+
.as_::<T>()
51+
.and_then(|chunk_sum| result.checked_add(&chunk_sum))
52+
else {
53+
// Bail out on null or overflow
54+
return Ok(None);
5255
};
53-
54-
result = Some(match result {
55-
None => chunk_sum,
56-
Some(result) => {
57-
let Some(chunk_result) = result.checked_add(&chunk_sum) else {
58-
// Bail out on overflow
59-
return Ok(None);
60-
};
61-
chunk_result
62-
}
63-
});
56+
result = chunk_sum;
6457
}
65-
Ok(result)
58+
Ok(Some(result))
6659
}
6760

68-
fn sum_float(chunks: &[ArrayRef]) -> VortexResult<f64> {
61+
fn sum_float(chunks: &[ArrayRef]) -> VortexResult<Option<f64>> {
6962
let mut result = 0f64;
7063
for chunk in chunks {
71-
if let Some(chunk_sum) = sum(chunk)?.as_primitive().as_::<f64>() {
72-
result += chunk_sum;
64+
let Some(chunk_sum) = sum(chunk)?.as_primitive().as_::<f64>() else {
65+
return Ok(None);
7366
};
67+
result += chunk_sum;
7468
}
75-
Ok(result)
69+
Ok(Some(result))
7670
}
7771

7872
fn sum_decimal(chunks: &[ArrayRef], result_decimal_type: DecimalDType) -> VortexResult<Scalar> {
@@ -84,21 +78,19 @@ fn sum_decimal(chunks: &[ArrayRef], result_decimal_type: DecimalDType) -> Vortex
8478
let chunk_sum = sum(chunk)?;
8579

8680
let chunk_decimal = DecimalScalar::try_from(&chunk_sum)?;
87-
let Some(chunk_value) = chunk_decimal.decimal_value() else {
88-
// skips all null chunks
89-
continue;
90-
};
91-
92-
// Perform checked addition with current result
93-
let Some(r) = result.checked_add(&chunk_value).filter(|sum_value| {
94-
sum_value
95-
.fits_in_precision(result_decimal_type)
96-
.unwrap_or(false)
97-
}) else {
98-
// Overflow
81+
let Some(r) = chunk_decimal
82+
.decimal_value()
83+
// TODO(joe): added a precision capped checked_add.
84+
.and_then(|c_sum| result.checked_add(&c_sum))
85+
.filter(|sum_value| {
86+
sum_value
87+
.fits_in_precision(result_decimal_type)
88+
.unwrap_or(false)
89+
})
90+
else {
91+
// null if any chunk is null or the sum overflows
9992
return Ok(null());
10093
};
101-
10294
result = r;
10395
}
10496

@@ -146,18 +138,17 @@ mod tests {
146138
}
147139

148140
#[test]
149-
fn test_sum_chunked_floats_all_nulls() {
141+
fn test_sum_chunked_floats_all_nulls_is_zero() {
150142
// Create chunks with all nulls
151143
let chunk1 = PrimitiveArray::from_option_iter::<f32, _>(vec![None, None, None]);
152144
let chunk2 = PrimitiveArray::from_option_iter::<f32, _>(vec![None, None]);
153145

154146
let dtype = chunk1.dtype().clone();
155147
let chunked =
156148
ChunkedArray::try_new(vec![chunk1.into_array(), chunk2.into_array()], dtype).unwrap();
157-
158149
// Compute sum - should return null for all nulls
159150
let result = sum(chunked.as_ref()).unwrap();
160-
assert!(result.as_primitive().as_::<f64>().is_none());
151+
assert_eq!(result, Scalar::primitive(0f64, Nullability::Nullable));
161152
}
162153

163154
#[test]

vortex-array/src/arrays/primitive/compute/sum.rs

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

44
use itertools::Itertools;
5-
use num_traits::{CheckedAdd, Float, ToPrimitive};
5+
use num_traits::{CheckedAdd, Float, ToPrimitive, Zero};
66
use vortex_buffer::BitBuffer;
77
use vortex_dtype::{NativePType, match_each_native_ptype};
88
use vortex_error::{VortexExpect, VortexResult};
@@ -28,11 +28,12 @@ impl SumKernel for PrimitiveVTable {
2828
}
2929
AllOr::None => {
3030
// All-invalid
31-
return Ok(Scalar::null(
32-
Stat::Sum
33-
.dtype(array.dtype())
34-
.vortex_expect("Sum dtype must be defined for primitive type"),
35-
));
31+
let sum_dtype = Stat::Sum
32+
.dtype(array.dtype())
33+
.vortex_expect("Sum dtype must be defined for primitive type");
34+
return Ok(match_each_native_ptype!(sum_dtype.as_ptype(), |P| {
35+
Scalar::primitive(P::zero(), sum_dtype.nullability())
36+
}));
3637
}
3738
AllOr::Some(validity_mask) => {
3839
// Some-valid

vortex-array/src/compute/sum.rs

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ use std::sync::LazyLock;
55

66
use arcref::ArcRef;
77
use vortex_dtype::DType;
8+
use vortex_dtype::Nullability::NonNullable;
89
use vortex_error::{VortexResult, vortex_err, vortex_panic};
910
use vortex_scalar::Scalar;
1011

@@ -124,13 +125,8 @@ pub fn sum_impl(
124125
sum_dtype: DType,
125126
kernels: &[ArcRef<dyn Kernel>],
126127
) -> VortexResult<Scalar> {
127-
if array.is_empty() {
128-
return Ok(Scalar::default_value(sum_dtype));
129-
}
130-
131-
// Sum of all null is null.
132-
if array.all_invalid() {
133-
return Ok(Scalar::null(sum_dtype));
128+
if array.is_empty() || array.all_invalid() {
129+
return Scalar::default_value(sum_dtype.with_nullability(NonNullable)).cast(&sum_dtype);
134130
}
135131

136132
// Try to find a sum kernel

0 commit comments

Comments
 (0)