Skip to content

Commit 86a1f1a

Browse files
committed
More operator rules
Signed-off-by: Nicholas Gates <[email protected]>
1 parent 3b33d70 commit 86a1f1a

File tree

7 files changed

+143
-39
lines changed

7 files changed

+143
-39
lines changed

encodings/fastlanes/src/bitpacking/vtable/kernels/filter.rs

Lines changed: 24 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,8 @@ use vortex_mask::Mask;
2222
use vortex_mask::MaskValues;
2323
use vortex_vector::Vector;
2424
use vortex_vector::VectorMutOps;
25-
use vortex_vector::primitive::PVector;
2625
use vortex_vector::primitive::PVectorMut;
27-
use vortex_vector::primitive::PrimitiveVector;
26+
use vortex_vector::primitive::PrimitiveVectorMut;
2827

2928
use crate::BitPackedArray;
3029
use crate::BitPackedVTable;
@@ -83,36 +82,46 @@ impl ExecuteParentKernel<BitPackedVTable> for BitPackingFilterKernel {
8382
}
8483
});
8584

86-
let primitive_vector: PrimitiveVector = match array.ptype() {
87-
PType::U8 => filter_primitive::<u8>(array, values)?.into(),
88-
PType::U16 => filter_primitive::<u16>(array, values)?.into(),
89-
PType::U32 => filter_primitive::<u32>(array, values)?.into(),
90-
PType::U64 => filter_primitive::<u64>(array, values)?.into(),
85+
let mut primitive_vector: PrimitiveVectorMut = match array.ptype() {
86+
PType::U8 => filter_primitive_without_patches::<u8>(array, values)?.into(),
87+
PType::U16 => filter_primitive_without_patches::<u16>(array, values)?.into(),
88+
PType::U32 => filter_primitive_without_patches::<u32>(array, values)?.into(),
89+
PType::U64 => filter_primitive_without_patches::<u64>(array, values)?.into(),
9190

9291
// Since the fastlanes crate only supports unsigned integers, and since we know that all
9392
// numbers are going to be non-negative, we can safely "cast" to unsigned and back.
9493
PType::I8 => {
95-
let pvector = filter_primitive::<u8>(array, values)?;
94+
let pvector = filter_primitive_without_patches::<u8>(array, values)?;
9695
unsafe { pvector.transmute::<i8>() }.into()
9796
}
9897
PType::I16 => {
99-
let pvector = filter_primitive::<u16>(array, values)?;
98+
let pvector = filter_primitive_without_patches::<u16>(array, values)?;
10099
unsafe { pvector.transmute::<i16>() }.into()
101100
}
102101
PType::I32 => {
103-
let pvector = filter_primitive::<u32>(array, values)?;
102+
let pvector = filter_primitive_without_patches::<u32>(array, values)?;
104103
unsafe { pvector.transmute::<i32>() }.into()
105104
}
106105
PType::I64 => {
107-
let pvector = filter_primitive::<u64>(array, values)?;
106+
let pvector = filter_primitive_without_patches::<u64>(array, values)?;
108107
unsafe { pvector.transmute::<i64>() }.into()
109108
}
110109
other => {
111110
unreachable!("Unsupported ptype {other} for bitpacking, we also checked this above")
112111
}
113112
};
114113

115-
Ok(Some(primitive_vector.into()))
114+
// TODO(connor): We want a `PatchesArray` or patching compute functions instead of this.
115+
let patches = array
116+
.patches()
117+
.map(|patches| patches.filter(&Mask::Values(values.clone())))
118+
.transpose()?
119+
.flatten();
120+
if let Some(patches) = patches {
121+
primitive_vector = patches.apply_to_primitive_vector(primitive_vector);
122+
}
123+
124+
Ok(Some(primitive_vector.freeze().into()))
116125
}
117126
}
118127

@@ -125,10 +134,10 @@ impl ExecuteParentKernel<BitPackedVTable> for BitPackingFilterKernel {
125134
/// This function fully decompresses the array for all but the most selective masks because the
126135
/// FastLanes decompression is so fast and the bookkeeping necessary to decompress individual
127136
/// elements is relatively slow.
128-
fn filter_primitive<U: UnsignedPType + BitPacking>(
137+
fn filter_primitive_without_patches<U: UnsignedPType + BitPacking>(
129138
array: &BitPackedArray,
130139
selection: &Arc<MaskValues>,
131-
) -> VortexResult<PVector<U>> {
140+
) -> VortexResult<PVectorMut<U>> {
132141
let values = filter_with_indices(array, selection.indices());
133142
let validity = array
134143
.validity_mask()
@@ -141,19 +150,7 @@ fn filter_primitive<U: UnsignedPType + BitPacking>(
141150
"`filter_with_indices` was somehow incorrect"
142151
);
143152

144-
let mut pvector = unsafe { PVectorMut::new_unchecked(values, validity) };
145-
146-
// TODO(connor): We want a `PatchesArray` or patching compute functions instead of this.
147-
let patches = array
148-
.patches()
149-
.map(|patches| patches.filter(&Mask::Values(selection.clone())))
150-
.transpose()?
151-
.flatten();
152-
if let Some(patches) = patches {
153-
pvector = patches.apply_to_pvector(pvector);
154-
}
155-
156-
Ok(pvector.freeze())
153+
Ok(unsafe { PVectorMut::new_unchecked(values, validity) })
157154
}
158155

159156
fn filter_with_indices<T: NativePType + BitPacking>(

vortex-array/src/arrays/dict/vtable/rules.rs

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ use crate::arrays::ConstantVTable;
1414
use crate::arrays::DictArray;
1515
use crate::arrays::DictVTable;
1616
use crate::arrays::ScalarFnArray;
17+
use crate::builtins::ArrayBuiltins;
1718
use crate::optimizer::ArrayOptimizer;
1819
use crate::optimizer::rules::ArrayParentReduceRule;
1920
use crate::optimizer::rules::ParentRuleSet;
@@ -94,10 +95,19 @@ impl ArrayParentReduceRule<DictVTable> for DictionaryScalarFnValuesPushDownRule
9495
.into_array()
9596
.optimize()?;
9697

97-
let new_dict =
98-
unsafe { DictArray::new_unchecked(array.codes().clone(), new_values) }.into_array();
98+
// We can only push down null-sensitive functions when we have all-valid codes.
99+
// In these cases, we cannot have the codes influence the nullability of the output DType.
100+
// Therefore, we cast the codes to be non-nullable and then cast the dictionary output
101+
// back to nullable if needed.
102+
if sig.is_null_sensitive() && array.codes().dtype().is_nullable() {
103+
let new_codes = array.codes().cast(array.codes().dtype().as_nonnullable())?;
104+
let new_dict = unsafe { DictArray::new_unchecked(new_codes, new_values) }.into_array();
105+
return Ok(Some(new_dict.cast(parent.dtype().clone())?));
106+
}
99107

100-
Ok(Some(new_dict))
108+
Ok(Some(
109+
unsafe { DictArray::new_unchecked(array.codes().clone(), new_values) }.into_array(),
110+
))
101111
}
102112
}
103113

@@ -117,13 +127,27 @@ impl ArrayParentReduceRule<DictVTable> for DictionaryScalarFnCodesPullUpRule {
117127
parent: &ScalarFnArray,
118128
child_idx: usize,
119129
) -> VortexResult<Option<ArrayRef>> {
120-
// Check that all siblings are dictionaries with the same codes as us.
130+
// Don't attempt to pull up if the parent has fewer than 2 children.
131+
if parent.children().len() < 2 {
132+
return Ok(None);
133+
}
134+
135+
// Check that all siblings are dictionaries, and have the same number of values as us.
136+
// This is a cheap first loop.
137+
if !parent.children().iter().enumerate().all(|(idx, c)| {
138+
idx == child_idx
139+
|| c.as_opt::<DictVTable>()
140+
.is_some_and(|c| c.values().len() == array.values().len())
141+
}) {
142+
return Ok(None);
143+
}
144+
145+
// Now run the slightly more expensive check that all siblings have the same codes as us.
146+
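// NOTE(review): this comment says the cheaper Precision::Ptr is used to avoid doing data comparisons, but the call below passes Precision::Value; confirm which is intended.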
121147
if !parent.children().iter().enumerate().all(|(idx, c)| {
122148
idx == child_idx
123-
|| c.as_opt::<DictVTable>().is_some_and(|c| {
124-
c.values().len() == array.values().len()
125-
&& c.codes().array_eq(array.codes(), Precision::Value)
126-
})
149+
|| c.as_opt::<DictVTable>()
150+
.is_some_and(|c| c.codes().array_eq(array.codes(), Precision::Value))
127151
}) {
128152
return Ok(None);
129153
}

vortex-array/src/optimizer/rules.rs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,25 @@ impl<V: VTable> ParentRuleSet<V> {
153153
continue;
154154
}
155155
if let Some(reduced) = rule.reduce_parent(child, parent, child_idx)? {
156+
// Debug assertions because these checks are already run elsewhere.
157+
#[cfg(debug_assertions)]
158+
{
159+
vortex_error::vortex_ensure!(
160+
reduced.len() == parent.len(),
161+
"Reduced array length mismatch from {:?}\nFrom:\n{}\nTo:\n{}",
162+
rule,
163+
parent.display_tree(),
164+
reduced.display_tree()
165+
);
166+
vortex_error::vortex_ensure!(
167+
reduced.dtype() == parent.dtype(),
168+
"Reduced array dtype mismatch from {:?}\nFrom:\n{}\nTo:\n{}",
169+
rule,
170+
parent.display_tree(),
171+
reduced.display_tree()
172+
);
173+
}
174+
156175
return Ok(Some(reduced));
157176
}
158177
}

vortex-array/src/patches.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,10 @@ use vortex_dtype::IntegerPType;
1515
use vortex_dtype::NativePType;
1616
use vortex_dtype::Nullability::NonNullable;
1717
use vortex_dtype::PType;
18+
use vortex_dtype::PTypeDowncastExt;
1819
use vortex_dtype::UnsignedPType;
1920
use vortex_dtype::match_each_integer_ptype;
21+
use vortex_dtype::match_each_native_ptype;
2022
use vortex_dtype::match_each_unsigned_integer_ptype;
2123
use vortex_error::VortexError;
2224
use vortex_error::VortexExpect;
@@ -31,6 +33,7 @@ use vortex_scalar::PValue;
3133
use vortex_scalar::Scalar;
3234
use vortex_utils::aliases::hash_map::HashMap;
3335
use vortex_vector::primitive::PVectorMut;
36+
use vortex_vector::primitive::PrimitiveVectorMut;
3437

3538
use crate::Array;
3639
use crate::ArrayRef;
@@ -823,6 +826,13 @@ impl Patches {
823826
}))
824827
}
825828

829+
/// Applies patches to a primitive vector, returning the patched vector.
830+
pub fn apply_to_primitive_vector(&self, vector: PrimitiveVectorMut) -> PrimitiveVectorMut {
831+
match_each_native_ptype!(vector.ptype(), |T| {
832+
self.apply_to_pvector(vector.downcast::<T>()).into()
833+
})
834+
}
835+
826836
/// Applies patches to a [`PVectorMut<T>`], returning the patched vector.
827837
///
828838
/// This function modifies the elements buffer in-place at the positions specified by the patch

vortex-array/src/vtable/dyn_.rs

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -160,11 +160,15 @@ impl<V: VTable> DynVTable for ArrayVTableAdapter<V> {
160160
};
161161
vortex_ensure!(
162162
reduced.len() == array.len(),
163-
"Reduced array length mismatch"
163+
"Reduced array length mismatch from {} to {}",
164+
array.display_tree(),
165+
reduced.display_tree()
164166
);
165167
vortex_ensure!(
166168
reduced.dtype() == array.dtype(),
167-
"Reduced array dtype mismatch"
169+
"Reduced array dtype mismatch from {} to {}",
170+
array.display_tree(),
171+
reduced.display_tree()
168172
);
169173
Ok(Some(reduced))
170174
}
@@ -181,11 +185,15 @@ impl<V: VTable> DynVTable for ArrayVTableAdapter<V> {
181185

182186
vortex_ensure!(
183187
reduced.len() == parent.len(),
184-
"Reduced array length mismatch"
188+
"Reduced array length mismatch from {} to {}",
189+
parent.display_tree(),
190+
reduced.display_tree()
185191
);
186192
vortex_ensure!(
187193
reduced.dtype() == parent.dtype(),
188-
"Reduced array dtype mismatch"
194+
"Reduced array dtype mismatch from {} to {}",
195+
parent.display_tree(),
196+
reduced.display_tree()
189197
);
190198

191199
Ok(Some(reduced))

vortex-buffer/src/buffer_mut.rs

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -444,6 +444,33 @@ impl<T> BufferMut<T> {
444444
Self::copy_from_aligned(self, alignment)
445445
}
446446
}
447+
448+
/// Transmute a `BufferMut<T>` into a `BufferMut<U>`.
449+
///
450+
/// # Safety
451+
///
452+
/// The caller must ensure that all possible bit representations of type `T` are valid when
453+
/// interpreted as type `U`.
454+
/// See [`std::mem::transmute`] for more details.
455+
///
456+
/// # Panics
457+
///
458+
/// Panics if the type `U` does not have the same size and alignment as `T`.
459+
pub unsafe fn transmute<U>(self) -> BufferMut<U> {
460+
assert_eq!(size_of::<T>(), size_of::<U>(), "Buffer type size mismatch");
461+
assert_eq!(
462+
align_of::<T>(),
463+
align_of::<U>(),
464+
"Buffer type alignment mismatch"
465+
);
466+
467+
BufferMut {
468+
bytes: self.bytes,
469+
length: self.length,
470+
alignment: self.alignment,
471+
_marker: std::marker::PhantomData,
472+
}
473+
}
447474
}
448475

449476
impl<T> Clone for BufferMut<T> {

vortex-vector/src/primitive/generic_mut.rs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,25 @@ impl<T> PVectorMut<T> {
125125
self.elements.push_n(value, n);
126126
self.validity.append_n(true, n);
127127
}
128+
129+
/// Transmute a `PVectorMut<T>` into a `PVectorMut<U>`.
130+
///
131+
/// # Safety
132+
///
133+
/// The caller must ensure that all values of type `T` in this vector are valid as type `U`.
134+
/// See [`std::mem::transmute`] for more details.
135+
///
136+
/// # Panics
137+
///
138+
/// Panics if the type `U` does not have the same size and alignment as `T`.
139+
pub unsafe fn transmute<U: NativePType>(self) -> PVectorMut<U> {
140+
let (buffer, mask) = self.into_parts();
141+
142+
// SAFETY: same guarantees as this function.
143+
let buffer = unsafe { buffer.transmute::<U>() };
144+
145+
PVectorMut::new(buffer, mask)
146+
}
128147
}
129148

130149
impl<T: NativePType> VectorMutOps for PVectorMut<T> {

0 commit comments

Comments
 (0)