11// SPDX-License-Identifier: Apache-2.0
22// SPDX-FileCopyrightText: Copyright the Vortex contributors
33
4- use num_traits:: { CheckedMul , ToPrimitive } ;
4+ use arrow_array:: ArrowNativeTypeOp ;
5+ use num_traits:: { CheckedAdd , CheckedMul , ToPrimitive } ;
56use vortex_dtype:: { DType , DecimalDType , NativePType , Nullability , i256, match_each_native_ptype} ;
67use vortex_error:: { VortexExpect , VortexResult , vortex_bail, vortex_err} ;
78use vortex_scalar:: { DecimalScalar , DecimalValue , PrimitiveScalar , Scalar , ScalarValue } ;
@@ -12,32 +13,44 @@ use crate::register_kernel;
1213use crate :: stats:: Stat ;
1314
1415impl SumKernel for ConstantVTable {
15- fn sum ( & self , array : & ConstantArray ) -> VortexResult < Scalar > {
16+ fn sum ( & self , array : & ConstantArray , accumulator : & Scalar ) -> VortexResult < Scalar > {
1617 // Compute the expected dtype of the sum.
1718 let sum_dtype = Stat :: Sum
1819 . dtype ( array. dtype ( ) )
1920 . ok_or_else ( || vortex_err ! ( "Sum not supported for dtype {}" , array. dtype( ) ) ) ?;
2021
21- let sum_value = sum_scalar ( array. scalar ( ) , array. len ( ) ) ?;
22+ let sum_value = sum_scalar ( array. scalar ( ) , array. len ( ) , accumulator ) ?;
2223 Ok ( Scalar :: new ( sum_dtype, sum_value) )
2324 }
2425}
2526
26- fn sum_scalar ( scalar : & Scalar , len : usize ) -> VortexResult < ScalarValue > {
27+ fn sum_scalar ( scalar : & Scalar , len : usize , accumulator : & Scalar ) -> VortexResult < ScalarValue > {
2728 match scalar. dtype ( ) {
28- DType :: Bool ( _) => Ok ( ScalarValue :: from ( match scalar. as_bool ( ) . value ( ) {
29- None => unreachable ! ( "Handled before reaching this point" ) ,
30- Some ( false ) => 0u64 ,
31- Some ( true ) => len as u64 ,
32- } ) ) ,
33- DType :: Primitive ( ptype, _) => Ok ( match_each_native_ptype ! (
34- ptype,
35- unsigned: |T | { sum_integral:: <u64 >( scalar. as_primitive( ) , len) ?. into( ) } ,
36- signed: |T | { sum_integral:: <i64 >( scalar. as_primitive( ) , len) ?. into( ) } ,
37- floating: |T | { sum_float( scalar. as_primitive( ) , len) ?. into( ) }
38- ) ) ,
39- DType :: Decimal ( decimal_dtype, _) => sum_decimal ( scalar. as_decimal ( ) , len, * decimal_dtype) ,
40- DType :: Extension ( _) => sum_scalar ( & scalar. as_extension ( ) . storage ( ) , len) ,
29+ DType :: Bool ( _) => {
30+ let count = match scalar. as_bool ( ) . value ( ) {
31+ None => unreachable ! ( "Handled before reaching this point" ) ,
32+ Some ( false ) => 0u64 ,
33+ Some ( true ) => len as u64 ,
34+ } ;
35+ let accumulator = accumulator
36+ . as_primitive ( )
37+ . as_ :: < u64 > ( )
38+ . vortex_expect ( "cannot be null" ) ;
39+ Ok ( ScalarValue :: from ( accumulator. checked_add ( count) ) )
40+ }
41+ DType :: Primitive ( ptype, _) => {
42+ let result = match_each_native_ptype ! (
43+ ptype,
44+ unsigned: |T | { sum_integral:: <u64 >( scalar. as_primitive( ) , len, accumulator) ?. into( ) } ,
45+ signed: |T | { sum_integral:: <i64 >( scalar. as_primitive( ) , len, accumulator) ?. into( ) } ,
46+ floating: |T | { sum_float( scalar. as_primitive( ) , len, accumulator) ?. into( ) }
47+ ) ;
48+ Ok ( result)
49+ }
50+ DType :: Decimal ( decimal_dtype, _) => {
51+ sum_decimal ( scalar. as_decimal ( ) , len, * decimal_dtype, accumulator)
52+ }
53+ DType :: Extension ( _) => sum_scalar ( & scalar. as_extension ( ) . storage ( ) , len, accumulator) ,
4154 dtype => vortex_bail ! ( "Unsupported dtype for sum: {}" , dtype) ,
4255 }
4356}
@@ -46,6 +59,7 @@ fn sum_decimal(
4659 decimal_scalar : DecimalScalar ,
4760 array_len : usize ,
4861 decimal_dtype : DecimalDType ,
62+ accumulator : & Scalar ,
4963) -> VortexResult < ScalarValue > {
5064 let result_dtype = Stat :: Sum
5165 . dtype ( & DType :: Decimal ( decimal_dtype, Nullability :: Nullable ) )
@@ -63,43 +77,82 @@ fn sum_decimal(
6377 let len_value = DecimalValue :: I256 ( i256:: from_i128 ( array_len as i128 ) ) ;
6478
6579 // Multiply value * len
66- let sum = value. checked_mul ( & len_value) . and_then ( |result| {
80+ let array_sum = value. checked_mul ( & len_value) . and_then ( |result| {
6781 // Check if result fits in the precision
6882 result
6983 . fits_in_precision ( * result_decimal_type)
7084 . unwrap_or ( false )
7185 . then_some ( result)
7286 } ) ;
7387
74- match sum {
75- Some ( result_value) => Ok ( ScalarValue :: from ( result_value) ) ,
88+ // Add accumulator to array_sum
89+ let initial_decimal = DecimalScalar :: try_from ( accumulator) ?;
90+ let initial_dec_value = initial_decimal
91+ . decimal_value ( )
92+ . unwrap_or ( DecimalValue :: I256 ( i256:: ZERO ) ) ;
93+
94+ match array_sum {
95+ Some ( array_sum_value) => {
96+ let total = array_sum_value
97+ . checked_add ( & initial_dec_value)
98+ . and_then ( |result| {
99+ result
100+ . fits_in_precision ( * result_decimal_type)
101+ . unwrap_or ( false )
102+ . then_some ( result)
103+ } ) ;
104+ match total {
105+ Some ( result_value) => Ok ( ScalarValue :: from ( result_value) ) ,
106+ None => Ok ( ScalarValue :: null ( ) ) , // Overflow
107+ }
108+ }
76109 None => Ok ( ScalarValue :: null ( ) ) , // Overflow
77110 }
78111}
79112
80113fn sum_integral < T > (
81114 primitive_scalar : PrimitiveScalar < ' _ > ,
82115 array_len : usize ,
116+ accumulator : & Scalar ,
83117) -> VortexResult < Option < T > >
84118where
85- T : NativePType + CheckedMul ,
119+ T : NativePType + CheckedMul + CheckedAdd ,
86120 Scalar : From < Option < T > > ,
87121{
88122 let v = primitive_scalar. as_ :: < T > ( ) ;
89123 let array_len =
90124 T :: from ( array_len) . ok_or_else ( || vortex_err ! ( "array_len must fit the sum type" ) ) ?;
91- let sum = v. and_then ( |v| v. checked_mul ( & array_len) ) ;
125+ let Some ( array_sum) = v. and_then ( |v| v. checked_mul ( & array_len) ) else {
126+ return Ok ( None ) ;
127+ } ;
92128
93- Ok ( sum)
129+ let initial = accumulator
130+ . as_primitive ( )
131+ . as_ :: < T > ( )
132+ . vortex_expect ( "cannot be null" ) ;
133+ Ok ( initial. checked_add ( & array_sum) )
94134}
95135
96- fn sum_float ( primitive_scalar : PrimitiveScalar < ' _ > , array_len : usize ) -> VortexResult < Option < f64 > > {
97- let v = primitive_scalar. as_ :: < f64 > ( ) ;
136+ fn sum_float (
137+ primitive_scalar : PrimitiveScalar < ' _ > ,
138+ array_len : usize ,
139+ accumulator : & Scalar ,
140+ ) -> VortexResult < Option < f64 > > {
141+ let v = primitive_scalar
142+ . as_ :: < f64 > ( )
143+ . vortex_expect ( "cannot be null" ) ;
98144 let array_len = array_len
99145 . to_f64 ( )
100146 . ok_or_else ( || vortex_err ! ( "array_len must fit the sum type" ) ) ?;
101147
102- Ok ( v. map ( |v| v * array_len) )
148+ let Ok ( array_sum) = v. mul_checked ( array_len) else {
149+ return Ok ( None ) ;
150+ } ;
151+ let initial = accumulator
152+ . as_primitive ( )
153+ . as_ :: < f64 > ( )
154+ . vortex_expect ( "cannot be null" ) ;
155+ Ok ( Some ( initial + array_sum) )
103156}
104157
105158register_kernel ! ( SumKernelAdapter ( ConstantVTable ) . lift( ) ) ;
0 commit comments