22// SPDX-FileCopyrightText: Copyright the Vortex contributors
33
44use num_traits:: PrimInt ;
5- use vortex_dtype:: { NativePType , PType , match_each_native_ptype} ;
6- use vortex_error:: { VortexExpect , VortexResult , vortex_err} ;
7- use vortex_scalar:: { FromPrimitiveOrF16 , Scalar } ;
5+ use vortex_dtype:: Nullability :: Nullable ;
6+ use vortex_dtype:: { DType , DecimalDType , NativePType , match_each_native_ptype} ;
7+ use vortex_error:: { VortexResult , vortex_bail, vortex_err} ;
8+ use vortex_scalar:: { DecimalScalar , DecimalValue , FromPrimitiveOrF16 , Scalar , i256} ;
89
910use crate :: arrays:: { ChunkedArray , ChunkedVTable } ;
1011use crate :: compute:: { SumKernel , SumKernelAdapter , sum} ;
@@ -16,16 +17,23 @@ impl SumKernel for ChunkedVTable {
1617 let sum_dtype = Stat :: Sum
1718 . dtype ( array. dtype ( ) )
1819 . ok_or_else ( || vortex_err ! ( "Sum not supported for dtype {}" , array. dtype( ) ) ) ?;
19- let sum_ptype = PType :: try_from ( & sum_dtype) . vortex_expect ( "sum dtype must be primitive" ) ;
2020
21- let scalar_value = match_each_native_ptype ! (
22- sum_ptype,
23- unsigned: |T | { sum_int:: <u64 >( array. chunks( ) ) ?. into( ) } ,
24- signed: |T | { sum_int:: <i64 >( array. chunks( ) ) ?. into( ) } ,
25- floating: |T | { sum_float( array. chunks( ) ) ?. into( ) }
26- ) ;
21+ match sum_dtype {
22+ DType :: Decimal ( decimal_dtype, _) => sum_decimal ( array. chunks ( ) , decimal_dtype) ,
23+ DType :: Primitive ( sum_ptype, _) => {
24+ let scalar_value = match_each_native_ptype ! (
25+ sum_ptype,
26+ unsigned: |T | { sum_int:: <u64 >( array. chunks( ) ) ?. into( ) } ,
27+ signed: |T | { sum_int:: <i64 >( array. chunks( ) ) ?. into( ) } ,
28+ floating: |T | { sum_float( array. chunks( ) ) ?. into( ) }
29+ ) ;
2730
28- Ok ( Scalar :: new ( sum_dtype, scalar_value) )
31+ Ok ( Scalar :: new ( sum_dtype, scalar_value) )
32+ }
33+ _ => {
34+ vortex_bail ! ( "Sum not supported for dtype {}" , sum_dtype) ;
35+ }
36+ }
2937 }
3038}
3139
@@ -39,7 +47,7 @@ fn sum_int<T: NativePType + PrimInt + FromPrimitiveOrF16>(
3947 let chunk_sum = sum ( chunk) ?;
4048
4149 let Some ( chunk_sum) = chunk_sum. as_primitive ( ) . as_ :: < T > ( ) else {
42- // Bail out on overflow
50+ // Bail out missing statistic
4351 return Ok ( None ) ;
4452 } ;
4553
@@ -63,14 +71,46 @@ fn sum_float(chunks: &[ArrayRef]) -> VortexResult<f64> {
6371 Ok ( result)
6472}
6573
74+ fn sum_decimal ( chunks : & [ ArrayRef ] , result_decimal_type : DecimalDType ) -> VortexResult < Scalar > {
75+ let mut result = DecimalValue :: I256 ( i256:: ZERO ) ;
76+
77+ let null = || Scalar :: null ( DType :: Decimal ( result_decimal_type, Nullable ) ) ;
78+
79+ for chunk in chunks {
80+ let chunk_sum = sum ( chunk) ?;
81+
82+ let chunk_decimal = DecimalScalar :: try_from ( & chunk_sum) ?;
83+ let Some ( chunk_value) = chunk_decimal. decimal_value ( ) else {
84+ // skips all null chunks
85+ continue ;
86+ } ;
87+
88+ // Perform checked addition with current result
89+ let Some ( r) = result. checked_add ( & chunk_value) . filter ( |sum_value| {
90+ sum_value
91+ . fits_in_precision ( result_decimal_type)
92+ . unwrap_or ( false )
93+ } ) else {
94+ // Overflow
95+ return Ok ( null ( ) ) ;
96+ } ;
97+
98+ result = r;
99+ }
100+
101+ Ok ( Scalar :: decimal ( result, result_decimal_type, Nullable ) )
102+ }
103+
66104#[ cfg( test) ]
67105mod tests {
68- use vortex_dtype:: Nullability ;
69- use vortex_scalar:: Scalar ;
106+ use vortex_buffer:: buffer;
107+ use vortex_dtype:: { DType , DecimalDType , Nullability } ;
108+ use vortex_scalar:: { DecimalValue , Scalar , i256} ;
70109
71110 use crate :: array:: IntoArray ;
72- use crate :: arrays:: { ChunkedArray , ConstantArray , PrimitiveArray } ;
111+ use crate :: arrays:: { ChunkedArray , ConstantArray , DecimalArray , PrimitiveArray } ;
73112 use crate :: compute:: sum;
113+ use crate :: validity:: Validity ;
74114
75115 #[ test]
76116 fn test_sum_chunked_floats_with_nulls ( ) {
@@ -138,4 +178,117 @@ mod tests {
138178 let result = sum ( chunked. as_ref ( ) ) . unwrap ( ) ;
139179 assert_eq ! ( result. as_primitive( ) . as_:: <f64 >( ) , Some ( 36.0 ) ) ;
140180 }
181+
182+ #[ test]
183+ fn test_sum_chunked_decimals ( ) {
184+ // Create decimal chunks with precision=10, scale=2
185+ let decimal_dtype = DecimalDType :: new ( 10 , 2 ) ;
186+ let chunk1 = DecimalArray :: new (
187+ buffer ! [ 100i32 , 100i32 , 100i32 , 100i32 , 100i32 ] ,
188+ decimal_dtype,
189+ Validity :: AllValid ,
190+ ) ;
191+ let chunk2 = DecimalArray :: new (
192+ buffer ! [ 200i32 , 200i32 , 200i32 ] ,
193+ decimal_dtype,
194+ Validity :: AllValid ,
195+ ) ;
196+ let chunk3 = DecimalArray :: new ( buffer ! [ 300i32 , 300i32 ] , decimal_dtype, Validity :: AllValid ) ;
197+
198+ let dtype = chunk1. dtype ( ) . clone ( ) ;
199+ let chunked = ChunkedArray :: try_new (
200+ vec ! [
201+ chunk1. into_array( ) ,
202+ chunk2. into_array( ) ,
203+ chunk3. into_array( ) ,
204+ ] ,
205+ dtype,
206+ )
207+ . unwrap ( ) ;
208+
209+ // Compute sum: 5*100 + 3*200 + 2*300 = 500 + 600 + 600 = 1700 (represents 17.00)
210+ let result = sum ( chunked. as_ref ( ) ) . unwrap ( ) ;
211+ let decimal_result = result. as_decimal ( ) ;
212+ assert_eq ! (
213+ decimal_result. decimal_value( ) ,
214+ Some ( DecimalValue :: I256 ( i256:: from_i128( 1700 ) ) )
215+ ) ;
216+ }
217+
218+ #[ test]
219+ fn test_sum_chunked_decimals_with_nulls ( ) {
220+ let decimal_dtype = DecimalDType :: new ( 10 , 2 ) ;
221+
222+ // Create chunks with some nulls - all must have same nullability
223+ let chunk1 = DecimalArray :: new (
224+ buffer ! [ 100i32 , 100i32 , 100i32 ] ,
225+ decimal_dtype,
226+ Validity :: AllValid ,
227+ ) ;
228+ let chunk2 = DecimalArray :: new (
229+ buffer ! [ 0i32 , 0i32 ] ,
230+ decimal_dtype,
231+ Validity :: from_iter ( [ false , false ] ) ,
232+ ) ;
233+ let chunk3 = DecimalArray :: new ( buffer ! [ 200i32 , 200i32 ] , decimal_dtype, Validity :: AllValid ) ;
234+
235+ let dtype = chunk1. dtype ( ) . clone ( ) ;
236+ let chunked = ChunkedArray :: try_new (
237+ vec ! [
238+ chunk1. into_array( ) ,
239+ chunk2. into_array( ) ,
240+ chunk3. into_array( ) ,
241+ ] ,
242+ dtype,
243+ )
244+ . unwrap ( ) ;
245+
246+ // Compute sum: 3*100 + 2*200 = 300 + 400 = 700 (nulls ignored)
247+ let result = sum ( chunked. as_ref ( ) ) . unwrap ( ) ;
248+ let decimal_result = result. as_decimal ( ) ;
249+ assert_eq ! (
250+ decimal_result. decimal_value( ) ,
251+ Some ( DecimalValue :: I256 ( i256:: from_i128( 700 ) ) )
252+ ) ;
253+ }
254+
255+ #[ test]
256+ fn test_sum_chunked_decimals_large ( ) {
257+ // Create decimals with precision 3 (max value 999)
258+ // Sum will be 500 + 600 = 1100, which fits in result precision 13 (3+10)
259+ let decimal_dtype = DecimalDType :: new ( 3 , 0 ) ;
260+ let chunk1 = ConstantArray :: new (
261+ Scalar :: decimal (
262+ DecimalValue :: I16 ( 500 ) ,
263+ decimal_dtype,
264+ Nullability :: NonNullable ,
265+ ) ,
266+ 1 ,
267+ ) ;
268+ let chunk2 = ConstantArray :: new (
269+ Scalar :: decimal (
270+ DecimalValue :: I16 ( 600 ) ,
271+ decimal_dtype,
272+ Nullability :: NonNullable ,
273+ ) ,
274+ 1 ,
275+ ) ;
276+
277+ let dtype = chunk1. dtype ( ) . clone ( ) ;
278+ let chunked =
279+ ChunkedArray :: try_new ( vec ! [ chunk1. into_array( ) , chunk2. into_array( ) ] , dtype) . unwrap ( ) ;
280+
281+ // Compute sum: 500 + 600 = 1100
282+ // Result should have precision 13 (3+10), scale 0
283+ let result = sum ( chunked. as_ref ( ) ) . unwrap ( ) ;
284+ let decimal_result = result. as_decimal ( ) ;
285+ assert_eq ! (
286+ decimal_result. decimal_value( ) ,
287+ Some ( DecimalValue :: I256 ( i256:: from_i128( 1100 ) ) )
288+ ) ;
289+ assert_eq ! (
290+ result. dtype( ) ,
291+ & DType :: Decimal ( DecimalDType :: new( 13 , 0 ) , Nullability :: Nullable )
292+ ) ;
293+ }
141294}
0 commit comments