Skip to content

Commit 79b5719

Browse files
committed
Wire up filter kernel
Signed-off-by: Nicholas Gates <[email protected]>
1 parent 222cd71 commit 79b5719

File tree

12 files changed

+103
-79
lines changed

12 files changed

+103
-79
lines changed

encodings/decimal-byte-parts/src/decimal_byte_parts/rules.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
// SPDX-License-Identifier: Apache-2.0
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

4+
use vortex_array::Array;
45
use vortex_array::ArrayRef;
56
use vortex_array::IntoArray;
67
use vortex_array::arrays::FilterArray;
@@ -38,8 +39,7 @@ impl ArrayParentReduceRule<DecimalBytePartsVTable> for DecimalBytePartsFilterPus
3839
return Ok(None);
3940
}
4041

41-
let new_msp =
42-
FilterArray::new(child.msp.clone(), parent.filter_mask().clone()).into_array();
42+
let new_msp = child.msp.filter(parent.filter_mask().clone())?;
4343
let new_child =
4444
DecimalBytePartsArray::try_new(new_msp, *child.decimal_dtype())?.into_array();
4545
Ok(Some(new_child))

encodings/fastlanes/src/bitpacking/mod.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ pub use array::bitpack_decompress;
88
pub use array::unpack_iter;
99

1010
mod compute;
11-
mod kernels;
1211

1312
mod vtable;
1413
pub use vtable::BitPackedVTable;

encodings/fastlanes/src/bitpacking/kernels/filter.rs renamed to encodings/fastlanes/src/bitpacking/vtable/kernels/filter.rs

Lines changed: 43 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -2,36 +2,37 @@
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

44
use std::mem::MaybeUninit;
5+
use std::sync::Arc;
56

67
use fastlanes::BitPacking;
78
use vortex_array::ExecutionCtx;
8-
use vortex_array::IntoArray;
9-
use vortex_array::VectorExecutor;
109
use vortex_array::arrays::FilterArray;
1110
use vortex_array::arrays::FilterVTable;
1211
use vortex_array::kernel::ExecuteParentKernel;
12+
use vortex_array::kernel::ParentKernelSet;
1313
use vortex_array::matchers::Exact;
14-
use vortex_array::patches::patch_pvector;
15-
use vortex_buffer::Buffer;
1614
use vortex_buffer::BufferMut;
1715
use vortex_compute::filter::Filter;
1816
use vortex_dtype::NativePType;
1917
use vortex_dtype::PType;
2018
use vortex_dtype::UnsignedPType;
2119
use vortex_dtype::match_each_integer_ptype;
22-
use vortex_error::VortexExpect;
2320
use vortex_error::VortexResult;
2421
use vortex_mask::Mask;
22+
use vortex_mask::MaskValues;
2523
use vortex_vector::Vector;
26-
use vortex_vector::VectorMut;
2724
use vortex_vector::VectorMutOps;
2825
use vortex_vector::primitive::PVector;
26+
use vortex_vector::primitive::PVectorMut;
2927
use vortex_vector::primitive::PrimitiveVector;
3028

3129
use crate::BitPackedArray;
3230
use crate::BitPackedVTable;
33-
use crate::bitpacking::kernels::UNPACK_CHUNK_THRESHOLD;
34-
use crate::bitpacking::kernels::chunked_indices;
31+
use crate::bitpacking::vtable::kernels::UNPACK_CHUNK_THRESHOLD;
32+
use crate::bitpacking::vtable::kernels::chunked_indices;
33+
34+
pub(crate) const PARENT_KERNELS: ParentKernelSet<BitPackedVTable> =
35+
ParentKernelSet::new(&[ParentKernelSet::lift(&BitPackingFilterKernel)]);
3536

3637
/// The threshold over which it is faster to fully unpack the entire [`BitPackedArray`] and then
3738
/// filter the result than to unpack only specific bitpacked values into the output buffer.
@@ -48,6 +49,7 @@ pub const fn unpack_then_filter_threshold<T>() -> f64 {
4849
}
4950
}
5051

52+
/// Kernel to execute filtering directly on a bit-packed array.
5153
#[derive(Debug)]
5254
struct BitPackingFilterKernel;
5355

