Skip to content

Commit 3333268

Browse files
committed
Move rules onto array vtable so we can optimize during construction without a session
Signed-off-by: Nicholas Gates <[email protected]>
1 parent f65667c commit 3333268

File tree

12 files changed

+387
-220
lines changed

12 files changed

+387
-220
lines changed

vortex-array/src/arrays/scalar_fn/rules.rs

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

44
use std::any::Any;
5+
use std::sync::Arc;
56

67
use vortex_dtype::DType;
78
use vortex_error::VortexExpect;
@@ -12,6 +13,7 @@ use vortex_vector::Datum;
1213
use vortex_vector::VectorOps;
1314
use vortex_vector::datum_matches_dtype;
1415

16+
use crate::Array;
1517
use crate::ArrayRef;
1618
use crate::ArrayVisitor;
1719
use crate::IntoArray;
@@ -22,6 +24,7 @@ use crate::arrays::ScalarFnVTable;
2224
use crate::expr::ExecutionArgs;
2325
use crate::expr::ReduceCtx;
2426
use crate::expr::ReduceNode;
27+
use crate::expr::ReduceNodeRef;
2528
use crate::expr::ScalarFn;
2629
use crate::optimizer::rules::ArrayReduceRule;
2730
use crate::optimizer::rules::ReduceRuleSet;
@@ -109,8 +112,12 @@ impl ReduceNode for ArrayRef {
109112
self.as_opt::<ScalarFnVTable>().map(|a| a.scalar_fn())
110113
}
111114

112-
fn child(&self, idx: usize) -> Box<dyn ReduceNode> {
113-
Box::new(self.children()[idx].clone())
115+
fn child(&self, idx: usize) -> ReduceNodeRef {
116+
Arc::new(<dyn Array>::children(self)[idx].clone())
117+
}
118+
119+
fn child_count(&self) -> usize {
120+
self.nchildren()
114121
}
115122
}
116123

