11// SPDX-License-Identifier: Apache-2.0
22// SPDX-FileCopyrightText: Copyright the Vortex contributors
33
4- use arrow_schema :: DECIMAL256_MAX_PRECISION ;
4+ use itertools :: Itertools ;
55use num_traits:: AsPrimitive ;
6- use vortex_dtype:: DecimalDType ;
6+ use num_traits:: CheckedAdd ;
7+ use vortex_buffer:: BitBuffer ;
8+ use vortex_buffer:: Buffer ;
79use vortex_dtype:: DecimalType ;
810use vortex_dtype:: Nullability :: Nullable ;
911use vortex_dtype:: match_each_decimal_value_type;
1012use vortex_error:: VortexExpect ;
1113use vortex_error:: VortexResult ;
1214use vortex_error:: vortex_bail;
13- use vortex_error:: vortex_err;
1415use vortex_mask:: Mask ;
1516use vortex_scalar:: DecimalScalar ;
1617use vortex_scalar:: DecimalValue ;
@@ -21,47 +22,20 @@ use crate::arrays::DecimalVTable;
2122use crate :: compute:: SumKernel ;
2223use crate :: compute:: SumKernelAdapter ;
2324use crate :: register_kernel;
24-
25- // Its safe to use `AsPrimitive` here because we always cast up.
26- macro_rules! sum_decimal {
27- ( $ty: ty, $values: expr, $initial: expr) => { {
28- let mut sum: $ty = $initial;
29- for v in $values. iter( ) {
30- let v: $ty = ( * v) . as_( ) ;
31- sum = num_traits:: CheckedAdd :: checked_add( & sum, & v)
32- . ok_or_else( || vortex_err!( "Overflow when summing decimal {sum:?} + {v:?}" ) ) ?
33- }
34- sum
35- } } ;
36- ( $ty: ty, $values: expr, $validity: expr, $initial: expr) => { {
37- use itertools:: Itertools ;
38-
39- let mut sum: $ty = $initial;
40- for ( v, valid) in $values. iter( ) . zip_eq( $validity) {
41- if valid {
42- let v: $ty = ( * v) . as_( ) ;
43- sum = num_traits:: CheckedAdd :: checked_add( & sum, & v)
44- . ok_or_else( || vortex_err!( "Overflow when summing decimal {sum:?} + {v:?}" ) ) ?
45- }
46- }
47- sum
48- } } ;
49- }
25+ use crate :: stats:: Stat ;
5026
5127impl SumKernel for DecimalVTable {
5228 #[ expect(
5329 clippy:: cognitive_complexity,
5430 reason = "complexity from nested match_each_* macros"
5531 ) ]
5632 fn sum ( & self , array : & DecimalArray , accumulator : & Scalar ) -> VortexResult < Scalar > {
57- let decimal_dtype = array. decimal_dtype ( ) ;
58-
59- // Both Spark and DataFusion use this heuristic.
60- // - https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
61- // - https://github.com/apache/datafusion/blob/4153adf2c0f6e317ef476febfdc834208bd46622/datafusion/functions-aggregate/src/sum.rs#L188
62- let new_precision = u8:: min ( DECIMAL256_MAX_PRECISION , decimal_dtype. precision ( ) + 10 ) ;
63- let new_scale = decimal_dtype. scale ( ) ;
64- let return_dtype = DecimalDType :: new ( new_precision, new_scale) ;
33+ let return_dtype = Stat :: Sum
34+ . dtype ( array. dtype ( ) )
35+ . vortex_expect ( "sum for decimals exists" ) ;
36+ let return_decimal_dtype = return_dtype
37+ . as_decimal_opt ( )
38+ . vortex_expect ( "must be decimal" ) ;
6539
6640 // Extract the initial value as a DecimalValue
6741 let initial_decimal = DecimalScalar :: try_from ( accumulator)
@@ -74,44 +48,79 @@ impl SumKernel for DecimalVTable {
7448 vortex_bail ! ( "invalid state, all-null array should be checked by top-level sum fn" )
7549 }
7650 Mask :: AllTrue ( _) => {
77- let values_type = DecimalType :: smallest_decimal_value_type ( & return_dtype ) ;
51+ let values_type = DecimalType :: smallest_decimal_value_type ( return_decimal_dtype ) ;
7852 match_each_decimal_value_type ! ( array. values_type( ) , |I | {
7953 match_each_decimal_value_type!( values_type, |O | {
8054 let initial_val: O = initial_decimal
8155 . cast( )
8256 . vortex_expect( "cannot fail to cast initial value" ) ;
83- Ok ( Scalar :: decimal(
84- DecimalValue :: from( sum_decimal!( O , array. buffer:: <I >( ) , initial_val) ) ,
85- return_dtype,
86- Nullable ,
87- ) )
57+ if let Some ( sum) = sum_decimal( array. buffer:: <I >( ) , initial_val) {
58+ Ok ( Scalar :: decimal(
59+ DecimalValue :: from( sum) ,
60+ * return_decimal_dtype,
61+ Nullable ,
62+ ) )
63+ } else {
64+ Ok ( Scalar :: null( return_dtype) )
65+ }
8866 } )
8967 } )
9068 }
9169 Mask :: Values ( mask_values) => {
92- let values_type = DecimalType :: smallest_decimal_value_type ( & return_dtype ) ;
70+ let values_type = DecimalType :: smallest_decimal_value_type ( return_decimal_dtype ) ;
9371 match_each_decimal_value_type ! ( array. values_type( ) , |I | {
9472 match_each_decimal_value_type!( values_type, |O | {
9573 let initial_val: O = initial_decimal
9674 . cast( )
9775 . vortex_expect( "cannot fail to cast initial value" ) ;
98- Ok ( Scalar :: decimal(
99- DecimalValue :: from( sum_decimal!(
100- O ,
101- array. buffer:: <I >( ) ,
102- mask_values. bit_buffer( ) ,
103- initial_val
104- ) ) ,
105- return_dtype,
106- Nullable ,
107- ) )
76+
77+ if let Some ( sum) = sum_decimal_with_validity(
78+ array. buffer:: <I >( ) ,
79+ mask_values. bit_buffer( ) ,
80+ initial_val,
81+ ) {
82+ Ok ( Scalar :: decimal(
83+ DecimalValue :: from( sum) ,
84+ * return_decimal_dtype,
85+ Nullable ,
86+ ) )
87+ } else {
88+ Ok ( Scalar :: null( return_dtype) )
89+ }
10890 } )
10991 } )
11092 }
11193 }
11294 }
11395}
11496
97+ fn sum_decimal < T : AsPrimitive < I > , I : Copy + CheckedAdd + ' static > (
98+ values : Buffer < T > ,
99+ initial : I ,
100+ ) -> Option < I > {
101+ let mut sum = initial;
102+ for v in values. iter ( ) {
103+ let v: I = v. as_ ( ) ;
104+ sum = CheckedAdd :: checked_add ( & sum, & v) ?;
105+ }
106+ Some ( sum)
107+ }
108+
109+ fn sum_decimal_with_validity < T : AsPrimitive < I > , I : Copy + CheckedAdd + ' static > (
110+ values : Buffer < T > ,
111+ validity : & BitBuffer ,
112+ initial : I ,
113+ ) -> Option < I > {
114+ let mut sum = initial;
115+ for ( v, valid) in values. iter ( ) . zip_eq ( validity) {
116+ if valid {
117+ let v: I = v. as_ ( ) ;
118+ sum = CheckedAdd :: checked_add ( & sum, & v) ?;
119+ }
120+ }
121+ Some ( sum)
122+ }
123+
115124register_kernel ! ( SumKernelAdapter ( DecimalVTable ) . lift( ) ) ;
116125
117126#[ cfg( test) ]
@@ -120,9 +129,11 @@ mod tests {
120129 use vortex_dtype:: DType ;
121130 use vortex_dtype:: DecimalDType ;
122131 use vortex_dtype:: Nullability ;
132+ use vortex_error:: VortexUnwrap ;
123133 use vortex_scalar:: DecimalValue ;
124134 use vortex_scalar:: Scalar ;
125135 use vortex_scalar:: ScalarValue ;
136+ use vortex_scalar:: i256;
126137
127138 use crate :: arrays:: DecimalArray ;
128139 use crate :: compute:: sum;
@@ -327,8 +338,6 @@ mod tests {
327338
328339 #[ test]
329340 fn test_sum_i128_to_i256_boundary ( ) {
330- use vortex_scalar:: i256;
331-
332341 // Test the boundary between i128 and i256 accumulation
333342 let large_i128 = i128:: MAX / 10 ;
334343 let decimal = DecimalArray :: new (
@@ -351,4 +360,19 @@ mod tests {
351360
352361 assert_eq ! ( result, expected) ;
353362 }
363+
/// Summing three `i256::MAX` values at precision 76 is expected to produce a
/// null decimal scalar rather than an error.
#[test]
fn test_i256_overflow() {
    let decimal_dtype = DecimalDType::new(76, 0);
    let values = buffer![i256::MAX, i256::MAX, i256::MAX];
    let decimal = DecimalArray::new(values, decimal_dtype, Validity::AllValid);

    let actual = sum(decimal.as_ref()).vortex_unwrap();
    let expected = Scalar::null(DType::Decimal(decimal_dtype, Nullability::Nullable));

    assert_eq!(actual, expected);
}
354378}
0 commit comments