Skip to content

Commit ecc9eaf

Browse files
committed
rename
Signed-off-by: Joe Isaacs <[email protected]>
1 parent ccee917 commit ecc9eaf

File tree

7 files changed

+68
-123
lines changed

7 files changed

+68
-123
lines changed

vortex-array/src/arrays/bool/compute/sum.rs

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ use crate::compute::{SumKernel, SumKernelAdapter};
1212
use crate::register_kernel;
1313

1414
impl SumKernel for BoolVTable {
15-
fn sum(&self, array: &BoolArray, initial_value: &Scalar) -> VortexResult<Scalar> {
15+
fn sum(&self, array: &BoolArray, accumulator: &Scalar) -> VortexResult<Scalar> {
1616
let true_count: Option<u64> = match array.validity_mask().bit_buffer() {
1717
AllOr::All => {
1818
// All-valid
@@ -27,13 +27,12 @@ impl SumKernel for BoolVTable {
2727
}
2828
};
2929

30-
// Add initial_value to true_count
31-
let initial_u64 = initial_value
30+
let accumulator = accumulator
3231
.as_primitive()
3332
.as_::<u64>()
3433
.vortex_expect("cannot be null");
3534
Ok(Scalar::from(
36-
true_count.and_then(|tc| tc.checked_add(initial_u64)),
35+
true_count.and_then(|tc| accumulator.checked_add(tc)),
3736
))
3837
}
3938
}

vortex-array/src/arrays/chunked/compute/sum.rs

Lines changed: 10 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -1,84 +1,26 @@
11
// SPDX-License-Identifier: Apache-2.0
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

4-
use num_traits::CheckedAdd;
5-
use vortex_dtype::{DType, NativePType, match_each_native_ptype};
6-
use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err};
4+
use vortex_error::VortexResult;
75
use vortex_scalar::Scalar;
86

97
use crate::arrays::{ChunkedArray, ChunkedVTable};
10-
use crate::compute::{SumKernel, SumKernelAdapter, sum, sum_with_initial};
11-
use crate::stats::Stat;
12-
use crate::{ArrayRef, register_kernel};
8+
use crate::compute::{SumKernel, SumKernelAdapter, sum_with_accumulator};
9+
use crate::register_kernel;
1310

1411
impl SumKernel for ChunkedVTable {
15-
fn sum(&self, array: &ChunkedArray, initial_value: &Scalar) -> VortexResult<Scalar> {
16-
let sum_dtype = Stat::Sum
17-
.dtype(array.dtype())
18-
.ok_or_else(|| vortex_err!("Sum not supported for dtype {}", array.dtype()))?;
19-
20-
match sum_dtype {
21-
DType::Decimal(..) => sum_decimal(array.chunks(), initial_value),
22-
DType::Primitive(sum_ptype, _) => {
23-
let scalar_value = match_each_native_ptype!(
24-
sum_ptype,
25-
unsigned: |T| { sum_int::<u64>(array.chunks(), initial_value.as_primitive().as_::<u64>().vortex_expect("cannot be null"))?.into() },
26-
signed: |T| { sum_int::<i64>(array.chunks(), initial_value.as_primitive().as_::<i64>().vortex_expect("cannot be null"))?.into() },
27-
floating: |T| { sum_float(array.chunks(), initial_value.as_primitive().as_::<f64>().vortex_expect("cannot be null"))?.into() }
28-
);
29-
30-
Ok(Scalar::new(sum_dtype, scalar_value))
31-
}
32-
_ => {
33-
vortex_bail!("Sum not supported for dtype {}", sum_dtype);
34-
}
35-
}
12+
fn sum(&self, array: &ChunkedArray, accumulator: &Scalar) -> VortexResult<Scalar> {
13+
array
14+
.chunks
15+
.iter()
16+
.try_fold(accumulator.clone(), |result, chunk| {
17+
sum_with_accumulator(chunk, &result)
18+
})
3619
}
3720
}
3821

3922
register_kernel!(SumKernelAdapter(ChunkedVTable).lift());
4023

41-
fn sum_int<T: NativePType + CheckedAdd>(
42-
chunks: &[ArrayRef],
43-
initial_value: T,
44-
) -> VortexResult<Option<T>> {
45-
let mut result: T = initial_value;
46-
for chunk in chunks {
47-
let chunk_sum = sum(chunk)?;
48-
let Some(chunk_sum) = chunk_sum
49-
.as_primitive()
50-
.as_::<T>()
51-
.and_then(|chunk_sum| result.checked_add(&chunk_sum))
52-
else {
53-
// Bail out on null or overflow
54-
return Ok(None);
55-
};
56-
result = chunk_sum;
57-
}
58-
Ok(Some(result))
59-
}
60-
61-
fn sum_float(chunks: &[ArrayRef], initial_value: f64) -> VortexResult<Option<f64>> {
62-
let mut result = initial_value;
63-
for chunk in chunks {
64-
let Some(chunk_sum) = sum(chunk)?.as_primitive().as_::<f64>() else {
65-
return Ok(None);
66-
};
67-
result += chunk_sum;
68-
}
69-
Ok(Some(result))
70-
}
71-
72-
fn sum_decimal(chunks: &[ArrayRef], initial_value: &Scalar) -> VortexResult<Scalar> {
73-
let mut result = initial_value.clone();
74-
75-
for chunk in chunks {
76-
result = sum_with_initial(chunk, &result)?;
77-
}
78-
79-
Ok(result)
80-
}
81-
8224
#[cfg(test)]
8325
mod tests {
8426
use vortex_buffer::buffer;

vortex-array/src/arrays/constant/compute/sum.rs

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,13 @@ use crate::register_kernel;
1313
use crate::stats::Stat;
1414

1515
impl SumKernel for ConstantVTable {
16-
fn sum(&self, array: &ConstantArray, initial_value: &Scalar) -> VortexResult<Scalar> {
16+
fn sum(&self, array: &ConstantArray, accumulator: &Scalar) -> VortexResult<Scalar> {
1717
// Compute the expected dtype of the sum.
1818
let sum_dtype = Stat::Sum
1919
.dtype(array.dtype())
2020
.ok_or_else(|| vortex_err!("Sum not supported for dtype {}", array.dtype()))?;
2121

22-
let sum_value = sum_scalar(array.scalar(), array.len(), initial_value)?;
22+
let sum_value = sum_scalar(array.scalar(), array.len(), accumulator)?;
2323
Ok(Scalar::new(sum_dtype, sum_value))
2424
}
2525
}
@@ -59,7 +59,7 @@ fn sum_decimal(
5959
decimal_scalar: DecimalScalar,
6060
array_len: usize,
6161
decimal_dtype: DecimalDType,
62-
initial_value: &Scalar,
62+
accumulator: &Scalar,
6363
) -> VortexResult<ScalarValue> {
6464
let result_dtype = Stat::Sum
6565
.dtype(&DType::Decimal(decimal_dtype, Nullability::Nullable))
@@ -85,8 +85,8 @@ fn sum_decimal(
8585
.then_some(result)
8686
});
8787

88-
// Add initial_value to array_sum
89-
let initial_decimal = DecimalScalar::try_from(initial_value)?;
88+
// Add accumulator to array_sum
89+
let initial_decimal = DecimalScalar::try_from(accumulator)?;
9090
let initial_dec_value = initial_decimal
9191
.decimal_value()
9292
.unwrap_or(DecimalValue::I256(i256::ZERO));
@@ -113,7 +113,7 @@ fn sum_decimal(
113113
fn sum_integral<T>(
114114
primitive_scalar: PrimitiveScalar<'_>,
115115
array_len: usize,
116-
initial_value: &Scalar,
116+
accumulator: &Scalar,
117117
) -> VortexResult<Option<T>>
118118
where
119119
T: NativePType + CheckedMul + CheckedAdd,
@@ -126,7 +126,7 @@ where
126126
return Ok(None);
127127
};
128128

129-
let initial = initial_value
129+
let initial = accumulator
130130
.as_primitive()
131131
.as_::<T>()
132132
.vortex_expect("cannot be null");
@@ -136,7 +136,7 @@ where
136136
fn sum_float(
137137
primitive_scalar: PrimitiveScalar<'_>,
138138
array_len: usize,
139-
initial_value: &Scalar,
139+
accumulator: &Scalar,
140140
) -> VortexResult<Option<f64>> {
141141
let v = primitive_scalar
142142
.as_::<f64>()
@@ -148,7 +148,7 @@ fn sum_float(
148148
let Ok(array_sum) = v.mul_checked(array_len) else {
149149
return Ok(None);
150150
};
151-
let initial = initial_value
151+
let initial = accumulator
152152
.as_primitive()
153153
.as_::<f64>()
154154
.vortex_expect("cannot be null");

vortex-array/src/arrays/decimal/compute/sum.rs

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ macro_rules! sum_decimal {
4141

4242
impl SumKernel for DecimalVTable {
4343
#[allow(clippy::cognitive_complexity)]
44-
fn sum(&self, array: &DecimalArray, initial_value: &Scalar) -> VortexResult<Scalar> {
44+
fn sum(&self, array: &DecimalArray, accumulator: &Scalar) -> VortexResult<Scalar> {
4545
let decimal_dtype = array.decimal_dtype();
4646

4747
// Both Spark and DataFusion use this heuristic.
@@ -52,7 +52,7 @@ impl SumKernel for DecimalVTable {
5252
let return_dtype = DecimalDType::new(new_precision, new_scale);
5353

5454
// Extract the initial value as a DecimalValue
55-
let initial_decimal = DecimalScalar::try_from(initial_value)
55+
let initial_decimal = DecimalScalar::try_from(accumulator)
5656
.vortex_expect("must be a decimal")
5757
.decimal_value()
5858
.vortex_expect("cannot be null");
@@ -65,7 +65,9 @@ impl SumKernel for DecimalVTable {
6565
let values_type = DecimalType::smallest_decimal_value_type(&return_dtype);
6666
match_each_decimal_value_type!(array.values_type(), |I| {
6767
match_each_decimal_value_type!(values_type, |O| {
68-
let initial_val: O = initial_decimal.cast().unwrap_or_else(O::default);
68+
let initial_val: O = initial_decimal
69+
.cast()
70+
.vortex_expect("cannot fail to cast initial value");
6971
Ok(Scalar::decimal(
7072
DecimalValue::from(sum_decimal!(O, array.buffer::<I>(), initial_val)),
7173
return_dtype,
@@ -78,7 +80,9 @@ impl SumKernel for DecimalVTable {
7880
let values_type = DecimalType::smallest_decimal_value_type(&return_dtype);
7981
match_each_decimal_value_type!(array.values_type(), |I| {
8082
match_each_decimal_value_type!(values_type, |O| {
81-
let initial_val: O = initial_decimal.cast().unwrap_or_else(O::default);
83+
let initial_val: O = initial_decimal
84+
.cast()
85+
.vortex_expect("cannot fail to cast initial value");
8286
Ok(Scalar::decimal(
8387
DecimalValue::from(sum_decimal!(
8488
O,

vortex-array/src/arrays/extension/compute/sum.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@ use crate::compute::{self, SumKernel, SumKernelAdapter};
99
use crate::register_kernel;
1010

1111
impl SumKernel for ExtensionVTable {
12-
fn sum(&self, array: &ExtensionArray, initial_value: &Scalar) -> VortexResult<Scalar> {
13-
compute::sum_with_initial(array.storage(), initial_value)
12+
fn sum(&self, array: &ExtensionArray, accumulator: &Scalar) -> VortexResult<Scalar> {
13+
compute::sum_with_accumulator(array.storage(), accumulator)
1414
}
1515
}
1616

vortex-array/src/arrays/primitive/compute/sum.rs

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -14,33 +14,33 @@ use crate::compute::{SumKernel, SumKernelAdapter};
1414
use crate::register_kernel;
1515

1616
impl SumKernel for PrimitiveVTable {
17-
fn sum(&self, array: &PrimitiveArray, initial_value: &Scalar) -> VortexResult<Scalar> {
17+
fn sum(&self, array: &PrimitiveArray, accumulator: &Scalar) -> VortexResult<Scalar> {
1818
let array_sum_scalar = match array.validity_mask().bit_buffer() {
1919
AllOr::All => {
2020
// All-valid
2121
match_each_native_ptype!(
2222
array.ptype(),
23-
unsigned: |T| { sum_integer::<_, u64>(array.as_slice::<T>(), initial_value.as_primitive().as_::<u64>().vortex_expect("cannot be null")).into() },
24-
signed: |T| { sum_integer::<_, i64>(array.as_slice::<T>(), initial_value.as_primitive().as_::<i64>().vortex_expect("cannot be null")).into() },
25-
floating: |T| { Some(sum_float(array.as_slice::<T>(), initial_value.as_primitive().as_::<f64>().vortex_expect("cannot be null"))).into() }
23+
unsigned: |T| { sum_integer::<_, u64>(array.as_slice::<T>(), accumulator.as_primitive().as_::<u64>().vortex_expect("cannot be null")).into() },
24+
signed: |T| { sum_integer::<_, i64>(array.as_slice::<T>(), accumulator.as_primitive().as_::<i64>().vortex_expect("cannot be null")).into() },
25+
floating: |T| { Some(sum_float(array.as_slice::<T>(), accumulator.as_primitive().as_::<f64>().vortex_expect("cannot be null"))).into() }
2626
)
2727
}
2828
AllOr::None => {
29-
// All-invalid, return initial_value
30-
return Ok(initial_value.clone());
29+
// All-invalid, return accumulator
30+
return Ok(accumulator.clone());
3131
}
3232
AllOr::Some(validity_mask) => {
3333
// Some-valid
3434
match_each_native_ptype!(
3535
array.ptype(),
3636
unsigned: |T| {
37-
sum_integer_with_validity::<_, u64>(array.as_slice::<T>(), validity_mask, initial_value.as_primitive().as_::<u64>().vortex_expect("cannot be null")).into()
37+
sum_integer_with_validity::<_, u64>(array.as_slice::<T>(), validity_mask, accumulator.as_primitive().as_::<u64>().vortex_expect("cannot be null")).into()
3838
},
3939
signed: |T| {
40-
sum_integer_with_validity::<_, i64>(array.as_slice::<T>(), validity_mask, initial_value.as_primitive().as_::<i64>().vortex_expect("cannot be null")).into()
40+
sum_integer_with_validity::<_, i64>(array.as_slice::<T>(), validity_mask, accumulator.as_primitive().as_::<i64>().vortex_expect("cannot be null")).into()
4141
},
4242
floating: |T| {
43-
Some(sum_float_with_validity(array.as_slice::<T>(), validity_mask, initial_value.as_primitive().as_::<f64>().vortex_expect("cannot be null"))).into()
43+
Some(sum_float_with_validity(array.as_slice::<T>(), validity_mask, accumulator.as_primitive().as_::<f64>().vortex_expect("cannot be null"))).into()
4444
}
4545
)
4646
}
@@ -54,9 +54,9 @@ register_kernel!(SumKernelAdapter(PrimitiveVTable).lift());
5454

5555
fn sum_integer<T: NativePType + ToPrimitive, R: NativePType + CheckedAdd>(
5656
values: &[T],
57-
initial_value: R,
57+
accumulator: R,
5858
) -> Option<R> {
59-
let mut sum = initial_value;
59+
let mut sum = accumulator;
6060
for &x in values {
6161
sum = sum.checked_add(&R::from(x)?)?;
6262
}
@@ -66,9 +66,9 @@ fn sum_integer<T: NativePType + ToPrimitive, R: NativePType + CheckedAdd>(
6666
fn sum_integer_with_validity<T: NativePType + ToPrimitive, R: NativePType + CheckedAdd>(
6767
values: &[T],
6868
validity: &BitBuffer,
69-
initial_value: R,
69+
accumulator: R,
7070
) -> Option<R> {
71-
let mut sum: R = initial_value;
71+
let mut sum: R = accumulator;
7272
for (&x, valid) in values.iter().zip_eq(validity.iter()) {
7373
if valid {
7474
sum = sum.checked_add(&R::from(x)?)?;
@@ -77,8 +77,8 @@ fn sum_integer_with_validity<T: NativePType + ToPrimitive, R: NativePType + Chec
7777
Some(sum)
7878
}
7979

80-
fn sum_float<T: NativePType + Float>(values: &[T], initial_value: f64) -> f64 {
81-
let mut sum = initial_value;
80+
fn sum_float<T: NativePType + Float>(values: &[T], accumulator: f64) -> f64 {
81+
let mut sum = accumulator;
8282
for &x in values {
8383
sum += x.to_f64().vortex_expect("Failed to cast value to f64");
8484
}
@@ -88,9 +88,9 @@ fn sum_float<T: NativePType + Float>(values: &[T], initial_value: f64) -> f64 {
8888
fn sum_float_with_validity<T: NativePType + Float>(
8989
array: &[T],
9090
validity: &BitBuffer,
91-
initial_value: f64,
91+
accumulator: f64,
9292
) -> f64 {
93-
let mut sum = initial_value;
93+
let mut sum = accumulator;
9494
for (&x, valid) in array.iter().zip_eq(validity.iter()) {
9595
if valid {
9696
sum += x.to_f64().vortex_expect("Failed to cast value to f64");

0 commit comments

Comments
 (0)