Skip to content

Commit 0fb522b

Browse files
committed
perf[ext]: pushdown filter + constant
Signed-off-by: Joe Isaacs <[email protected]>
1 parent aa7a891 commit 0fb522b

File tree

2 files changed

+355
-0
lines changed

2 files changed

+355
-0
lines changed

vortex-array/src/arrays/extension/vtable/mod.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
mod array;
55
mod canonical;
66
mod operations;
7+
mod rules;
78
mod validity;
89
mod visitor;
910

@@ -19,6 +20,7 @@ use crate::ArrayRef;
1920
use crate::EmptyMetadata;
2021
use crate::VectorExecutor;
2122
use crate::arrays::extension::ExtensionArray;
23+
use crate::arrays::extension::vtable::rules::PARENT_RULES;
2224
use crate::executor::ExecutionCtx;
2325
use crate::serde::ArrayChildren;
2426
use crate::vtable;
@@ -98,6 +100,14 @@ impl VTable for ExtensionVTable {
98100
fn execute(array: &Self::Array, ctx: &mut ExecutionCtx) -> VortexResult<Vector> {
99101
array.storage().execute(ctx)
100102
}
103+
104+
fn reduce_parent(
105+
array: &Self::Array,
106+
parent: &ArrayRef,
107+
child_idx: usize,
108+
) -> VortexResult<Option<ArrayRef>> {
109+
PARENT_RULES.evaluate(array, parent, child_idx)
110+
}
101111
}
102112