@@ -63,50 +65,47 @@ impl ExecuteParentKernel<BitPackedVTable> for BitPackingFilterKernel {
6365
array: &BitPackedArray,
6466
parent: &FilterArray,
6567
_child_idx: usize,
66-
ctx: &mut ExecutionCtx,
68+
_ctx: &mut ExecutionCtx,
6769
) -> VortexResult<Option<Vector>> {
68-
let selection = parent.filter_mask();
69-
70-
let true_count = selection.true_count();
71-
if true_count == 0 {
72-
// Fast-path for an empty mask.
73-
return Ok(Some(VectorMut::with_capacity(array.dtype(), 0).freeze()));
74-
} else if true_count == selection.len() {
75-
// Fast-path for a full mask.
76-
return Ok(Some(array.to_array().execute(ctx)?));
77-
}
70+
let values = match parent.filter_mask() {
71+
Mask::AllTrue(_) | Mask::AllFalse(_) => {
72+
// No optimization for full or empty mask
73+
return Ok(None);
74+
}
75+
Mask::Values(values) => values,
76+
};
7877

7978
match_each_integer_ptype!(array.ptype(), |I| {
8079
// If the density is high enough, then we would rather decompress the whole array and then apply
8180
// a filter over decompressing values one by one.
82-
if selection.density() > unpack_then_filter_threshold::<I>() {
81+
if values.density() > unpack_then_filter_threshold::<I>() {
8382
return Ok(None);
8483
}
8584
});
8685

8786
let primitive_vector: PrimitiveVector = match array.ptype() {
88-
PType::U8 => filter_primitive::<u8>(array, selection)?.into(),
89-
PType::U16 => filter_primitive::<u16>(array, selection)?.into(),
90-
PType::U32 => filter_primitive::<u32>(array, selection)?.into(),
91-
PType::U64 => filter_primitive::<u64>(array, selection)?.into(),
87+
PType::U8 => filter_primitive::<u8>(array, values)?.into(),
88+
PType::U16 => filter_primitive::<u16>(array, values)?.into(),
89+
PType::U32 => filter_primitive::<u32>(array, values)?.into(),
90+
PType::U64 => filter_primitive::<u64>(array, values)?.into(),
9291

9392
// Since the fastlanes crate only supports unsigned integers, and since we know that all
9493
// numbers are going to be non-negative, we can safely "cast" to unsigned and back.
9594
PType::I8 => {
96-
let pvector = filter_primitive::<u8>(array, selection)?;
97-
pvector.cast_into::<i8>().into()
95+
let pvector = filter_primitive::<u8>(array, values)?;
96+
unsafe { pvector.transmute::<i8>() }.into()
9897
}
9998
PType::I16 => {
100-
let pvector = filter_primitive::<u16>(array, selection)?;
101-
pvector.cast_into::<i16>().into()
99+
let pvector = filter_primitive::<u16>(array, values)?;
100+
unsafe { pvector.transmute::<i16>() }.into()
102101
}
103102
PType::I32 => {
104-
let pvector = filter_primitive::<u32>(array, selection)?;
105-
pvector.cast_into::<i32>().into()
103+
let pvector = filter_primitive::<u32>(array, values)?;
104+
unsafe { pvector.transmute::<i32>() }.into()
106105
}
107106
PType::I64 => {
108-
let pvector = filter_primitive::<u64>(array, selection)?;
109-
pvector.cast_into::<i64>().into()
107+
let pvector = filter_primitive::<u64>(array, values)?;
108+
unsafe { pvector.transmute::<i64>() }.into()
110109
}
111110
other => {
112111
unreachable!("Unsupported ptype {other} for bitpacking, we also checked this above")
@@ -128,42 +127,39 @@ impl ExecuteParentKernel<BitPackedVTable> for BitPackingFilterKernel {
128127
/// elements is relatively slow.
129128
fn filter_primitive<U: UnsignedPType + BitPacking>(
130129
array: &BitPackedArray,
131-
selection: &Mask,
130+
selection: &Arc<MaskValues>,
132131
) -> VortexResult<PVector<U>> {
133-
let values = filter_with_indices(
134-
array,
135-
selection
136-
.values()
137-
.vortex_expect("AllTrue and AllFalse handled by filter fn")
138-
.indices(),
139-
);
140-
let validity = array.validity_mask().filter(selection);
132+
let values = filter_with_indices(array, selection.indices());
133+
let validity = array
134+
.validity_mask()
135+
.filter(&Mask::Values(selection.clone()))
136+
.into_mut();
141137

142138
debug_assert_eq!(
143139
values.len(),
144140
validity.len(),
145141
"`filter_with_indices` was somehow incorrect"
146142
);
147143

148-
let mut pvector = unsafe { PVector::new_unchecked(values, validity) };
144+
let mut pvector = unsafe { PVectorMut::new_unchecked(values, validity) };
149145

150146
// TODO(connor): We want a `PatchesArray` or patching compute functions instead of this.
151147
let patches = array
152148
.patches()
153-
.map(|patches| patches.filter(selection))
149+
.map(|patches| patches.filter(&Mask::Values(selection.clone())))
154150
.transpose()?
155151
.flatten();
156152
if let Some(patches) = patches {
157-
pvector = patch_pvector(pvector, &patches);
153+
pvector = patches.apply_to_pvector(pvector);
158154
}
159155

160-
Ok(pvector)
156+
Ok(pvector.freeze())
161157
}
162158

163159
fn filter_with_indices<T: NativePType + BitPacking>(
164160
array: &BitPackedArray,
165161
indices: &[usize],
166-
) -> Buffer<T> {
162+
) -> BufferMut<T> {
167163
let offset = array.offset() as usize;
168164
let bit_width = array.bit_width() as usize;
169165
let mut values = BufferMut::with_capacity(indices.len());
@@ -209,5 +205,5 @@ fn filter_with_indices<T: NativePType + BitPacking>(
209205
}
210206
});
211207

212-
values.freeze()
208+
values
213209
}

encodings/fastlanes/src/bitpacking/kernels/mod.rs renamed to encodings/fastlanes/src/bitpacking/vtable/kernels/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
// SPDX-License-Identifier: Apache-2.0
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

4-
mod filter;
4+
pub(crate) mod filter;
55

66
/// Assuming the buffer is already allocated (which will happen at most once), then unpacking all
77
/// 1024 elements takes ~8.8x as long as unpacking a single element on an M2 Macbook Air.

encodings/fastlanes/src/bitpacking/vtable/mod.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,12 @@ use vortex_vector::VectorMutOps;
3030

3131
use crate::BitPackedArray;
3232
use crate::bitpack_decompress::unpack_to_primitive_vector;
33+
use crate::bitpacking::vtable::kernels::filter::PARENT_KERNELS;
3334

3435
mod array;
3536
mod canonical;
3637
mod encode;
38+
mod kernels;
3739
mod operations;
3840
mod validity;
3941
mod visitor;
@@ -246,6 +248,15 @@ impl VTable for BitPackedVTable {
246248
fn execute(array: &Self::Array, _ctx: &mut ExecutionCtx) -> VortexResult<Vector> {
247249
Ok(unpack_to_primitive_vector(array).freeze().into())
248250
}
251+
252+
fn execute_parent(
253+
array: &Self::Array,
254+
parent: &ArrayRef,
255+
child_idx: usize,
256+
ctx: &mut ExecutionCtx,
257+
) -> VortexResult<Option<Vector>> {
258+
PARENT_KERNELS.execute(array, parent, child_idx, ctx)
259+
}
249260
}
250261

251262
#[derive(Debug)]

encodings/fastlanes/src/for/vtable/rules.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
// SPDX-License-Identifier: Apache-2.0
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

4+
use vortex_array::Array;
45
use vortex_array::ArrayRef;
56
use vortex_array::IntoArray;
67
use vortex_array::arrays::FilterArray;
@@ -34,8 +35,7 @@ impl ArrayParentReduceRule<FoRVTable> for FoRFilterPushDownRule {
3435
) -> VortexResult<Option<ArrayRef>> {
3536
let new_array = unsafe {
3637
FoRArray::new_unchecked(
37-
FilterArray::new(child.encoded().clone(), parent.filter_mask().clone())
38-
.into_array(),
38+
child.encoded.filter(parent.filter_mask().clone())?,
3939
child.reference.clone(),
4040
)
4141
};

vortex-array/src/patches.rs

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,7 @@ use vortex_mask::MaskMut;
3030
use vortex_scalar::PValue;
3131
use vortex_scalar::Scalar;
3232
use vortex_utils::aliases::hash_map::HashMap;
33-
use vortex_vector::VectorOps;
34-
use vortex_vector::primitive::PVector;
33+
use vortex_vector::primitive::PVectorMut;
3534

3635
use crate::Array;
3736
use crate::ArrayRef;
@@ -824,6 +823,21 @@ impl Patches {
824823
}))
825824
}
826825

826+
/// Applies patches to a [`PVector<T>`], returning the patched vector.
827+
///
828+
/// This function modifies the elements buffer in-place at the positions specified by the patch
829+
/// indices. It also updates the validity mask to reflect the nullability of patch values.
830+
pub fn apply_to_pvector<T: NativePType>(&self, pvector: PVectorMut<T>) -> PVectorMut<T> {
831+
let (mut elements, mut validity) = pvector.into_parts();
832+
833+
// SAFETY: We maintain the invariant that elements and validity have the same length, and all
834+
// patch indices are valid after offset adjustment (guaranteed by `Patches`).
835+
unsafe { self.apply_to_buffer(elements.as_mut_slice(), &mut validity) };
836+
837+
// SAFETY: We have not modified the length of elements or validity.
838+
unsafe { PVectorMut::new_unchecked(elements, validity) }
839+
}
840+
827841
/// Apply patches to a mutable buffer and validity mask.
828842
///
829843
/// This method applies the patch values to the given buffer at the positions specified by the
@@ -885,20 +899,6 @@ impl Patches {
885899
}
886900
}
887901

