Skip to content

Commit f5ee1a0

Browse files
authored
Add more operator rules (#5746)
Signed-off-by: Nicholas Gates <[email protected]>
1 parent 6afde02 commit f5ee1a0

File tree

12 files changed

+320
-49
lines changed

12 files changed

+320
-49
lines changed

encodings/fastlanes/src/bitpacking/vtable/kernels/filter.rs

Lines changed: 24 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,8 @@ use vortex_mask::Mask;
2222
use vortex_mask::MaskValues;
2323
use vortex_vector::Vector;
2424
use vortex_vector::VectorMutOps;
25-
use vortex_vector::primitive::PVector;
2625
use vortex_vector::primitive::PVectorMut;
27-
use vortex_vector::primitive::PrimitiveVector;
26+
use vortex_vector::primitive::PrimitiveVectorMut;
2827

2928
use crate::BitPackedArray;
3029
use crate::BitPackedVTable;
@@ -83,36 +82,46 @@ impl ExecuteParentKernel<BitPackedVTable> for BitPackingFilterKernel {
8382
}
8483
});
8584

86-
let primitive_vector: PrimitiveVector = match array.ptype() {
87-
PType::U8 => filter_primitive::<u8>(array, values)?.into(),
88-
PType::U16 => filter_primitive::<u16>(array, values)?.into(),
89-
PType::U32 => filter_primitive::<u32>(array, values)?.into(),
90-
PType::U64 => filter_primitive::<u64>(array, values)?.into(),
85+
let mut primitive_vector: PrimitiveVectorMut = match array.ptype() {
86+
PType::U8 => filter_primitive_without_patches::<u8>(array, values)?.into(),
87+
PType::U16 => filter_primitive_without_patches::<u16>(array, values)?.into(),
88+
PType::U32 => filter_primitive_without_patches::<u32>(array, values)?.into(),
89+
PType::U64 => filter_primitive_without_patches::<u64>(array, values)?.into(),
9190

9291
// Since the fastlanes crate only supports unsigned integers, and since we know that all
9392
// numbers are going to be non-negative, we can safely "cast" to unsigned and back.
9493
PType::I8 => {
95-
let pvector = filter_primitive::<u8>(array, values)?;
94+
let pvector = filter_primitive_without_patches::<u8>(array, values)?;
9695
unsafe { pvector.transmute::<i8>() }.into()
9796
}
9897
PType::I16 => {
99-
let pvector = filter_primitive::<u16>(array, values)?;
98+
let pvector = filter_primitive_without_patches::<u16>(array, values)?;
10099
unsafe { pvector.transmute::<i16>() }.into()
101100
}
102101
PType::I32 => {
103-
let pvector = filter_primitive::<u32>(array, values)?;
102+
let pvector = filter_primitive_without_patches::<u32>(array, values)?;
104103
unsafe { pvector.transmute::<i32>() }.into()
105104
}
106105
PType::I64 => {
107-
let pvector = filter_primitive::<u64>(array, values)?;
106+
let pvector = filter_primitive_without_patches::<u64>(array, values)?;
108107
unsafe { pvector.transmute::<i64>() }.into()
109108
}
110109
other => {
111110
unreachable!("Unsupported ptype {other} for bitpacking, we also checked this above")
112111
}
113112
};
114113

115-
Ok(Some(primitive_vector.into()))
114+
// TODO(connor): We want a `PatchesArray` or patching compute functions instead of this.
115+
let patches = array
116+
.patches()
117+
.map(|patches| patches.filter(&Mask::Values(values.clone())))
118+
.transpose()?
119+
.flatten();
120+
if let Some(patches) = patches {
121+
primitive_vector = patches.apply_to_primitive_vector(primitive_vector);
122+
}
123+
124+
Ok(Some(primitive_vector.freeze().into()))
116125
}
117126
}
118127

