Skip to content

Commit 60d7520

Browse files
authored
fix: Decimal sum doesn't panic but returns null on overflow (#5564)
Signed-off-by: Robert Kruszewski <[email protected]>
1 parent c6ed47f commit 60d7520

File tree

3 files changed

+86
-62
lines changed

3 files changed

+86
-62
lines changed

vortex-array/src/arrays/decimal/compute/sum.rs

Lines changed: 80 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,17 @@
11
// SPDX-License-Identifier: Apache-2.0
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

4-
use arrow_schema::DECIMAL256_MAX_PRECISION;
4+
use itertools::Itertools;
55
use num_traits::AsPrimitive;
6-
use vortex_dtype::DecimalDType;
6+
use num_traits::CheckedAdd;
7+
use vortex_buffer::BitBuffer;
8+
use vortex_buffer::Buffer;
79
use vortex_dtype::DecimalType;
810
use vortex_dtype::Nullability::Nullable;
911
use vortex_dtype::match_each_decimal_value_type;
1012
use vortex_error::VortexExpect;
1113
use vortex_error::VortexResult;
1214
use vortex_error::vortex_bail;
13-
use vortex_error::vortex_err;
1415
use vortex_mask::Mask;
1516
use vortex_scalar::DecimalScalar;
1617
use vortex_scalar::DecimalValue;
@@ -21,47 +22,20 @@ use crate::arrays::DecimalVTable;
2122
use crate::compute::SumKernel;
2223
use crate::compute::SumKernelAdapter;
2324
use crate::register_kernel;
24-
25-
// It's safe to use `AsPrimitive` here because we always cast up.
26-
macro_rules! sum_decimal {
27-
($ty:ty, $values:expr, $initial:expr) => {{
28-
let mut sum: $ty = $initial;
29-
for v in $values.iter() {
30-
let v: $ty = (*v).as_();
31-
sum = num_traits::CheckedAdd::checked_add(&sum, &v)
32-
.ok_or_else(|| vortex_err!("Overflow when summing decimal {sum:?} + {v:?}"))?
33-
}
34-
sum
35-
}};
36-
($ty:ty, $values:expr, $validity:expr, $initial:expr) => {{
37-
use itertools::Itertools;
38-
39-
let mut sum: $ty = $initial;
40-
for (v, valid) in $values.iter().zip_eq($validity) {
41-
if valid {
42-
let v: $ty = (*v).as_();
43-
sum = num_traits::CheckedAdd::checked_add(&sum, &v)
44-
.ok_or_else(|| vortex_err!("Overflow when summing decimal {sum:?} + {v:?}"))?
45-
}
46-
}
47-
sum
48-
}};
49-
}
25+
use crate::stats::Stat;
5026

