Skip to content

Commit 5d3de8a

Browse files
committed
fix: Sum from stats has to account for accumulator
Signed-off-by: Robert Kruszewski <[email protected]>
1 parent aafe376 commit 5d3de8a

File tree

1 file changed

+46
-8
lines changed
  • vortex-array/src/compute

1 file changed

+46
-8
lines changed

vortex-array/src/compute/sum.rs

Lines changed: 46 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
use std::sync::LazyLock;
55

66
use arcref::ArcRef;
7+
use num_traits::{CheckedAdd, CheckedSub};
78
use vortex_dtype::DType;
89
use vortex_error::{
910
VortexError, VortexResult, vortex_bail, vortex_ensure, vortex_err, vortex_panic,
@@ -102,15 +103,27 @@ impl ComputeFnVTable for Sum {
102103

103104
// Short-circuit using array statistics.
104105
if let Some(Precision::Exact(sum)) = array.statistics().get(Stat::Sum) {
105-
return Ok(sum.into());
106+
let sum_from_stat = sum
107+
.as_primitive()
108+
.checked_add(&accumulator.as_primitive())
109+
.map(|s| Scalar::from(s));
110+
return Ok(sum_from_stat
111+
.unwrap_or_else(|| Scalar::null(sum_dtype))
112+
.into());
106113
}
107114

108115
let sum_scalar = sum_impl(array, accumulator, kernels)?;
109116

110-
// Update the statistics with the computed sum.
111-
array
112-
.statistics()
113-
.set(Stat::Sum, Precision::Exact(sum_scalar.value().clone()));
117+
// Update the statistics with the computed sum. Stored statistic shouldn't include the accumulator.
118+
if let Some(less_accumulator) = sum_scalar
119+
.as_primitive()
120+
.checked_sub(&accumulator.as_primitive())
121+
{
122+
array.statistics().set(
123+
Stat::Sum,
124+
Precision::Exact(Scalar::from(less_accumulator).value().clone()),
125+
);
126+
}
114127

115128
Ok(sum_scalar.into())
116129
}
@@ -206,12 +219,13 @@ pub fn sum_impl(
206219
#[cfg(test)]
207220
mod test {
208221
use vortex_buffer::buffer;
209-
use vortex_dtype::Nullability;
222+
use vortex_dtype::{DType, Nullability, PType};
223+
use vortex_error::VortexUnwrap;
210224
use vortex_scalar::Scalar;
211225

212226
use crate::IntoArray as _;
213-
use crate::arrays::{BoolArray, PrimitiveArray};
214-
use crate::compute::sum;
227+
use crate::arrays::{BoolArray, ChunkedArray, PrimitiveArray};
228+
use crate::compute::{sum, sum_with_accumulator};
215229

216230
#[test]
217231
fn sum_all_invalid() {
@@ -247,4 +261,28 @@ mod test {
247261
let result = sum(array.as_ref()).unwrap();
248262
assert_eq!(result.as_primitive().as_::<i32>(), Some(2));
249263
}
264+
265+
#[test]
266+
fn sum_stats() {
267+
let array = ChunkedArray::try_new(
268+
vec![
269+
PrimitiveArray::from_iter([1, 1, 1]).into_array(),
270+
PrimitiveArray::from_iter([2, 2, 2]).into_array(),
271+
],
272+
DType::Primitive(PType::I32, Nullability::NonNullable),
273+
)
274+
.vortex_unwrap();
275+
// comptue sum with accumulator to populate stats
276+
sum_with_accumulator(
277+
array.as_ref(),
278+
&Scalar::primitive(2i64, Nullability::Nullable),
279+
)
280+
.unwrap();
281+
282+
let sum_with_acc = sum(array.as_ref()).unwrap();
283+
assert_eq!(
284+
sum_with_acc,
285+
Scalar::primitive(11i64, Nullability::Nullable)
286+
);
287+
}
250288
}

0 commit comments

Comments
 (0)