11// SPDX-License-Identifier: Apache-2.0
22// SPDX-FileCopyrightText: Copyright the Vortex contributors
33
4- use num_traits:: { CheckedMul , ToPrimitive } ;
4+ use num_traits:: { CheckedAdd , CheckedMul , ToPrimitive } ;
55use vortex_dtype:: { DType , DecimalDType , NativePType , Nullability , i256, match_each_native_ptype} ;
66use vortex_error:: { VortexExpect , VortexResult , vortex_bail, vortex_err} ;
77use vortex_scalar:: { DecimalScalar , DecimalValue , PrimitiveScalar , Scalar , ScalarValue } ;
@@ -12,32 +12,44 @@ use crate::register_kernel;
1212use crate :: stats:: Stat ;
1313
1414impl SumKernel for ConstantVTable {
15- fn sum ( & self , array : & ConstantArray ) -> VortexResult < Scalar > {
15+ fn sum ( & self , array : & ConstantArray , initial_value : & Scalar ) -> VortexResult < Scalar > {
1616 // Compute the expected dtype of the sum.
1717 let sum_dtype = Stat :: Sum
1818 . dtype ( array. dtype ( ) )
1919 . ok_or_else ( || vortex_err ! ( "Sum not supported for dtype {}" , array. dtype( ) ) ) ?;
2020
21- let sum_value = sum_scalar ( array. scalar ( ) , array. len ( ) ) ?;
21+ let sum_value = sum_scalar ( array. scalar ( ) , array. len ( ) , initial_value ) ?;
2222 Ok ( Scalar :: new ( sum_dtype, sum_value) )
2323 }
2424}
2525
26- fn sum_scalar ( scalar : & Scalar , len : usize ) -> VortexResult < ScalarValue > {
26+ fn sum_scalar ( scalar : & Scalar , len : usize , acc : & Scalar ) -> VortexResult < ScalarValue > {
2727 match scalar. dtype ( ) {
28- DType :: Bool ( _) => Ok ( ScalarValue :: from ( match scalar. as_bool ( ) . value ( ) {
29- None => unreachable ! ( "Handled before reaching this point" ) ,
30- Some ( false ) => 0u64 ,
31- Some ( true ) => len as u64 ,
32- } ) ) ,
33- DType :: Primitive ( ptype, _) => Ok ( match_each_native_ptype ! (
34- ptype,
35- unsigned: |T | { sum_integral:: <u64 >( scalar. as_primitive( ) , len) ?. into( ) } ,
36- signed: |T | { sum_integral:: <i64 >( scalar. as_primitive( ) , len) ?. into( ) } ,
37- floating: |T | { sum_float( scalar. as_primitive( ) , len) ?. into( ) }
38- ) ) ,
39- DType :: Decimal ( decimal_dtype, _) => sum_decimal ( scalar. as_decimal ( ) , len, * decimal_dtype) ,
40- DType :: Extension ( _) => sum_scalar ( & scalar. as_extension ( ) . storage ( ) , len) ,
28+ DType :: Bool ( _) => {
29+ let count = match scalar. as_bool ( ) . value ( ) {
30+ None => unreachable ! ( "Handled before reaching this point" ) ,
31+ Some ( false ) => 0u64 ,
32+ Some ( true ) => len as u64 ,
33+ } ;
34+ let initial_u64 = acc
35+ . as_primitive ( )
36+ . as_ :: < u64 > ( )
37+ . vortex_expect ( "cannot be null" ) ;
38+ Ok ( ScalarValue :: from ( initial_u64. checked_add ( count) ) )
39+ }
40+ DType :: Primitive ( ptype, _) => {
41+ let result = match_each_native_ptype ! (
42+ ptype,
43+ unsigned: |T | { sum_integral:: <u64 >( scalar. as_primitive( ) , len, acc) ?. into( ) } ,
44+ signed: |T | { sum_integral:: <i64 >( scalar. as_primitive( ) , len, acc) ?. into( ) } ,
45+ floating: |T | { sum_float( scalar. as_primitive( ) , len, acc) ?. into( ) }
46+ ) ;
47+ Ok ( result)
48+ }
49+ DType :: Decimal ( decimal_dtype, _) => {
50+ sum_decimal ( scalar. as_decimal ( ) , len, * decimal_dtype, acc)
51+ }
52+ DType :: Extension ( _) => sum_scalar ( & scalar. as_extension ( ) . storage ( ) , len, acc) ,
4153 dtype => vortex_bail ! ( "Unsupported dtype for sum: {}" , dtype) ,
4254 }
4355}
@@ -46,6 +58,7 @@ fn sum_decimal(
4658 decimal_scalar : DecimalScalar ,
4759 array_len : usize ,
4860 decimal_dtype : DecimalDType ,
61+ initial_value : & Scalar ,
4962) -> VortexResult < ScalarValue > {
5063 let result_dtype = Stat :: Sum
5164 . dtype ( & DType :: Decimal ( decimal_dtype, Nullability :: Nullable ) )
@@ -63,43 +76,75 @@ fn sum_decimal(
6376 let len_value = DecimalValue :: I256 ( i256:: from_i128 ( array_len as i128 ) ) ;
6477
6578 // Multiply value * len
66- let sum = value. checked_mul ( & len_value) . and_then ( |result| {
79+ let array_sum = value. checked_mul ( & len_value) . and_then ( |result| {
6780 // Check if result fits in the precision
6881 result
6982 . fits_in_precision ( * result_decimal_type)
7083 . unwrap_or ( false )
7184 . then_some ( result)
7285 } ) ;
7386
74- match sum {
75- Some ( result_value) => Ok ( ScalarValue :: from ( result_value) ) ,
87+ // Add initial_value to array_sum
88+ let initial_decimal = DecimalScalar :: try_from ( initial_value) ?;
89+ let initial_dec_value = initial_decimal
90+ . decimal_value ( )
91+ . unwrap_or ( DecimalValue :: I256 ( i256:: ZERO ) ) ;
92+
93+ match array_sum {
94+ Some ( array_sum_value) => {
95+ let total = array_sum_value
96+ . checked_add ( & initial_dec_value)
97+ . and_then ( |result| {
98+ result
99+ . fits_in_precision ( * result_decimal_type)
100+ . unwrap_or ( false )
101+ . then_some ( result)
102+ } ) ;
103+ match total {
104+ Some ( result_value) => Ok ( ScalarValue :: from ( result_value) ) ,
105+ None => Ok ( ScalarValue :: null ( ) ) , // Overflow
106+ }
107+ }
76108 None => Ok ( ScalarValue :: null ( ) ) , // Overflow
77109 }
78110}
79111
80112fn sum_integral < T > (
81113 primitive_scalar : PrimitiveScalar < ' _ > ,
82114 array_len : usize ,
115+ initial_value : & Scalar ,
83116) -> VortexResult < Option < T > >
84117where
85- T : NativePType + CheckedMul ,
118+ T : NativePType + CheckedMul + CheckedAdd ,
86119 Scalar : From < Option < T > > ,
87120{
88121 let v = primitive_scalar. as_ :: < T > ( ) ;
89122 let array_len =
90123 T :: from ( array_len) . ok_or_else ( || vortex_err ! ( "array_len must fit the sum type" ) ) ?;
91- let sum = v. and_then ( |v| v. checked_mul ( & array_len) ) ;
124+ let Some ( array_sum) = v. and_then ( |v| v. checked_mul ( & array_len) ) else {
125+ return Ok ( None ) ;
126+ } ;
92127
93- Ok ( sum)
128+ let initial = initial_value
129+ . as_primitive ( )
130+ . as_ :: < T > ( )
131+ . unwrap_or_else ( T :: zero) ;
132+ Ok ( initial. checked_add ( & array_sum) )
94133}
95134
96- fn sum_float ( primitive_scalar : PrimitiveScalar < ' _ > , array_len : usize ) -> VortexResult < Option < f64 > > {
135+ fn sum_float (
136+ primitive_scalar : PrimitiveScalar < ' _ > ,
137+ array_len : usize ,
138+ initial_value : & Scalar ,
139+ ) -> VortexResult < Option < f64 > > {
97140 let v = primitive_scalar. as_ :: < f64 > ( ) ;
98141 let array_len = array_len
99142 . to_f64 ( )
100143 . ok_or_else ( || vortex_err ! ( "array_len must fit the sum type" ) ) ?;
101144
102- Ok ( v. map ( |v| v * array_len) )
145+ let array_sum = v. map ( |v| v * array_len) . unwrap_or ( 0.0 ) ;
146+ let initial = initial_value. as_primitive ( ) . as_ :: < f64 > ( ) . unwrap_or ( 0.0 ) ;
147+ Ok ( Some ( initial + array_sum) )
103148}
104149
105150register_kernel ! ( SumKernelAdapter ( ConstantVTable ) . lift( ) ) ;
0 commit comments