Skip to content

Commit efdf70e

Browse files
committed
fix: Decimal sum doesn't panic but returns null on overflow
Signed-off-by: Robert Kruszewski <[email protected]>
1 parent c6ed47f commit efdf70e

File tree

2 files changed

+78
-48
lines changed
  • vortex-array/src/arrays/decimal/compute
  • vortex-dtype/src/decimal

2 files changed

+78
-48
lines changed

vortex-array/src/arrays/decimal/compute/sum.rs

Lines changed: 74 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,20 @@
11
// SPDX-License-Identifier: Apache-2.0
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

4-
use arrow_schema::DECIMAL256_MAX_PRECISION;
4+
use itertools::Itertools;
55
use num_traits::AsPrimitive;
6+
use num_traits::CheckedAdd;
7+
use vortex_buffer::BitBuffer;
8+
use vortex_buffer::Buffer;
9+
use vortex_dtype::DType;
610
use vortex_dtype::DecimalDType;
711
use vortex_dtype::DecimalType;
12+
use vortex_dtype::MAX_PRECISION;
813
use vortex_dtype::Nullability::Nullable;
914
use vortex_dtype::match_each_decimal_value_type;
1015
use vortex_error::VortexExpect;
1116
use vortex_error::VortexResult;
1217
use vortex_error::vortex_bail;
13-
use vortex_error::vortex_err;
1418
use vortex_mask::Mask;
1519
use vortex_scalar::DecimalScalar;
1620
use vortex_scalar::DecimalValue;
@@ -22,32 +26,6 @@ use crate::compute::SumKernel;
2226
use crate::compute::SumKernelAdapter;
2327
use crate::register_kernel;
2428

25-
// It's safe to use `AsPrimitive` here because we always cast up.
26-
macro_rules! sum_decimal {
27-
($ty:ty, $values:expr, $initial:expr) => {{
28-
let mut sum: $ty = $initial;
29-
for v in $values.iter() {
30-
let v: $ty = (*v).as_();
31-
sum = num_traits::CheckedAdd::checked_add(&sum, &v)
32-
.ok_or_else(|| vortex_err!("Overflow when summing decimal {sum:?} + {v:?}"))?
33-
}
34-
sum
35-
}};
36-
($ty:ty, $values:expr, $validity:expr, $initial:expr) => {{
37-
use itertools::Itertools;
38-
39-
let mut sum: $ty = $initial;
40-
for (v, valid) in $values.iter().zip_eq($validity) {
41-
if valid {
42-
let v: $ty = (*v).as_();
43-
sum = num_traits::CheckedAdd::checked_add(&sum, &v)
44-
.ok_or_else(|| vortex_err!("Overflow when summing decimal {sum:?} + {v:?}"))?
45-
}
46-
}
47-
sum
48-
}};
49-
}
50-
5129
impl SumKernel for DecimalVTable {
5230
#[expect(
5331
clippy::cognitive_complexity,
@@ -59,7 +37,7 @@ impl SumKernel for DecimalVTable {
5937
// Both Spark and DataFusion use this heuristic.
6038
// - https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
6139
// - https://github.com/apache/datafusion/blob/4153adf2c0f6e317ef476febfdc834208bd46622/datafusion/functions-aggregate/src/sum.rs#L188
62-
let new_precision = u8::min(DECIMAL256_MAX_PRECISION, decimal_dtype.precision() + 10);
40+
let new_precision = u8::min(MAX_PRECISION, decimal_dtype.precision() + 10);
6341
let new_scale = decimal_dtype.scale();
6442
let return_dtype = DecimalDType::new(new_precision, new_scale);
6543

@@ -80,11 +58,15 @@ impl SumKernel for DecimalVTable {
8058
let initial_val: O = initial_decimal
8159
.cast()
8260
.vortex_expect("cannot fail to cast initial value");
83-
Ok(Scalar::decimal(
84-
DecimalValue::from(sum_decimal!(O, array.buffer::<I>(), initial_val)),
85-
return_dtype,
86-
Nullable,
87-
))
61+
if let Some(sum) = sum_decimal(array.buffer::<I>(), initial_val) {
62+
Ok(Scalar::decimal(
63+
DecimalValue::from(sum),
64+
return_dtype,
65+
Nullable,
66+
))
67+
} else {
68+
Ok(Scalar::null(DType::Decimal(return_dtype, Nullable)))
69+
}
8870
})
8971
})
9072
}
@@ -95,23 +77,54 @@ impl SumKernel for DecimalVTable {
9577
let initial_val: O = initial_decimal
9678
.cast()
9779
.vortex_expect("cannot fail to cast initial value");
98-
Ok(Scalar::decimal(
99-
DecimalValue::from(sum_decimal!(
100-
O,
101-
array.buffer::<I>(),
102-
mask_values.bit_buffer(),
103-
initial_val
104-
)),
105-
return_dtype,
106-
Nullable,
107-
))
80+
81+
if let Some(sum) = sum_decimal_with_validity(
82+
array.buffer::<I>(),
83+
mask_values.bit_buffer(),
84+
initial_val,
85+
) {
86+
Ok(Scalar::decimal(
87+
DecimalValue::from(sum),
88+
return_dtype,
89+
Nullable,
90+
))
91+
} else {
92+
Ok(Scalar::null(DType::Decimal(return_dtype, Nullable)))
93+
}
10894
})
10995
})
11096
}
11197
}
11298
}
11399
}
114100

