11// SPDX-License-Identifier: Apache-2.0
22// SPDX-FileCopyrightText: Copyright the Vortex contributors
33
4- use arrow_schema :: DECIMAL256_MAX_PRECISION ;
4+ use itertools :: Itertools ;
55use num_traits:: AsPrimitive ;
6+ use num_traits:: CheckedAdd ;
7+ use vortex_buffer:: BitBuffer ;
8+ use vortex_buffer:: Buffer ;
9+ use vortex_dtype:: DType ;
610use vortex_dtype:: DecimalDType ;
711use vortex_dtype:: DecimalType ;
12+ use vortex_dtype:: MAX_PRECISION ;
813use vortex_dtype:: Nullability :: Nullable ;
914use vortex_dtype:: match_each_decimal_value_type;
1015use vortex_error:: VortexExpect ;
1116use vortex_error:: VortexResult ;
1217use vortex_error:: vortex_bail;
13- use vortex_error:: vortex_err;
1418use vortex_mask:: Mask ;
1519use vortex_scalar:: DecimalScalar ;
1620use vortex_scalar:: DecimalValue ;
@@ -22,32 +26,6 @@ use crate::compute::SumKernel;
2226use crate :: compute:: SumKernelAdapter ;
2327use crate :: register_kernel;
2428
25- // It's safe to use `AsPrimitive` here because we always cast up.
26- macro_rules! sum_decimal {
27- ( $ty: ty, $values: expr, $initial: expr) => { {
28- let mut sum: $ty = $initial;
29- for v in $values. iter( ) {
30- let v: $ty = ( * v) . as_( ) ;
31- sum = num_traits:: CheckedAdd :: checked_add( & sum, & v)
32- . ok_or_else( || vortex_err!( "Overflow when summing decimal {sum:?} + {v:?}" ) ) ?
33- }
34- sum
35- } } ;
36- ( $ty: ty, $values: expr, $validity: expr, $initial: expr) => { {
37- use itertools:: Itertools ;
38-
39- let mut sum: $ty = $initial;
40- for ( v, valid) in $values. iter( ) . zip_eq( $validity) {
41- if valid {
42- let v: $ty = ( * v) . as_( ) ;
43- sum = num_traits:: CheckedAdd :: checked_add( & sum, & v)
44- . ok_or_else( || vortex_err!( "Overflow when summing decimal {sum:?} + {v:?}" ) ) ?
45- }
46- }
47- sum
48- } } ;
49- }
50-
5129impl SumKernel for DecimalVTable {
5230 #[ expect(
5331 clippy:: cognitive_complexity,
@@ -59,7 +37,7 @@ impl SumKernel for DecimalVTable {
5937 // Both Spark and DataFusion use this heuristic.
6038 // - https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
6139 // - https://github.com/apache/datafusion/blob/4153adf2c0f6e317ef476febfdc834208bd46622/datafusion/functions-aggregate/src/sum.rs#L188
62- let new_precision = u8:: min ( DECIMAL256_MAX_PRECISION , decimal_dtype. precision ( ) + 10 ) ;
40+ let new_precision = u8:: min ( MAX_PRECISION , decimal_dtype. precision ( ) + 10 ) ;
6341 let new_scale = decimal_dtype. scale ( ) ;
6442 let return_dtype = DecimalDType :: new ( new_precision, new_scale) ;
6543
@@ -80,11 +58,15 @@ impl SumKernel for DecimalVTable {
8058 let initial_val: O = initial_decimal
8159 . cast( )
8260 . vortex_expect( "cannot fail to cast initial value" ) ;
83- Ok ( Scalar :: decimal(
84- DecimalValue :: from( sum_decimal!( O , array. buffer:: <I >( ) , initial_val) ) ,
85- return_dtype,
86- Nullable ,
87- ) )
61+ if let Some ( sum) = sum_decimal( array. buffer:: <I >( ) , initial_val) {
62+ Ok ( Scalar :: decimal(
63+ DecimalValue :: from( sum) ,
64+ return_dtype,
65+ Nullable ,
66+ ) )
67+ } else {
68+ Ok ( Scalar :: null( DType :: Decimal ( return_dtype, Nullable ) ) )
69+ }
8870 } )
8971 } )
9072 }
@@ -95,23 +77,54 @@ impl SumKernel for DecimalVTable {
9577 let initial_val: O = initial_decimal
9678 . cast( )
9779 . vortex_expect( "cannot fail to cast initial value" ) ;
98- Ok ( Scalar :: decimal(
99- DecimalValue :: from( sum_decimal!(
100- O ,
101- array. buffer:: <I >( ) ,
102- mask_values. bit_buffer( ) ,
103- initial_val
104- ) ) ,
105- return_dtype,
106- Nullable ,
107- ) )
80+
81+ if let Some ( sum) = sum_decimal_with_validity(
82+ array. buffer:: <I >( ) ,
83+ mask_values. bit_buffer( ) ,
84+ initial_val,
85+ ) {
86+ Ok ( Scalar :: decimal(
87+ DecimalValue :: from( sum) ,
88+ return_dtype,
89+ Nullable ,
90+ ) )
91+ } else {
92+ Ok ( Scalar :: null( DType :: Decimal ( return_dtype, Nullable ) ) )
93+ }
10894 } )
10995 } )
11096 }
11197 }
11298 }
11399}
114100
101+ fn sum_decimal < T : AsPrimitive < I > , I : Copy + CheckedAdd + ' static > (
102+ values : Buffer < T > ,
103+ initial : I ,
104+ ) -> Option < I > {
105+ let mut sum = initial;
106+ for v in values. iter ( ) {
107+ let v: I = v. as_ ( ) ;
108+ sum = CheckedAdd :: checked_add ( & sum, & v) ?;
109+ }
110+ Some ( sum)
111+ }
112+
113+ fn sum_decimal_with_validity < T : AsPrimitive < I > , I : Copy + CheckedAdd + ' static > (
114+ values : Buffer < T > ,
115+ validity : & BitBuffer ,
116+ initial : I ,
117+ ) -> Option < I > {
118+ let mut sum = initial;
119+ for ( v, valid) in values. iter ( ) . zip_eq ( validity) {
120+ if valid {
121+ let v: I = v. as_ ( ) ;
122+ sum = CheckedAdd :: checked_add ( & sum, & v) ?;
123+ }
124+ }
125+ Some ( sum)
126+ }
127+
115128register_kernel ! ( SumKernelAdapter ( DecimalVTable ) . lift( ) ) ;
116129
117130#[ cfg( test) ]
@@ -120,9 +133,11 @@ mod tests {
120133 use vortex_dtype:: DType ;
121134 use vortex_dtype:: DecimalDType ;
122135 use vortex_dtype:: Nullability ;
136+ use vortex_error:: VortexUnwrap ;
123137 use vortex_scalar:: DecimalValue ;
124138 use vortex_scalar:: Scalar ;
125139 use vortex_scalar:: ScalarValue ;
140+ use vortex_scalar:: i256;
126141
127142 use crate :: arrays:: DecimalArray ;
128143 use crate :: compute:: sum;
@@ -327,8 +342,6 @@ mod tests {
327342
328343 #[ test]
329344 fn test_sum_i128_to_i256_boundary ( ) {
330- use vortex_scalar:: i256;
331-
332345 // Test the boundary between i128 and i256 accumulation
333346 let large_i128 = i128:: MAX / 10 ;
334347 let decimal = DecimalArray :: new (
@@ -351,4 +364,19 @@ mod tests {
351364
352365 assert_eq ! ( result, expected) ;
353366 }
367+
368+ #[ test]
369+ fn test_i256_overflow ( ) {
370+ let decimal_dtype = DecimalDType :: new ( 76 , 0 ) ;
371+ let decimal = DecimalArray :: new (
372+ buffer ! [ i256:: MAX , i256:: MAX , i256:: MAX ] ,
373+ decimal_dtype,
374+ Validity :: AllValid ,
375+ ) ;
376+
377+ assert_eq ! (
378+ sum( decimal. as_ref( ) ) . vortex_unwrap( ) ,
379+ Scalar :: null( DType :: Decimal ( decimal_dtype, Nullability :: Nullable ) )
380+ ) ;
381+ }
354382}
0 commit comments