Skip to content

Commit ccee917

Browse files
committed
fix[array]: sum with initial value to fix op assoc
Signed-off-by: Joe Isaacs <[email protected]>
1 parent 91a5f0a commit ccee917

File tree

3 files changed

+47
-31
lines changed

3 files changed

+47
-31
lines changed

vortex-array/src/arrays/chunked/compute/sum.rs

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
11
// SPDX-License-Identifier: Apache-2.0
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

4-
use std::ops::AddAssign;
5-
6-
use num_traits::PrimInt;
4+
use num_traits::CheckedAdd;
75
use vortex_dtype::{DType, NativePType, match_each_native_ptype};
86
use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err};
97
use vortex_scalar::Scalar;
@@ -26,7 +24,7 @@ impl SumKernel for ChunkedVTable {
2624
sum_ptype,
2725
unsigned: |T| { sum_int::<u64>(array.chunks(), initial_value.as_primitive().as_::<u64>().vortex_expect("cannot be null"))?.into() },
2826
signed: |T| { sum_int::<i64>(array.chunks(), initial_value.as_primitive().as_::<i64>().vortex_expect("cannot be null"))?.into() },
29-
floating: |T| { sum_float::<f64>(array.chunks(), initial_value.as_primitive().as_::<f64>().vortex_expect("cannot be null"))?.into() }
27+
floating: |T| { sum_float(array.chunks(), initial_value.as_primitive().as_::<f64>().vortex_expect("cannot be null"))?.into() }
3028
);
3129