103113
#[derive(Debug)]
Lines changed: 345 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,345 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
use vortex_error::VortexResult;
5+
6+
use crate::Array;
7+
use crate::ArrayRef;
8+
use crate::IntoArray;
9+
use crate::arrays::AnyScalarFn;
10+
use crate::arrays::ConstantArray;
11+
use crate::arrays::ConstantVTable;
12+
use crate::arrays::ExtensionArray;
13+
use crate::arrays::ExtensionVTable;
14+
use crate::arrays::FilterArray;
15+
use crate::arrays::FilterVTable;
16+
use crate::arrays::ScalarFnArray;
17+
use crate::matchers::Exact;
18+
use crate::optimizer::rules::ArrayParentReduceRule;
19+
use crate::optimizer::rules::ParentRuleSet;
20+
21+
pub(super) const PARENT_RULES: ParentRuleSet<ExtensionVTable> = ParentRuleSet::new(&[
22+
ParentRuleSet::lift(&ExtensionFilterPushDownRule),
23+
ParentRuleSet::lift(&ExtensionScalarFnConstantPushDownRule),
24+
]);
25+
26+
/// Push filter operations into the storage array of an extension array.
27+
#[derive(Debug)]
28+
struct ExtensionFilterPushDownRule;
29+
30+
impl ArrayParentReduceRule<ExtensionVTable> for ExtensionFilterPushDownRule {
31+
type Parent = Exact<FilterVTable>;
32+
33+
fn parent(&self) -> Self::Parent {
34+
Exact::from(&FilterVTable)
35+
}
36+
37+
fn reduce_parent(
38+
&self,
39+
child: &ExtensionArray,
40+
parent: &FilterArray,
41+
_child_idx: usize,
42+
) -> VortexResult<Option<ArrayRef>> {
43+
let filtered_storage = child
44+
.storage()
45+
.clone()
46+
.filter(parent.filter_mask().clone())?;
47+
Ok(Some(
48+
ExtensionArray::new(child.ext_dtype().clone(), filtered_storage).into_array(),
49+
))
50+
}
51+
}
52+
53+
/// Push scalar function operations into the storage array when the other operand is a constant
54+
/// with the same extension type.
55+
#[derive(Debug)]
56+
struct ExtensionScalarFnConstantPushDownRule;
57+
58+
impl ArrayParentReduceRule<ExtensionVTable> for ExtensionScalarFnConstantPushDownRule {
59+
type Parent = AnyScalarFn;
60+
61+
fn parent(&self) -> Self::Parent {
62+
AnyScalarFn
63+
}
64+
65+
fn reduce_parent(
66+
&self,
67+
child: &ExtensionArray,
68+
parent: &ScalarFnArray,
69+
child_idx: usize,
70+
) -> VortexResult<Option<ArrayRef>> {
71+
// Check that all other children are constants with matching extension types.
72+
for (idx, sibling) in parent.children().iter().enumerate() {
73+
if idx == child_idx {
74+
continue;
75+
}
76+
77+
// Sibling must be a constant.
78+
let Some(const_array) = sibling.as_opt::<ConstantVTable>() else {
79+
return Ok(None);
80+
};
81+
82+
// Sibling must be an extension scalar with the same extension type.
83+
let Some(ext_scalar) = const_array.scalar().as_extension_opt() else {
84+
return Ok(None);
85+
};
86+
87+
if !ext_scalar
88+
.ext_dtype()
89+
.eq_ignore_nullability(child.ext_dtype())
90+
{
91+
return Ok(None);
92+
}
93+
94+
// The storage dtype must match.
95+
if !ext_scalar
96+
.ext_dtype()
97+
.storage_dtype()
98+
.eq_ignore_nullability(child.ext_dtype().storage_dtype())
99+
{
100+
return Ok(None);
101+
}
102+
}
103+
104+
// Build new children with storage arrays/scalars.
105+
let mut new_children = Vec::with_capacity(parent.children().len());
106+
for (idx, sibling) in parent.children().iter().enumerate() {
107+
if idx == child_idx {
108+
new_children.push(child.storage().clone());
109+
} else {
110+
let const_array = sibling.as_::<ConstantVTable>();
111+
let storage_scalar = const_array.scalar().as_extension().storage();
112+
new_children.push(ConstantArray::new(storage_scalar, child.len()).into_array());
113+
}
114+
}
115+
116+
Ok(Some(
117+
ScalarFnArray::try_new(parent.scalar_fn().clone(), new_children, child.len())?
118+
.into_array(),
119+
))
120+
}
121+
}
122+
123+
#[cfg(test)]
124+
mod tests {
125+
use std::sync::Arc;
126+
127+
use vortex_buffer::buffer;
128+
use vortex_dtype::DType;
129+
use vortex_dtype::ExtDType;
130+
use vortex_dtype::ExtID;
131+
use vortex_dtype::Nullability;
132+
use vortex_dtype::PType;
133+
use vortex_mask::Mask;
134+
use vortex_scalar::Scalar;
135+
136+
use crate::Array;
137+
use crate::IntoArray;
138+
use crate::ToCanonical;
139+
use crate::arrays::ConstantArray;
140+
use crate::arrays::ExtensionArray;
141+
use crate::arrays::ExtensionVTable;
142+
use crate::arrays::FilterArray;
143+
use crate::arrays::PrimitiveArray;
144+
use crate::arrays::PrimitiveVTable;
145+
use crate::arrays::ScalarFnArrayExt;
146+
use crate::expr::Binary;
147+
use crate::expr::Operator;
148+
use crate::optimizer::ArrayOptimizer;
149+
150+
fn test_ext_dtype() -> Arc<ExtDType> {
151+
Arc::new(ExtDType::new(
152+
ExtID::new("test_ext".into()),
153+
Arc::new(DType::Primitive(PType::I64, Nullability::NonNullable)),
154+
None,
155+
))
156+
}
157+
158+
#[test]
159+
fn test_filter_pushdown() {
160+
let ext_dtype = test_ext_dtype();
161+
let storage = buffer![1i64, 2, 3, 4, 5].into_array();
162+
let ext_array = ExtensionArray::new(ext_dtype.clone(), storage).into_array();
163+
164+
// Create a filter that selects elements at indices 0, 2, 4
165+
let mask = Mask::from_iter([true, false, true, false, true]);
166+
let filter_array = FilterArray::new(ext_array, mask).into_array();
167+
168+
// Optimize should push the filter into the storage
169+
let optimized = filter_array.optimize().unwrap();
170+
171+
// The result should be an ExtensionArray, not a FilterArray
172+
assert!(
173+
optimized.as_opt::<ExtensionVTable>().is_some(),
174+
"Expected ExtensionArray after optimization, got {}",
175+
optimized.encoding_id()
176+
);
177+
178+
let ext_result = optimized.as_::<ExtensionVTable>();
179+
assert_eq!(ext_result.len(), 3);
180+
assert_eq!(ext_result.ext_dtype().as_ref(), ext_dtype.as_ref());
181+
182+
// Check the storage values
183+
let storage_result: &[i64] = &ext_result.storage().to_primitive().buffer::<i64>();
184+
assert_eq!(storage_result, &[1, 3, 5]);
185+
}
186+
187+
#[test]
188+
fn test_filter_pushdown_nullable() {
189+
let ext_dtype = Arc::new(ExtDType::new(
190+
ExtID::new("test_ext".into()),
191+
Arc::new(DType::Primitive(PType::I64, Nullability::Nullable)),
192+
None,
193+
));
194+
let storage = PrimitiveArray::from_option_iter([Some(1i64), None, Some(3), Some(4), None])
195+
.into_array();
196+
let ext_array = ExtensionArray::new(ext_dtype, storage).into_array();
197+
198+
let mask = Mask::from_iter([true, true, false, false, true]);
199+
let filter_array = FilterArray::new(ext_array, mask).into_array();
200+
201+
let optimized = filter_array.optimize().unwrap();
202+
203+
assert!(optimized.as_opt::<ExtensionVTable>().is_some());
204+
let ext_result = optimized.as_::<ExtensionVTable>();
205+
assert_eq!(ext_result.len(), 3);
206+
207+
// Check values: should be [Some(1), None, None]
208+
let canonical = ext_result.storage().to_primitive();
209+
assert_eq!(canonical.len(), 3);
210+
}
211+
212+
#[test]
213+
fn test_scalar_fn_constant_pushdown_comparison() {
214+
let ext_dtype = test_ext_dtype();
215+
let storage = buffer![10i64, 20, 30, 40, 50].into_array();
216+
let ext_array = ExtensionArray::new(ext_dtype.clone(), storage).into_array();
217+
218+
// Create a constant extension scalar with value 25
219+
let const_scalar = Scalar::extension(ext_dtype, Scalar::from(25i64));
220+
let const_array = ConstantArray::new(const_scalar, 5).into_array();
221+
222+
// Create a binary comparison: ext_array < const_array
223+
let scalar_fn_array = Binary
224+
.try_new_array(5, Operator::Lt, [ext_array, const_array])
225+
.unwrap();
226+
227+
// Optimize should push down the comparison to storage
228+
let optimized = scalar_fn_array.optimize().unwrap();
229+
230+
// The result should still be a ScalarFnArray but operating on primitive storage
231+
let scalar_fn = optimized.as_opt::<crate::arrays::ScalarFnVTable>();
232+
assert!(
233+
scalar_fn.is_some(),
234+
"Expected ScalarFnArray after optimization"
235+
);
236+
237+
// The children should now be primitives, not extensions
238+
let children = scalar_fn.unwrap().children();
239+
assert_eq!(children.len(), 2);
240+
241+
// First child should be the primitive storage
242+
assert!(
243+
children[0].as_opt::<PrimitiveVTable>().is_some(),
244+
"Expected first child to be PrimitiveArray, got {}",
245+
children[0].encoding_id()
246+
);
247+
248+
// Second child should be a constant with primitive value
249+
assert!(
250+
children[1]
251+
.as_opt::<crate::arrays::ConstantVTable>()
252+
.is_some(),
253+
"Expected second child to be ConstantArray, got {}",
254+
children[1].encoding_id()
255+
);
256+
}
257+
258+
#[test]
259+
fn test_scalar_fn_no_pushdown_different_ext_types() {
260+
let ext_dtype1 = Arc::new(ExtDType::new(
261+
ExtID::new("type1".into()),
262+
Arc::new(DType::Primitive(PType::I64, Nullability::NonNullable)),
263+
None,
264+
));
265+
let ext_dtype2 = Arc::new(ExtDType::new(
266+
ExtID::new("type2".into()),
267+
Arc::new(DType::Primitive(PType::I64, Nullability::NonNullable)),
268+
None,
269+
));
270+
271+
let storage = buffer![10i64, 20, 30].into_array();
272+
let ext_array = ExtensionArray::new(ext_dtype1, storage).into_array();
273+
274+
// Create constant with different extension type
275+
let const_scalar = Scalar::extension(ext_dtype2, Scalar::from(25i64));
276+
let const_array = ConstantArray::new(const_scalar, 3).into_array();
277+
278+
let scalar_fn_array = Binary
279+
.try_new_array(3, Operator::Lt, [ext_array.clone(), const_array])
280+
.unwrap();
281+
282+
let optimized = scalar_fn_array.optimize().unwrap();
283+
284+
// The first child should still be an ExtensionArray (no pushdown happened)
285+
let scalar_fn = optimized.as_opt::<crate::arrays::ScalarFnVTable>().unwrap();
286+
assert!(
287+
scalar_fn.children()[0]
288+
.as_opt::<ExtensionVTable>()
289+
.is_some(),
290+
"Expected first child to remain ExtensionArray when ext types differ"
291+
);
292+
}
293+
294+
#[test]
295+
fn test_scalar_fn_no_pushdown_non_constant_sibling() {
296+
let ext_dtype = test_ext_dtype();
297+
298+
let storage1 = buffer![10i64, 20, 30].into_array();
299+
let ext_array1 = ExtensionArray::new(ext_dtype.clone(), storage1).into_array();
300+
301+
let storage2 = buffer![15i64, 25, 35].into_array();
302+
let ext_array2 = ExtensionArray::new(ext_dtype, storage2).into_array();
303+
304+
// Both children are extension arrays (not constants)
305+
let scalar_fn_array = Binary
306+
.try_new_array(3, Operator::Lt, [ext_array1.clone(), ext_array2])
307+
.unwrap();
308+
309+
let optimized = scalar_fn_array.optimize().unwrap();
310+
311+
// No pushdown should happen because sibling is not a constant
312+
let scalar_fn = optimized.as_opt::<crate::arrays::ScalarFnVTable>().unwrap();
313+
assert!(
314+
scalar_fn.children()[0]
315+
.as_opt::<ExtensionVTable>()
316+
.is_some(),
317+
"Expected first child to remain ExtensionArray when sibling is not constant"
318+
);
319+
}
320+
321+
#[test]
322+
fn test_scalar_fn_no_pushdown_non_extension_constant() {
323+
let ext_dtype = test_ext_dtype();
324+
let storage = buffer![10i64, 20, 30].into_array();
325+
let ext_array = ExtensionArray::new(ext_dtype, storage).into_array();
326+
327+
// Create a non-extension constant (plain primitive)
328+
let const_array = ConstantArray::new(Scalar::from(25i64), 3).into_array();
329+
330+
let scalar_fn_array = Binary
331+
.try_new_array(3, Operator::Lt, [ext_array.clone(), const_array])
332+
.unwrap();
333+
334+
let optimized = scalar_fn_array.optimize().unwrap();
335+
336+
// No pushdown should happen because constant is not an extension scalar
337+
let scalar_fn = optimized.as_opt::<crate::arrays::ScalarFnVTable>().unwrap();
338+
assert!(
339+
scalar_fn.children()[0]
340+
.as_opt::<ExtensionVTable>()
341+
.is_some(),
342+
"Expected first child to remain ExtensionArray when constant is not extension"
343+
);
344+
}
345+
}

0 commit comments

Comments
 (0)