5127
impl SumKernel for DecimalVTable {
5228
#[expect(
5329
clippy::cognitive_complexity,
5430
reason = "complexity from nested match_each_* macros"
5531
)]
5632
fn sum(&self, array: &DecimalArray, accumulator: &Scalar) -> VortexResult<Scalar> {
57-
let decimal_dtype = array.decimal_dtype();
58-
59-
// Both Spark and DataFusion use this heuristic.
60-
// - https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
61-
// - https://github.com/apache/datafusion/blob/4153adf2c0f6e317ef476febfdc834208bd46622/datafusion/functions-aggregate/src/sum.rs#L188
62-
let new_precision = u8::min(DECIMAL256_MAX_PRECISION, decimal_dtype.precision() + 10);
63-
let new_scale = decimal_dtype.scale();
64-
let return_dtype = DecimalDType::new(new_precision, new_scale);
33+
let return_dtype = Stat::Sum
34+
.dtype(array.dtype())
35+
.vortex_expect("sum for decimals exists");
36+
let return_decimal_dtype = return_dtype
37+
.as_decimal_opt()
38+
.vortex_expect("must be decimal");
6539

6640
// Extract the initial value as a DecimalValue
6741
let initial_decimal = DecimalScalar::try_from(accumulator)
@@ -74,44 +48,79 @@ impl SumKernel for DecimalVTable {
7448
vortex_bail!("invalid state, all-null array should be checked by top-level sum fn")
7549
}
7650
Mask::AllTrue(_) => {
77-
let values_type = DecimalType::smallest_decimal_value_type(&return_dtype);
51+
let values_type = DecimalType::smallest_decimal_value_type(return_decimal_dtype);
7852
match_each_decimal_value_type!(array.values_type(), |I| {
7953
match_each_decimal_value_type!(values_type, |O| {
8054
let initial_val: O = initial_decimal
8155
.cast()
8256
.vortex_expect("cannot fail to cast initial value");
83-
Ok(Scalar::decimal(
84-
DecimalValue::from(sum_decimal!(O, array.buffer::<I>(), initial_val)),
85-
return_dtype,
86-
Nullable,
87-
))
57+
if let Some(sum) = sum_decimal(array.buffer::<I>(), initial_val) {
58+
Ok(Scalar::decimal(
59+
DecimalValue::from(sum),
60+
*return_decimal_dtype,
61+
Nullable,
62+
))
63+
} else {
64+
Ok(Scalar::null(return_dtype))
65+
}
8866
})
8967
})
9068
}
9169
Mask::Values(mask_values) => {
92-
let values_type = DecimalType::smallest_decimal_value_type(&return_dtype);
70+
let values_type = DecimalType::smallest_decimal_value_type(return_decimal_dtype);
9371
match_each_decimal_value_type!(array.values_type(), |I| {
9472
match_each_decimal_value_type!(values_type, |O| {
9573
let initial_val: O = initial_decimal
9674
.cast()
9775
.vortex_expect("cannot fail to cast initial value");
98-
Ok(Scalar::decimal(
99-
DecimalValue::from(sum_decimal!(
100-
O,
101-
array.buffer::<I>(),
102-
mask_values.bit_buffer(),
103-
initial_val
104-
)),
105-
return_dtype,
106-
Nullable,
107-
))
76+
77+
if let Some(sum) = sum_decimal_with_validity(
78+
array.buffer::<I>(),
79+
mask_values.bit_buffer(),
80+
initial_val,
81+
) {
82+
Ok(Scalar::decimal(
83+
DecimalValue::from(sum),
84+
*return_decimal_dtype,
85+
Nullable,
86+
))
87+
} else {
88+
Ok(Scalar::null(return_dtype))
89+
}
10890
})
10991
})
11092
}
11193
}
11294
}
11395
}
11496

97+
/// Adds every element of `values` onto `initial` using overflow-checked addition.
///
/// Each element of type `T` is first converted to the accumulator type `I` with
/// `AsPrimitive::as_`. NOTE(review): `as_` is a plain primitive cast, so callers are
/// presumably expected to pick an `I` at least as wide as `T` — confirm at call sites.
///
/// Returns `Some(total)` when every addition succeeds, or `None` as soon as one
/// `checked_add` reports overflow. An empty buffer yields `Some(initial)`.
fn sum_decimal<T: AsPrimitive<I>, I: Copy + CheckedAdd + 'static>(
98+
values: Buffer<T>,
99+
initial: I,
100+
) -> Option<I> {
101+
let mut sum = initial;
102+
for v in values.iter() {
103+
let v: I = v.as_();
// `?` short-circuits the whole function with `None` on the first overflow.
104+
sum = CheckedAdd::checked_add(&sum, &v)?;
105+
}
106+
Some(sum)
107+
}
108+
109+
/// Adds the elements of `values` whose matching `validity` bit is set onto `initial`,
/// using overflow-checked addition.
///
/// `values` and `validity` are walked in lockstep with `zip_eq`; elements whose bit is
/// `false` are skipped and do not contribute to the sum. Each included element of type
/// `T` is converted to the accumulator type `I` with `AsPrimitive::as_`.
/// NOTE(review): itertools' `zip_eq` panics when the two sides differ in length, so
/// callers presumably guarantee `validity` has one bit per value — confirm.
///
/// Returns `Some(total)` when every addition succeeds, or `None` as soon as one
/// `checked_add` reports overflow.
fn sum_decimal_with_validity<T: AsPrimitive<I>, I: Copy + CheckedAdd + 'static>(
110+
values: Buffer<T>,
111+
validity: &BitBuffer,
112+
initial: I,
113+
) -> Option<I> {
114+
let mut sum = initial;
115+
for (v, valid) in values.iter().zip_eq(validity) {
116+
if valid {
117+
let v: I = v.as_();
// `?` short-circuits the whole function with `None` on the first overflow.
118+
sum = CheckedAdd::checked_add(&sum, &v)?;
119+
}
120+
}
121+
Some(sum)
122+
}
123+
115124
register_kernel!(SumKernelAdapter(DecimalVTable).lift());
116125

