Skip to content

Commit 25095f7

Browse files
authored
Fix: ConstantArray float sum behaves the same way as primitive array (#5513)
fix #5508 Signed-off-by: Robert Kruszewski <[email protected]>
1 parent eba8252 commit 25095f7

File tree

2 files changed

+33
-21
lines changed

2 files changed

+33
-21
lines changed

vortex-array/src/arrays/constant/compute/sum.rs

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
// SPDX-License-Identifier: Apache-2.0
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

4-
use arrow_array::ArrowNativeTypeOp;
5-
use num_traits::{CheckedAdd, CheckedMul, ToPrimitive};
4+
use num_traits::{CheckedAdd, CheckedMul};
65
use vortex_dtype::{DType, DecimalDType, NativePType, Nullability, i256, match_each_native_ptype};
76
use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err};
87
use vortex_scalar::{DecimalScalar, DecimalValue, PrimitiveScalar, Scalar, ScalarValue};
@@ -138,21 +137,20 @@ fn sum_float(
138137
array_len: usize,
139138
accumulator: &Scalar,
140139
) -> VortexResult<Option<f64>> {
141-
let v = primitive_scalar
142-
.as_::<f64>()
143-
.vortex_expect("cannot be null");
144-
let array_len = array_len
145-
.to_f64()
146-
.ok_or_else(|| vortex_err!("array_len must fit the sum type"))?;
147-
148-
let Ok(array_sum) = v.mul_checked(array_len) else {
149-
return Ok(None);
150-
};
151140
let initial = accumulator
152141
.as_primitive()
153142
.as_::<f64>()
154143
.vortex_expect("cannot be null");
155-
Ok(Some(initial + array_sum))
144+
let v = primitive_scalar
145+
.as_::<f64>()
146+
.vortex_expect("cannot be null");
147+
148+
// Preserve numerical behaviour of summation of floats by using a loop instead of simplifying to multiplication.
149+
let mut sum = initial;
150+
for _ in 0..array_len {
151+
sum += v;
152+
}
153+
Ok(Some(sum))
156154
}
157155

158156
register_kernel!(SumKernelAdapter(ConstantVTable).lift());
@@ -161,10 +159,11 @@ register_kernel!(SumKernelAdapter(ConstantVTable).lift());
161159
mod tests {
162160
use vortex_dtype::Nullability::Nullable;
163161
use vortex_dtype::{DType, DecimalDType, Nullability, PType, i256};
162+
use vortex_error::VortexUnwrap;
164163
use vortex_scalar::{DecimalValue, Scalar};
165164

166165
use crate::arrays::ConstantArray;
167-
use crate::compute::sum;
166+
use crate::compute::{sum, sum_with_accumulator};
168167
use crate::stats::Stat;
169168
use crate::{Array, IntoArray};
170169

@@ -269,4 +268,16 @@ mod tests {
269268
Some(DecimalValue::I256(i256::from_i128(99_999_999_900)))
270269
);
271270
}
271+
272+
#[test]
273+
fn test_sum_float_non_multiply() {
274+
let acc = -2048669276050936500000000000f64;
275+
let array = ConstantArray::new(6.1811675e16f64, 25);
276+
let sum =
277+
sum_with_accumulator(array.as_ref(), &Scalar::primitive(acc, Nullable)).vortex_unwrap();
278+
assert_eq!(
279+
sum,
280+
Scalar::primitive(-2048669274505641600000000000f64, Nullable)
281+
);
282+
}
272283
}

vortex-array/src/compute/sum.rs

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -108,14 +108,15 @@ impl ComputeFnVTable for Sum {
108108
DType::Primitive(p, _) => {
109109
if p.is_float() && accumulator.is_zero() {
110110
return Ok(sum.into());
111+
} else if p.is_int() {
112+
let sum_from_stat = accumulator
113+
.as_primitive()
114+
.checked_add(&sum.as_primitive())
115+
.map(Scalar::from);
116+
return Ok(sum_from_stat
117+
.unwrap_or_else(|| Scalar::null(sum_dtype))
118+
.into());
111119
}
112-
let sum_from_stat = accumulator
113-
.as_primitive()
114-
.checked_add(&sum.as_primitive())
115-
.map(Scalar::from);
116-
return Ok(sum_from_stat
117-
.unwrap_or_else(|| Scalar::null(sum_dtype))
118-
.into());
119120
}
120121
DType::Decimal(..) => {
121122
let sum_from_stat = accumulator

0 commit comments

Comments
 (0)