|
4 | 4 | use std::sync::LazyLock; |
5 | 5 |
|
6 | 6 | use arcref::ArcRef; |
| 7 | +use num_traits::{CheckedAdd, CheckedSub}; |
7 | 8 | use vortex_dtype::DType; |
8 | 9 | use vortex_error::{ |
9 | 10 | VortexError, VortexResult, vortex_bail, vortex_ensure, vortex_err, vortex_panic, |
@@ -102,15 +103,27 @@ impl ComputeFnVTable for Sum { |
102 | 103 |
|
103 | 104 | // Short-circuit using array statistics. |
104 | 105 | if let Some(Precision::Exact(sum)) = array.statistics().get(Stat::Sum) { |
105 | | - return Ok(sum.into()); |
| 106 | + let sum_from_stat = sum |
| 107 | + .as_primitive() |
| 108 | + .checked_add(&accumulator.as_primitive()) |
| 109 | + .map(|s| Scalar::from(s)); |
| 110 | + return Ok(sum_from_stat |
| 111 | + .unwrap_or_else(|| Scalar::null(sum_dtype)) |
| 112 | + .into()); |
106 | 113 | } |
107 | 114 |
|
108 | 115 | let sum_scalar = sum_impl(array, accumulator, kernels)?; |
109 | 116 |
|
110 | | - // Update the statistics with the computed sum. |
111 | | - array |
112 | | - .statistics() |
113 | | - .set(Stat::Sum, Precision::Exact(sum_scalar.value().clone())); |
| 117 | + // Update the statistics with the computed sum. Stored statistic shouldn't include the accumulator. |
| 118 | + if let Some(less_accumulator) = sum_scalar |
| 119 | + .as_primitive() |
| 120 | + .checked_sub(&accumulator.as_primitive()) |
| 121 | + { |
| 122 | + array.statistics().set( |
| 123 | + Stat::Sum, |
| 124 | + Precision::Exact(Scalar::from(less_accumulator).value().clone()), |
| 125 | + ); |
| 126 | + } |
114 | 127 |
|
115 | 128 | Ok(sum_scalar.into()) |
116 | 129 | } |
@@ -206,12 +219,13 @@ pub fn sum_impl( |
206 | 219 | #[cfg(test)] |
207 | 220 | mod test { |
208 | 221 | use vortex_buffer::buffer; |
209 | | - use vortex_dtype::Nullability; |
| 222 | + use vortex_dtype::{DType, Nullability, PType}; |
| 223 | + use vortex_error::VortexUnwrap; |
210 | 224 | use vortex_scalar::Scalar; |
211 | 225 |
|
212 | 226 | use crate::IntoArray as _; |
213 | | - use crate::arrays::{BoolArray, PrimitiveArray}; |
214 | | - use crate::compute::sum; |
| 227 | + use crate::arrays::{BoolArray, ChunkedArray, PrimitiveArray}; |
| 228 | + use crate::compute::{sum, sum_with_accumulator}; |
215 | 229 |
|
216 | 230 | #[test] |
217 | 231 | fn sum_all_invalid() { |
@@ -247,4 +261,28 @@ mod test { |
247 | 261 | let result = sum(array.as_ref()).unwrap(); |
248 | 262 | assert_eq!(result.as_primitive().as_::<i32>(), Some(2)); |
249 | 263 | } |
| 264 | + |
| 265 | + #[test] |
| 266 | + fn sum_stats() { |
| 267 | + let array = ChunkedArray::try_new( |
| 268 | + vec![ |
| 269 | + PrimitiveArray::from_iter([1, 1, 1]).into_array(), |
| 270 | + PrimitiveArray::from_iter([2, 2, 2]).into_array(), |
| 271 | + ], |
| 272 | + DType::Primitive(PType::I32, Nullability::NonNullable), |
| 273 | + ) |
| 274 | + .vortex_unwrap(); |
| 275 | + // compute sum with accumulator to populate stats |
| 276 | + sum_with_accumulator( |
| 277 | + array.as_ref(), |
| 278 | + &Scalar::primitive(2i64, Nullability::Nullable), |
| 279 | + ) |
| 280 | + .unwrap(); |
| 281 | + |
| 282 | + let sum_with_acc = sum(array.as_ref()).unwrap(); |
| 283 | + assert_eq!( |
| 284 | + sum_with_acc, |
| 285 | + Scalar::primitive(11i64, Nullability::Nullable) |
| 286 | + ); |
| 287 | + } |
250 | 288 | } |
0 commit comments