@@ -42,37 +42,31 @@ register_kernel!(SumKernelAdapter(ChunkedVTable).lift());
4242fn sum_int < T : NativePType + PrimInt + FromPrimitiveOrF16 > (
4343 chunks : & [ ArrayRef ] ,
4444) -> VortexResult < Option < T > > {
45- let mut result: Option < T > = None ;
45+ let mut result: T = T :: zero ( ) ;
4646 for chunk in chunks {
4747 let chunk_sum = sum ( chunk) ?;
48-
49- let Some ( chunk_sum) = chunk_sum. as_primitive ( ) . as_ :: < T > ( ) else {
50- // Skip missing null chunk
51- continue ;
48+ let Some ( chunk_sum) = chunk_sum
49+ . as_primitive ( )
50+ . as_ :: < T > ( )
51+ . and_then ( |chunk_sum| result. checked_add ( & chunk_sum) )
52+ else {
53+ // Bail out on null or overflow
54+ return Ok ( None ) ;
5255 } ;
53-
54- result = Some ( match result {
55- None => chunk_sum,
56- Some ( result) => {
57- let Some ( chunk_result) = result. checked_add ( & chunk_sum) else {
58- // Bail out on overflow
59- return Ok ( None ) ;
60- } ;
61- chunk_result
62- }
63- } ) ;
56+ result = chunk_sum;
6457 }
65- Ok ( result)
58+ Ok ( Some ( result) )
6659}
6760
68- fn sum_float ( chunks : & [ ArrayRef ] ) -> VortexResult < f64 > {
61+ fn sum_float ( chunks : & [ ArrayRef ] ) -> VortexResult < Option < f64 > > {
6962 let mut result = 0f64 ;
7063 for chunk in chunks {
71- if let Some ( chunk_sum) = sum ( chunk) ?. as_primitive ( ) . as_ :: < f64 > ( ) {
72- result += chunk_sum ;
64+ let Some ( chunk_sum) = sum ( chunk) ?. as_primitive ( ) . as_ :: < f64 > ( ) else {
65+ return Ok ( None ) ;
7366 } ;
67+ result += chunk_sum;
7468 }
75- Ok ( result)
69+ Ok ( Some ( result) )
7670}
7771
7872fn sum_decimal ( chunks : & [ ArrayRef ] , result_decimal_type : DecimalDType ) -> VortexResult < Scalar > {
@@ -84,21 +78,19 @@ fn sum_decimal(chunks: &[ArrayRef], result_decimal_type: DecimalDType) -> Vortex
8478 let chunk_sum = sum ( chunk) ?;
8579
8680 let chunk_decimal = DecimalScalar :: try_from ( & chunk_sum) ?;
87- let Some ( chunk_value) = chunk_decimal. decimal_value ( ) else {
88- // skips all null chunks
89- continue ;
90- } ;
91-
92- // Perform checked addition with current result
93- let Some ( r) = result. checked_add ( & chunk_value) . filter ( |sum_value| {
94- sum_value
95- . fits_in_precision ( result_decimal_type)
96- . unwrap_or ( false )
97- } ) else {
98- // Overflow
81+ let Some ( r) = chunk_decimal
82+ . decimal_value ( )
83+ // TODO(joe): added a precision capped checked_add.
84+ . and_then ( |c_sum| result. checked_add ( & c_sum) )
85+ . filter ( |sum_value| {
86+ sum_value
87+ . fits_in_precision ( result_decimal_type)
88+ . unwrap_or ( false )
89+ } )
90+ else {
91+ // null if any chunk is null or the sum overflows
9992 return Ok ( null ( ) ) ;
10093 } ;
101-
10294 result = r;
10395 }
10496
@@ -146,18 +138,17 @@ mod tests {
146138 }
147139
148140 #[ test]
149- fn test_sum_chunked_floats_all_nulls ( ) {
141+ fn test_sum_chunked_floats_all_nulls_is_zero ( ) {
150142 // Create chunks with all nulls
151143 let chunk1 = PrimitiveArray :: from_option_iter :: < f32 , _ > ( vec ! [ None , None , None ] ) ;
152144 let chunk2 = PrimitiveArray :: from_option_iter :: < f32 , _ > ( vec ! [ None , None ] ) ;
153145
154146 let dtype = chunk1. dtype ( ) . clone ( ) ;
155147 let chunked =
156148 ChunkedArray :: try_new ( vec ! [ chunk1. into_array( ) , chunk2. into_array( ) ] , dtype) . unwrap ( ) ;
157-
158149 // Compute sum - should return null for all nulls
159150 let result = sum ( chunked. as_ref ( ) ) . unwrap ( ) ;
160- assert ! ( result. as_primitive ( ) . as_ :: < f64 > ( ) . is_none ( ) ) ;
151+ assert_eq ! ( result, Scalar :: primitive ( 0f64 , Nullability :: Nullable ) ) ;
161152 }
162153
163154 #[ test]
0 commit comments