Skip to content

Commit 2492c9b

Browse files
fix[array]: handle empty/all_invalid sum correctly (#5180)
empty/all_invalid arrays have a sum of 0. Signed-off-by: Joe Isaacs <[email protected]>
1 parent 4d2d47d commit 2492c9b

File tree

4 files changed

+60
-79
lines changed

4 files changed

+60
-79
lines changed

vortex-array/src/arrays/chunked/compute/sum.rs

Lines changed: 28 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -42,37 +42,31 @@ register_kernel!(SumKernelAdapter(ChunkedVTable).lift());
4242
fn sum_int<T: NativePType + PrimInt + FromPrimitiveOrF16>(
4343
chunks: &[ArrayRef],
4444
) -> VortexResult<Option<T>> {
45-
let mut result: Option<T> = None;
45+
let mut result: T = T::zero();
4646
for chunk in chunks {
4747
let chunk_sum = sum(chunk)?;
48-
49-
let Some(chunk_sum) = chunk_sum.as_primitive().as_::<T>() else {
50-
// Skip missing null chunk
51-
continue;
48+
let Some(chunk_sum) = chunk_sum
49+
.as_primitive()
50+
.as_::<T>()
51+
.and_then(|chunk_sum| result.checked_add(&chunk_sum))
52+
else {
53+
// Bail out on null or overflow
54+
return Ok(None);
5255
};
53-
54-
result = Some(match result {
55-
None => chunk_sum,
56-
Some(result) => {
57-
let Some(chunk_result) = result.checked_add(&chunk_sum) else {
58-
// Bail out on overflow
59-
return Ok(None);
60-
};
61-
chunk_result
62-
}
63-
});
56+
result = chunk_sum;
6457
}
65-
Ok(result)
58+
Ok(Some(result))
6659
}
6760

68-
fn sum_float(chunks: &[ArrayRef]) -> VortexResult<f64> {
61+
fn sum_float(chunks: &[ArrayRef]) -> VortexResult<Option<f64>> {
6962
let mut result = 0f64;
7063
for chunk in chunks {
71-
if let Some(chunk_sum) = sum(chunk)?.as_primitive().as_::<f64>() {
72-
result += chunk_sum;
64+
let Some(chunk_sum) = sum(chunk)?.as_primitive().as_::<f64>() else {
65+
return Ok(None);
7366
};
67+
result += chunk_sum;
7468
}
75-
Ok(result)
69+
Ok(Some(result))
7670
}
7771

7872
fn sum_decimal(chunks: &[ArrayRef], result_decimal_type: DecimalDType) -> VortexResult<Scalar> {
@@ -84,21 +78,19 @@ fn sum_decimal(chunks: &[ArrayRef], result_decimal_type: DecimalDType) -> Vortex
8478
let chunk_sum = sum(chunk)?;
8579

8680
let chunk_decimal = DecimalScalar::try_from(&chunk_sum)?;
87-
let Some(chunk_value) = chunk_decimal.decimal_value() else {
88-
// skips all null chunks
89-
continue;
90-
};
91-
92-
// Perform checked addition with current result
93-
let Some(r) = result.checked_add(&chunk_value).filter(|sum_value| {
94-
sum_value
95-
.fits_in_precision(result_decimal_type)
96-
.unwrap_or(false)
97-
}) else {
98-
// Overflow
81+
let Some(r) = chunk_decimal
82+
.decimal_value()
83+
// TODO(joe): add a precision-capped checked_add.
84+
.and_then(|c_sum| result.checked_add(&c_sum))
85+
.filter(|sum_value| {
86+
sum_value
87+
.fits_in_precision(result_decimal_type)
88+
.unwrap_or(false)
89+
})
90+
else {
91+
// null if any chunk is null or the sum overflows
9992
return Ok(null());
10093
};
101-
10294
result = r;
10395
}
10496

@@ -146,18 +138,17 @@ mod tests {
146138
}
147139

148140
#[test]
149-
fn test_sum_chunked_floats_all_nulls() {
141+
fn test_sum_chunked_floats_all_nulls_is_zero() {
150142
// Create chunks with all nulls
151143
let chunk1 = PrimitiveArray::from_option_iter::<f32, _>(vec![None, None, None]);
152144
let chunk2 = PrimitiveArray::from_option_iter::<f32, _>(vec![None, None]);
153145

154146
let dtype = chunk1.dtype().clone();
155147
let chunked =
156148
ChunkedArray::try_new(vec![chunk1.into_array(), chunk2.into_array()], dtype).unwrap();
157-
158149
// Compute sum - an all-null array sums to zero rather than null
159150
let result = sum(chunked.as_ref()).unwrap();
160-
assert!(result.as_primitive().as_::<f64>().is_none());
151+
assert_eq!(result, Scalar::primitive(0f64, Nullability::Nullable));
161152
}
162153

163154
#[test]

vortex-array/src/arrays/constant/compute/sum.rs

Lines changed: 19 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,8 @@ register_kernel!(SumKernelAdapter(ConstantVTable).lift());
108108

109109
#[cfg(test)]
110110
mod tests {
111-
use vortex_dtype::{DType, DecimalDType, Nullability, PType};
111+
use vortex_dtype::Nullability::Nullable;
112+
use vortex_dtype::{DType, DecimalDType, Nullability, PType, i256};
112113
use vortex_scalar::{DecimalValue, Scalar};
113114

114115
use crate::arrays::ConstantArray;
@@ -132,13 +133,10 @@ mod tests {
132133

133134
#[test]
134135
fn test_sum_nullable_value() {
135-
let array = ConstantArray::new(
136-
Scalar::null(DType::Primitive(PType::U32, Nullability::Nullable)),
137-
10,
138-
)
139-
.into_array();
136+
let array = ConstantArray::new(Scalar::null(DType::Primitive(PType::U32, Nullable)), 10)
137+
.into_array();
140138
let result = sum(&array).unwrap();
141-
assert!(result.is_null());
139+
assert_eq!(result, Scalar::primitive(0u64, Nullable));
142140
}
143141

144142
#[test]
@@ -157,10 +155,9 @@ mod tests {
157155

158156
#[test]
159157
fn test_sum_bool_null() {
160-
let array =
161-
ConstantArray::new(Scalar::null(DType::Bool(Nullability::Nullable)), 10).into_array();
158+
let array = ConstantArray::new(Scalar::null(DType::Bool(Nullable)), 10).into_array();
162159
let result = sum(&array).unwrap();
163-
assert!(result.is_null());
160+
assert_eq!(result, Scalar::primitive(0u64, Nullable));
164161
}
165162

166163
#[test]
@@ -180,22 +177,26 @@ mod tests {
180177

181178
assert_eq!(
182179
result.as_decimal().decimal_value(),
183-
Some(DecimalValue::I256(vortex_scalar::i256::from_i128(500)))
180+
Some(DecimalValue::I256(i256::from_i128(500)))
184181
);
185182
assert_eq!(result.dtype(), &Stat::Sum.dtype(array.dtype()).unwrap());
186183
}
187184

188185
#[test]
189186
fn test_sum_decimal_null() {
190187
let decimal_dtype = DecimalDType::new(10, 2);
191-
let array = ConstantArray::new(
192-
Scalar::null(DType::Decimal(decimal_dtype, Nullability::Nullable)),
193-
10,
194-
)
195-
.into_array();
188+
let array = ConstantArray::new(Scalar::null(DType::Decimal(decimal_dtype, Nullable)), 10)
189+
.into_array();
196190

197191
let result = sum(&array).unwrap();
198-
assert!(result.is_null());
192+
assert_eq!(
193+
result,
194+
Scalar::decimal(
195+
DecimalValue::I256(i256::ZERO),
196+
DecimalDType::new(20, 2),
197+
Nullable
198+
)
199+
);
199200
}
200201

201202
#[test]
@@ -214,9 +215,7 @@ mod tests {
214215
let result = sum(&array).unwrap();
215216
assert_eq!(
216217
result.as_decimal().decimal_value(),
217-
Some(DecimalValue::I256(vortex_scalar::i256::from_i128(
218-
99_999_999_900
219-
)))
218+
Some(DecimalValue::I256(i256::from_i128(99_999_999_900)))
220219
);
221220
}
222221
}

vortex-array/src/arrays/primitive/compute/sum.rs

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

44
use itertools::Itertools;
5-
use num_traits::{CheckedAdd, Float, ToPrimitive};
5+
use num_traits::{CheckedAdd, Float, ToPrimitive, Zero};
66
use vortex_buffer::BitBuffer;
77
use vortex_dtype::{NativePType, match_each_native_ptype};
88
use vortex_error::{VortexExpect, VortexResult};
@@ -28,11 +28,12 @@ impl SumKernel for PrimitiveVTable {
2828
}
2929
AllOr::None => {
3030
// All-invalid
31-
return Ok(Scalar::null(
32-
Stat::Sum
33-
.dtype(array.dtype())
34-
.vortex_expect("Sum dtype must be defined for primitive type"),
35-
));
31+
let sum_dtype = Stat::Sum
32+
.dtype(array.dtype())
33+
.vortex_expect("Sum dtype must be defined for primitive type");
34+
return Ok(match_each_native_ptype!(sum_dtype.as_ptype(), |P| {
35+
Scalar::primitive(P::zero(), sum_dtype.nullability())
36+
}));
3637
}
3738
AllOr::Some(validity_mask) => {
3839
// Some-valid

vortex-array/src/compute/sum.rs

Lines changed: 6 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ use std::sync::LazyLock;
55

66
use arcref::ArcRef;
77
use vortex_dtype::DType;
8+
use vortex_dtype::Nullability::NonNullable;
89
use vortex_error::{VortexResult, vortex_err, vortex_panic};
910
use vortex_scalar::Scalar;
1011

@@ -124,13 +125,8 @@ pub fn sum_impl(
124125
sum_dtype: DType,
125126
kernels: &[ArcRef<dyn Kernel>],
126127
) -> VortexResult<Scalar> {
127-
if array.is_empty() {
128-
return Ok(Scalar::default_value(sum_dtype));
129-
}
130-
131-
// Sum of all null is null.
132-
if array.all_invalid() {
133-
return Ok(Scalar::null(sum_dtype));
128+
if array.is_empty() || array.all_invalid() {
129+
return Scalar::default_value(sum_dtype.with_nullability(NonNullable)).cast(&sum_dtype);
134130
}
135131

136132
// Try to find a sum kernel
@@ -162,7 +158,7 @@ pub fn sum_impl(
162158
#[cfg(test)]
163159
mod test {
164160
use vortex_buffer::buffer;
165-
use vortex_dtype::{DType, Nullability, PType};
161+
use vortex_dtype::Nullability;
166162
use vortex_scalar::Scalar;
167163

168164
use crate::IntoArray as _;
@@ -173,20 +169,14 @@ mod test {
173169
fn sum_all_invalid() {
174170
let array = PrimitiveArray::from_option_iter::<i32, _>([None, None, None]);
175171
let result = sum(array.as_ref()).unwrap();
176-
assert_eq!(
177-
result,
178-
Scalar::null(DType::Primitive(PType::I64, Nullability::Nullable))
179-
);
172+
assert_eq!(result, Scalar::primitive(0i64, Nullability::Nullable));
180173
}
181174

182175
#[test]
183176
fn sum_all_invalid_float() {
184177
let array = PrimitiveArray::from_option_iter::<f32, _>([None, None, None]);
185178
let result = sum(array.as_ref()).unwrap();
186-
assert_eq!(
187-
result,
188-
Scalar::null(DType::Primitive(PType::F64, Nullability::Nullable))
189-
);
179+
assert_eq!(result, Scalar::primitive(0f64, Nullability::Nullable));
190180
}
191181

192182
#[test]

0 commit comments

Comments
 (0)