@@ -104,52 +104,64 @@ impl ComputeFnVTable for Sum {
104104 // Short-circuit using array statistics.
105105 if let Some ( Precision :: Exact ( sum) ) = array. statistics ( ) . get ( Stat :: Sum ) {
106106 // For floats only use stats if accumulator is zero. otherwise we might have numerical stability issues.
107- if sum_dtype. is_float ( ) && accumulator == & Scalar :: zero_value ( sum_dtype. clone ( ) ) {
108- return Ok ( sum. into ( ) ) ;
109- } else if sum_dtype. is_int ( ) {
110- let sum_from_stat = accumulator
111- . as_primitive ( )
112- . checked_add ( & sum. as_primitive ( ) )
113- . map ( Scalar :: from) ;
114- return Ok ( sum_from_stat
115- . unwrap_or_else ( || Scalar :: null ( sum_dtype) )
116- . into ( ) ) ;
117- } else if sum_dtype. is_decimal ( ) {
118- let sum_from_stat = accumulator
119- . as_decimal ( )
120- . checked_binary_numeric ( & sum. as_decimal ( ) , NumericOperator :: Add )
121- . map ( Scalar :: from) ;
122- return Ok ( sum_from_stat
123- . unwrap_or_else ( || Scalar :: null ( sum_dtype) )
124- . into ( ) ) ;
107+ match sum_dtype {
108+ DType :: Primitive ( p, _) => {
109+ if p. is_float ( ) {
110+ return Ok ( sum. into ( ) ) ;
111+ }
112+ let sum_from_stat = accumulator
113+ . as_primitive ( )
114+ . checked_add ( & sum. as_primitive ( ) )
115+ . map ( Scalar :: from) ;
116+ return Ok ( sum_from_stat
117+ . unwrap_or_else ( || Scalar :: null ( sum_dtype) )
118+ . into ( ) ) ;
119+ }
120+ DType :: Decimal ( ..) => {
121+ let sum_from_stat = accumulator
122+ . as_decimal ( )
123+ . checked_binary_numeric ( & sum. as_decimal ( ) , NumericOperator :: Add )
124+ . map ( Scalar :: from) ;
125+ return Ok ( sum_from_stat
126+ . unwrap_or_else ( || Scalar :: null ( sum_dtype) )
127+ . into ( ) ) ;
128+ }
129+ _ => unreachable ! ( "Sum will always be a decimal or a primitive dtype" ) ,
125130 }
126131 }
127132
128133 let sum_scalar = sum_impl ( array, accumulator, kernels) ?;
129134
130135 // Update the statistics with the computed sum. Stored statistic shouldn't include the accumulator.
131- if sum_dtype. is_float ( ) && accumulator == & Scalar :: zero_value ( sum_dtype. clone ( ) ) {
132- array
133- . statistics ( )
134- . set ( Stat :: Sum , Precision :: Exact ( sum_scalar. value ( ) . clone ( ) ) ) ;
135- } else if sum_dtype. is_int ( )
136- && let Some ( less_accumulator) = sum_scalar
137- . as_primitive ( )
138- . checked_sub ( & accumulator. as_primitive ( ) )
139- {
140- array. statistics ( ) . set (
141- Stat :: Sum ,
142- Precision :: Exact ( Scalar :: from ( less_accumulator) . value ( ) . clone ( ) ) ,
143- ) ;
144- } else if sum_dtype. is_decimal ( )
145- && let Some ( less_accumulator) = sum_scalar
146- . as_decimal ( )
147- . checked_binary_numeric ( & accumulator. as_decimal ( ) , NumericOperator :: Sub )
148- {
149- array. statistics ( ) . set (
150- Stat :: Sum ,
151- Precision :: Exact ( Scalar :: from ( less_accumulator) . value ( ) . clone ( ) ) ,
152- ) ;
136+ match sum_dtype {
137+ DType :: Primitive ( p, _) => {
138+ if p. is_float ( ) && accumulator. is_zero ( ) {
139+ array
140+ . statistics ( )
141+ . set ( Stat :: Sum , Precision :: Exact ( sum_scalar. value ( ) . clone ( ) ) ) ;
142+ } else if p. is_int ( )
143+ && let Some ( less_accumulator) = sum_scalar
144+ . as_primitive ( )
145+ . checked_sub ( & accumulator. as_primitive ( ) )
146+ {
147+ array. statistics ( ) . set (
148+ Stat :: Sum ,
149+ Precision :: Exact ( Scalar :: from ( less_accumulator) . value ( ) . clone ( ) ) ,
150+ ) ;
151+ }
152+ }
153+ DType :: Decimal ( ..) => {
154+ if let Some ( less_accumulator) = sum_scalar
155+ . as_decimal ( )
156+ . checked_binary_numeric ( & accumulator. as_decimal ( ) , NumericOperator :: Sub )
157+ {
158+ array. statistics ( ) . set (
159+ Stat :: Sum ,
160+ Precision :: Exact ( Scalar :: from ( less_accumulator) . value ( ) . clone ( ) ) ,
161+ )
162+ }
163+ }
164+ _ => unreachable ! ( "Sum will always be a decimal or a primitive dtype" ) ,
153165 }
154166
155167 Ok ( sum_scalar. into ( ) )
0 commit comments