Skip to content

Commit afc488d

Browse files
committed
fix[array]: sum with initial value to fix operation associativity
Signed-off-by: Joe Isaacs <[email protected]>
1 parent 32d0834 commit afc488d

File tree

8 files changed

+216
-97
lines changed

8 files changed

+216
-97
lines changed

fuzz/fuzz_targets/array_ops.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,8 @@ fuzz_target!(|fuzz_action: FuzzArrayAction| -> Corpus {
8989
current_array = cast_result;
9090
}
9191
Action::Sum => {
92+
println!("sum {}", current_array.display_tree());
93+
println!("sum {}", current_array.display_values());
9294
let sum_result = sum(&current_array).vortex_unwrap();
9395
assert_scalar_eq(&expected.scalar(), &sum_result, i).unwrap();
9496
}

vortex-array/src/arrays/bool/compute/sum.rs

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
use std::ops::BitAnd;
55

6-
use vortex_error::VortexResult;
6+
use vortex_error::{VortexExpect, VortexResult};
77
use vortex_mask::AllOr;
88
use vortex_scalar::Scalar;
99

@@ -12,7 +12,7 @@ use crate::compute::{SumKernel, SumKernelAdapter};
1212
use crate::register_kernel;
1313

1414
impl SumKernel for BoolVTable {
15-
fn sum(&self, array: &BoolArray) -> VortexResult<Scalar> {
15+
fn sum(&self, array: &BoolArray, initial_value: &Scalar) -> VortexResult<Scalar> {
1616
let true_count: Option<u64> = match array.validity_mask().bit_buffer() {
1717
AllOr::All => {
1818
// All-valid
@@ -26,7 +26,15 @@ impl SumKernel for BoolVTable {
2626
Some(array.bit_buffer().bitand(validity_mask).true_count() as u64)
2727
}
2828
};
29-
Ok(Scalar::from(true_count))
29+
30+
// Add initial_value to true_count
31+
let initial_u64 = initial_value
32+
.as_primitive()
33+
.as_::<u64>()
34+
.vortex_expect("cannot be null");
35+
Ok(Scalar::from(
36+
true_count.and_then(|tc| tc.checked_add(initial_u64)),
37+
))
3038
}
3139
}
3240

vortex-array/src/arrays/chunked/compute/sum.rs

Lines changed: 33 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
// SPDX-License-Identifier: Apache-2.0
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

4+
use std::ops::AddAssign;
5+
46
use num_traits::PrimInt;
57
use vortex_dtype::Nullability::Nullable;
6-
use vortex_dtype::{DType, DecimalDType, NativePType, i256, match_each_native_ptype};
7-
use vortex_error::{VortexResult, vortex_bail, vortex_err};
8+
use vortex_dtype::{DType, DecimalDType, NativePType, match_each_native_ptype};
9+
use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err};
810
use vortex_scalar::{DecimalScalar, DecimalValue, Scalar};
911

1012
use crate::arrays::{ChunkedArray, ChunkedVTable};
@@ -13,19 +15,26 @@ use crate::stats::Stat;
1315
use crate::{ArrayRef, register_kernel};
1416

1517
impl SumKernel for ChunkedVTable {
16-
fn sum(&self, array: &ChunkedArray) -> VortexResult<Scalar> {
18+
fn sum(&self, array: &ChunkedArray, initial_value: &Scalar) -> VortexResult<Scalar> {
1719
let sum_dtype = Stat::Sum
1820
.dtype(array.dtype())
1921
.ok_or_else(|| vortex_err!("Sum not supported for dtype {}", array.dtype()))?;
2022

2123
match sum_dtype {
22-
DType::Decimal(decimal_dtype, _) => sum_decimal(array.chunks(), decimal_dtype),
24+
DType::Decimal(decimal_dtype, _) => sum_decimal(
25+
array.chunks(),
26+
decimal_dtype,
27+
initial_value
28+
.as_decimal()
29+
.decimal_value()
30+
.vortex_expect("cannot be null"),
31+
),
2332
DType::Primitive(sum_ptype, _) => {
2433
let scalar_value = match_each_native_ptype!(
2534
sum_ptype,
26-
unsigned: |T| { sum_int::<u64>(array.chunks())?.into() },
27-
signed: |T| { sum_int::<i64>(array.chunks())?.into() },
28-
floating: |T| { sum_float(array.chunks())?.into() }
35+
unsigned: |T| { sum_int::<u64>(array.chunks(), initial_value.as_primitive().as_::<u64>().vortex_expect("cannot be null"))?.into() },
36+
signed: |T| { sum_int::<i64>(array.chunks(), initial_value.as_primitive().as_::<i64>().vortex_expect("cannot be null"))?.into() },
37+
floating: |T| { sum_float::<f64>(array.chunks(), initial_value.as_primitive().as_::<f64>().vortex_expect("cannot be null"))?.into() }
2938
);
3039

3140
Ok(Scalar::new(sum_dtype, scalar_value))
@@ -39,8 +48,11 @@ impl SumKernel for ChunkedVTable {
3948

4049
register_kernel!(SumKernelAdapter(ChunkedVTable).lift());
4150

42-
fn sum_int<T: NativePType + PrimInt>(chunks: &[ArrayRef]) -> VortexResult<Option<T>> {
43-
let mut result: T = T::zero();
51+
fn sum_int<T: NativePType + PrimInt>(
52+
chunks: &[ArrayRef],
53+
initial_value: T,
54+
) -> VortexResult<Option<T>> {
55+
let mut result: T = initial_value;
4456
for chunk in chunks {
4557
let chunk_sum = sum(chunk)?;
4658
let Some(chunk_sum) = chunk_sum
@@ -56,19 +68,26 @@ fn sum_int<T: NativePType + PrimInt>(chunks: &[ArrayRef]) -> VortexResult<Option
5668
Ok(Some(result))
5769
}
5870

59-
fn sum_float(chunks: &[ArrayRef]) -> VortexResult<Option<f64>> {
60-
let mut result = 0f64;
71+
fn sum_float<T: NativePType + AddAssign>(
72+
chunks: &[ArrayRef],
73+
initial_value: T,
74+
) -> VortexResult<Option<T>> {
75+
let mut result = initial_value;
6176
for chunk in chunks {
62-
let Some(chunk_sum) = sum(chunk)?.as_primitive().as_::<f64>() else {
77+
let Some(chunk_sum) = sum(chunk)?.as_primitive().as_::<T>() else {
6378
return Ok(None);
6479
};
6580
result += chunk_sum;
6681
}
6782
Ok(Some(result))
6883
}
6984

70-
fn sum_decimal(chunks: &[ArrayRef], result_decimal_type: DecimalDType) -> VortexResult<Scalar> {
71-
let mut result = DecimalValue::I256(i256::ZERO);
85+
fn sum_decimal(
86+
chunks: &[ArrayRef],
87+
result_decimal_type: DecimalDType,
88+
initial_value: DecimalValue,
89+
) -> VortexResult<Scalar> {
90+
let mut result = initial_value;
7291

7392
let null = || Scalar::null(DType::Decimal(result_decimal_type, Nullable));
7493

vortex-array/src/arrays/constant/compute/sum.rs

Lines changed: 70 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
// SPDX-License-Identifier: Apache-2.0
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

4-
use num_traits::{CheckedMul, ToPrimitive};
4+
use num_traits::{CheckedAdd, CheckedMul, ToPrimitive};
55
use vortex_dtype::{DType, DecimalDType, NativePType, Nullability, i256, match_each_native_ptype};
66
use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err};
77
use vortex_scalar::{DecimalScalar, DecimalValue, PrimitiveScalar, Scalar, ScalarValue};
@@ -12,32 +12,44 @@ use crate::register_kernel;
1212
use crate::stats::Stat;
1313

1414
impl SumKernel for ConstantVTable {
15-
fn sum(&self, array: &ConstantArray) -> VortexResult<Scalar> {
15+
fn sum(&self, array: &ConstantArray, initial_value: &Scalar) -> VortexResult<Scalar> {
1616
// Compute the expected dtype of the sum.
1717
let sum_dtype = Stat::Sum
1818
.dtype(array.dtype())
1919
.ok_or_else(|| vortex_err!("Sum not supported for dtype {}", array.dtype()))?;
2020

21-
let sum_value = sum_scalar(array.scalar(), array.len())?;
21+
let sum_value = sum_scalar(array.scalar(), array.len(), initial_value)?;
2222
Ok(Scalar::new(sum_dtype, sum_value))
2323
}
2424
}
2525

26-
fn sum_scalar(scalar: &Scalar, len: usize) -> VortexResult<ScalarValue> {
26+
fn sum_scalar(scalar: &Scalar, len: usize, acc: &Scalar) -> VortexResult<ScalarValue> {
2727
match scalar.dtype() {
28-
DType::Bool(_) => Ok(ScalarValue::from(match scalar.as_bool().value() {
29-
None => unreachable!("Handled before reaching this point"),
30-
Some(false) => 0u64,
31-
Some(true) => len as u64,
32-
})),
33-
DType::Primitive(ptype, _) => Ok(match_each_native_ptype!(
34-
ptype,
35-
unsigned: |T| { sum_integral::<u64>(scalar.as_primitive(), len)?.into() },
36-
signed: |T| { sum_integral::<i64>(scalar.as_primitive(), len)?.into() },
37-
floating: |T| { sum_float(scalar.as_primitive(), len)?.into() }
38-
)),
39-
DType::Decimal(decimal_dtype, _) => sum_decimal(scalar.as_decimal(), len, *decimal_dtype),
40-
DType::Extension(_) => sum_scalar(&scalar.as_extension().storage(), len),
28+
DType::Bool(_) => {
29+
let count = match scalar.as_bool().value() {
30+
None => unreachable!("Handled before reaching this point"),
31+
Some(false) => 0u64,
32+
Some(true) => len as u64,
33+
};
34+
let initial_u64 = acc
35+
.as_primitive()
36+
.as_::<u64>()
37+
.vortex_expect("cannot be null");
38+
Ok(ScalarValue::from(initial_u64.checked_add(count)))
39+
}
40+
DType::Primitive(ptype, _) => {
41+
let result = match_each_native_ptype!(
42+
ptype,
43+
unsigned: |T| { sum_integral::<u64>(scalar.as_primitive(), len, acc)?.into() },
44+
signed: |T| { sum_integral::<i64>(scalar.as_primitive(), len, acc)?.into() },
45+
floating: |T| { sum_float(scalar.as_primitive(), len, acc)?.into() }
46+
);
47+
Ok(result)
48+
}
49+
DType::Decimal(decimal_dtype, _) => {
50+
sum_decimal(scalar.as_decimal(), len, *decimal_dtype, acc)
51+
}
52+
DType::Extension(_) => sum_scalar(&scalar.as_extension().storage(), len, acc),
4153
dtype => vortex_bail!("Unsupported dtype for sum: {}", dtype),
4254
}
4355
}
@@ -46,6 +58,7 @@ fn sum_decimal(
4658
decimal_scalar: DecimalScalar,
4759
array_len: usize,
4860
decimal_dtype: DecimalDType,
61+
initial_value: &Scalar,
4962
) -> VortexResult<ScalarValue> {
5063
let result_dtype = Stat::Sum
5164
.dtype(&DType::Decimal(decimal_dtype, Nullability::Nullable))
@@ -63,43 +76,75 @@ fn sum_decimal(
6376
let len_value = DecimalValue::I256(i256::from_i128(array_len as i128));
6477

6578
// Multiply value * len
66-
let sum = value.checked_mul(&len_value).and_then(|result| {
79+
let array_sum = value.checked_mul(&len_value).and_then(|result| {
6780
// Check if result fits in the precision
6881
result
6982
.fits_in_precision(*result_decimal_type)
7083
.unwrap_or(false)
7184
.then_some(result)
7285
});
7386

74-
match sum {
75-
Some(result_value) => Ok(ScalarValue::from(result_value)),
87+
// Add initial_value to array_sum
88+
let initial_decimal = DecimalScalar::try_from(initial_value)?;
89+
let initial_dec_value = initial_decimal
90+
.decimal_value()
91+
.unwrap_or(DecimalValue::I256(i256::ZERO));
92+
93+
match array_sum {
94+
Some(array_sum_value) => {
95+
let total = array_sum_value
96+
.checked_add(&initial_dec_value)
97+
.and_then(|result| {
98+
result
99+
.fits_in_precision(*result_decimal_type)
100+
.unwrap_or(false)
101+
.then_some(result)
102+
});
103+
match total {
104+
Some(result_value) => Ok(ScalarValue::from(result_value)),
105+
None => Ok(ScalarValue::null()), // Overflow
106+
}
107+
}
76108
None => Ok(ScalarValue::null()), // Overflow
77109
}
78110
}
79111

80112
fn sum_integral<T>(
81113
primitive_scalar: PrimitiveScalar<'_>,
82114
array_len: usize,
115+
initial_value: &Scalar,
83116
) -> VortexResult<Option<T>>
84117
where
85-
T: NativePType + CheckedMul,
118+
T: NativePType + CheckedMul + CheckedAdd,
86119
Scalar: From<Option<T>>,
87120
{
88121
let v = primitive_scalar.as_::<T>();
89122
let array_len =
90123
T::from(array_len).ok_or_else(|| vortex_err!("array_len must fit the sum type"))?;
91-
let sum = v.and_then(|v| v.checked_mul(&array_len));
124+
let Some(array_sum) = v.and_then(|v| v.checked_mul(&array_len)) else {
125+
return Ok(None);
126+
};
92127

93-
Ok(sum)
128+
let initial = initial_value
129+
.as_primitive()
130+
.as_::<T>()
131+
.unwrap_or_else(T::zero);
132+
Ok(initial.checked_add(&array_sum))
94133
}
95134

96-
fn sum_float(primitive_scalar: PrimitiveScalar<'_>, array_len: usize) -> VortexResult<Option<f64>> {
135+
fn sum_float(
136+
primitive_scalar: PrimitiveScalar<'_>,
137+
array_len: usize,
138+
initial_value: &Scalar,
139+
) -> VortexResult<Option<f64>> {
97140
let v = primitive_scalar.as_::<f64>();
98141
let array_len = array_len
99142
.to_f64()
100143
.ok_or_else(|| vortex_err!("array_len must fit the sum type"))?;
101144

102-
Ok(v.map(|v| v * array_len))
145+
let array_sum = v.map(|v| v * array_len).unwrap_or(0.0);
146+
let initial = initial_value.as_primitive().as_::<f64>().unwrap_or(0.0);
147+
Ok(Some(initial + array_sum))
103148
}
104149

105150
register_kernel!(SumKernelAdapter(ConstantVTable).lift());

vortex-array/src/arrays/decimal/compute/sum.rs

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,32 +5,34 @@ use arrow_schema::DECIMAL256_MAX_PRECISION;
55
use num_traits::AsPrimitive;
66
use vortex_dtype::Nullability::Nullable;
77
use vortex_dtype::{DecimalDType, DecimalType, match_each_decimal_value_type};
8-
use vortex_error::{VortexResult, vortex_bail};
8+
use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err};
99
use vortex_mask::Mask;
10-
use vortex_scalar::{DecimalValue, Scalar};
10+
use vortex_scalar::{DecimalScalar, DecimalValue, Scalar};
1111

1212
use crate::arrays::{DecimalArray, DecimalVTable};
1313
use crate::compute::{SumKernel, SumKernelAdapter};
1414
use crate::register_kernel;
1515

1616
// It's safe to use `AsPrimitive` here because we always cast up.
1717
macro_rules! sum_decimal {
18-
($ty:ty, $values:expr) => {{
19-
let mut sum: $ty = <$ty>::default();
18+
($ty:ty, $values:expr, $initial:expr) => {{
19+
let mut sum: $ty = $initial;
2020
for v in $values.iter() {
2121
let v: $ty = (*v).as_();
22-
sum += v;
22+
sum = num_traits::CheckedAdd::checked_add(&sum, &v)
23+
.ok_or_else(|| vortex_err!("Overflow when summing decimal {sum:?} + {v:?}"))?
2324
}
2425
sum
2526
}};
26-
($ty:ty, $values:expr, $validity:expr) => {{
27+
($ty:ty, $values:expr, $validity:expr, $initial:expr) => {{
2728
use itertools::Itertools;
2829

29-
let mut sum: $ty = <$ty>::default();
30+
let mut sum: $ty = $initial;
3031
for (v, valid) in $values.iter().zip_eq($validity) {
3132
if valid {
3233
let v: $ty = (*v).as_();
33-
sum += v;
34+
sum = num_traits::CheckedAdd::checked_add(&sum, &v)
35+
.ok_or_else(|| vortex_err!("Overflow when summing decimal {sum:?} + {v:?}"))?
3436
}
3537
}
3638
sum
@@ -39,7 +41,7 @@ macro_rules! sum_decimal {
3941

4042
impl SumKernel for DecimalVTable {
4143
#[allow(clippy::cognitive_complexity)]
42-
fn sum(&self, array: &DecimalArray) -> VortexResult<Scalar> {
44+
fn sum(&self, array: &DecimalArray, initial_value: &Scalar) -> VortexResult<Scalar> {
4345
let decimal_dtype = array.decimal_dtype();
4446

4547
// Both Spark and DataFusion use this heuristic.
@@ -49,6 +51,12 @@ impl SumKernel for DecimalVTable {
4951
let new_scale = decimal_dtype.scale();
5052
let return_dtype = DecimalDType::new(new_precision, new_scale);
5153

54+
// Extract the initial value as a DecimalValue
55+
let initial_decimal = DecimalScalar::try_from(initial_value)
56+
.vortex_expect("must be a decimal")
57+
.decimal_value()
58+
.vortex_expect("cannot be null");
59+
5260
match array.validity_mask() {
5361
Mask::AllFalse(_) => {
5462
vortex_bail!("invalid state, all-null array should be checked by top-level sum fn")
@@ -57,8 +65,9 @@ impl SumKernel for DecimalVTable {
5765
let values_type = DecimalType::smallest_decimal_value_type(&return_dtype);
5866
match_each_decimal_value_type!(array.values_type(), |I| {
5967
match_each_decimal_value_type!(values_type, |O| {
68+
let initial_val: O = initial_decimal.cast().unwrap_or_else(O::default);
6069
Ok(Scalar::decimal(
61-
DecimalValue::from(sum_decimal!(O, array.buffer::<I>())),
70+
DecimalValue::from(sum_decimal!(O, array.buffer::<I>(), initial_val)),
6271
return_dtype,
6372
Nullable,
6473
))
@@ -69,11 +78,13 @@ impl SumKernel for DecimalVTable {
6978
let values_type = DecimalType::smallest_decimal_value_type(&return_dtype);
7079
match_each_decimal_value_type!(array.values_type(), |I| {
7180
match_each_decimal_value_type!(values_type, |O| {
81+
let initial_val: O = initial_decimal.cast().unwrap_or_else(O::default);
7282
Ok(Scalar::decimal(
7383
DecimalValue::from(sum_decimal!(
7484
O,
7585
array.buffer::<I>(),
76-
mask_values.bit_buffer()
86+
mask_values.bit_buffer(),
87+
initial_val
7788
)),
7889
return_dtype,
7990
Nullable,

0 commit comments

Comments
 (0)