Skip to content

Commit fcf2d98

Browse files
fix[array]: sum with initial value to fix op assoc (#5278)
Signed-off-by: Joe Isaacs <[email protected]>
1 parent 5f998ed commit fcf2d98

File tree

7 files changed

+221
-170
lines changed

7 files changed

+221
-170
lines changed

vortex-array/src/arrays/bool/compute/sum.rs

Lines changed: 10 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,7 @@
33

44
use std::ops::BitAnd;
55

6-
use vortex_error::VortexResult;
6+
use vortex_error::{VortexExpect, VortexResult};
77
use vortex_mask::AllOr;
88
use vortex_scalar::Scalar;
99

@@ -12,7 +12,7 @@ use crate::compute::{SumKernel, SumKernelAdapter};
1212
use crate::register_kernel;
1313

1414
impl SumKernel for BoolVTable {
15-
fn sum(&self, array: &BoolArray) -> VortexResult<Scalar> {
15+
fn sum(&self, array: &BoolArray, accumulator: &Scalar) -> VortexResult<Scalar> {
1616
let true_count: Option<u64> = match array.validity_mask().bit_buffer() {
1717
AllOr::All => {
1818
// All-valid
@@ -26,7 +26,14 @@ impl SumKernel for BoolVTable {
2626
Some(array.bit_buffer().bitand(validity_mask).true_count() as u64)
2727
}
2828
};
29-
Ok(Scalar::from(true_count))
29+
30+
let accumulator = accumulator
31+
.as_primitive()
32+
.as_::<u64>()
33+
.vortex_expect("cannot be null");
34+
Ok(Scalar::from(
35+
true_count.and_then(|tc| accumulator.checked_add(tc)),
36+
))
3037
}
3138
}
3239

vortex-array/src/arrays/chunked/compute/sum.rs

Lines changed: 11 additions & 85 deletions
Original file line number | Diff line number | Diff line change
@@ -1,100 +1,26 @@
11
// SPDX-License-Identifier: Apache-2.0
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

4-
use num_traits::PrimInt;
5-
use vortex_dtype::Nullability::Nullable;
6-
use vortex_dtype::{DType, DecimalDType, NativePType, i256, match_each_native_ptype};
7-
use vortex_error::{VortexResult, vortex_bail, vortex_err};
8-
use vortex_scalar::{DecimalScalar, DecimalValue, Scalar};
4+
use vortex_error::VortexResult;
5+
use vortex_scalar::Scalar;
96

107
use crate::arrays::{ChunkedArray, ChunkedVTable};
11-
use crate::compute::{SumKernel, SumKernelAdapter, sum};
12-
use crate::stats::Stat;
13-
use crate::{ArrayRef, register_kernel};
8+
use crate::compute::{SumKernel, SumKernelAdapter, sum_with_accumulator};
9+
use crate::register_kernel;
1410

1511
impl SumKernel for ChunkedVTable {
16-
fn sum(&self, array: &ChunkedArray) -> VortexResult<Scalar> {
17-
let sum_dtype = Stat::Sum
18-
.dtype(array.dtype())
19-
.ok_or_else(|| vortex_err!("Sum not supported for dtype {}", array.dtype()))?;
20-
21-
match sum_dtype {
22-
DType::Decimal(decimal_dtype, _) => sum_decimal(array.chunks(), decimal_dtype),
23-
DType::Primitive(sum_ptype, _) => {
24-
let scalar_value = match_each_native_ptype!(
25-
sum_ptype,
26-
unsigned: |T| { sum_int::<u64>(array.chunks())?.into() },
27-
signed: |T| { sum_int::<i64>(array.chunks())?.into() },
28-
floating: |T| { sum_float(array.chunks())?.into() }
29-
);
30-
31-
Ok(Scalar::new(sum_dtype, scalar_value))
32-
}
33-
_ => {
34-
vortex_bail!("Sum not supported for dtype {}", sum_dtype);
35-
}
36-
}
12+
fn sum(&self, array: &ChunkedArray, accumulator: &Scalar) -> VortexResult<Scalar> {
13+
array
14+
.chunks
15+
.iter()
16+
.try_fold(accumulator.clone(), |result, chunk| {
17+
sum_with_accumulator(chunk, &result)
18+
})
3719
}
3820
}
3921

4022
register_kernel!(SumKernelAdapter(ChunkedVTable).lift());
4123

42-
fn sum_int<T: NativePType + PrimInt>(chunks: &[ArrayRef]) -> VortexResult<Option<T>> {
43-
let mut result: T = T::zero();
44-
for chunk in chunks {
45-
let chunk_sum = sum(chunk)?;
46-
let Some(chunk_sum) = chunk_sum
47-
.as_primitive()
48-
.as_::<T>()
49-
.and_then(|chunk_sum| result.checked_add(&chunk_sum))
50-
else {
51-
// Bail out on null or overflow
52-
return Ok(None);
53-
};
54-
result = chunk_sum;
55-
}
56-
Ok(Some(result))
57-
}
58-
59-
fn sum_float(chunks: &[ArrayRef]) -> VortexResult<Option<f64>> {
60-
let mut result = 0f64;
61-
for chunk in chunks {
62-
let Some(chunk_sum) = sum(chunk)?.as_primitive().as_::<f64>() else {
63-
return Ok(None);
64-
};
65-
result += chunk_sum;
66-
}
67-
Ok(Some(result))
68-
}
69-
70-
fn sum_decimal(chunks: &[ArrayRef], result_decimal_type: DecimalDType) -> VortexResult<Scalar> {
71-
let mut result = DecimalValue::I256(i256::ZERO);
72-
73-
let null = || Scalar::null(DType::Decimal(result_decimal_type, Nullable));
74-
75-
for chunk in chunks {
76-
let chunk_sum = sum(chunk)?;
77-
78-
let chunk_decimal = DecimalScalar::try_from(&chunk_sum)?;
79-
let Some(r) = chunk_decimal
80-
.decimal_value()
81-
// TODO(joe): added a precision capped checked_add.
82-
.and_then(|c_sum| result.checked_add(&c_sum))
83-
.filter(|sum_value| {
84-
sum_value
85-
.fits_in_precision(result_decimal_type)
86-
.unwrap_or(false)
87-
})
88-
else {
89-
// null if any chunk is null or the sum overflows
90-
return Ok(null());
91-
};
92-
result = r;
93-
}
94-
95-
Ok(Scalar::decimal(result, result_decimal_type, Nullable))
96-
}
97-
9824
#[cfg(test)]
9925
mod tests {
10026
use vortex_buffer::buffer;

vortex-array/src/arrays/constant/compute/sum.rs

Lines changed: 79 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,8 @@
11
// SPDX-License-Identifier: Apache-2.0
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

4-
use num_traits::{CheckedMul, ToPrimitive};
4+
use arrow_array::ArrowNativeTypeOp;
5+
use num_traits::{CheckedAdd, CheckedMul, ToPrimitive};
56
use vortex_dtype::{DType, DecimalDType, NativePType, Nullability, i256, match_each_native_ptype};
67
use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err};
78
use vortex_scalar::{DecimalScalar, DecimalValue, PrimitiveScalar, Scalar, ScalarValue};
@@ -12,32 +13,44 @@ use crate::register_kernel;
1213
use crate::stats::Stat;
1314

1415
impl SumKernel for ConstantVTable {
15-
fn sum(&self, array: &ConstantArray) -> VortexResult<Scalar> {
16+
fn sum(&self, array: &ConstantArray, accumulator: &Scalar) -> VortexResult<Scalar> {
1617
// Compute the expected dtype of the sum.
1718
let sum_dtype = Stat::Sum
1819
.dtype(array.dtype())
1920
.ok_or_else(|| vortex_err!("Sum not supported for dtype {}", array.dtype()))?;
2021

21-
let sum_value = sum_scalar(array.scalar(), array.len())?;
22+
let sum_value = sum_scalar(array.scalar(), array.len(), accumulator)?;
2223
Ok(Scalar::new(sum_dtype, sum_value))
2324
}
2425
}
2526

26-
fn sum_scalar(scalar: &Scalar, len: usize) -> VortexResult<ScalarValue> {
27+
fn sum_scalar(scalar: &Scalar, len: usize, accumulator: &Scalar) -> VortexResult<ScalarValue> {
2728
match scalar.dtype() {
28-
DType::Bool(_) => Ok(ScalarValue::from(match scalar.as_bool().value() {
29-
None => unreachable!("Handled before reaching this point"),
30-
Some(false) => 0u64,
31-
Some(true) => len as u64,
32-
})),
33-
DType::Primitive(ptype, _) => Ok(match_each_native_ptype!(
34-
ptype,
35-
unsigned: |T| { sum_integral::<u64>(scalar.as_primitive(), len)?.into() },
36-
signed: |T| { sum_integral::<i64>(scalar.as_primitive(), len)?.into() },
37-
floating: |T| { sum_float(scalar.as_primitive(), len)?.into() }
38-
)),
39-
DType::Decimal(decimal_dtype, _) => sum_decimal(scalar.as_decimal(), len, *decimal_dtype),
40-
DType::Extension(_) => sum_scalar(&scalar.as_extension().storage(), len),
29+
DType::Bool(_) => {
30+
let count = match scalar.as_bool().value() {
31+
None => unreachable!("Handled before reaching this point"),
32+
Some(false) => 0u64,
33+
Some(true) => len as u64,
34+
};
35+
let accumulator = accumulator
36+
.as_primitive()
37+
.as_::<u64>()
38+
.vortex_expect("cannot be null");
39+
Ok(ScalarValue::from(accumulator.checked_add(count)))
40+
}
41+
DType::Primitive(ptype, _) => {
42+
let result = match_each_native_ptype!(
43+
ptype,
44+
unsigned: |T| { sum_integral::<u64>(scalar.as_primitive(), len, accumulator)?.into() },
45+
signed: |T| { sum_integral::<i64>(scalar.as_primitive(), len, accumulator)?.into() },
46+
floating: |T| { sum_float(scalar.as_primitive(), len, accumulator)?.into() }
47+
);
48+
Ok(result)
49+
}
50+
DType::Decimal(decimal_dtype, _) => {
51+
sum_decimal(scalar.as_decimal(), len, *decimal_dtype, accumulator)
52+
}
53+
DType::Extension(_) => sum_scalar(&scalar.as_extension().storage(), len, accumulator),
4154
dtype => vortex_bail!("Unsupported dtype for sum: {}", dtype),
4255
}
4356
}
@@ -46,6 +59,7 @@ fn sum_decimal(
4659
decimal_scalar: DecimalScalar,
4760
array_len: usize,
4861
decimal_dtype: DecimalDType,
62+
accumulator: &Scalar,
4963
) -> VortexResult<ScalarValue> {
5064
let result_dtype = Stat::Sum
5165
.dtype(&DType::Decimal(decimal_dtype, Nullability::Nullable))
@@ -63,43 +77,82 @@ fn sum_decimal(
6377
let len_value = DecimalValue::I256(i256::from_i128(array_len as i128));
6478

6579
// Multiply value * len
66-
let sum = value.checked_mul(&len_value).and_then(|result| {
80+
let array_sum = value.checked_mul(&len_value).and_then(|result| {
6781
// Check if result fits in the precision
6882
result
6983
.fits_in_precision(*result_decimal_type)
7084
.unwrap_or(false)
7185
.then_some(result)
7286
});
7387

74-
match sum {
75-
Some(result_value) => Ok(ScalarValue::from(result_value)),
88+
// Add accumulator to array_sum
89+
let initial_decimal = DecimalScalar::try_from(accumulator)?;
90+
let initial_dec_value = initial_decimal
91+
.decimal_value()
92+
.unwrap_or(DecimalValue::I256(i256::ZERO));
93+
94+
match array_sum {
95+
Some(array_sum_value) => {
96+
let total = array_sum_value
97+
.checked_add(&initial_dec_value)
98+
.and_then(|result| {
99+
result
100+
.fits_in_precision(*result_decimal_type)
101+
.unwrap_or(false)
102+
.then_some(result)
103+
});
104+
match total {
105+
Some(result_value) => Ok(ScalarValue::from(result_value)),
106+
None => Ok(ScalarValue::null()), // Overflow
107+
}
108+
}
76109
None => Ok(ScalarValue::null()), // Overflow
77110
}
78111
}
79112

80113
fn sum_integral<T>(
81114
primitive_scalar: PrimitiveScalar<'_>,
82115
array_len: usize,
116+
accumulator: &Scalar,
83117
) -> VortexResult<Option<T>>
84118
where
85-
T: NativePType + CheckedMul,
119+
T: NativePType + CheckedMul + CheckedAdd,
86120
Scalar: From<Option<T>>,
87121
{
88122
let v = primitive_scalar.as_::<T>();
89123
let array_len =
90124
T::from(array_len).ok_or_else(|| vortex_err!("array_len must fit the sum type"))?;
91-
let sum = v.and_then(|v| v.checked_mul(&array_len));
125+
let Some(array_sum) = v.and_then(|v| v.checked_mul(&array_len)) else {
126+
return Ok(None);
127+
};
92128

93-
Ok(sum)
129+
let initial = accumulator
130+
.as_primitive()
131+
.as_::<T>()
132+
.vortex_expect("cannot be null");
133+
Ok(initial.checked_add(&array_sum))
94134
}
95135

96-
fn sum_float(primitive_scalar: PrimitiveScalar<'_>, array_len: usize) -> VortexResult<Option<f64>> {
97-
let v = primitive_scalar.as_::<f64>();
136+
fn sum_float(
137+
primitive_scalar: PrimitiveScalar<'_>,
138+
array_len: usize,
139+
accumulator: &Scalar,
140+
) -> VortexResult<Option<f64>> {
141+
let v = primitive_scalar
142+
.as_::<f64>()
143+
.vortex_expect("cannot be null");
98144
let array_len = array_len
99145
.to_f64()
100146
.ok_or_else(|| vortex_err!("array_len must fit the sum type"))?;
101147

102-
Ok(v.map(|v| v * array_len))
148+
let Ok(array_sum) = v.mul_checked(array_len) else {
149+
return Ok(None);
150+
};
151+
let initial = accumulator
152+
.as_primitive()
153+
.as_::<f64>()
154+
.vortex_expect("cannot be null");
155+
Ok(Some(initial + array_sum))
103156
}
104157

105158
register_kernel!(SumKernelAdapter(ConstantVTable).lift());

0 commit comments

Comments
 (0)