3230
Ok(Scalar::new(sum_dtype, scalar_value))
@@ -40,7 +38,7 @@ impl SumKernel for ChunkedVTable {
4038

4139
register_kernel!(SumKernelAdapter(ChunkedVTable).lift());
4240

43-
fn sum_int<T: NativePType + PrimInt>(
41+
fn sum_int<T: NativePType + CheckedAdd>(
4442
chunks: &[ArrayRef],
4543
initial_value: T,
4644
) -> VortexResult<Option<T>> {
@@ -60,13 +58,10 @@ fn sum_int<T: NativePType + PrimInt>(
6058
Ok(Some(result))
6159
}
6260

63-
fn sum_float<T: NativePType + AddAssign>(
64-
chunks: &[ArrayRef],
65-
initial_value: T,
66-
) -> VortexResult<Option<T>> {
61+
fn sum_float(chunks: &[ArrayRef], initial_value: f64) -> VortexResult<Option<f64>> {
6762
let mut result = initial_value;
6863
for chunk in chunks {
69-
let Some(chunk_sum) = sum(chunk)?.as_primitive().as_::<T>() else {
64+
let Some(chunk_sum) = sum(chunk)?.as_primitive().as_::<f64>() else {
7065
return Ok(None);
7166
};
7267
result += chunk_sum;
@@ -78,7 +73,7 @@ fn sum_decimal(chunks: &[ArrayRef], initial_value: &Scalar) -> VortexResult<Scal
7873
let mut result = initial_value.clone();
7974

8075
for chunk in chunks {
81-
result = sum_with_initial(chunk, result)?;
76+
result = sum_with_initial(chunk, &result)?;
8277
}
8378

8479
Ok(result)

vortex-array/src/arrays/extension/compute/sum.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ use crate::register_kernel;
1010

1111
impl SumKernel for ExtensionVTable {
1212
fn sum(&self, array: &ExtensionArray, initial_value: &Scalar) -> VortexResult<Scalar> {
13-
compute::sum_with_initial(array.storage(), initial_value.clone())
13+
compute::sum_with_initial(array.storage(), initial_value)
1414
}
1515
}
1616

vortex-array/src/compute/sum.rs

Lines changed: 40 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,11 @@ use std::sync::LazyLock;
66

77
use arcref::ArcRef;
88
use vortex_dtype::DType;
9-
use vortex_error::{VortexResult, vortex_err, vortex_panic};
9+
use vortex_error::{VortexError, VortexResult, vortex_bail, vortex_err, vortex_panic};
1010
use vortex_scalar::Scalar;
1111

1212
use crate::Array;
13-
use crate::compute::{
14-
ComputeFn, ComputeFnVTable, InvocationArgs, Kernel, Options, Output, UnaryArgs,
15-
};
13+
use crate::compute::{ComputeFn, ComputeFnVTable, InvocationArgs, Kernel, Options, Output};
1614
use crate::stats::{Precision, Stat, StatsProvider};
1715
use crate::vtable::VTable;
1816

@@ -45,11 +43,11 @@ pub(crate) fn warm_up_vtable() -> usize {
4543
/// If the sum is not supported for the array's dtype, an error will be raised.
4644
/// If the array is all-invalid, the sum will be the initial_value.
4745
/// The initial_value must have a dtype compatible with the sum result dtype.
48-
pub(crate) fn sum_with_initial(array: &dyn Array, initial_value: Scalar) -> VortexResult<Scalar> {
46+
pub(crate) fn sum_with_initial(array: &dyn Array, initial_value: &Scalar) -> VortexResult<Scalar> {
4947
SUM_FN
5048
.invoke(&InvocationArgs {
51-
inputs: &[array.into()],
52-
options: &SumOptions { initial_value },
49+
inputs: &[array.into(), initial_value.into()],
50+
options: &(),
5351
})?
5452
.unwrap_scalar()
5553
}
@@ -64,7 +62,30 @@ pub fn sum(array: &dyn Array) -> VortexResult<Scalar> {
6462
.dtype(array.dtype())
6563
.ok_or_else(|| vortex_err!("Sum not supported for dtype: {}", array.dtype()))?;
6664
let zero = Scalar::zero_value(sum_dtype);
67-
sum_with_initial(array, zero)
65+
sum_with_initial(array, &zero)
66+
}
67+
68+
/// Short-cut for unpacking the two inputs of the sum compute function: the array to sum and the scalar accumulator (the initial value).
69+
pub struct SumArgs<'a> {
70+
pub array: &'a dyn Array,
71+
pub accumulator: &'a Scalar,
72+
}
73+
74+
impl<'a> TryFrom<&InvocationArgs<'a>> for SumArgs<'a> {
75+
type Error = VortexError;
76+
77+
fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
78+
if value.inputs.len() != 2 {
79+
vortex_bail!("Expected 2 inputs, found {}", value.inputs.len());
80+
}
81+
let array = value.inputs[0]
82+
.array()
83+
.ok_or_else(|| vortex_err!("Expected input 0 to be an array"))?;
84+
let accumulator = value.inputs[1]
85+
.scalar()
86+
.ok_or_else(|| vortex_err!("Expected input 1 to be a scalar"))?;
87+
Ok(SumArgs { array, accumulator })
88+
}
6889
}
6990

7091
struct Sum;
@@ -75,8 +96,7 @@ impl ComputeFnVTable for Sum {
7596
args: &InvocationArgs,
7697
kernels: &[ArcRef<dyn Kernel>],
7798
) -> VortexResult<Output> {
78-
let UnaryArgs { array, options } = UnaryArgs::<SumOptions>::try_from(args)?;
79-
let initial_value = &options.initial_value;
99+
let SumArgs { array, accumulator } = args.try_into()?;
80100

81101
// Compute the expected dtype of the sum.
82102
let sum_dtype = self.return_dtype(args)?;
@@ -86,7 +106,7 @@ impl ComputeFnVTable for Sum {
86106
return Ok(sum.into());
87107
}
88108

89-
let sum_scalar = sum_impl(array, sum_dtype, initial_value, kernels)?;
109+
let sum_scalar = sum_impl(array, sum_dtype, accumulator, kernels)?;
90110

91111
// Update the statistics with the computed sum.
92112
array
@@ -97,7 +117,7 @@ impl ComputeFnVTable for Sum {
97117
}
98118

99119
fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
100-
let UnaryArgs { array, .. } = UnaryArgs::<SumOptions>::try_from(args)?;
120+
let SumArgs { array, .. } = args.try_into()?;
101121
Stat::Sum
102122
.dtype(array.dtype())
103123
.ok_or_else(|| vortex_err!("Sum not supported for dtype: {}", array.dtype()))
@@ -136,11 +156,14 @@ impl<V: VTable + SumKernel> SumKernelAdapter<V> {
136156

137157
impl<V: VTable + SumKernel> Kernel for SumKernelAdapter<V> {
138158
fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
139-
let UnaryArgs { array, options } = UnaryArgs::<SumOptions>::try_from(args)?;
159+
let SumArgs {
160+
array,
161+
accumulator: initial_value,
162+
} = args.try_into()?;
140163
let Some(array) = array.as_opt::<V>() else {
141164
return Ok(None);
142165
};
143-
Ok(Some(V::sum(&self.0, array, &options.initial_value)?.into()))
166+
Ok(Some(V::sum(&self.0, array, initial_value)?.into()))
144167
}
145168
}
146169

@@ -161,10 +184,8 @@ pub fn sum_impl(
161184

162185
// Try to find a sum kernel
163186
let args = InvocationArgs {
164-
inputs: &[array.into()],
165-
options: &SumOptions {
166-
initial_value: initial_value.clone(),
167-
},
187+
inputs: &[array.into(), initial_value.into()],
188+
options: &(),
168189
};
169190
for kernel in kernels {
170191
if let Some(output) = kernel.invoke(&args)? {
@@ -184,7 +205,7 @@ pub fn sum_impl(
184205
array.encoding_id()
185206
);
186207
}
187-
sum_with_initial(array.to_canonical().as_ref(), initial_value.clone())
208+
sum_with_initial(array.to_canonical().as_ref(), initial_value)
188209
}
189210

190211
#[cfg(test)]

0 commit comments

Comments
 (0)