Skip to content

Commit 3165d29

Browse files
fix: Sum from stats has to account for accumulator (#5445)
fix: #5434 Signed-off-by: Robert Kruszewski <[email protected]> --------- Signed-off-by: Robert Kruszewski <[email protected]> Co-authored-by: Joe Isaacs <[email protected]>
1 parent a2fe4d3 commit 3165d29

File tree

2 files changed

+124
-9
lines changed

2 files changed

+124
-9
lines changed

vortex-array/src/compute/sum.rs

Lines changed: 86 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,12 @@
44
use std::sync::LazyLock;
55

66
use arcref::ArcRef;
7+
use num_traits::{CheckedAdd, CheckedSub};
78
use vortex_dtype::DType;
89
use vortex_error::{
910
VortexError, VortexResult, vortex_bail, vortex_ensure, vortex_err, vortex_panic,
1011
};
11-
use vortex_scalar::Scalar;
12+
use vortex_scalar::{NumericOperator, Scalar};
1213

1314
use crate::Array;
1415
use crate::compute::{ComputeFn, ComputeFnVTable, InvocationArgs, Kernel, Output};
@@ -102,15 +103,66 @@ impl ComputeFnVTable for Sum {
102103

103104
// Short-circuit using array statistics.
104105
if let Some(Precision::Exact(sum)) = array.statistics().get(Stat::Sum) {
105-
return Ok(sum.into());
106+
// For floats, only use stats if the accumulator is zero; otherwise we might have numerical stability issues.
107+
match sum_dtype {
108+
DType::Primitive(p, _) => {
109+
if p.is_float() && accumulator.is_zero() {
110+
return Ok(sum.into());
111+
}
112+
let sum_from_stat = accumulator
113+
.as_primitive()
114+
.checked_add(&sum.as_primitive())
115+
.map(Scalar::from);
116+
return Ok(sum_from_stat
117+
.unwrap_or_else(|| Scalar::null(sum_dtype))
118+
.into());
119+
}
120+
DType::Decimal(..) => {
121+
let sum_from_stat = accumulator
122+
.as_decimal()
123+
.checked_binary_numeric(&sum.as_decimal(), NumericOperator::Add)
124+
.map(Scalar::from);
125+
return Ok(sum_from_stat
126+
.unwrap_or_else(|| Scalar::null(sum_dtype))
127+
.into());
128+
}
129+
_ => unreachable!("Sum will always be a decimal or a primitive dtype"),
130+
}
106131
}
107132

108133
let sum_scalar = sum_impl(array, accumulator, kernels)?;
109134

110-
// Update the statistics with the computed sum.
111-
array
112-
.statistics()
113-
.set(Stat::Sum, Precision::Exact(sum_scalar.value().clone()));
135+
// Update the statistics with the computed sum. The stored statistic shouldn't include the accumulator.
136+
match sum_dtype {
137+
DType::Primitive(p, _) => {
138+
if p.is_float() && accumulator.is_zero() {
139+
array
140+
.statistics()
141+
.set(Stat::Sum, Precision::Exact(sum_scalar.value().clone()));
142+
} else if p.is_int()
143+
&& let Some(less_accumulator) = sum_scalar
144+
.as_primitive()
145+
.checked_sub(&accumulator.as_primitive())
146+
{
147+
array.statistics().set(
148+
Stat::Sum,
149+
Precision::Exact(Scalar::from(less_accumulator).value().clone()),
150+
);
151+
}
152+
}
153+
DType::Decimal(..) => {
154+
if let Some(less_accumulator) = sum_scalar
155+
.as_decimal()
156+
.checked_binary_numeric(&accumulator.as_decimal(), NumericOperator::Sub)
157+
{
158+
array.statistics().set(
159+
Stat::Sum,
160+
Precision::Exact(Scalar::from(less_accumulator).value().clone()),
161+
)
162+
}
163+
}
164+
_ => unreachable!("Sum will always be a decimal or a primitive dtype"),
165+
}
114166

115167
Ok(sum_scalar.into())
116168
}
@@ -206,12 +258,13 @@ pub fn sum_impl(
206258
#[cfg(test)]
207259
mod test {
208260
use vortex_buffer::buffer;
209-
use vortex_dtype::Nullability;
261+
use vortex_dtype::{DType, Nullability, PType};
262+
use vortex_error::VortexUnwrap;
210263
use vortex_scalar::Scalar;
211264

212265
use crate::IntoArray as _;
213-
use crate::arrays::{BoolArray, PrimitiveArray};
214-
use crate::compute::sum;
266+
use crate::arrays::{BoolArray, ChunkedArray, PrimitiveArray};
267+
use crate::compute::{sum, sum_with_accumulator};
215268

216269
#[test]
217270
fn sum_all_invalid() {
@@ -247,4 +300,28 @@ mod test {
247300
let result = sum(array.as_ref()).unwrap();
248301
assert_eq!(result.as_primitive().as_::<i32>(), Some(2));
249302
}
303+
304+
#[test]
305+
fn sum_stats() {
306+
let array = ChunkedArray::try_new(
307+
vec![
308+
PrimitiveArray::from_iter([1, 1, 1]).into_array(),
309+
PrimitiveArray::from_iter([2, 2, 2]).into_array(),
310+
],
311+
DType::Primitive(PType::I32, Nullability::NonNullable),
312+
)
313+
.vortex_unwrap();
314+
// compute sum with accumulator to populate stats
315+
sum_with_accumulator(
316+
array.as_ref(),
317+
&Scalar::primitive(2i64, Nullability::Nullable),
318+
)
319+
.unwrap();
320+
321+
let sum_without_acc = sum(array.as_ref()).unwrap();
322+
assert_eq!(
323+
sum_without_acc,
324+
Scalar::primitive(9i64, Nullability::Nullable)
325+
);
326+
}
250327
}

vortex-scalar/src/scalar.rs

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,44 @@ impl Scalar {
239239
}
240240
}
241241

242+
/// Returns true if the scalar is a zero value, i.e., equal to a scalar returned from the `zero_value` method.
243+
pub fn is_zero(&self) -> bool {
244+
match self.dtype() {
245+
DType::Null => true,
246+
DType::Bool(_) => self.as_bool().value() == Some(false),
247+
DType::Primitive(pt, _) => self.as_primitive().pvalue() == Some(PValue::zero(*pt)),
248+
DType::Decimal(..) => {
249+
self.as_decimal().decimal_value() == Some(DecimalValue::from(0i8))
250+
}
251+
DType::Utf8(_) => self
252+
.as_utf8()
253+
.value()
254+
.map(|v| v.is_empty())
255+
.unwrap_or(false),
256+
DType::Binary(_) => self
257+
.as_binary()
258+
.value()
259+
.map(|v| v.is_empty())
260+
.unwrap_or(false),
261+
DType::Struct(..) => self
262+
.as_struct()
263+
.fields()
264+
.map(|mut sf| sf.all(|f| f.is_zero()))
265+
.unwrap_or(false),
266+
DType::List(..) => self
267+
.as_list()
268+
.elements()
269+
.map(|vals| vals.is_empty())
270+
.unwrap_or(false),
271+
DType::FixedSizeList(..) => self
272+
.as_list()
273+
.elements()
274+
.map(|vals| vals.iter().all(|f| f.is_zero()))
275+
.unwrap_or(false),
276+
DType::Extension(..) => self.as_extension().storage().is_zero(),
277+
}
278+
}
279+
242280
/// Creates a "default" scalar value for the given data type.
243281
///
244282
/// For nullable types, returns null. For non-nullable types, returns an appropriate zero/empty

0 commit comments

Comments
 (0)