Skip to content

Commit 768e27c

Browse files
authored
Port SumFn to SumKernel (#3149)
1 parent 2b1337c commit 768e27c

File tree

12 files changed

+230
-151
lines changed

12 files changed

+230
-151
lines changed

vortex-array/src/arrays/bool/compute/mod.rs

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use crate::Array;
22
use crate::arrays::BoolEncoding;
33
use crate::compute::{
44
CastFn, FillForwardFn, FillNullFn, InvertFn, IsConstantFn, IsSortedFn, MaskFn, MinMaxFn,
5-
ScalarAtFn, SliceFn, SumFn, TakeFn, ToArrowFn, UncompressedSizeFn,
5+
ScalarAtFn, SliceFn, TakeFn, ToArrowFn, UncompressedSizeFn,
66
};
77
use crate::vtable::ComputeVTable;
88

@@ -64,10 +64,6 @@ impl ComputeVTable for BoolEncoding {
6464
Some(self)
6565
}
6666

67-
fn sum_fn(&self) -> Option<&dyn SumFn<&dyn Array>> {
68-
Some(self)
69-
}
70-
7167
fn take_fn(&self) -> Option<&dyn TakeFn<&dyn Array>> {
7268
Some(self)
7369
}

vortex-array/src/arrays/bool/compute/sum.rs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,11 @@ use vortex_error::VortexResult;
44
use vortex_mask::AllOr;
55
use vortex_scalar::Scalar;
66

7-
use crate::Array;
87
use crate::arrays::{BoolArray, BoolEncoding};
9-
use crate::compute::SumFn;
8+
use crate::compute::{SumKernel, SumKernelAdapter};
9+
use crate::{Array, register_kernel};
1010

11-
impl SumFn<&BoolArray> for BoolEncoding {
11+
impl SumKernel for BoolEncoding {
1212
fn sum(&self, array: &BoolArray) -> VortexResult<Scalar> {
1313
let true_count: Option<u64> = match array.validity_mask()?.boolean_buffer() {
1414
AllOr::All => {
@@ -29,3 +29,5 @@ impl SumFn<&BoolArray> for BoolEncoding {
2929
Ok(Scalar::from(true_count))
3030
}
3131
}
32+
33+
register_kernel!(SumKernelAdapter(BoolEncoding).lift());

vortex-array/src/arrays/chunked/compute/mod.rs

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use crate::arrays::ChunkedEncoding;
55
use crate::arrays::chunked::ChunkedArray;
66
use crate::compute::{
77
CastFn, FillNullFn, InvertFn, IsConstantFn, IsSortedFn, MaskFn, MinMaxFn, ScalarAtFn, SliceFn,
8-
SumFn, TakeFn, UncompressedSizeFn, try_cast,
8+
TakeFn, UncompressedSizeFn, try_cast,
99
};
1010
use crate::vtable::ComputeVTable;
1111
use crate::{Array, ArrayRef};
@@ -70,10 +70,6 @@ impl ComputeVTable for ChunkedEncoding {
7070
fn uncompressed_size_fn(&self) -> Option<&dyn UncompressedSizeFn<&dyn Array>> {
7171
Some(self)
7272
}
73-
74-
fn sum_fn(&self) -> Option<&dyn SumFn<&dyn Array>> {
75-
Some(self)
76-
}
7773
}
7874

7975
impl CastFn<&ChunkedArray> for ChunkedEncoding {

vortex-array/src/arrays/chunked/compute/sum.rs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,11 @@ use vortex_error::{VortexExpect, VortexResult, vortex_err};
44
use vortex_scalar::{FromPrimitiveOrF16, Scalar};
55

66
use crate::arrays::{ChunkedArray, ChunkedEncoding};
7-
use crate::compute::{SumFn, sum};
7+
use crate::compute::{SumKernel, SumKernelAdapter, sum};
88
use crate::stats::Stat;
9-
use crate::{Array, ArrayRef};
9+
use crate::{Array, ArrayRef, register_kernel};
1010

11-
impl SumFn<&ChunkedArray> for ChunkedEncoding {
11+
impl SumKernel for ChunkedEncoding {
1212
fn sum(&self, array: &ChunkedArray) -> VortexResult<Scalar> {
1313
let sum_dtype = Stat::Sum
1414
.dtype(array.dtype())
@@ -26,6 +26,8 @@ impl SumFn<&ChunkedArray> for ChunkedEncoding {
2626
}
2727
}
2828

29+
register_kernel!(SumKernelAdapter(ChunkedEncoding).lift());
30+
2931
fn sum_int<T: NativePType + PrimInt + FromPrimitiveOrF16>(
3032
chunks: &[ArrayRef],
3133
) -> VortexResult<Option<T>> {

vortex-array/src/arrays/constant/compute/mod.rs

Lines changed: 4 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,17 @@ mod compare;
55
mod filter;
66
mod invert;
77
mod search_sorted;
8+
mod sum;
89
mod take;
910

10-
use num_traits::{CheckedMul, ToPrimitive};
11-
use vortex_dtype::{NativePType, PType, match_each_native_ptype};
12-
use vortex_error::{VortexExpect, VortexResult, vortex_err};
13-
use vortex_scalar::{FromPrimitiveOrF16, PrimitiveScalar, Scalar};
11+
use vortex_error::VortexResult;
12+
use vortex_scalar::Scalar;
1413

1514
use crate::arrays::ConstantEncoding;
1615
use crate::arrays::constant::ConstantArray;
1716
use crate::compute::{
18-
CastFn, InvertFn, ScalarAtFn, SearchSortedFn, SliceFn, SumFn, TakeFn, UncompressedSizeFn,
17+
CastFn, InvertFn, ScalarAtFn, SearchSortedFn, SliceFn, TakeFn, UncompressedSizeFn,
1918
};
20-
use crate::stats::Stat;
2119
use crate::vtable::ComputeVTable;
2220
use crate::{Array, ArrayRef};
2321

@@ -42,10 +40,6 @@ impl ComputeVTable for ConstantEncoding {
4240
Some(self)
4341
}
4442

45-
fn sum_fn(&self) -> Option<&dyn SumFn<&dyn Array>> {
46-
Some(self)
47-
}
48-
4943
fn take_fn(&self) -> Option<&dyn TakeFn<&dyn Array>> {
5044
Some(self)
5145
}
@@ -79,51 +73,6 @@ impl UncompressedSizeFn<&ConstantArray> for ConstantEncoding {
7973
}
8074
}
8175

82-
impl SumFn<&ConstantArray> for ConstantEncoding {
83-
fn sum(&self, array: &ConstantArray) -> VortexResult<Scalar> {
84-
let sum_dtype = Stat::Sum
85-
.dtype(array.dtype())
86-
.ok_or_else(|| vortex_err!("Sum not supported for dtype {}", array.dtype()))?;
87-
let sum_ptype = PType::try_from(&sum_dtype).vortex_expect("sum dtype must be primitive");
88-
89-
let scalar = array.scalar();
90-
91-
let scalar_value = match_each_native_ptype!(
92-
sum_ptype,
93-
unsigned: |$T| { sum_integral::<u64>(scalar.as_primitive(), array.len())?.into() }
94-
signed: |$T| { sum_integral::<i64>(scalar.as_primitive(), array.len())?.into() }
95-
floating: |$T| { sum_float(scalar.as_primitive(), array.len())?.into() }
96-
);
97-
98-
Ok(Scalar::new(sum_dtype, scalar_value))
99-
}
100-
}
101-
102-
fn sum_integral<T>(
103-
primitive_scalar: PrimitiveScalar<'_>,
104-
array_len: usize,
105-
) -> VortexResult<Option<T>>
106-
where
107-
T: FromPrimitiveOrF16 + NativePType + CheckedMul,
108-
Scalar: From<Option<T>>,
109-
{
110-
let v = primitive_scalar.as_::<T>()?;
111-
let array_len =
112-
T::from(array_len).ok_or_else(|| vortex_err!("array_len must fit the sum type"))?;
113-
let sum = v.and_then(|v| v.checked_mul(&array_len));
114-
115-
Ok(sum)
116-
}
117-
118-
fn sum_float(primitive_scalar: PrimitiveScalar<'_>, array_len: usize) -> VortexResult<Option<f64>> {
119-
let v = primitive_scalar.as_::<f64>()?;
120-
let array_len = array_len
121-
.to_f64()
122-
.ok_or_else(|| vortex_err!("array_len must fit the sum type"))?;
123-
124-
Ok(v.map(|v| v * array_len))
125-
}
126-
12776
#[cfg(test)]
12877
mod test {
12978
use vortex_dtype::half::f16;
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
use num_traits::{CheckedMul, ToPrimitive};
2+
use vortex_dtype::{NativePType, PType, match_each_native_ptype};
3+
use vortex_error::{VortexExpect, VortexResult, vortex_err};
4+
use vortex_scalar::{FromPrimitiveOrF16, PrimitiveScalar, Scalar};
5+
6+
use crate::arrays::{ConstantArray, ConstantEncoding};
7+
use crate::compute::{SumKernel, SumKernelAdapter};
8+
use crate::stats::Stat;
9+
use crate::{Array, register_kernel};
10+
11+
impl SumKernel for ConstantEncoding {
12+
fn sum(&self, array: &ConstantArray) -> VortexResult<Scalar> {
13+
let sum_dtype = Stat::Sum
14+
.dtype(array.dtype())
15+
.ok_or_else(|| vortex_err!("Sum not supported for dtype {}", array.dtype()))?;
16+
let sum_ptype = PType::try_from(&sum_dtype).vortex_expect("sum dtype must be primitive");
17+
18+
let scalar = array.scalar();
19+
20+
let scalar_value = match_each_native_ptype!(
21+
sum_ptype,
22+
unsigned: |$T| { sum_integral::<u64>(scalar.as_primitive(), array.len())?.into() }
23+
signed: |$T| { sum_integral::<i64>(scalar.as_primitive(), array.len())?.into() }
24+
floating: |$T| { sum_float(scalar.as_primitive(), array.len())?.into() }
25+
);
26+
27+
Ok(Scalar::new(sum_dtype, scalar_value))
28+
}
29+
}
30+
31+
fn sum_integral<T>(
32+
primitive_scalar: PrimitiveScalar<'_>,
33+
array_len: usize,
34+
) -> VortexResult<Option<T>>
35+
where
36+
T: FromPrimitiveOrF16 + NativePType + CheckedMul,
37+
Scalar: From<Option<T>>,
38+
{
39+
let v = primitive_scalar.as_::<T>()?;
40+
let array_len =
41+
T::from(array_len).ok_or_else(|| vortex_err!("array_len must fit the sum type"))?;
42+
let sum = v.and_then(|v| v.checked_mul(&array_len));
43+
44+
Ok(sum)
45+
}
46+
47+
fn sum_float(primitive_scalar: PrimitiveScalar<'_>, array_len: usize) -> VortexResult<Option<f64>> {
48+
let v = primitive_scalar.as_::<f64>()?;
49+
let array_len = array_len
50+
.to_f64()
51+
.ok_or_else(|| vortex_err!("array_len must fit the sum type"))?;
52+
53+
Ok(v.map(|v| v * array_len))
54+
}
55+
56+
register_kernel!(SumKernelAdapter(ConstantEncoding).lift());

vortex-array/src/arrays/extension/compute/mod.rs

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@ use crate::arrays::ExtensionEncoding;
99
use crate::arrays::extension::ExtensionArray;
1010
use crate::compute::{
1111
CastFn, FilterKernel, FilterKernelAdapter, IsConstantFn, IsConstantOpts, IsSortedFn, MinMaxFn,
12-
MinMaxResult, ScalarAtFn, SliceFn, SumFn, TakeFn, ToArrowFn, UncompressedSizeFn, filter,
13-
is_constant_opts, is_sorted, is_strict_sorted, min_max, scalar_at, slice, sum, take,
14-
uncompressed_size,
12+
MinMaxResult, ScalarAtFn, SliceFn, SumKernel, SumKernelAdapter, TakeFn, ToArrowFn,
13+
UncompressedSizeFn, filter, is_constant_opts, is_sorted, is_strict_sorted, min_max, scalar_at,
14+
slice, sum, take, uncompressed_size,
1515
};
1616
use crate::variants::ExtensionArrayTrait;
1717
use crate::vtable::ComputeVTable;
@@ -41,10 +41,6 @@ impl ComputeVTable for ExtensionEncoding {
4141
Some(self)
4242
}
4343

44-
fn sum_fn(&self) -> Option<&dyn SumFn<&dyn Array>> {
45-
Some(self)
46-
}
47-
4844
fn take_fn(&self) -> Option<&dyn TakeFn<&dyn Array>> {
4945
Some(self)
5046
}
@@ -92,12 +88,14 @@ impl SliceFn<&ExtensionArray> for ExtensionEncoding {
9288
}
9389
}
9490

95-
impl SumFn<&ExtensionArray> for ExtensionEncoding {
91+
impl SumKernel for ExtensionEncoding {
9692
fn sum(&self, array: &ExtensionArray) -> VortexResult<Scalar> {
9793
sum(array.storage())
9894
}
9995
}
10096

97+
register_kernel!(SumKernelAdapter(ExtensionEncoding).lift());
98+
10199
impl TakeFn<&ExtensionArray> for ExtensionEncoding {
102100
fn take(&self, array: &ExtensionArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
103101
Ok(

vortex-array/src/arrays/primitive/compute/mod.rs

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use crate::Array;
22
use crate::arrays::PrimitiveEncoding;
33
use crate::compute::{
44
CastFn, FillForwardFn, FillNullFn, IsConstantFn, IsSortedFn, MaskFn, MinMaxFn, ScalarAtFn,
5-
SearchSortedFn, SearchSortedUsizeFn, SliceFn, SumFn, TakeFn, ToArrowFn, UncompressedSizeFn,
5+
SearchSortedFn, SearchSortedUsizeFn, SliceFn, TakeFn, ToArrowFn, UncompressedSizeFn,
66
};
77
use crate::vtable::ComputeVTable;
88

@@ -70,10 +70,6 @@ impl ComputeVTable for PrimitiveEncoding {
7070
Some(self)
7171
}
7272

73-
fn sum_fn(&self) -> Option<&dyn SumFn<&dyn Array>> {
74-
Some(self)
75-
}
76-
7773
fn take_fn(&self) -> Option<&dyn TakeFn<&dyn Array>> {
7874
Some(self)
7975
}

vortex-array/src/arrays/primitive/compute/sum.rs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,13 @@ use vortex_error::{VortexExpect, VortexResult};
66
use vortex_mask::AllOr;
77
use vortex_scalar::Scalar;
88

9-
use crate::Array;
109
use crate::arrays::{PrimitiveArray, PrimitiveEncoding};
11-
use crate::compute::SumFn;
10+
use crate::compute::{SumKernel, SumKernelAdapter};
1211
use crate::stats::Stat;
1312
use crate::variants::PrimitiveArrayTrait;
13+
use crate::{Array, register_kernel};
1414

15-
impl SumFn<&PrimitiveArray> for PrimitiveEncoding {
15+
impl SumKernel for PrimitiveEncoding {
1616
fn sum(&self, array: &PrimitiveArray) -> VortexResult<Scalar> {
1717
Ok(match array.validity_mask()?.boolean_buffer() {
1818
AllOr::All => {
@@ -53,6 +53,8 @@ impl SumFn<&PrimitiveArray> for PrimitiveEncoding {
5353
}
5454
}
5555

56+
register_kernel!(SumKernelAdapter(PrimitiveEncoding).lift());
57+
5658
fn sum_integer<T: NativePType + ToPrimitive, R: NativePType + CheckedAdd>(
5759
values: &[T],
5860
) -> Option<R> {

vortex-array/src/compute/mod.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -261,10 +261,10 @@ impl Output {
261261
}
262262
}
263263

264-
pub fn into_scalar(self) -> Option<Scalar> {
265-
match self {
266-
Output::Scalar(scalar) => Some(scalar),
267-
_ => None,
264+
pub fn unwrap_scalar(self) -> VortexResult<Scalar> {
265+
match &self {
266+
Output::Array(_) => vortex_bail!("Expected array output, got Array"),
267+
Output::Scalar(scalar) => Ok(scalar.clone()),
268268
}
269269
}
270270

0 commit comments

Comments
 (0)