Skip to content

Commit 157579e

Browse files
committed
Move rules onto array vtable so we can optimize during construction without a session
Signed-off-by: Nicholas Gates <[email protected]>
1 parent 6ce6189 commit 157579e

File tree

9 files changed

+331
-80
lines changed

9 files changed

+331
-80
lines changed

vortex-array/src/arrays/scalar_fn/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
mod array;
55
mod kernel;
66
mod metadata;
7+
mod rules;
78
mod vtable;
89

910
pub use array::*;
Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
use std::any::Any;
5+
6+
use vortex_dtype::DType;
7+
use vortex_error::VortexExpect;
8+
use vortex_error::VortexResult;
9+
use vortex_error::vortex_ensure;
10+
use vortex_scalar::Scalar;
11+
use vortex_vector::Datum;
12+
use vortex_vector::VectorOps;
13+
use vortex_vector::datum_matches_dtype;
14+
15+
use crate::ArrayRef;
16+
use crate::ArrayVisitor;
17+
use crate::IntoArray;
18+
use crate::arrays::ConstantArray;
19+
use crate::arrays::ConstantVTable;
20+
use crate::arrays::ScalarFnArray;
21+
use crate::arrays::ScalarFnVTable;
22+
use crate::expr::ExecutionArgs;
23+
use crate::expr::ReduceCtx;
24+
use crate::expr::ReduceNode;
25+
use crate::expr::ScalarFn;
26+
use crate::optimizer::rules::ArrayReduceRule;
27+
use crate::optimizer::rules::ReduceRuleSet;
28+
29+
pub(super) const RULES: ReduceRuleSet<ScalarFnVTable> =
30+
ReduceRuleSet::new(&[&ScalarFnConstantRule, &ScalarFnAbstractReduceRule]);
31+
32+
#[derive(Debug)]
33+
struct ScalarFnConstantRule;
34+
impl ArrayReduceRule<ScalarFnVTable> for ScalarFnConstantRule {
35+
fn reduce(&self, array: &ScalarFnArray) -> VortexResult<Option<ArrayRef>> {
36+
if !array.children.iter().all(|c| c.is::<ConstantVTable>()) {
37+
return Ok(None);
38+
}
39+
40+
let input_datums: Vec<_> = array
41+
.children
42+
.iter()
43+
.map(|c| c.as_::<ConstantVTable>().scalar().to_vector_scalar())
44+
.map(Datum::Scalar)
45+
.collect();
46+
let input_dtypes = array.children.iter().map(|c| c.dtype().clone()).collect();
47+
48+
let result = array.scalar_fn.execute(ExecutionArgs {
49+
datums: input_datums,
50+
dtypes: input_dtypes,
51+
row_count: array.len,
52+
return_dtype: array.dtype.clone(),
53+
})?;
54+
vortex_ensure!(
55+
datum_matches_dtype(&result, &array.dtype),
56+
"Scalar function {} result does not match expected dtype",
57+
array.scalar_fn
58+
);
59+
60+
let result = match result {
61+
Datum::Scalar(s) => s,
62+
Datum::Vector(v) => {
63+
tracing::warn!(
64+
"Scalar function {} returned vector from execution over all scalar inputs",
65+
array.scalar_fn
66+
);
67+
v.scalar_at(0)
68+
}
69+
};
70+
71+
Ok(Some(
72+
ConstantArray::new(Scalar::from_vector_scalar(result, &array.dtype)?, array.len)
73+
.into_array(),
74+
))
75+
}
76+
}
77+
78+
#[derive(Debug)]
79+
struct ScalarFnAbstractReduceRule;
80+
impl ArrayReduceRule<ScalarFnVTable> for ScalarFnAbstractReduceRule {
81+
fn reduce(&self, array: &ScalarFnArray) -> VortexResult<Option<ArrayRef>> {
82+
if let Some(reduced) = array.scalar_fn.reduce(
83+
// Blergh, re-boxing
84+
&array.to_array(),
85+
&ArrayReduceCtx { len: array.len },
86+
)? {
87+
return Ok(Some(
88+
reduced
89+
.as_any()
90+
.downcast_ref::<ArrayRef>()
91+
.vortex_expect("ReduceNode is not an ArrayRef")
92+
.clone(),
93+
));
94+
}
95+
Ok(None)
96+
}
97+
}
98+
99+
impl ReduceNode for ArrayRef {
100+
fn as_any(&self) -> &dyn Any {
101+
self
102+
}
103+
104+
fn node_dtype(&self) -> VortexResult<DType> {
105+
Ok(self.as_ref().dtype().clone())
106+
}
107+
108+
fn scalar_fn(&self) -> Option<&ScalarFn> {
109+
self.as_opt::<ScalarFnVTable>().map(|a| a.scalar_fn())
110+
}
111+
112+
fn child(&self, idx: usize) -> Box<dyn ReduceNode> {
113+
Box::new(self.children()[idx].clone())
114+
}
115+
}
116+
117+
struct ArrayReduceCtx {
118+
// The length of the array being reduced
119+
len: usize,
120+
}
121+
impl ReduceCtx for ArrayReduceCtx {
122+
fn create_node(
123+
&self,
124+
scalar_fn: ScalarFn,
125+
children: &[Box<dyn ReduceNode>],
126+
) -> VortexResult<Box<dyn ReduceNode>> {
127+
Ok(Box::new(
128+
ScalarFnArray::try_new(
129+
scalar_fn,
130+
children
131+
.iter()
132+
.map(|c| {
133+
c.as_any()
134+
.downcast_ref::<ArrayRef>()
135+
.vortex_expect("ReduceNode is not an ArrayRef")
136+
.clone()
137+
})
138+
.collect(),
139+
self.len,
140+
)?
141+
.into_array(),
142+
))
143+
}
144+
}

