Skip to content

Commit 1a1f89e

Browse files
committed
types
Signed-off-by: Robert Kruszewski <[email protected]>
1 parent 4a6d4c5 commit 1a1f89e

File tree

1 file changed

+45
-16
lines changed
  • vortex-array/src/compute

1 file changed

+45
-16
lines changed

vortex-array/src/compute/sum.rs

Lines changed: 45 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ use vortex_dtype::DType;
99
use vortex_error::{
1010
VortexError, VortexResult, vortex_bail, vortex_ensure, vortex_err, vortex_panic,
1111
};
12-
use vortex_scalar::Scalar;
12+
use vortex_scalar::{NumericOperator, Scalar};
1313

1414
use crate::Array;
1515
use crate::compute::{ComputeFn, ComputeFnVTable, InvocationArgs, Kernel, Output};
@@ -103,26 +103,55 @@ impl ComputeFnVTable for Sum {
103103

104104
// Short-circuit using array statistics.
105105
if let Some(Precision::Exact(sum)) = array.statistics().get(Stat::Sum) {
106-
let sum_from_stat = sum
107-
.as_primitive()
108-
.checked_add(&accumulator.as_primitive())
109-
.map(|s| Scalar::from(s));
110-
return Ok(sum_from_stat
111-
.unwrap_or_else(|| Scalar::null(sum_dtype))
112-
.into());
106+
// For floats only use stats if accumulator is zero. otherwise we might have numerical stability issues.
107+
if sum_dtype.is_float() && accumulator == &Scalar::zero_value(sum_dtype.clone()) {
108+
return Ok(sum.into());
109+
} else if sum_dtype.is_int() {
110+
let sum_from_stat = accumulator
111+
.as_primitive()
112+
.checked_add(&sum.as_primitive())
113+
.map(|s| Scalar::from(s));
114+
return Ok(sum_from_stat
115+
.unwrap_or_else(|| Scalar::null(sum_dtype))
116+
.into());
117+
} else if sum_dtype.is_decimal() {
118+
let sum_from_stat = accumulator
119+
.as_decimal()
120+
.checked_binary_numeric(&sum.as_decimal(), NumericOperator::Add)
121+
.map(|s| Scalar::from(s));
122+
return Ok(sum_from_stat
123+
.unwrap_or_else(|| Scalar::null(sum_dtype))
124+
.into());
125+
}
113126
}
114127

115128
let sum_scalar = sum_impl(array, accumulator, kernels)?;
116129

117130
// Update the statistics with the computed sum. Stored statistic shouldn't include the accumulator.
118-
if let Some(less_accumulator) = sum_scalar
119-
.as_primitive()
120-
.checked_sub(&accumulator.as_primitive())
121-
{
122-
array.statistics().set(
123-
Stat::Sum,
124-
Precision::Exact(Scalar::from(less_accumulator).value().clone()),
125-
);
131+
if sum_dtype.is_float() && accumulator == &Scalar::zero_value(sum_dtype.clone()) {
132+
array
133+
.statistics()
134+
.set(Stat::Sum, Precision::Exact(sum_scalar.value().clone()));
135+
} else if sum_dtype.is_int() {
136+
if let Some(less_accumulator) = sum_scalar
137+
.as_primitive()
138+
.checked_sub(&accumulator.as_primitive())
139+
{
140+
array.statistics().set(
141+
Stat::Sum,
142+
Precision::Exact(Scalar::from(less_accumulator).value().clone()),
143+
);
144+
}
145+
} else if sum_dtype.is_decimal() {
146+
if let Some(less_accumulator) = sum_scalar
147+
.as_decimal()
148+
.checked_binary_numeric(&accumulator.as_decimal(), NumericOperator::Sub)
149+
{
150+
array.statistics().set(
151+
Stat::Sum,
152+
Precision::Exact(Scalar::from(less_accumulator).value().clone()),
153+
);
154+
}
126155
}
127156

128157
Ok(sum_scalar.into())

0 commit comments

Comments
 (0)