Skip to content

Commit d3021c7

Browse files
perf[array]: extension pushdown filter + constant (#5794)
Signed-off-by: Joe Isaacs <[email protected]>
1 parent 064b5b0 commit d3021c7

File tree

2 files changed

+348
-0
lines changed

2 files changed

+348
-0
lines changed

vortex-array/src/arrays/extension/vtable/mod.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
mod array;
55
mod canonical;
66
mod operations;
7+
mod rules;
78
mod validity;
89
mod visitor;
910

@@ -19,6 +20,7 @@ use crate::ArrayRef;
1920
use crate::EmptyMetadata;
2021
use crate::VectorExecutor;
2122
use crate::arrays::extension::ExtensionArray;
23+
use crate::arrays::extension::vtable::rules::PARENT_RULES;
2224
use crate::executor::ExecutionCtx;
2325
use crate::serde::ArrayChildren;
2426
use crate::vtable;
@@ -98,6 +100,14 @@ impl VTable for ExtensionVTable {
98100
fn execute(array: &Self::Array, ctx: &mut ExecutionCtx) -> VortexResult<Vector> {
99101
array.storage().execute(ctx)
100102
}
103+
104+
fn reduce_parent(
105+
array: &Self::Array,
106+
parent: &ArrayRef,
107+
child_idx: usize,
108+
) -> VortexResult<Option<ArrayRef>> {
109+
PARENT_RULES.evaluate(array, parent, child_idx)
110+
}
101111
}
102112

103113
#[derive(Debug)]
Lines changed: 338 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,338 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
use vortex_error::VortexResult;
5+
6+
use crate::Array;
7+
use crate::ArrayRef;
8+
use crate::IntoArray;
9+
use crate::arrays::AnyScalarFn;
10+
use crate::arrays::ConstantArray;
11+
use crate::arrays::ConstantVTable;
12+
use crate::arrays::ExtensionArray;
13+
use crate::arrays::ExtensionVTable;
14+
use crate::arrays::FilterArray;
15+
use crate::arrays::FilterVTable;
16+
use crate::arrays::ScalarFnArray;
17+
use crate::matchers::Exact;
18+
use crate::optimizer::rules::ArrayParentReduceRule;
19+
use crate::optimizer::rules::ParentRuleSet;
20+
21+
pub(super) const PARENT_RULES: ParentRuleSet<ExtensionVTable> = ParentRuleSet::new(&[
22+
ParentRuleSet::lift(&ExtensionFilterPushDownRule),
23+
ParentRuleSet::lift(&ExtensionScalarFnConstantPushDownRule),
24+
]);
25+
26+
/// Push filter operations into the storage array of an extension array.
27+
#[derive(Debug)]
28+
struct ExtensionFilterPushDownRule;
29+
30+
impl ArrayParentReduceRule<ExtensionVTable> for ExtensionFilterPushDownRule {
31+
type Parent = Exact<FilterVTable>;
32+
33+
fn parent(&self) -> Self::Parent {
34+
Exact::from(&FilterVTable)
35+
}
36+
37+
fn reduce_parent(
38+
&self,
39+
child: &ExtensionArray,
40+
parent: &FilterArray,
41+
child_idx: usize,
42+
) -> VortexResult<Option<ArrayRef>> {
43+
debug_assert_eq!(child_idx, 0);
44+
let filtered_storage = child
45+
.storage()
46+
.clone()
47+
.filter(parent.filter_mask().clone())?;
48+
Ok(Some(
49+
ExtensionArray::new(child.ext_dtype().clone(), filtered_storage).into_array(),
50+
))
51+
}
52+
}
53+
54+
/// Push scalar function operations into the storage array when the other operand is a constant
55+
/// with the same extension type.
56+
#[derive(Debug)]
57+
struct ExtensionScalarFnConstantPushDownRule;
58+
59+
impl ArrayParentReduceRule<ExtensionVTable> for ExtensionScalarFnConstantPushDownRule {
60+
type Parent = AnyScalarFn;
61+
62+
fn parent(&self) -> Self::Parent {
63+
AnyScalarFn
64+
}
65+
66+
fn reduce_parent(
67+
&self,
68+
child: &ExtensionArray,
69+
parent: &ScalarFnArray,
70+
child_idx: usize,
71+
) -> VortexResult<Option<ArrayRef>> {
72+
// Check that all other children are constants with matching extension types.
73+
for (idx, sibling) in parent.children().iter().enumerate() {
74+
if idx == child_idx {
75+
continue;
76+
}
77+
78+
// Sibling must be a constant.
79+
let Some(const_array) = sibling.as_opt::<ConstantVTable>() else {
80+
return Ok(None);
81+
};
82+
83+
// Sibling must be an extension scalar with the same extension type.
84+
let Some(ext_scalar) = const_array.scalar().as_extension_opt() else {
85+
return Ok(None);
86+
};
87+
88+
// ExtDType::eq_ignore_nullability checks id, metadata, and storage dtype
89+
if !ext_scalar
90+
.ext_dtype()
91+
.eq_ignore_nullability(child.ext_dtype())
92+
{
93+
return Ok(None);
94+
}
95+
}
96+
97+
// Build new children with storage arrays/scalars.
98+
let mut new_children = Vec::with_capacity(parent.children().len());
99+
for (idx, sibling) in parent.children().iter().enumerate() {
100+
if idx == child_idx {
101+
new_children.push(child.storage().clone());
102+
} else {
103+
let const_array = sibling.as_::<ConstantVTable>();
104+
let storage_scalar = const_array.scalar().as_extension().storage();
105+
new_children.push(ConstantArray::new(storage_scalar, child.len()).into_array());
106+
}
107+
}
108+
109+
Ok(Some(
110+
ScalarFnArray::try_new(parent.scalar_fn().clone(), new_children, child.len())?
111+
.into_array(),
112+
))
113+
}
114+
}
115+
116+
#[cfg(test)]
117+
mod tests {
118+
use std::sync::Arc;
119+
120+
use vortex_buffer::buffer;
121+
use vortex_dtype::DType;
122+
use vortex_dtype::ExtDType;
123+
use vortex_dtype::ExtID;
124+
use vortex_dtype::Nullability;
125+
use vortex_dtype::PType;
126+
use vortex_mask::Mask;
127+
use vortex_scalar::Scalar;
128+
129+
use crate::Array;
130+
use crate::IntoArray;
131+
use crate::ToCanonical;
132+
use crate::arrays::ConstantArray;
133+
use crate::arrays::ExtensionArray;
134+
use crate::arrays::ExtensionVTable;
135+
use crate::arrays::FilterArray;
136+
use crate::arrays::PrimitiveArray;
137+
use crate::arrays::PrimitiveVTable;
138+
use crate::arrays::ScalarFnArrayExt;
139+
use crate::expr::Binary;
140+
use crate::expr::Operator;
141+
use crate::optimizer::ArrayOptimizer;
142+
143+
fn test_ext_dtype() -> Arc<ExtDType> {
144+
Arc::new(ExtDType::new(
145+
ExtID::new("test_ext".into()),
146+
Arc::new(DType::Primitive(PType::I64, Nullability::NonNullable)),
147+
None,
148+
))
149+
}
150+
151+
#[test]
152+
fn test_filter_pushdown() {
153+
let ext_dtype = test_ext_dtype();
154+
let storage = buffer![1i64, 2, 3, 4, 5].into_array();
155+
let ext_array = ExtensionArray::new(ext_dtype.clone(), storage).into_array();
156+
157+
// Create a filter that selects elements at indices 0, 2, 4
158+
let mask = Mask::from_iter([true, false, true, false, true]);
159+
let filter_array = FilterArray::new(ext_array, mask).into_array();
160+
161+
// Optimize should push the filter into the storage
162+
let optimized = filter_array.optimize().unwrap();
163+
164+
// The result should be an ExtensionArray, not a FilterArray
165+
assert!(
166+
optimized.as_opt::<ExtensionVTable>().is_some(),
167+
"Expected ExtensionArray after optimization, got {}",
168+
optimized.encoding_id()
169+
);
170+
171+
let ext_result = optimized.as_::<ExtensionVTable>();
172+
assert_eq!(ext_result.len(), 3);
173+
assert_eq!(ext_result.ext_dtype().as_ref(), ext_dtype.as_ref());
174+
175+
// Check the storage values
176+
let storage_result: &[i64] = &ext_result.storage().to_primitive().buffer::<i64>();
177+
assert_eq!(storage_result, &[1, 3, 5]);
178+
}
179+
180+
#[test]
181+
fn test_filter_pushdown_nullable() {
182+
let ext_dtype = Arc::new(ExtDType::new(
183+
ExtID::new("test_ext".into()),
184+
Arc::new(DType::Primitive(PType::I64, Nullability::Nullable)),
185+
None,
186+
));
187+
let storage = PrimitiveArray::from_option_iter([Some(1i64), None, Some(3), Some(4), None])
188+
.into_array();
189+
let ext_array = ExtensionArray::new(ext_dtype, storage).into_array();
190+
191+
let mask = Mask::from_iter([true, true, false, false, true]);
192+
let filter_array = FilterArray::new(ext_array, mask).into_array();
193+
194+
let optimized = filter_array.optimize().unwrap();
195+
196+
assert!(optimized.as_opt::<ExtensionVTable>().is_some());
197+
let ext_result = optimized.as_::<ExtensionVTable>();
198+
assert_eq!(ext_result.len(), 3);
199+
200+
// Check values: should be [Some(1), None, None]
201+
let canonical = ext_result.storage().to_primitive();
202+
assert_eq!(canonical.len(), 3);
203+
}
204+
205+
#[test]
206+
fn test_scalar_fn_constant_pushdown_comparison() {
207+
let ext_dtype = test_ext_dtype();
208+
let storage = buffer![10i64, 20, 30, 40, 50].into_array();
209+
let ext_array = ExtensionArray::new(ext_dtype.clone(), storage).into_array();
210+
211+
// Create a constant extension scalar with value 25
212+
let const_scalar = Scalar::extension(ext_dtype, Scalar::from(25i64));
213+
let const_array = ConstantArray::new(const_scalar, 5).into_array();
214+
215+
// Create a binary comparison: ext_array < const_array
216+
let scalar_fn_array = Binary
217+
.try_new_array(5, Operator::Lt, [ext_array, const_array])
218+
.unwrap();
219+
220+
// Optimize should push down the comparison to storage
221+
let optimized = scalar_fn_array.optimize().unwrap();
222+
223+
// The result should still be a ScalarFnArray but operating on primitive storage
224+
let scalar_fn = optimized.as_opt::<crate::arrays::ScalarFnVTable>();
225+
assert!(
226+
scalar_fn.is_some(),
227+
"Expected ScalarFnArray after optimization"
228+
);
229+
230+
// The children should now be primitives, not extensions
231+
let children = scalar_fn.unwrap().children();
232+
assert_eq!(children.len(), 2);
233+
234+
// First child should be the primitive storage
235+
assert!(
236+
children[0].as_opt::<PrimitiveVTable>().is_some(),
237+
"Expected first child to be PrimitiveArray, got {}",
238+
children[0].encoding_id()
239+
);
240+
241+
// Second child should be a constant with primitive value
242+
assert!(
243+
children[1]
244+
.as_opt::<crate::arrays::ConstantVTable>()
245+
.is_some(),
246+
"Expected second child to be ConstantArray, got {}",
247+
children[1].encoding_id()
248+
);
249+
}
250+
251+
#[test]
252+
fn test_scalar_fn_no_pushdown_different_ext_types() {
253+
let ext_dtype1 = Arc::new(ExtDType::new(
254+
ExtID::new("type1".into()),
255+
Arc::new(DType::Primitive(PType::I64, Nullability::NonNullable)),
256+
None,
257+
));
258+
let ext_dtype2 = Arc::new(ExtDType::new(
259+
ExtID::new("type2".into()),
260+
Arc::new(DType::Primitive(PType::I64, Nullability::NonNullable)),
261+
None,
262+
));
263+
264+
let storage = buffer![10i64, 20, 30].into_array();
265+
let ext_array = ExtensionArray::new(ext_dtype1, storage).into_array();
266+
267+
// Create constant with different extension type
268+
let const_scalar = Scalar::extension(ext_dtype2, Scalar::from(25i64));
269+
let const_array = ConstantArray::new(const_scalar, 3).into_array();
270+
271+
let scalar_fn_array = Binary
272+
.try_new_array(3, Operator::Lt, [ext_array.clone(), const_array])
273+
.unwrap();
274+
275+
let optimized = scalar_fn_array.optimize().unwrap();
276+
277+
// The first child should still be an ExtensionArray (no pushdown happened)
278+
let scalar_fn = optimized.as_opt::<crate::arrays::ScalarFnVTable>().unwrap();
279+
assert!(
280+
scalar_fn.children()[0]
281+
.as_opt::<ExtensionVTable>()
282+
.is_some(),
283+
"Expected first child to remain ExtensionArray when ext types differ"
284+
);
285+
}
286+
287+
#[test]
288+
fn test_scalar_fn_no_pushdown_non_constant_sibling() {
289+
let ext_dtype = test_ext_dtype();
290+
291+
let storage1 = buffer![10i64, 20, 30].into_array();
292+
let ext_array1 = ExtensionArray::new(ext_dtype.clone(), storage1).into_array();
293+
294+
let storage2 = buffer![15i64, 25, 35].into_array();
295+
let ext_array2 = ExtensionArray::new(ext_dtype, storage2).into_array();
296+
297+
// Both children are extension arrays (not constants)
298+
let scalar_fn_array = Binary
299+
.try_new_array(3, Operator::Lt, [ext_array1.clone(), ext_array2])
300+
.unwrap();
301+
302+
let optimized = scalar_fn_array.optimize().unwrap();
303+
304+
// No pushdown should happen because sibling is not a constant
305+
let scalar_fn = optimized.as_opt::<crate::arrays::ScalarFnVTable>().unwrap();
306+
assert!(
307+
scalar_fn.children()[0]
308+
.as_opt::<ExtensionVTable>()
309+
.is_some(),
310+
"Expected first child to remain ExtensionArray when sibling is not constant"
311+
);
312+
}
313+
314+
#[test]
315+
fn test_scalar_fn_no_pushdown_non_extension_constant() {
316+
let ext_dtype = test_ext_dtype();
317+
let storage = buffer![10i64, 20, 30].into_array();
318+
let ext_array = ExtensionArray::new(ext_dtype, storage).into_array();
319+
320+
// Create a non-extension constant (plain primitive)
321+
let const_array = ConstantArray::new(Scalar::from(25i64), 3).into_array();
322+
323+
let scalar_fn_array = Binary
324+
.try_new_array(3, Operator::Lt, [ext_array.clone(), const_array])
325+
.unwrap();
326+
327+
let optimized = scalar_fn_array.optimize().unwrap();
328+
329+
// No pushdown should happen because constant is not an extension scalar
330+
let scalar_fn = optimized.as_opt::<crate::arrays::ScalarFnVTable>().unwrap();
331+
assert!(
332+
scalar_fn.children()[0]
333+
.as_opt::<ExtensionVTable>()
334+
.is_some(),
335+
"Expected first child to remain ExtensionArray when constant is not extension"
336+
);
337+
}
338+
}

0 commit comments

Comments
 (0)