vortex-array/src/arrays/scalar_fn/vtable/mod.rs

Lines changed: 2 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -16,22 +16,17 @@ use vortex_dtype::DType;
1616
use vortex_error::VortexResult;
1717
use vortex_error::vortex_bail;
1818
use vortex_error::vortex_ensure;
19-
use vortex_scalar::Scalar;
20-
use vortex_vector::Datum;
21-
use vortex_vector::VectorOps;
22-
use vortex_vector::datum_matches_dtype;
2319

2420
use crate::Array;
2521
use crate::ArrayRef;
2622
use crate::IntoArray;
27-
use crate::arrays::ConstantArray;
2823
use crate::arrays::ConstantVTable;
2924
use crate::arrays::scalar_fn::array::ScalarFnArray;
3025
use crate::arrays::scalar_fn::kernel::KernelInput;
3126
use crate::arrays::scalar_fn::kernel::ScalarFnKernel;
3227
use crate::arrays::scalar_fn::metadata::ScalarFnMetadata;
28+
use crate::arrays::scalar_fn::rules::RULES;
3329
use crate::expr;
34-
use crate::expr::ExecutionArgs;
3530
use crate::expr::ExprVTable;
3631
use crate::expr::ScalarFn;
3732
use crate::kernel::BindCtx;
@@ -170,45 +165,7 @@ impl VTable for ScalarFnVTable {
170165
}
171166

172167
fn reduce(array: &Self::Array) -> VortexResult<Option<ArrayRef>> {
173-
if !array.children.iter().all(|c| c.is::<ConstantVTable>()) {
174-
return Ok(None);
175-
}
176-
177-
let input_datums: Vec<_> = array
178-
.children
179-
.iter()
180-
.map(|c| c.as_::<ConstantVTable>().scalar().to_vector_scalar())
181-
.map(Datum::Scalar)
182-
.collect();
183-
let input_dtypes = array.children.iter().map(|c| c.dtype().clone()).collect();
184-
185-
let result = array.scalar_fn.execute(ExecutionArgs {
186-
datums: input_datums,
187-
dtypes: input_dtypes,
188-
row_count: array.len,
189-
return_dtype: array.dtype.clone(),
190-
})?;
191-
vortex_ensure!(
192-
datum_matches_dtype(&result, &array.dtype),
193-
"Scalar function {} result does not match expected dtype",
194-
array.scalar_fn
195-
);
196-
197-
let result = match result {
198-
Datum::Scalar(s) => s,
199-
Datum::Vector(v) => {
200-
tracing::warn!(
201-
"Scalar function {} returned vector from execution over all scalar inputs",
202-
array.scalar_fn
203-
);
204-
v.scalar_at(0)
205-
}
206-
};
207-
208-
Ok(Some(
209-
ConstantArray::new(Scalar::from_vector_scalar(result, &array.dtype)?, array.len)
210-
.into_array(),
211-
))
168+
RULES.evaluate(array)
212169
}
213170
}
214171

vortex-array/src/expr/exprs/get_item.rs

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,15 @@ use crate::builtins::ExprBuiltins;
2323
use crate::compute::mask;
2424
use crate::expr::Arity;
2525
use crate::expr::ChildName;
26+
use crate::expr::EmptyOptions;
2627
use crate::expr::ExecutionArgs;
2728
use crate::expr::ExprId;
2829
use crate::expr::Expression;
30+
use crate::expr::Literal;
31+
use crate::expr::Mask;
2932
use crate::expr::Pack;
33+
use crate::expr::ReduceCtx;
34+
use crate::expr::ReduceNode;
3035
use crate::expr::StatsCatalog;
3136
use crate::expr::VTable;
3237
use crate::expr::VTableExt;
@@ -135,6 +140,33 @@ impl VTable for GetItem {
135140
}
136141
}
137142

