Skip to content

Commit 965891e

Browse files
authored
Compare kernel (#3144)
Migrate the CompareFn to the new kernels
1 parent 0a95b1b commit 965891e

File tree

67 files changed

+494
-377
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

67 files changed

+494
-377
lines changed

bench-vortex/src/bin/notimplemented.rs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,6 @@ fn compute_funcs(encodings: &[ArrayRef]) {
159159
table_builder.push_record(vec![
160160
"Encoding",
161161
"cast",
162-
"compare",
163162
"fill_forward",
164163
"fill_null",
165164
"scalar_at",
@@ -175,7 +174,6 @@ fn compute_funcs(encodings: &[ArrayRef]) {
175174
let mut impls = vec![id.as_ref()];
176175
for (j, func) in [
177176
encoding.cast_fn().is_some(),
178-
encoding.compare_fn().is_some(),
179177
encoding.fill_forward_fn().is_some(),
180178
encoding.fill_null_fn().is_some(),
181179
encoding.scalar_at_fn().is_some(),

encodings/alp/src/alp/array.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ pub struct ALPArray {
2424
stats_set: ArrayStats,
2525
}
2626

27+
#[derive(Debug)]
2728
pub struct ALPEncoding;
2829
impl Encoding for ALPEncoding {
2930
type Array = ALPArray;

encodings/alp/src/alp/compute/compare.rs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
use std::fmt::Debug;
22

33
use vortex_array::arrays::ConstantArray;
4-
use vortex_array::compute::{CompareFn, Operator, compare};
5-
use vortex_array::{Array, ArrayRef};
4+
use vortex_array::compute::{CompareKernel, CompareKernelAdapter, Operator, compare};
5+
use vortex_array::{Array, ArrayRef, register_kernel};
66
use vortex_dtype::NativePType;
77
use vortex_error::{VortexResult, vortex_bail};
88
use vortex_scalar::{PrimitiveScalar, Scalar};
@@ -11,7 +11,7 @@ use crate::{ALPArray, ALPEncoding, ALPFloat, match_each_alp_float_ptype};
1111

1212
// TODO(joe): add fuzzing.
1313

14-
impl CompareFn<&ALPArray> for ALPEncoding {
14+
impl CompareKernel for ALPEncoding {
1515
fn compare(
1616
&self,
1717
lhs: &ALPArray,
@@ -42,6 +42,8 @@ impl CompareFn<&ALPArray> for ALPEncoding {
4242
}
4343
}
4444

45+
register_kernel!(CompareKernelAdapter(ALPEncoding).lift());
46+
4547
// We can compare a scalar to an ALPArray by encoding the scalar into the ALP domain and comparing
4648
// the encoded value to the encoded values in the ALPArray. There are fixups when the value doesn't
4749
// encode into the ALP domain.
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
use vortex_array::compute::{FilterKernel, FilterKernelAdapter, filter};
2+
use vortex_array::{Array, ArrayRef, register_kernel};
3+
use vortex_error::VortexResult;
4+
use vortex_mask::Mask;
5+
6+
use crate::{ALPArray, ALPEncoding};
7+
8+
impl FilterKernel for ALPEncoding {
9+
fn filter(&self, array: &ALPArray, mask: &Mask) -> VortexResult<ArrayRef> {
10+
let patches = array
11+
.patches()
12+
.map(|p| p.filter(mask))
13+
.transpose()?
14+
.flatten();
15+
16+
Ok(
17+
ALPArray::try_new(filter(array.encoded(), mask)?, array.exponents(), patches)?
18+
.into_array(),
19+
)
20+
}
21+
}
22+
23+
register_kernel!(FilterKernelAdapter(ALPEncoding).lift());

encodings/alp/src/alp/compute/mod.rs

Lines changed: 4 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,18 @@
11
mod compare;
2+
mod filter;
23

34
use std::fmt::Debug;
45

56
use vortex_array::arrays::ConstantArray;
67
use vortex_array::compute::{
7-
BetweenFn, BetweenOptions, CompareFn, FilterKernelAdapter, FilterKernelImpl, ScalarAtFn,
8-
SliceFn, StrictComparison, TakeFn, between, filter, scalar_at, slice, take,
8+
BetweenFn, BetweenOptions, ScalarAtFn, SliceFn, StrictComparison, TakeFn, between, scalar_at,
9+
slice, take,
910
};
1011
use vortex_array::variants::PrimitiveArrayTrait;
1112
use vortex_array::vtable::ComputeVTable;
12-
use vortex_array::{Array, ArrayRef, register_kernel};
13+
use vortex_array::{Array, ArrayRef};
1314
use vortex_dtype::{NativePType, Nullability};
1415
use vortex_error::VortexResult;
15-
use vortex_mask::Mask;
1616
use vortex_scalar::{Scalar, ScalarType};
1717

1818
use crate::{ALPArray, ALPEncoding, ALPFloat, match_each_alp_float_ptype};
@@ -22,10 +22,6 @@ impl ComputeVTable for ALPEncoding {
2222
Some(self)
2323
}
2424

25-
fn compare_fn(&self) -> Option<&dyn CompareFn<&dyn Array>> {
26-
Some(self)
27-
}
28-
2925
fn scalar_at_fn(&self) -> Option<&dyn ScalarAtFn<&dyn Array>> {
3026
Some(self)
3127
}
@@ -98,23 +94,6 @@ impl SliceFn<&ALPArray> for ALPEncoding {
9894
}
9995
}
10096

101-
impl FilterKernelImpl for ALPEncoding {
102-
fn filter(&self, array: &ALPArray, mask: &Mask) -> VortexResult<ArrayRef> {
103-
let patches = array
104-
.patches()
105-
.map(|p| p.filter(mask))
106-
.transpose()?
107-
.flatten();
108-
109-
Ok(
110-
ALPArray::try_new(filter(array.encoded(), mask)?, array.exponents(), patches)?
111-
.into_array(),
112-
)
113-
}
114-
}
115-
116-
register_kernel!(FilterKernelAdapter(ALPEncoding).lift());
117-
11897
impl BetweenFn<&ALPArray> for ALPEncoding {
11998
fn between(
12099
&self,

encodings/alp/src/alp_rd/array.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ pub struct ALPRDArray {
2828
stats_set: ArrayStats,
2929
}
3030

31+
#[derive(Debug)]
3132
pub struct ALPRDEncoding;
3233
impl Encoding for ALPRDEncoding {
3334
type Array = ALPRDArray;

encodings/alp/src/alp_rd/compute/filter.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
1-
use vortex_array::compute::{FilterKernelAdapter, FilterKernelImpl, filter};
1+
use vortex_array::compute::{FilterKernel, FilterKernelAdapter, filter};
22
use vortex_array::{Array, ArrayRef, register_kernel};
33
use vortex_error::VortexResult;
44
use vortex_mask::Mask;
55

66
use crate::{ALPRDArray, ALPRDEncoding};
77

8-
impl FilterKernelImpl for ALPRDEncoding {
8+
impl FilterKernel for ALPRDEncoding {
99
fn filter(&self, array: &ALPRDArray, mask: &Mask) -> VortexResult<ArrayRef> {
1010
let left_parts_exceptions = array
1111
.left_parts_patches()

encodings/bytebool/src/array.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ pub struct ByteBoolArray {
2525

2626
try_from_array_ref!(ByteBoolArray);
2727

28+
#[derive(Debug)]
2829
pub struct ByteBoolEncoding;
2930
impl Encoding for ByteBoolEncoding {
3031
type Array = ByteBoolArray;

encodings/datetime-parts/src/array.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ pub struct DateTimePartsArray {
2525
stats_set: ArrayStats,
2626
}
2727

28+
#[derive(Debug)]
2829
pub struct DateTimePartsEncoding;
2930
impl Encoding for DateTimePartsEncoding {
3031
type Array = DateTimePartsArray;

encodings/datetime-parts/src/compute/compare.rs

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
11
use vortex_array::arrays::ConstantArray;
2-
use vortex_array::compute::{CompareFn, Operator, and, compare, or, try_cast};
3-
use vortex_array::{Array, ArrayRef};
2+
use vortex_array::compute::{
3+
CompareKernel, CompareKernelAdapter, Operator, and, compare, or, try_cast,
4+
};
5+
use vortex_array::{Array, ArrayRef, register_kernel};
46
use vortex_dtype::DType;
57
use vortex_dtype::datetime::TemporalMetadata;
68
use vortex_error::{VortexExpect as _, VortexResult};
79

810
use crate::array::{DateTimePartsArray, DateTimePartsEncoding};
911
use crate::timestamp;
1012

11-
impl CompareFn<&DateTimePartsArray> for DateTimePartsEncoding {
13+
impl CompareKernel for DateTimePartsEncoding {
1214
/// Compares two arrays and returns a new boolean array with the result of the comparison.
1315
/// Or, returns None if comparison is not supported.
1416
fn compare(
@@ -54,6 +56,8 @@ impl CompareFn<&DateTimePartsArray> for DateTimePartsEncoding {
5456
}
5557
}
5658

59+
register_kernel!(CompareKernelAdapter(DateTimePartsEncoding).lift());
60+
5761
fn compare_eq(
5862
lhs: &DateTimePartsArray,
5963
ts_parts: &timestamp::TimestampParts,

0 commit comments

Comments
 (0)