Skip to content

Commit c1e07fa

Browse files
committed
fixes
Signed-off-by: Robert Kruszewski <[email protected]>
1 parent 882e7ed commit c1e07fa

File tree

2 files changed

+90
-40
lines changed

2 files changed

+90
-40
lines changed

vortex-array/src/compute/sum.rs

Lines changed: 52 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -104,52 +104,64 @@ impl ComputeFnVTable for Sum {
104104
// Short-circuit using array statistics.
105105
if let Some(Precision::Exact(sum)) = array.statistics().get(Stat::Sum) {
106106
// For floats only use stats if accumulator is zero. otherwise we might have numerical stability issues.
107-
if sum_dtype.is_float() && accumulator == &Scalar::zero_value(sum_dtype.clone()) {
108-
return Ok(sum.into());
109-
} else if sum_dtype.is_int() {
110-
let sum_from_stat = accumulator
111-
.as_primitive()
112-
.checked_add(&sum.as_primitive())
113-
.map(Scalar::from);
114-
return Ok(sum_from_stat
115-
.unwrap_or_else(|| Scalar::null(sum_dtype))
116-
.into());
117-
} else if sum_dtype.is_decimal() {
118-
let sum_from_stat = accumulator
119-
.as_decimal()
120-
.checked_binary_numeric(&sum.as_decimal(), NumericOperator::Add)
121-
.map(Scalar::from);
122-
return Ok(sum_from_stat
123-
.unwrap_or_else(|| Scalar::null(sum_dtype))
124-
.into());
107+
match sum_dtype {
108+
DType::Primitive(p, _) => {
109+
if p.is_float() {
110+
return Ok(sum.into());
111+
}
112+
let sum_from_stat = accumulator
113+
.as_primitive()
114+
.checked_add(&sum.as_primitive())
115+
.map(Scalar::from);
116+
return Ok(sum_from_stat
117+
.unwrap_or_else(|| Scalar::null(sum_dtype))
118+
.into());
119+
}
120+
DType::Decimal(..) => {
121+
let sum_from_stat = accumulator
122+
.as_decimal()
123+
.checked_binary_numeric(&sum.as_decimal(), NumericOperator::Add)
124+
.map(Scalar::from);
125+
return Ok(sum_from_stat
126+
.unwrap_or_else(|| Scalar::null(sum_dtype))
127+
.into());
128+
}
129+
_ => unreachable!("Sum will always be a decimal or a primitive dtype"),
125130
}
126131
}
127132

128133
let sum_scalar = sum_impl(array, accumulator, kernels)?;
129134

130135
// Update the statistics with the computed sum. Stored statistic shouldn't include the accumulator.
131-
if sum_dtype.is_float() && accumulator == &Scalar::zero_value(sum_dtype.clone()) {
132-
array
133-
.statistics()
134-
.set(Stat::Sum, Precision::Exact(sum_scalar.value().clone()));
135-
} else if sum_dtype.is_int()
136-
&& let Some(less_accumulator) = sum_scalar
137-
.as_primitive()
138-
.checked_sub(&accumulator.as_primitive())
139-
{
140-
array.statistics().set(
141-
Stat::Sum,
142-
Precision::Exact(Scalar::from(less_accumulator).value().clone()),
143-
);
144-
} else if sum_dtype.is_decimal()
145-
&& let Some(less_accumulator) = sum_scalar
146-
.as_decimal()
147-
.checked_binary_numeric(&accumulator.as_decimal(), NumericOperator::Sub)
148-
{
149-
array.statistics().set(
150-
Stat::Sum,
151-
Precision::Exact(Scalar::from(less_accumulator).value().clone()),
152-
);
136+
match sum_dtype {
137+
DType::Primitive(p, _) => {
138+
if p.is_float() && accumulator.is_zero() {
139+
array
140+
.statistics()
141+
.set(Stat::Sum, Precision::Exact(sum_scalar.value().clone()));
142+
} else if p.is_int()
143+
&& let Some(less_accumulator) = sum_scalar
144+
.as_primitive()
145+
.checked_sub(&accumulator.as_primitive())
146+
{
147+
array.statistics().set(
148+
Stat::Sum,
149+
Precision::Exact(Scalar::from(less_accumulator).value().clone()),
150+
);
151+
}
152+
}
153+
DType::Decimal(..) => {
154+
if let Some(less_accumulator) = sum_scalar
155+
.as_decimal()
156+
.checked_binary_numeric(&accumulator.as_decimal(), NumericOperator::Sub)
157+
{
158+
array.statistics().set(
159+
Stat::Sum,
160+
Precision::Exact(Scalar::from(less_accumulator).value().clone()),
161+
)
162+
}
163+
}
164+
_ => unreachable!("Sum will always be a decimal or a primitive dtype"),
153165
}
154166

155167
Ok(sum_scalar.into())

vortex-scalar/src/scalar.rs

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,44 @@ impl Scalar {
239239
}
240240
}
241241

242+
/// Returns true if the scalar is a zero value i.e., equal to a scalar returned from the ` zero_value ` method.
243+
pub fn is_zero(&self) -> bool {
244+
match self.dtype() {
245+
DType::Null => true,
246+
DType::Bool(_) => self.as_bool().value() == Some(false),
247+
DType::Primitive(pt, _) => self.as_primitive().pvalue() == Some(PValue::zero(*pt)),
248+
DType::Decimal(..) => {
249+
self.as_decimal().decimal_value() == Some(DecimalValue::from(0i8))
250+
}
251+
DType::Utf8(_) => self
252+
.as_utf8()
253+
.value()
254+
.map(|v| v.is_empty())
255+
.unwrap_or(false),
256+
DType::Binary(_) => self
257+
.as_binary()
258+
.value()
259+
.map(|v| v.is_empty())
260+
.unwrap_or(false),
261+
DType::Struct(..) => self
262+
.as_struct()
263+
.fields()
264+
.map(|mut sf| sf.all(|f| f.is_zero()))
265+
.unwrap_or(false),
266+
DType::List(..) => self
267+
.as_list()
268+
.elements()
269+
.map(|vals| vals.is_empty())
270+
.unwrap_or(false),
271+
DType::FixedSizeList(..) => self
272+
.as_list()
273+
.elements()
274+
.map(|vals| vals.iter().all(|f| f.is_zero()))
275+
.unwrap_or(false),
276+
DType::Extension(..) => self.as_extension().storage().is_zero(),
277+
}
278+
}
279+
242280
/// Creates a "default" scalar value for the given data type.
243281
///
244282
/// For nullable types, returns null. For non-nullable types, returns an appropriate zero/empty

0 commit comments

Comments
 (0)