Skip to content

Commit 9059d9d

Browse files
committed
fix[array]: handle empty/all_invalid sum correctly
Signed-off-by: Joe Isaacs <[email protected]>
1 parent cb4553c commit 9059d9d

File tree

2 files changed

+21
-28
lines changed

2 files changed

+21
-28
lines changed

vortex-array/src/arrays/constant/compute/sum.rs

Lines changed: 19 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,8 @@ register_kernel!(SumKernelAdapter(ConstantVTable).lift());
108108

109109
#[cfg(test)]
110110
mod tests {
111-
use vortex_dtype::{DType, DecimalDType, Nullability, PType};
111+
use vortex_dtype::Nullability::Nullable;
112+
use vortex_dtype::{DType, DecimalDType, Nullability, PType, i256};
112113
use vortex_scalar::{DecimalValue, Scalar};
113114

114115
use crate::arrays::ConstantArray;
@@ -132,13 +133,10 @@ mod tests {
132133

133134
#[test]
134135
fn test_sum_nullable_value() {
135-
let array = ConstantArray::new(
136-
Scalar::null(DType::Primitive(PType::U32, Nullability::Nullable)),
137-
10,
138-
)
139-
.into_array();
136+
let array = ConstantArray::new(Scalar::null(DType::Primitive(PType::U32, Nullable)), 10)
137+
.into_array();
140138
let result = sum(&array).unwrap();
141-
assert!(result.is_null());
139+
assert_eq!(result, Scalar::primitive(0u64, Nullable));
142140
}
143141

144142
#[test]
@@ -157,10 +155,9 @@ mod tests {
157155

158156
#[test]
159157
fn test_sum_bool_null() {
160-
let array =
161-
ConstantArray::new(Scalar::null(DType::Bool(Nullability::Nullable)), 10).into_array();
158+
let array = ConstantArray::new(Scalar::null(DType::Bool(Nullable)), 10).into_array();
162159
let result = sum(&array).unwrap();
163-
assert!(result.is_null());
160+
assert_eq!(result, Scalar::primitive(0u64, Nullable));
164161
}
165162

166163
#[test]
@@ -180,22 +177,26 @@ mod tests {
180177

181178
assert_eq!(
182179
result.as_decimal().decimal_value(),
183-
Some(DecimalValue::I256(vortex_scalar::i256::from_i128(500)))
180+
Some(DecimalValue::I256(i256::from_i128(500)))
184181
);
185182
assert_eq!(result.dtype(), &Stat::Sum.dtype(array.dtype()).unwrap());
186183
}
187184

188185
#[test]
189186
fn test_sum_decimal_null() {
190187
let decimal_dtype = DecimalDType::new(10, 2);
191-
let array = ConstantArray::new(
192-
Scalar::null(DType::Decimal(decimal_dtype, Nullability::Nullable)),
193-
10,
194-
)
195-
.into_array();
188+
let array = ConstantArray::new(Scalar::null(DType::Decimal(decimal_dtype, Nullable)), 10)
189+
.into_array();
196190

197191
let result = sum(&array).unwrap();
198-
assert!(result.is_null());
192+
assert_eq!(
193+
result,
194+
Scalar::decimal(
195+
DecimalValue::I256(i256::ZERO),
196+
DecimalDType::new(20, 2),
197+
Nullable
198+
)
199+
);
199200
}
200201

201202
#[test]
@@ -214,9 +215,7 @@ mod tests {
214215
let result = sum(&array).unwrap();
215216
assert_eq!(
216217
result.as_decimal().decimal_value(),
217-
Some(DecimalValue::I256(vortex_scalar::i256::from_i128(
218-
99_999_999_900
219-
)))
218+
Some(DecimalValue::I256(i256::from_i128(99_999_999_900)))
220219
);
221220
}
222221
}

vortex-array/src/compute/sum.rs

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -169,20 +169,14 @@ mod test {
169169
fn sum_all_invalid() {
170170
let array = PrimitiveArray::from_option_iter::<i32, _>([None, None, None]);
171171
let result = sum(array.as_ref()).unwrap();
172-
assert_eq!(
173-
result,
174-
Scalar::null(DType::Primitive(PType::I64, Nullability::Nullable))
175-
);
172+
assert_eq!(result, Scalar::primitive(0i64, Nullability::Nullable));
176173
}
177174

178175
#[test]
179176
fn sum_all_invalid_float() {
180177
let array = PrimitiveArray::from_option_iter::<f32, _>([None, None, None]);
181178
let result = sum(array.as_ref()).unwrap();
182-
assert_eq!(
183-
result,
184-
Scalar::null(DType::Primitive(PType::F64, Nullability::Nullable))
185-
);
179+
assert_eq!(result, Scalar::primitive(0f64, Nullability::Nullable));
186180
}
187181

188182
#[test]

0 commit comments

Comments
 (0)