@@ -9,7 +9,7 @@ use vortex_dtype::DType;
99use vortex_error:: {
1010 VortexError , VortexResult , vortex_bail, vortex_ensure, vortex_err, vortex_panic,
1111} ;
12- use vortex_scalar:: Scalar ;
12+ use vortex_scalar:: { NumericOperator , Scalar } ;
1313
1414use crate :: Array ;
1515use crate :: compute:: { ComputeFn , ComputeFnVTable , InvocationArgs , Kernel , Output } ;
@@ -103,26 +103,55 @@ impl ComputeFnVTable for Sum {
103103
104104 // Short-circuit using array statistics.
105105 if let Some ( Precision :: Exact ( sum) ) = array. statistics ( ) . get ( Stat :: Sum ) {
106- let sum_from_stat = sum
107- . as_primitive ( )
108- . checked_add ( & accumulator. as_primitive ( ) )
109- . map ( |s| Scalar :: from ( s) ) ;
110- return Ok ( sum_from_stat
111- . unwrap_or_else ( || Scalar :: null ( sum_dtype) )
112- . into ( ) ) ;
106+ // For floats only use stats if accumulator is zero. otherwise we might have numerical stability issues.
107+ if sum_dtype. is_float ( ) && accumulator == & Scalar :: zero_value ( sum_dtype. clone ( ) ) {
108+ return Ok ( sum. into ( ) ) ;
109+ } else if sum_dtype. is_int ( ) {
110+ let sum_from_stat = accumulator
111+ . as_primitive ( )
112+ . checked_add ( & sum. as_primitive ( ) )
113+ . map ( |s| Scalar :: from ( s) ) ;
114+ return Ok ( sum_from_stat
115+ . unwrap_or_else ( || Scalar :: null ( sum_dtype) )
116+ . into ( ) ) ;
117+ } else if sum_dtype. is_decimal ( ) {
118+ let sum_from_stat = accumulator
119+ . as_decimal ( )
120+ . checked_binary_numeric ( & sum. as_decimal ( ) , NumericOperator :: Add )
121+ . map ( |s| Scalar :: from ( s) ) ;
122+ return Ok ( sum_from_stat
123+ . unwrap_or_else ( || Scalar :: null ( sum_dtype) )
124+ . into ( ) ) ;
125+ }
113126 }
114127
115128 let sum_scalar = sum_impl ( array, accumulator, kernels) ?;
116129
117130 // Update the statistics with the computed sum. Stored statistic shouldn't include the accumulator.
118- if let Some ( less_accumulator) = sum_scalar
119- . as_primitive ( )
120- . checked_sub ( & accumulator. as_primitive ( ) )
121- {
122- array. statistics ( ) . set (
123- Stat :: Sum ,
124- Precision :: Exact ( Scalar :: from ( less_accumulator) . value ( ) . clone ( ) ) ,
125- ) ;
131+ if sum_dtype. is_float ( ) && accumulator == & Scalar :: zero_value ( sum_dtype. clone ( ) ) {
132+ array
133+ . statistics ( )
134+ . set ( Stat :: Sum , Precision :: Exact ( sum_scalar. value ( ) . clone ( ) ) ) ;
135+ } else if sum_dtype. is_int ( ) {
136+ if let Some ( less_accumulator) = sum_scalar
137+ . as_primitive ( )
138+ . checked_sub ( & accumulator. as_primitive ( ) )
139+ {
140+ array. statistics ( ) . set (
141+ Stat :: Sum ,
142+ Precision :: Exact ( Scalar :: from ( less_accumulator) . value ( ) . clone ( ) ) ,
143+ ) ;
144+ }
145+ } else if sum_dtype. is_decimal ( ) {
146+ if let Some ( less_accumulator) = sum_scalar
147+ . as_decimal ( )
148+ . checked_binary_numeric ( & accumulator. as_decimal ( ) , NumericOperator :: Sub )
149+ {
150+ array. statistics ( ) . set (
151+ Stat :: Sum ,
152+ Precision :: Exact ( Scalar :: from ( less_accumulator) . value ( ) . clone ( ) ) ,
153+ ) ;
154+ }
126155 }
127156
128157 Ok ( sum_scalar. into ( ) )
0 commit comments