@@ -119,12 +126,12 @@ struct ArrayReduceCtx {
119126
len: usize,
120127
}
121128
impl ReduceCtx for ArrayReduceCtx {
122-
fn create_node(
129+
fn new_node(
123130
&self,
124131
scalar_fn: ScalarFn,
125-
children: &[Box<dyn ReduceNode>],
126-
) -> VortexResult<Box<dyn ReduceNode>> {
127-
Ok(Box::new(
132+
children: &[ReduceNodeRef],
133+
) -> VortexResult<ReduceNodeRef> {
134+
Ok(Arc::new(
128135
ScalarFnArray::try_new(
129136
scalar_fn,
130137
children

vortex-array/src/expr/exprs/get_item.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ use crate::expr::Mask;
3232
use crate::expr::Pack;
3333
use crate::expr::ReduceCtx;
3434
use crate::expr::ReduceNode;
35+
use crate::expr::ReduceNodeRef;
3536
use crate::expr::StatsCatalog;
3637
use crate::expr::VTable;
3738
use crate::expr::VTableExt;
@@ -145,7 +146,7 @@ impl VTable for GetItem {
145146
field_name: &FieldName,
146147
node: &dyn ReduceNode,
147148
ctx: &dyn ReduceCtx,
148-
) -> VortexResult<Option<Box<dyn ReduceNode>>> {
149+
) -> VortexResult<Option<ReduceNodeRef>> {
149150
let child = node.child(0);
150151
if let Some(child_fn) = child.scalar_fn()
151152
&& let Some(pack) = child_fn.as_opt::<Pack>()
@@ -155,9 +156,9 @@ impl VTable for GetItem {
155156

156157
// Possibly mask the field if the pack is nullable
157158
if pack.nullability.is_nullable() {
158-
field = ctx.create_node(
159+
field = ctx.new_node(
159160
Mask.bind(EmptyOptions),
160-
&[field, ctx.create_node(Literal.bind(true.into()), &[])?],
161+
&[field, ctx.new_node(Literal.bind(true.into()), &[])?],
161162
)?;
162163
}
163164

vortex-array/src/expr/exprs/merge.rs

Lines changed: 31 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,14 @@ use crate::expr::ChildName;
2727
use crate::expr::ExecutionArgs;
2828
use crate::expr::ExprId;
2929
use crate::expr::Expression;
30-
use crate::expr::SimplifyCtx;
30+
use crate::expr::GetItem;
31+
use crate::expr::Pack;
32+
use crate::expr::PackOptions;
33+
use crate::expr::ReduceCtx;
34+
use crate::expr::ReduceNode;
35+
use crate::expr::ReduceNodeRef;
3136
use crate::expr::VTable;
3237
use crate::expr::VTableExt;
33-
use crate::expr::get_item;
34-
use crate::expr::pack;
3538
use crate::validity::Validity;
3639

3740
/// Merge zero or more expressions that ALL return structs.
@@ -185,19 +188,19 @@ impl VTable for Merge {
185188
todo!()
186189
}
187190

188-
fn simplify(
191+
fn reduce(
189192
&self,
190193
options: &Self::Options,
191-
expr: &Expression,
192-
ctx: &dyn SimplifyCtx,
193-
) -> VortexResult<Option<Expression>> {
194-
let merge_dtype = ctx.return_dtype(expr)?;
195-
let mut names = Vec::with_capacity(expr.children().len() * 2);
196-
let mut children = Vec::with_capacity(expr.children().len() * 2);
194+
node: &dyn ReduceNode,
195+
ctx: &dyn ReduceCtx,
196+
) -> VortexResult<Option<ReduceNodeRef>> {
197+
let merge_dtype = node.node_dtype()?;
198+
let mut names = Vec::with_capacity(node.child_count() * 2);
199+
let mut children = Vec::with_capacity(node.child_count() * 2);
197200
let mut duplicate_names = HashSet::<_>::new();
198201

199-
for child in expr.children().iter() {
200-
let child_dtype = ctx.return_dtype(child)?;
202+
for child in node.children() {
203+
let child_dtype = child.node_dtype()?;
201204
if !child_dtype.is_struct() {
202205
vortex_bail!(
203206
"Merge child must return a non-nullable struct dtype, got {}",
@@ -227,15 +230,21 @@ impl VTable for Merge {
227230
}
228231
}
229232

230-
let expr = pack(
231-
names
232-
.into_iter()
233-
.zip(children)
234-
.map(|(name, child)| (name.clone(), get_item(name, child))),
235-
merge_dtype.nullability(),
236-
);
237-
238-
Ok(Some(expr))
233+
let pack_children: Vec<_> = names
234+
.iter()
235+
.zip(children)
236+
.map(|(name, child)| ctx.new_node(GetItem.bind(name.clone()), &[child]))
237+
.try_collect()?;
238+
239+
let pack_expr = ctx.new_node(
240+
Pack.bind(PackOptions {
241+
names: FieldNames::from(names),
242+
nullability: merge_dtype.nullability(),
243+
}),
244+
&pack_children,
245+
)?;
246+
247+
Ok(Some(pack_expr))
239248
}
240249

241250
fn is_null_sensitive(&self, _instance: &Self::Options) -> bool {
@@ -575,7 +584,7 @@ mod tests {
575584
DuplicateHandling::RightMost,
576585
);
577586

578-
let result = e.simplify(&dtype).unwrap();
587+
let result = e.optimize_root(&dtype).unwrap();
579588

580589
assert!(result.is::<Pack>());
581590
assert_eq!(

vortex-array/src/expr/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,13 +31,13 @@ mod expression;
3131
mod exprs;
3232
mod field;
3333
pub mod forms;
34+
mod optimize;
3435
mod options;
3536
pub mod proto;
3637
pub mod pruning;
3738
mod scalar_fn;
3839
pub mod session;
3940
mod signature;
40-
mod simplify;
4141
pub mod stats;
4242
pub mod transform;
4343
pub mod traversal;

0 commit comments

Comments
 (0)