Skip to content

Commit 91a5f0a

Browse files
committed
fix[array]: sum with initial value to fix op assoc
Signed-off-by: Joe Isaacs <[email protected]>
1 parent afc488d commit 91a5f0a

File tree

3 files changed

+20
-44
lines changed

3 files changed

+20
-44
lines changed

fuzz/fuzz_targets/array_ops.rs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,8 +89,6 @@ fuzz_target!(|fuzz_action: FuzzArrayAction| -> Corpus {
8989
current_array = cast_result;
9090
}
9191
Action::Sum => {
92-
println!("sum {}", current_array.display_tree());
93-
println!("sum {}", current_array.display_values());
9492
let sum_result = sum(&current_array).vortex_unwrap();
9593
assert_scalar_eq(&expected.scalar(), &sum_result, i).unwrap();
9694
}

vortex-array/src/arrays/chunked/compute/sum.rs

Lines changed: 8 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,12 @@
44
use std::ops::AddAssign;
55

66
use num_traits::PrimInt;
7-
use vortex_dtype::Nullability::Nullable;
8-
use vortex_dtype::{DType, DecimalDType, NativePType, match_each_native_ptype};
7+
use vortex_dtype::{DType, NativePType, match_each_native_ptype};
98
use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err};
10-
use vortex_scalar::{DecimalScalar, DecimalValue, Scalar};
9+
use vortex_scalar::Scalar;
1110

1211
use crate::arrays::{ChunkedArray, ChunkedVTable};
13-
use crate::compute::{SumKernel, SumKernelAdapter, sum};
12+
use crate::compute::{SumKernel, SumKernelAdapter, sum, sum_with_initial};
1413
use crate::stats::Stat;
1514
use crate::{ArrayRef, register_kernel};
1615

@@ -21,14 +20,7 @@ impl SumKernel for ChunkedVTable {
2120
.ok_or_else(|| vortex_err!("Sum not supported for dtype {}", array.dtype()))?;
2221

2322
match sum_dtype {
24-
DType::Decimal(decimal_dtype, _) => sum_decimal(
25-
array.chunks(),
26-
decimal_dtype,
27-
initial_value
28-
.as_decimal()
29-
.decimal_value()
30-
.vortex_expect("cannot be null"),
31-
),
23+
DType::Decimal(..) => sum_decimal(array.chunks(), initial_value),
3224
DType::Primitive(sum_ptype, _) => {
3325
let scalar_value = match_each_native_ptype!(
3426
sum_ptype,
@@ -82,36 +74,14 @@ fn sum_float<T: NativePType + AddAssign>(
8274
Ok(Some(result))
8375
}
8476

85-
fn sum_decimal(
86-
chunks: &[ArrayRef],
87-
result_decimal_type: DecimalDType,
88-
initial_value: DecimalValue,
89-
) -> VortexResult<Scalar> {
90-
let mut result = initial_value;
91-
92-
let null = || Scalar::null(DType::Decimal(result_decimal_type, Nullable));
77+
fn sum_decimal(chunks: &[ArrayRef], initial_value: &Scalar) -> VortexResult<Scalar> {
78+
let mut result = initial_value.clone();
9379

9480
for chunk in chunks {
95-
let chunk_sum = sum(chunk)?;
96-
97-
let chunk_decimal = DecimalScalar::try_from(&chunk_sum)?;
98-
let Some(r) = chunk_decimal
99-
.decimal_value()
100-
// TODO(joe): added a precision capped checked_add.
101-
.and_then(|c_sum| result.checked_add(&c_sum))
102-
.filter(|sum_value| {
103-
sum_value
104-
.fits_in_precision(result_decimal_type)
105-
.unwrap_or(false)
106-
})
107-
else {
108-
// null if any chunk is null or the sum overflows
109-
return Ok(null());
110-
};
111-
result = r;
81+
result = sum_with_initial(chunk, result)?;
11282
}
11383

114-
Ok(Scalar::decimal(result, result_decimal_type, Nullable))
84+
Ok(result)
11585
}
11686

11787
#[cfg(test)]

vortex-array/src/arrays/constant/compute/sum.rs

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
// SPDX-License-Identifier: Apache-2.0
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

4+
use arrow_array::ArrowNativeTypeOp;
45
use num_traits::{CheckedAdd, CheckedMul, ToPrimitive};
56
use vortex_dtype::{DType, DecimalDType, NativePType, Nullability, i256, match_each_native_ptype};
67
use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err};
@@ -128,7 +129,7 @@ where
128129
let initial = initial_value
129130
.as_primitive()
130131
.as_::<T>()
131-
.unwrap_or_else(T::zero);
132+
.vortex_expect("cannot be null");
132133
Ok(initial.checked_add(&array_sum))
133134
}
134135

@@ -137,13 +138,20 @@ fn sum_float(
137138
array_len: usize,
138139
initial_value: &Scalar,
139140
) -> VortexResult<Option<f64>> {
140-
let v = primitive_scalar.as_::<f64>();
141+
let v = primitive_scalar
142+
.as_::<f64>()
143+
.vortex_expect("cannot be null");
141144
let array_len = array_len
142145
.to_f64()
143146
.ok_or_else(|| vortex_err!("array_len must fit the sum type"))?;
144147

145-
let array_sum = v.map(|v| v * array_len).unwrap_or(0.0);
146-
let initial = initial_value.as_primitive().as_::<f64>().unwrap_or(0.0);
148+
let Ok(array_sum) = v.mul_checked(array_len) else {
149+
return Ok(None);
150+
};
151+
let initial = initial_value
152+
.as_primitive()
153+
.as_::<f64>()
154+
.vortex_expect("cannot be null");
147155
Ok(Some(initial + array_sum))
148156
}
149157

0 commit comments

Comments
 (0)