|
4 | 4 | use std::sync::LazyLock; |
5 | 5 |
|
6 | 6 | use arcref::ArcRef; |
| 7 | +use num_traits::{CheckedAdd, CheckedSub}; |
7 | 8 | use vortex_dtype::DType; |
8 | 9 | use vortex_error::{ |
9 | 10 | VortexError, VortexResult, vortex_bail, vortex_ensure, vortex_err, vortex_panic, |
10 | 11 | }; |
11 | | -use vortex_scalar::Scalar; |
| 12 | +use vortex_scalar::{NumericOperator, Scalar}; |
12 | 13 |
|
13 | 14 | use crate::Array; |
14 | 15 | use crate::compute::{ComputeFn, ComputeFnVTable, InvocationArgs, Kernel, Output}; |
@@ -102,15 +103,66 @@ impl ComputeFnVTable for Sum { |
102 | 103 |
|
103 | 104 | // Short-circuit using array statistics. |
104 | 105 | if let Some(Precision::Exact(sum)) = array.statistics().get(Stat::Sum) { |
105 | | - return Ok(sum.into()); |
| 106 | + // For floats only use stats if accumulator is zero. otherwise we might have numerical stability issues. |
| 107 | + match sum_dtype { |
| 108 | + DType::Primitive(p, _) => { |
| 109 | + if p.is_float() && accumulator.is_zero() { |
| 110 | + return Ok(sum.into()); |
| 111 | + } |
| 112 | + let sum_from_stat = accumulator |
| 113 | + .as_primitive() |
| 114 | + .checked_add(&sum.as_primitive()) |
| 115 | + .map(Scalar::from); |
| 116 | + return Ok(sum_from_stat |
| 117 | + .unwrap_or_else(|| Scalar::null(sum_dtype)) |
| 118 | + .into()); |
| 119 | + } |
| 120 | + DType::Decimal(..) => { |
| 121 | + let sum_from_stat = accumulator |
| 122 | + .as_decimal() |
| 123 | + .checked_binary_numeric(&sum.as_decimal(), NumericOperator::Add) |
| 124 | + .map(Scalar::from); |
| 125 | + return Ok(sum_from_stat |
| 126 | + .unwrap_or_else(|| Scalar::null(sum_dtype)) |
| 127 | + .into()); |
| 128 | + } |
| 129 | + _ => unreachable!("Sum will always be a decimal or a primitive dtype"), |
| 130 | + } |
106 | 131 | } |
107 | 132 |
|
108 | 133 | let sum_scalar = sum_impl(array, accumulator, kernels)?; |
109 | 134 |
|
110 | | - // Update the statistics with the computed sum. |
111 | | - array |
112 | | - .statistics() |
113 | | - .set(Stat::Sum, Precision::Exact(sum_scalar.value().clone())); |
| 135 | + // Update the statistics with the computed sum. Stored statistic shouldn't include the accumulator. |
| 136 | + match sum_dtype { |
| 137 | + DType::Primitive(p, _) => { |
| 138 | + if p.is_float() && accumulator.is_zero() { |
| 139 | + array |
| 140 | + .statistics() |
| 141 | + .set(Stat::Sum, Precision::Exact(sum_scalar.value().clone())); |
| 142 | + } else if p.is_int() |
| 143 | + && let Some(less_accumulator) = sum_scalar |
| 144 | + .as_primitive() |
| 145 | + .checked_sub(&accumulator.as_primitive()) |
| 146 | + { |
| 147 | + array.statistics().set( |
| 148 | + Stat::Sum, |
| 149 | + Precision::Exact(Scalar::from(less_accumulator).value().clone()), |
| 150 | + ); |
| 151 | + } |
| 152 | + } |
| 153 | + DType::Decimal(..) => { |
| 154 | + if let Some(less_accumulator) = sum_scalar |
| 155 | + .as_decimal() |
| 156 | + .checked_binary_numeric(&accumulator.as_decimal(), NumericOperator::Sub) |
| 157 | + { |
| 158 | + array.statistics().set( |
| 159 | + Stat::Sum, |
| 160 | + Precision::Exact(Scalar::from(less_accumulator).value().clone()), |
| 161 | + ) |
| 162 | + } |
| 163 | + } |
| 164 | + _ => unreachable!("Sum will always be a decimal or a primitive dtype"), |
| 165 | + } |
114 | 166 |
|
115 | 167 | Ok(sum_scalar.into()) |
116 | 168 | } |
@@ -206,12 +258,13 @@ pub fn sum_impl( |
206 | 258 | #[cfg(test)] |
207 | 259 | mod test { |
208 | 260 | use vortex_buffer::buffer; |
209 | | - use vortex_dtype::Nullability; |
| 261 | + use vortex_dtype::{DType, Nullability, PType}; |
| 262 | + use vortex_error::VortexUnwrap; |
210 | 263 | use vortex_scalar::Scalar; |
211 | 264 |
|
212 | 265 | use crate::IntoArray as _; |
213 | | - use crate::arrays::{BoolArray, PrimitiveArray}; |
214 | | - use crate::compute::sum; |
| 266 | + use crate::arrays::{BoolArray, ChunkedArray, PrimitiveArray}; |
| 267 | + use crate::compute::{sum, sum_with_accumulator}; |
215 | 268 |
|
216 | 269 | #[test] |
217 | 270 | fn sum_all_invalid() { |
@@ -247,4 +300,28 @@ mod test { |
247 | 300 | let result = sum(array.as_ref()).unwrap(); |
248 | 301 | assert_eq!(result.as_primitive().as_::<i32>(), Some(2)); |
249 | 302 | } |
| 303 | + |
| 304 | + #[test] |
| 305 | + fn sum_stats() { |
| 306 | + let array = ChunkedArray::try_new( |
| 307 | + vec![ |
| 308 | + PrimitiveArray::from_iter([1, 1, 1]).into_array(), |
| 309 | + PrimitiveArray::from_iter([2, 2, 2]).into_array(), |
| 310 | + ], |
| 311 | + DType::Primitive(PType::I32, Nullability::NonNullable), |
| 312 | + ) |
| 313 | + .vortex_unwrap(); |
| 314 | + // compute sum with accumulator to populate stats |
| 315 | + sum_with_accumulator( |
| 316 | + array.as_ref(), |
| 317 | + &Scalar::primitive(2i64, Nullability::Nullable), |
| 318 | + ) |
| 319 | + .unwrap(); |
| 320 | + |
| 321 | + let sum_without_acc = sum(array.as_ref()).unwrap(); |
| 322 | + assert_eq!( |
| 323 | + sum_without_acc, |
| 324 | + Scalar::primitive(9i64, Nullability::Nullable) |
| 325 | + ); |
| 326 | + } |
250 | 327 | } |
0 commit comments