117126
#[cfg(test)]
@@ -120,9 +129,11 @@ mod tests {
120129
use vortex_dtype::DType;
121130
use vortex_dtype::DecimalDType;
122131
use vortex_dtype::Nullability;
132+
use vortex_error::VortexUnwrap;
123133
use vortex_scalar::DecimalValue;
124134
use vortex_scalar::Scalar;
125135
use vortex_scalar::ScalarValue;
136+
use vortex_scalar::i256;
126137

127138
use crate::arrays::DecimalArray;
128139
use crate::compute::sum;
@@ -327,8 +338,6 @@ mod tests {
327338

328339
#[test]
329340
fn test_sum_i128_to_i256_boundary() {
330-
use vortex_scalar::i256;
331-
332341
// Test the boundary between i128 and i256 accumulation
333342
let large_i128 = i128::MAX / 10;
334343
let decimal = DecimalArray::new(
@@ -351,4 +360,19 @@ mod tests {
351360

352361
assert_eq!(result, expected);
353362
}
363+
364+
// Regression test: summing values that overflow the accumulator must produce a null
// scalar rather than an error or a panic.
//
// The array holds three `i256::MAX` values, all valid, with precision 76 and scale 0.
// The expected null scalar reuses the input `decimal_dtype` (made nullable).
// NOTE(review): this presumably works because precision 76 is already the maximum, so
// the `+ 10` widening of the sum dtype is clamped back to 76 — confirm against
// `MAX_PRECISION`.
#[test]
365+
fn test_i256_overflow() {
366+
let decimal_dtype = DecimalDType::new(76, 0);
367+
let decimal = DecimalArray::new(
368+
buffer![i256::MAX, i256::MAX, i256::MAX],
369+
decimal_dtype,
370+
Validity::AllValid,
371+
);
372+
373+
assert_eq!(
374+
sum(decimal.as_ref()).vortex_unwrap(),
375+
Scalar::null(DType::Decimal(decimal_dtype, Nullability::Nullable))
376+
);
377+
}
354378
}

vortex-array/src/stats/mod.rs

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,10 @@ use num_enum::TryFromPrimitive;
2020
pub use stats_set::*;
2121
use vortex_dtype::DType;
2222
use vortex_dtype::DecimalDType;
23-
use vortex_dtype::NativeDecimalType;
23+
use vortex_dtype::MAX_PRECISION;
2424
use vortex_dtype::Nullability::NonNullable;
2525
use vortex_dtype::Nullability::Nullable;
2626
use vortex_dtype::PType;
27-
use vortex_dtype::i256;
2827

2928
mod array;
3029
mod bound;
@@ -221,8 +220,7 @@ impl Stat {
221220
// Both Spark and DataFusion use this heuristic.
222221
// - https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
223222
// - https://github.com/apache/datafusion/blob/4153adf2c0f6e317ef476febfdc834208bd46622/datafusion/functions-aggregate/src/sum.rs#L188
224-
let precision =
225-
u8::min(i256::MAX_PRECISION, decimal_dtype.precision() + 10);
223+
let precision = u8::min(MAX_PRECISION, decimal_dtype.precision() + 10);
226224
DType::Decimal(
227225
DecimalDType::new(precision, decimal_dtype.scale()),
228226
Nullable,

vortex-dtype/src/decimal/mod.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,10 @@ use vortex_error::vortex_panic;
2323
use crate::DType;
2424
use crate::i256;
2525

26-
const MAX_PRECISION: u8 = <i256 as NativeDecimalType>::MAX_PRECISION;
27-
const MAX_SCALE: i8 = <i256 as NativeDecimalType>::MAX_SCALE;
26+
/// The maximum precision allowed for a decimal type.
27+
pub const MAX_PRECISION: u8 = <i256 as NativeDecimalType>::MAX_PRECISION;
28+
/// The maximum scale allowed for a decimal type.
29+
pub const MAX_SCALE: i8 = <i256 as NativeDecimalType>::MAX_SCALE;
2830

2931
/// Parameters that define the precision and scale of a decimal type.
3032
///

0 commit comments

Comments
 (0)