888-
/// Applies patches to a [`PVector<T>`], returning the patched vector.
889-
///
890-
/// This function modifies the elements buffer in-place at the positions specified by the patch
891-
/// indices. It also updates the validity mask to reflect the nullability of patch values.
892-
pub fn patch_pvector<T: NativePType>(pvector: PVector<T>, patches: &Patches) -> PVector<T> {
893-
let (mut elements, mut validity) = pvector.into_mut().into_parts();
894-
895-
// SAFETY: We maintain the invariant that elements and validity have the same length, and all
896-
// patch indices are valid after offset adjustment (guaranteed by `Patches`).
897-
unsafe { patches.apply_to_buffer(elements.as_mut_slice(), &mut validity) };
898-
899-
PVector::new(elements.freeze(), validity.freeze())
900-
}
901-
902902
/// Helper function to apply patches to a buffer.
903903
///
904904
/// # Safety

vortex-buffer/src/buffer.rs

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -466,13 +466,19 @@ impl<T> Buffer<T> {
466466
}
467467
}
468468

469-
impl<T: Copy> Buffer<T> {
470-
/// Cast a `Buffer<T>` into a `Buffer<U>`.
469+
impl<T> Buffer<T> {
470+
/// Transmute a `Buffer<T>` into a `Buffer<U>`.
471+
///
472+
/// # Safety
473+
///
474+
/// The caller must ensure that all possible bit representations of type `T` are valid when
475+
/// interpreted as type `U`.
476+
/// See [`std::mem::transmute`] for more details.
471477
///
472478
/// # Panics
473479
///
474480
/// Panics if the type `U` does not have the same size and alignment as `T`.
475-
pub fn cast_into<U>(self) -> Buffer<U> {
481+
pub unsafe fn transmute<U>(self) -> Buffer<U> {
476482
assert_eq!(size_of::<T>(), size_of::<U>(), "Buffer type size mismatch");
477483
assert_eq!(
478484
align_of::<T>(),

vortex-compute/src/take/slice/avx2.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ pub unsafe fn take_avx2<V: NativePType, I: UnsignedPType>(
7272
let values = unsafe { std::mem::transmute::<&[V], &[$cast]>(buffer) };
7373

7474
let result = exec_take::<$cast, $indices, AVX2Gather>(values, indices);
75-
result.cast_into::<V>()
75+
result.transmute::<V>()
7676
}};
7777
}
7878

vortex-compute/src/take/slice/portable.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ pub fn take_portable<T: NativePType, I: UnsignedPType>(buffer: &[T], indices: &[
4040
// make.
4141
let u16_slice: &[u16] =
4242
unsafe { std::slice::from_raw_parts(buffer.as_ptr() as *const u16, buffer.len()) };
43-
return take_with_indices(u16_slice, indices).cast_into::<T>();
43+
return take_with_indices(u16_slice, indices).transmute::<T>();
4444
}
4545

4646
match_each_native_simd_ptype!(T::PTYPE, |TC| {

0 commit comments

Comments
 (0)