@@ -125,10 +134,10 @@ impl ExecuteParentKernel<BitPackedVTable> for BitPackingFilterKernel {
125134
/// This function fully decompresses the array for all but the most selective masks because the
126135
/// FastLanes decompression is so fast and the bookkeepping necessary to decompress individual
127136
/// elements is relatively slow.
128-
fn filter_primitive<U: UnsignedPType + BitPacking>(
137+
fn filter_primitive_without_patches<U: UnsignedPType + BitPacking>(
129138
array: &BitPackedArray,
130139
selection: &Arc<MaskValues>,
131-
) -> VortexResult<PVector<U>> {
140+
) -> VortexResult<PVectorMut<U>> {
132141
let values = filter_with_indices(array, selection.indices());
133142
let validity = array
134143
.validity_mask()
@@ -141,19 +150,7 @@ fn filter_primitive<U: UnsignedPType + BitPacking>(
141150
"`filter_with_indices` was somehow incorrect"
142151
);
143152

144-
let mut pvector = unsafe { PVectorMut::new_unchecked(values, validity) };
145-
146-
// TODO(connor): We want a `PatchesArray` or patching compute functions instead of this.
147-
let patches = array
148-
.patches()
149-
.map(|patches| patches.filter(&Mask::Values(selection.clone())))
150-
.transpose()?
151-
.flatten();
152-
if let Some(patches) = patches {
153-
pvector = patches.apply_to_pvector(pvector);
154-
}
155-
156-
Ok(pvector.freeze())
153+
Ok(unsafe { PVectorMut::new_unchecked(values, validity) })
157154
}
158155

159156
fn filter_with_indices<T: NativePType + BitPacking>(

vortex-array/src/arrays/dict/vtable/mod.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ use crate::DeserializeMetadata;
1919
use crate::ProstMetadata;
2020
use crate::SerializeMetadata;
2121
use crate::VectorExecutor;
22+
use crate::arrays::vtable::rules::PARENT_RULES;
2223
use crate::executor::ExecutionCtx;
2324
use crate::serde::ArrayChildren;
2425
use crate::vtable;
@@ -32,6 +33,7 @@ mod array;
3233
mod canonical;
3334
mod encode;
3435
mod operations;
36+
mod rules;
3537
mod validity;
3638
mod visitor;
3739

@@ -134,4 +136,12 @@ impl VTable for DictVTable {
134136
let codes = array.codes().execute(ctx)?.into_primitive();
135137
Ok(values.take(&codes))
136138
}
139+
140+
fn reduce_parent(
141+
array: &Self::Array,
142+
parent: &ArrayRef,
143+
child_idx: usize,
144+
) -> VortexResult<Option<ArrayRef>> {
145+
PARENT_RULES.evaluate(array, parent, child_idx)
146+
}
137147
}
Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
use vortex_error::VortexResult;
5+
6+
use crate::Array;
7+
use crate::ArrayEq;
8+
use crate::ArrayRef;
9+
use crate::IntoArray;
10+
use crate::Precision;
11+
use crate::arrays::AnyScalarFn;
12+
use crate::arrays::ConstantArray;
13+
use crate::arrays::ConstantVTable;
14+
use crate::arrays::DictArray;
15+
use crate::arrays::DictVTable;
16+
use crate::arrays::ScalarFnArray;
17+
use crate::builtins::ArrayBuiltins;
18+
use crate::optimizer::ArrayOptimizer;
19+
use crate::optimizer::rules::ArrayParentReduceRule;
20+
use crate::optimizer::rules::ParentRuleSet;
21+
22+
pub(super) const PARENT_RULES: ParentRuleSet<DictVTable> = ParentRuleSet::new(&[
23+
ParentRuleSet::lift(&DictionaryScalarFnValuesPushDownRule),
24+
ParentRuleSet::lift(&DictionaryScalarFnCodesPullUpRule),
25+
]);
26+
27+
/// Push down a scalar function to run only over the values of a dictionary array.
28+
#[derive(Debug)]
29+
struct DictionaryScalarFnValuesPushDownRule;
30+
31+
impl ArrayParentReduceRule<DictVTable> for DictionaryScalarFnValuesPushDownRule {
32+
type Parent = AnyScalarFn;
33+
34+
fn parent(&self) -> Self::Parent {
35+
AnyScalarFn
36+
}
37+
38+
fn reduce_parent(
39+
&self,
40+
array: &DictArray,
41+
parent: &ScalarFnArray,
42+
child_idx: usize,
43+
) -> VortexResult<Option<ArrayRef>> {
44+
// Check that the scalar function can actually be pushed down.
45+
let sig = parent.scalar_fn().signature();
46+
47+
// If the scalar function is fallible, we cannot push it down since it may fail over a
48+
// value that isn't referenced by any code.
49+
if !array.all_values_referenced && sig.is_fallible() {
50+
tracing::trace!(
51+
"Not pushing down fallible scalar function {} over dictionary with sparse codes {}",
52+
parent.scalar_fn(),
53+
array.display_tree(),
54+
);
55+
return Ok(None);
56+
}
57+
58+
// Check that all siblings are constant
59+
// TODO(ngates): we can also support other dictionaries if the values are the same!
60+
if !parent
61+
.children()
62+
.iter()
63+
.enumerate()
64+
.all(|(idx, c)| idx == child_idx || c.is::<ConstantVTable>())
65+
{
66+
return Ok(None);
67+
}
68+
69+
// If the scalar function is null-sensitive, then we cannot push it down to values if
70+
// we have any nulls in the codes.
71+
if array.codes.dtype().is_nullable() && !array.codes.all_valid() && sig.is_null_sensitive()
72+
{
73+
tracing::trace!(
74+
"Not pushing down null-sensitive scalar function {} over dictionary with null codes {}",
75+
parent.scalar_fn(),
76+
array.display_tree(),
77+
);
78+
return Ok(None);
79+
}
80+
81+
// Now we push the parent scalar function into the dictionary values.
82+
let values_len = array.values().len();
83+
let mut new_children = Vec::with_capacity(parent.children().len());
84+
for (idx, child) in parent.children().iter().enumerate() {
85+
if idx == child_idx {
86+
new_children.push(array.values().clone());
87+
} else {
88+
let scalar = child.as_::<ConstantVTable>().scalar().clone();
89+
new_children.push(ConstantArray::new(scalar, values_len).into_array());
90+
}
91+
}
92+
93+
let new_values =
94+
ScalarFnArray::try_new(parent.scalar_fn().clone(), new_children, values_len)?
95+
.into_array()
96+
.optimize()?;
97+
98+
// We can only push down null-sensitive functions when we have all-valid codes.
99+
// In these cases, we cannot have the codes influence the nullability of the output DType.
100+
// Therefore, we cast the codes to be non-nullable and then cast the dictionary output
101+
// back to nullable if needed.
102+
if sig.is_null_sensitive() && array.codes().dtype().is_nullable() {
103+
let new_codes = array.codes().cast(array.codes().dtype().as_nonnullable())?;
104+
let new_dict = unsafe { DictArray::new_unchecked(new_codes, new_values) }.into_array();
105+
return Ok(Some(new_dict.cast(parent.dtype().clone())?));
106+
}
107+
108+
Ok(Some(
109+
unsafe { DictArray::new_unchecked(array.codes().clone(), new_values) }.into_array(),
110+
))
111+
}
112+
}
113+
114+
#[derive(Debug)]
115+
struct DictionaryScalarFnCodesPullUpRule;
116+
117+
impl ArrayParentReduceRule<DictVTable> for DictionaryScalarFnCodesPullUpRule {
118+
type Parent = AnyScalarFn;
119+
120+
fn parent(&self) -> Self::Parent {
121+
AnyScalarFn
122+
}
123+
124+
fn reduce_parent(
125+
&self,
126+
array: &DictArray,
127+
parent: &ScalarFnArray,
128+
child_idx: usize,
129+
) -> VortexResult<Option<ArrayRef>> {
130+
// Don't attempt to pull up if there are less than 2 siblings.
131+
if parent.children().len() < 2 {
132+
return Ok(None);
133+
}
134+
135+
// Check that all siblings are dictionaries, and have the same number of values as us.
136+
// This is a cheap first loop.
137+
if !parent.children().iter().enumerate().all(|(idx, c)| {
138+
idx == child_idx
139+
|| c.as_opt::<DictVTable>()
140+
.is_some_and(|c| c.values().len() == array.values().len())
141+
}) {
142+
return Ok(None);
143+
}
144+
145+
// Now run the slightly more expensive check that all siblings have the same codes as us.
146+
// We use the cheaper Precision::Ptr to avoid doing data comparisons.
147+
if !parent.children().iter().enumerate().all(|(idx, c)| {
148+
idx == child_idx
149+
|| c.as_opt::<DictVTable>()
150+
.is_some_and(|c| c.codes().array_eq(array.codes(), Precision::Value))
151+
}) {
152+
return Ok(None);
153+
}
154+
155+
let mut new_children = Vec::with_capacity(parent.children().len());
156+
for (idx, child) in parent.children().iter().enumerate() {
157+
if idx == child_idx {
158+
new_children.push(array.values().clone());
159+
} else {
160+
new_children.push(child.as_::<DictVTable>().values().clone());
161+
}
162+
}
163+
164+
let new_values =
165+
ScalarFnArray::try_new(parent.scalar_fn().clone(), new_children, array.values.len())?
166+
.into_array()
167+
.optimize()?;
168+
169+
let new_dict =
170+
unsafe { DictArray::new_unchecked(array.codes().clone(), new_values) }.into_array();
171+
172+
Ok(Some(new_dict))
173+
}
174+
}

vortex-array/src/arrays/scalar_fn/vtable/mod.rs

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -140,20 +140,16 @@ impl VTable for ScalarFnVTable {
140140
}
141141

142142
fn execute(array: &Self::Array, ctx: &mut ExecutionCtx) -> VortexResult<Vector> {
143-
let datums: Vec<_> = array
144-
.children()
145-
.iter()
146-
.map(|child| match child.as_opt::<ConstantVTable>() {
147-
None => child.execute(ctx).map(Datum::Vector),
148-
Some(constant) => Ok(Datum::Scalar(constant.scalar().to_vector_scalar())),
149-
})
150-
.try_collect()?;
151-
152-
let input_dtypes: Vec<_> = array
153-
.children()
154-
.iter()
155-
.map(|child| child.dtype().clone())
156-
.collect();
143+
// NOTE: we don't use iterators here to make the profiles easier to read!
144+
let mut datums = Vec::with_capacity(array.children.len());
145+
let mut input_dtypes = Vec::with_capacity(array.children.len());
146+
for child in array.children.iter() {
147+
match child.as_opt::<ConstantVTable>() {
148+
None => datums.push(child.execute(ctx).map(Datum::Vector)?),
149+
Some(constant) => datums.push(Datum::Scalar(constant.scalar().to_vector_scalar())),
150+
}
151+
input_dtypes.push(child.dtype().clone());
152+
}
157153

158154
let args = ExecutionArgs {
159155
datums,

vortex-array/src/executor.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ impl VectorExecutor for ArrayRef {
6767
}
6868

6969
let mut ctx = ExecutionCtx::new(session.clone());
70-
tracing::debug!("Executing array:\n{}", self.display_tree());
70+
tracing::debug!("Executing array {}:\n{}", self, self.display_tree());
7171
Ok(Datum::Vector(self.execute(&mut ctx)?))
7272
}
7373

vortex-array/src/expr/exprs/list_contains.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,10 @@ impl VTable for ListContains {
199199
fn is_null_sensitive(&self, _instance: &Self::Options) -> bool {
200200
true
201201
}
202+
203+
fn is_fallible(&self, _options: &Self::Options) -> bool {
204+
false
205+
}
202206
}
203207

204208
/// Creates an expression that checks if a value is contained in a list.

vortex-array/src/optimizer/rules.rs

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ pub trait ArrayParentReduceRule<V: VTable>: Debug + Send + Sync + 'static {
3838
/// - `Err(e)` if an error occurred
3939
fn reduce_parent(
4040
&self,
41-
child: &V::Array,
41+
array: &V::Array,
4242
parent: <Self::Parent as Matcher>::View<'_>,
4343
child_idx: usize,
4444
) -> VortexResult<Option<ArrayRef>>;
@@ -153,6 +153,25 @@ impl<V: VTable> ParentRuleSet<V> {
153153
continue;
154154
}
155155
if let Some(reduced) = rule.reduce_parent(child, parent, child_idx)? {
156+
// Debug assertions because these checks are already run elsewhere.
157+
#[cfg(debug_assertions)]
158+
{
159+
vortex_error::vortex_ensure!(
160+
reduced.len() == parent.len(),
161+
"Reduced array length mismatch from {:?}\nFrom:\n{}\nTo:\n{}",
162+
rule,
163+
parent.display_tree(),
164+
reduced.display_tree()
165+
);
166+
vortex_error::vortex_ensure!(
167+
reduced.dtype() == parent.dtype(),
168+
"Reduced array dtype mismatch from {:?}\nFrom:\n{}\nTo:\n{}",
169+
rule,
170+
parent.display_tree(),
171+
reduced.display_tree()
172+
);
173+
}
174+
156175
return Ok(Some(reduced));
157176
}
158177
}

0 commit comments

Comments
 (0)