Skip to content

Commit 99d3fec

Browse files
authored
Port TakeFn to TakeKernel (#3239)
1 parent 9b9f045 commit 99d3fec

File tree

40 files changed

+474
-354
lines changed

40 files changed

+474
-354
lines changed

encodings/alp/src/alp/compute/mod.rs

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,20 +2,16 @@ mod between;
22
mod compare;
33
mod nan_count;
44

5-
use vortex_array::compute::{TakeFn, take};
5+
use vortex_array::compute::{TakeKernel, TakeKernelAdapter, take};
66
use vortex_array::vtable::ComputeVTable;
7-
use vortex_array::{Array, ArrayRef};
7+
use vortex_array::{Array, ArrayRef, register_kernel};
88
use vortex_error::VortexResult;
99

1010
use crate::{ALPArray, ALPEncoding};
1111

12-
impl ComputeVTable for ALPEncoding {
13-
fn take_fn(&self) -> Option<&dyn TakeFn<&dyn Array>> {
14-
Some(self)
15-
}
16-
}
12+
impl ComputeVTable for ALPEncoding {}
1713

18-
impl TakeFn<&ALPArray> for ALPEncoding {
14+
impl TakeKernel for ALPEncoding {
1915
fn take(&self, array: &ALPArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
2016
let taken_encoded = take(array.encoded(), indices)?;
2117
let taken_patches = array
@@ -34,3 +30,5 @@ impl TakeFn<&ALPArray> for ALPEncoding {
3430
Ok(ALPArray::try_new(taken_encoded, array.exponents(), taken_patches)?.into_array())
3531
}
3632
}
33+
34+
register_kernel!(TakeKernelAdapter(ALPEncoding).lift());
Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
use vortex_array::Array;
2-
use vortex_array::compute::TakeFn;
31
use vortex_array::vtable::ComputeVTable;
42

53
use crate::ALPRDEncoding;
@@ -8,8 +6,4 @@ mod filter;
86
mod mask;
97
mod take;
108

11-
impl ComputeVTable for ALPRDEncoding {
12-
fn take_fn(&self) -> Option<&dyn TakeFn<&dyn Array>> {
13-
Some(self)
14-
}
15-
}
9+
impl ComputeVTable for ALPRDEncoding {}

encodings/alp/src/alp_rd/compute/take.rs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
1-
use vortex_array::compute::{TakeFn, fill_null, take};
2-
use vortex_array::{Array, ArrayRef};
1+
use vortex_array::compute::{TakeKernel, TakeKernelAdapter, fill_null, take};
2+
use vortex_array::{Array, ArrayRef, register_kernel};
33
use vortex_error::VortexResult;
44
use vortex_scalar::{Scalar, ScalarValue};
55

66
use crate::{ALPRDArray, ALPRDEncoding};
77

8-
impl TakeFn<&ALPRDArray> for ALPRDEncoding {
8+
impl TakeKernel for ALPRDEncoding {
99
fn take(&self, array: &ALPRDArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
1010
let taken_left_parts = take(array.left_parts(), indices)?;
1111
let left_parts_exceptions = array
@@ -40,6 +40,8 @@ impl TakeFn<&ALPRDArray> for ALPRDEncoding {
4040
}
4141
}
4242

43+
register_kernel!(TakeKernelAdapter(ALPRDEncoding).lift());
44+
4345
#[cfg(test)]
4446
mod test {
4547
use rstest::rstest;

encodings/bytebool/src/compute.rs

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use num_traits::AsPrimitive;
2-
use vortex_array::compute::{MaskKernel, MaskKernelAdapter, TakeFn};
2+
use vortex_array::compute::{MaskKernel, MaskKernelAdapter, TakeKernel, TakeKernelAdapter};
33
use vortex_array::variants::PrimitiveArrayTrait;
44
use vortex_array::vtable::ComputeVTable;
55
use vortex_array::{Array, ArrayRef, ToCanonical, register_kernel};
@@ -9,11 +9,7 @@ use vortex_mask::Mask;
99

1010
use super::{ByteBoolArray, ByteBoolEncoding};
1111

12-
impl ComputeVTable for ByteBoolEncoding {
13-
fn take_fn(&self) -> Option<&dyn TakeFn<&dyn Array>> {
14-
Some(self)
15-
}
16-
}
12+
impl ComputeVTable for ByteBoolEncoding {}
1713

1814
impl MaskKernel for ByteBoolEncoding {
1915
fn mask(&self, array: &ByteBoolArray, mask: &Mask) -> VortexResult<ArrayRef> {
@@ -23,7 +19,7 @@ impl MaskKernel for ByteBoolEncoding {
2319

2420
register_kernel!(MaskKernelAdapter(ByteBoolEncoding).lift());
2521

26-
impl TakeFn<&ByteBoolArray> for ByteBoolEncoding {
22+
impl TakeKernel for ByteBoolEncoding {
2723
fn take(&self, array: &ByteBoolArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
2824
let validity = array.validity_mask()?;
2925
let indices = indices.to_primitive()?;
@@ -69,6 +65,8 @@ impl TakeFn<&ByteBoolArray> for ByteBoolEncoding {
6965
}
7066
}
7167

68+
register_kernel!(TakeKernelAdapter(ByteBoolEncoding).lift());
69+
7270
#[cfg(test)]
7371
mod tests {
7472
use vortex_array::compute::conformance::mask::test_mask;

encodings/datetime-parts/src/compute/mod.rs

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,8 @@ mod filter;
44
mod is_constant;
55
mod take;
66

7-
use vortex_array::Array;
8-
use vortex_array::compute::TakeFn;
97
use vortex_array::vtable::ComputeVTable;
108

119
use crate::DateTimePartsEncoding;
1210

13-
impl ComputeVTable for DateTimePartsEncoding {
14-
fn take_fn(&self) -> Option<&dyn TakeFn<&dyn Array>> {
15-
Some(self)
16-
}
17-
18-
// TODO(joe): implement `between_fn` this is used at lot.
19-
}
11+
impl ComputeVTable for DateTimePartsEncoding {}
Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1-
use vortex_array::compute::{TakeFn, take};
2-
use vortex_array::{Array, ArrayRef};
1+
use vortex_array::compute::{TakeKernel, TakeKernelAdapter, take};
2+
use vortex_array::{Array, ArrayRef, register_kernel};
33
use vortex_error::VortexResult;
44

55
use crate::{DateTimePartsArray, DateTimePartsEncoding};
66

7-
impl TakeFn<&DateTimePartsArray> for DateTimePartsEncoding {
7+
impl TakeKernel for DateTimePartsEncoding {
88
fn take(&self, array: &DateTimePartsArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
99
Ok(DateTimePartsArray::try_new(
1010
array.dtype().clone(),
@@ -15,3 +15,5 @@ impl TakeFn<&DateTimePartsArray> for DateTimePartsEncoding {
1515
.into_array())
1616
}
1717
}
18+
19+
register_kernel!(TakeKernelAdapter(DateTimePartsEncoding).lift());

encodings/dict/src/compute/mod.rs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,27 +6,27 @@ mod is_sorted;
66
mod like;
77
mod min_max;
88

9-
use vortex_array::compute::{FilterKernel, FilterKernelAdapter, TakeFn, filter, take};
9+
use vortex_array::compute::{
10+
FilterKernel, FilterKernelAdapter, TakeKernel, TakeKernelAdapter, filter, take,
11+
};
1012
use vortex_array::vtable::ComputeVTable;
1113
use vortex_array::{Array, ArrayRef, register_kernel};
1214
use vortex_error::VortexResult;
1315
use vortex_mask::Mask;
1416

1517
use crate::{DictArray, DictEncoding};
1618

17-
impl ComputeVTable for DictEncoding {
18-
fn take_fn(&self) -> Option<&dyn TakeFn<&dyn Array>> {
19-
Some(self)
20-
}
21-
}
19+
impl ComputeVTable for DictEncoding {}
2220

23-
impl TakeFn<&DictArray> for DictEncoding {
21+
impl TakeKernel for DictEncoding {
2422
fn take(&self, array: &DictArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
2523
let codes = take(array.codes(), indices)?;
2624
DictArray::try_new(codes, array.values().clone()).map(|a| a.into_array())
2725
}
2826
}
2927

28+
register_kernel!(TakeKernelAdapter(DictEncoding).lift());
29+
3030
impl FilterKernel for DictEncoding {
3131
fn filter(&self, array: &DictArray, mask: &Mask) -> VortexResult<ArrayRef> {
3232
let codes = filter(array.codes(), mask)?;

encodings/fastlanes/src/bitpacking/compute/mod.rs

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use vortex_array::Array;
2-
use vortex_array::compute::{SearchSortedFn, TakeFn};
2+
use vortex_array::compute::SearchSortedFn;
33
use vortex_array::vtable::ComputeVTable;
44

55
use crate::BitPackedEncoding;
@@ -14,10 +14,6 @@ impl ComputeVTable for BitPackedEncoding {
1414
fn search_sorted_fn(&self) -> Option<&dyn SearchSortedFn<&dyn Array>> {
1515
Some(self)
1616
}
17-
18-
fn take_fn(&self) -> Option<&dyn TakeFn<&dyn Array>> {
19-
Some(self)
20-
}
2117
}
2218

2319
fn chunked_indices<F: FnMut(usize, &[usize])>(

encodings/fastlanes/src/bitpacking/compute/take.rs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@ use std::mem::MaybeUninit;
33

44
use fastlanes::BitPacking;
55
use vortex_array::arrays::PrimitiveArray;
6-
use vortex_array::compute::{TakeFn, take};
6+
use vortex_array::compute::{TakeKernel, TakeKernelAdapter, take};
77
use vortex_array::validity::Validity;
88
use vortex_array::variants::PrimitiveArrayTrait;
9-
use vortex_array::{Array, ArrayRef, ToCanonical};
9+
use vortex_array::{Array, ArrayRef, ToCanonical, register_kernel};
1010
use vortex_buffer::{Buffer, BufferMut};
1111
use vortex_dtype::{
1212
NativePType, PType, match_each_integer_ptype, match_each_unsigned_integer_ptype,
@@ -21,7 +21,7 @@ use crate::{BitPackedArray, BitPackedEncoding, unpack_single_primitive};
2121
// see https://github.com/spiraldb/vortex/pull/190#issue-2223752833
2222
pub(super) const UNPACK_CHUNK_THRESHOLD: usize = 8;
2323

24-
impl TakeFn<&BitPackedArray> for BitPackedEncoding {
24+
impl TakeKernel for BitPackedEncoding {
2525
fn take(&self, array: &BitPackedArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
2626
// If the indices are large enough, it's faster to flatten and take the primitive array.
2727
if indices.len() * UNPACK_CHUNK_THRESHOLD > array.len() {
@@ -44,6 +44,8 @@ impl TakeFn<&BitPackedArray> for BitPackedEncoding {
4444
}
4545
}
4646

47+
register_kernel!(TakeKernelAdapter(BitPackedEncoding).lift());
48+
4749
fn take_primitive<T: NativePType + BitPacking, I: NativePType>(
4850
array: &BitPackedArray,
4951
indices: &PrimitiveArray,

encodings/fastlanes/src/for/compute/mod.rs

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@ mod is_constant;
33

44
use num_traits::WrappingSub;
55
use vortex_array::compute::{
6-
FilterKernel, FilterKernelAdapter, SearchResult, SearchSortedFn, SearchSortedSide, TakeFn,
7-
filter, search_sorted, take,
6+
FilterKernel, FilterKernelAdapter, SearchResult, SearchSortedFn, SearchSortedSide, TakeKernel,
7+
TakeKernelAdapter, filter, search_sorted, take,
88
};
99
use vortex_array::variants::PrimitiveArrayTrait;
1010
use vortex_array::vtable::ComputeVTable;
@@ -20,13 +20,9 @@ impl ComputeVTable for FoREncoding {
2020
fn search_sorted_fn(&self) -> Option<&dyn SearchSortedFn<&dyn Array>> {
2121
Some(self)
2222
}
23-
24-
fn take_fn(&self) -> Option<&dyn TakeFn<&dyn Array>> {
25-
Some(self)
26-
}
2723
}
2824

29-
impl TakeFn<&FoRArray> for FoREncoding {
25+
impl TakeKernel for FoREncoding {
3026
fn take(&self, array: &FoRArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
3127
FoRArray::try_new(
3228
take(array.encoded(), indices)?,
@@ -36,6 +32,8 @@ impl TakeFn<&FoRArray> for FoREncoding {
3632
}
3733
}
3834

35+
register_kernel!(TakeKernelAdapter(FoREncoding).lift());
36+
3937
impl FilterKernel for FoREncoding {
4038
fn filter(&self, array: &FoRArray, mask: &Mask) -> VortexResult<ArrayRef> {
4139
FoRArray::try_new(

0 commit comments

Comments
 (0)