101+
fn sum_decimal<T: AsPrimitive<I>, I: Copy + CheckedAdd + 'static>(
102+
values: Buffer<T>,
103+
initial: I,
104+
) -> Option<I> {
105+
let mut sum = initial;
106+
for v in values.iter() {
107+
let v: I = v.as_();
108+
sum = CheckedAdd::checked_add(&sum, &v)?;
109+
}
110+
Some(sum)
111+
}
112+
113+
fn sum_decimal_with_validity<T: AsPrimitive<I>, I: Copy + CheckedAdd + 'static>(
114+
values: Buffer<T>,
115+
validity: &BitBuffer,
116+
initial: I,
117+
) -> Option<I> {
118+
let mut sum = initial;
119+
for (v, valid) in values.iter().zip_eq(validity) {
120+
if valid {
121+
let v: I = v.as_();
122+
sum = CheckedAdd::checked_add(&sum, &v)?;
123+
}
124+
}
125+
Some(sum)
126+
}
127+
115128
register_kernel!(SumKernelAdapter(DecimalVTable).lift());
116129

117130
#[cfg(test)]
@@ -120,9 +133,11 @@ mod tests {
120133
use vortex_dtype::DType;
121134
use vortex_dtype::DecimalDType;
122135
use vortex_dtype::Nullability;
136+
use vortex_error::VortexUnwrap;
123137
use vortex_scalar::DecimalValue;
124138
use vortex_scalar::Scalar;
125139
use vortex_scalar::ScalarValue;
140+
use vortex_scalar::i256;
126141

127142
use crate::arrays::DecimalArray;
128143
use crate::compute::sum;
@@ -327,8 +342,6 @@ mod tests {
327342

328343
#[test]
329344
fn test_sum_i128_to_i256_boundary() {
330-
use vortex_scalar::i256;
331-
332345
// Test the boundary between i128 and i256 accumulation
333346
let large_i128 = i128::MAX / 10;
334347
let decimal = DecimalArray::new(
@@ -351,4 +364,19 @@ mod tests {
351364

352365
assert_eq!(result, expected);
353366
}
367+
368+
#[test]
369+
fn test_i256_overflow() {
370+
let decimal_dtype = DecimalDType::new(76, 0);
371+
let decimal = DecimalArray::new(
372+
buffer![i256::MAX, i256::MAX, i256::MAX],
373+
decimal_dtype,
374+
Validity::AllValid,
375+
);
376+
377+
assert_eq!(
378+
sum(decimal.as_ref()).vortex_unwrap(),
379+
Scalar::null(DType::Decimal(decimal_dtype, Nullability::Nullable))
380+
);
381+
}
354382
}

vortex-dtype/src/decimal/mod.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,10 @@ use vortex_error::vortex_panic;
2323
use crate::DType;
2424
use crate::i256;
2525

26-
const MAX_PRECISION: u8 = <i256 as NativeDecimalType>::MAX_PRECISION;
27-
const MAX_SCALE: i8 = <i256 as NativeDecimalType>::MAX_SCALE;
26+
/// The maximum precision allowed for a decimal type.
27+
pub const MAX_PRECISION: u8 = <i256 as NativeDecimalType>::MAX_PRECISION;
28+
/// The maximum scale allowed for a decimal type.
29+
pub const MAX_SCALE: i8 = <i256 as NativeDecimalType>::MAX_SCALE;
2830

2931
/// Parameters that define the precision and scale of a decimal type.
3032
///

0 commit comments

Comments
 (0)