143+
fn reduce(
144+
&self,
145+
field_name: &FieldName,
146+
node: &dyn ReduceNode,
147+
ctx: &dyn ReduceCtx,
148+
) -> VortexResult<Option<Box<dyn ReduceNode>>> {
149+
let child = node.child(0);
150+
if let Some(child_fn) = child.scalar_fn()
151+
&& let Some(pack) = child_fn.as_opt::<Pack>()
152+
&& let Some(idx) = pack.names.find(field_name)
153+
{
154+
let mut field = child.child(idx);
155+
156+
// Possibly mask the field if the pack is nullable
157+
if pack.nullability.is_nullable() {
158+
field = ctx.create_node(
159+
Mask.bind(EmptyOptions),
160+
&[field, ctx.create_node(Literal.bind(true.into()), &[])?],
161+
)?;
162+
}
163+
164+
return Ok(Some(field));
165+
}
166+
167+
Ok(None)
168+
}
169+
138170
fn simplify_untyped(
139171
&self,
140172
field_name: &FieldName,

vortex-array/src/expr/exprs/select.rs

Lines changed: 35 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -136,41 +136,6 @@ impl VTable for Select {
136136
Ok(DType::Struct(projected, child_dtype.nullability()))
137137
}
138138

139-
fn simplify(
140-
&self,
141-
options: &Self::Options,
142-
expr: &Expression,
143-
ctx: &dyn SimplifyCtx,
144-
) -> VortexResult<Option<Expression>> {
145-
let child = expr.child(0);
146-
let child_dtype = ctx.return_dtype(child)?;
147-
let child_nullability = child_dtype.nullability();
148-
149-
let child_dtype = child_dtype.as_struct_fields_opt().ok_or_else(|| {
150-
vortex_err!(
151-
"Select child must return a struct dtype, however it was a {}",
152-
child_dtype
153-
)
154-
})?;
155-
156-
let expr = pack(
157-
options
158-
.as_include_names(child_dtype.names())
159-
.map_err(|e| {
160-
e.with_context(format!(
161-
"Select fields {:?} must be a subset of child fields {:?}",
162-
options,
163-
child_dtype.names()
164-
))
165-
})?
166-
.iter()
167-
.map(|name| (name.clone(), get_item(name.clone(), child.clone()))),
168-
child_nullability,
169-
);
170-
171-
Ok(Some(expr))
172-
}
173-
174139
fn evaluate(
175140
&self,
176141
selection: &FieldSelection,
@@ -238,6 +203,41 @@ impl VTable for Select {
238203
.into())
239204
}
240205

206+
fn simplify(
207+
&self,
208+
options: &Self::Options,
209+
expr: &Expression,
210+
ctx: &dyn SimplifyCtx,
211+
) -> VortexResult<Option<Expression>> {
212+
let child = expr.child(0);
213+
let child_dtype = ctx.return_dtype(child)?;
214+
let child_nullability = child_dtype.nullability();
215+
216+
let child_dtype = child_dtype.as_struct_fields_opt().ok_or_else(|| {
217+
vortex_err!(
218+
"Select child must return a struct dtype, however it was a {}",
219+
child_dtype
220+
)
221+
})?;
222+
223+
let expr = pack(
224+
options
225+
.as_include_names(child_dtype.names())
226+
.map_err(|e| {
227+
e.with_context(format!(
228+
"Select fields {:?} must be a subset of child fields {:?}",
229+
options,
230+
child_dtype.names()
231+
))
232+
})?
233+
.iter()
234+
.map(|name| (name.clone(), get_item(name.clone(), child.clone()))),
235+
child_nullability,
236+
);
237+
238+
Ok(Some(expr))
239+
}
240+
241241
fn is_null_sensitive(&self, _instance: &Self::Options) -> bool {
242242
true
243243
}

vortex-array/src/expr/scalar_fn.rs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ use crate::expr::ExecutionArgs;
1919
use crate::expr::ExprId;
2020
use crate::expr::ExprVTable;
2121
use crate::expr::Expression;
22+
use crate::expr::ReduceCtx;
23+
use crate::expr::ReduceNode;
2224
use crate::expr::VTable;
2325
use crate::expr::options::ExpressionOptions;
2426
use crate::expr::signature::ExpressionSignature;
@@ -74,6 +76,16 @@ impl ScalarFn {
7476
}
7577
}
7678

79+
/// Returns whether the scalar function is of the given vtable type.
80+
pub fn is<V: VTable>(&self) -> bool {
81+
self.vtable.is::<V>()
82+
}
83+
84+
/// Returns the typed options for this expression if it matches the given vtable type.
85+
pub fn as_opt<V: VTable>(&self) -> Option<&V::Options> {
86+
self.options().as_any().downcast_ref::<V::Options>()
87+
}
88+
7789
/// Signature information for this expression.
7890
pub fn signature(&self) -> ExpressionSignature<'_> {
7991
ExpressionSignature {
@@ -101,6 +113,15 @@ impl ScalarFn {
101113
pub fn execute(&self, ctx: ExecutionArgs) -> VortexResult<Datum> {
102114
self.vtable.as_dyn().execute(self.options.deref(), ctx)
103115
}
116+
117+
/// Perform abstract reduction on this scalar function node.
118+
pub fn reduce(
119+
&self,
120+
node: &dyn ReduceNode,
121+
ctx: &dyn ReduceCtx,
122+
) -> VortexResult<Option<Box<dyn ReduceNode>>> {
123+
self.vtable.as_dyn().reduce(self.options.deref(), node, ctx)
124+
}
104125
}
105126

106127
impl Clone for ScalarFn {

0 commit comments

Comments
 (0)