Skip to content

Commit bedf610

Browse files
committed
Expressions
Signed-off-by: Nicholas Gates <[email protected]>
1 parent 71d8418 commit bedf610

File tree

25 files changed

+228
-117
lines changed

25 files changed

+228
-117
lines changed

vortex-array/src/arrays/scalar_fn/array.rs

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@ use vortex_error::vortex_ensure;
88
use crate::Array;
99
use crate::ArrayRef;
1010
use crate::arrays::ScalarFnVTable;
11-
use crate::expr::BoundExpression;
11+
use crate::expr::ReturnDTypeCtx;
12+
use crate::expr::ScalarFn;
1213
use crate::stats::ArrayStats;
1314
use crate::vtable::ArrayVTable;
1415
use crate::vtable::ArrayVTableExt;
@@ -17,7 +18,7 @@ use crate::vtable::ArrayVTableExt;
1718
pub struct ScalarFnArray {
1819
// NOTE(ngates): we should fix vtables so we don't have to hold this
1920
pub(super) vtable: ArrayVTable,
20-
pub(super) bound: BoundExpression,
21+
pub(super) bound: ScalarFn,
2122
pub(super) dtype: DType,
2223
pub(super) len: usize,
2324
pub(super) children: Vec<ArrayRef>,
@@ -26,13 +27,8 @@ pub struct ScalarFnArray {
2627

2728
impl ScalarFnArray {
2829
/// Create a new ScalarFnArray from a scalar function and its children.
29-
pub fn try_new(
30-
bound: BoundExpression,
31-
children: Vec<ArrayRef>,
32-
len: usize,
33-
) -> VortexResult<Self> {
34-
let arg_dtypes: Vec<_> = children.iter().map(|c| c.dtype().clone()).collect();
35-
let dtype = bound.return_dtype(&arg_dtypes)?;
30+
pub fn try_new(bound: ScalarFn, children: Vec<ArrayRef>, len: usize) -> VortexResult<Self> {
31+
let dtype = bound.return_dtype(&ChildArraysReturnDTypeCtx(&children))?;
3632

3733
vortex_ensure!(
3834
children.iter().all(|c| c.len() == len),
@@ -49,3 +45,14 @@ impl ScalarFnArray {
4945
})
5046
}
5147
}
48+
49+
pub(super) struct ChildArraysReturnDTypeCtx<'a>(pub(super) &'a [ArrayRef]);
50+
impl ReturnDTypeCtx for ChildArraysReturnDTypeCtx<'_> {
51+
fn child_count(&self) -> usize {
52+
self.0.len()
53+
}
54+
55+
fn return_dtype(&self, child_idx: usize) -> VortexResult<DType> {
56+
Ok(self.0[child_idx].dtype().clone())
57+
}
58+
}

vortex-array/src/arrays/scalar_fn/metadata.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@
33

44
use vortex_dtype::DType;
55

6-
use crate::expr::BoundExpression;
6+
use crate::expr::ScalarFn;
77

88
#[derive(Clone, Debug)]
99
pub struct ScalarFnMetadata {
10-
pub(super) bound: BoundExpression,
10+
pub(super) bound: ScalarFn,
1111
pub(super) child_dtypes: Vec<DType>,
1212
}

vortex-array/src/arrays/scalar_fn/vtable/mod.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,14 @@ use vortex_vector::Vector;
2525
use crate::Array;
2626
use crate::ArrayRef;
2727
use crate::IntoArray;
28+
use crate::arrays::scalar_fn::array::ChildArraysReturnDTypeCtx;
2829
use crate::arrays::scalar_fn::array::ScalarFnArray;
2930
use crate::arrays::scalar_fn::metadata::ScalarFnMetadata;
3031
use crate::execution::ExecutionCtx;
3132
use crate::expr;
32-
use crate::expr::BoundExpression;
3333
use crate::expr::ExecutionArgs;
3434
use crate::expr::ExprVTable;
35+
use crate::expr::ScalarFn;
3536
use crate::optimizer::rules::MatchKey;
3637
use crate::optimizer::rules::Matcher;
3738
use crate::serde::ArrayChildren;
@@ -115,9 +116,9 @@ impl VTable for ScalarFnVTable {
115116

116117
#[cfg(debug_assertions)]
117118
{
118-
let child_dtypes: Vec<_> = children.iter().map(|c| c.dtype().clone()).collect();
119+
let ctx = ChildArraysReturnDTypeCtx(&children);
119120
vortex_error::vortex_ensure!(
120-
&metadata.bound.return_dtype(&child_dtypes)? == dtype,
121+
&metadata.bound.return_dtype(&ctx)? == dtype,
121122
"Return dtype mismatch when building ScalarFnArray"
122123
);
123124
}
@@ -162,16 +163,15 @@ pub trait ScalarFnArrayExt: expr::VTable {
162163
options: Self::Options,
163164
children: impl Into<Vec<ArrayRef>>,
164165
) -> VortexResult<ArrayRef> {
165-
let bound = BoundExpression::new_static(self, options);
166+
let bound = ScalarFn::new_static(self, options);
166167

167168
let children = children.into();
168169
vortex_ensure!(
169170
children.iter().all(|c| c.len() == len),
170171
"All child arrays must have the same length as the scalar function array"
171172
);
172173

173-
let child_dtypes = children.iter().map(|c| c.dtype().clone()).collect_vec();
174-
let dtype = bound.return_dtype(&child_dtypes)?;
174+
let dtype = bound.return_dtype(&ChildArraysReturnDTypeCtx(&children))?;
175175

176176
let array_vtable: ArrayVTable = ScalarFnVTable {
177177
vtable: bound.vtable().clone(),

vortex-array/src/expr/display.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@ use std::fmt::Display;
55
use std::fmt::Formatter;
66
use std::ops::Deref;
77

8-
use crate::expr::BoundExpression;
98
use crate::expr::Expression;
9+
use crate::expr::ScalarFn;
1010

1111
pub enum DisplayFormat {
1212
Compact,
@@ -19,7 +19,7 @@ impl Display for DisplayTreeExpr<'_> {
1919
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
2020
pub use termtree::Tree;
2121
fn make_tree(expr: &Expression) -> Result<Tree<String>, std::fmt::Error> {
22-
let bound: &BoundExpression = expr.deref();
22+
let bound: &ScalarFn = expr.deref();
2323
let node_name = format!("{}", bound);
2424

2525
// Get child names for display purposes

vortex-array/src/expr/expression.rs

Lines changed: 45 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -9,18 +9,18 @@ use std::hash::Hash;
99
use std::ops::Deref;
1010
use std::sync::Arc;
1111

12-
use itertools::Itertools;
1312
use vortex_dtype::DType;
1413
use vortex_error::VortexExpect;
1514
use vortex_error::VortexResult;
1615
use vortex_error::vortex_ensure;
1716

1817
use crate::ArrayRef;
18+
use crate::expr::ReturnDTypeCtx;
1919
use crate::expr::Root;
2020
use crate::expr::StatsCatalog;
2121
use crate::expr::VTable;
22-
use crate::expr::bound::BoundExpression;
2322
use crate::expr::display::DisplayTreeExpr;
23+
use crate::expr::scalar_fn::ScalarFn;
2424
use crate::expr::stats::Stat;
2525

2626
/// A node in a Vortex expression tree.
@@ -29,36 +29,39 @@ use crate::expr::stats::Stat;
2929
/// expression consists of an encoding (vtable), heap-allocated metadata, and child expressions.
3030
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
3131
pub struct Expression {
32-
/// The bound expression for this node.
33-
bound: BoundExpression,
32+
/// The scalar function for this node.
33+
scalar_fn: ScalarFn,
3434
/// Any children of this expression.
3535
children: Arc<[Expression]>,
3636
}
3737

3838
impl Deref for Expression {
39-
type Target = BoundExpression;
39+
type Target = ScalarFn;
4040

4141
fn deref(&self) -> &Self::Target {
42-
&self.bound
42+
&self.scalar_fn
4343
}
4444
}
4545

4646
impl Expression {
47-
/// Create a new expression node from a bound expression and its children.
47+
/// Create a new expression node from a scalar function and children.
4848
pub fn try_new(
49-
bound: BoundExpression,
49+
scalar_fn: ScalarFn,
5050
children: impl Into<Arc<[Expression]>>,
5151
) -> VortexResult<Self> {
5252
let children: Arc<[Expression]> = children.into();
5353

5454
vortex_ensure!(
55-
bound.signature().arity().matches(children.len()),
55+
scalar_fn.signature().arity().matches(children.len()),
5656
"Expression arity mismatch: expected {} children but got {}",
57-
bound.signature().arity(),
57+
scalar_fn.signature().arity(),
5858
children.len()
5959
);
6060

61-
Ok(Self { bound, children })
61+
Ok(Self {
62+
scalar_fn,
63+
children,
64+
})
6265
}
6366

6467
/// Returns true if this expression is of the given vtable type.
@@ -102,24 +105,18 @@ impl Expression {
102105

103106
/// Computes the return dtype of this expression given the input dtype.
104107
pub fn return_dtype(&self, scope: &DType) -> VortexResult<DType> {
105-
if self.is::<Root>() {
106-
return Ok(scope.clone());
107-
}
108-
109-
let dtypes: Vec<_> = self
110-
.children
111-
.iter()
112-
.map(|c| c.return_dtype(scope))
113-
.try_collect()?;
114-
self.bound.return_dtype(&dtypes)
108+
self.scalar_fn.return_dtype(&ExpressionReturnDTypeCtx {
109+
expr: self,
110+
scope_dtype: scope,
111+
})
115112
}
116113

117114
/// Evaluates the expression in the given scope, returning an array.
118115
pub fn evaluate(&self, scope: &ArrayRef) -> VortexResult<ArrayRef> {
119116
if self.is::<Root>() {
120117
return Ok(scope.clone());
121118
}
122-
self.bound.evaluate(self, scope)
119+
self.scalar_fn.evaluate(self, scope)
123120
}
124121

125122
/// An expression over zone-statistics which implies all records in the zone evaluate to false.
@@ -237,3 +234,29 @@ impl Display for Expression {
237234
self.fmt_sql(f)
238235
}
239236
}
237+
238+
pub(super) struct ExpressionReturnDTypeCtx<'a> {
239+
pub(super) expr: &'a Expression,
240+
pub(super) scope_dtype: &'a DType,
241+
}
242+
243+
impl ReturnDTypeCtx for ExpressionReturnDTypeCtx<'_> {
244+
fn child_count(&self) -> usize {
245+
self.expr.children().len()
246+
}
247+
248+
fn return_dtype(&self, child_idx: usize) -> VortexResult<DType> {
249+
let child = &self.expr.children()[child_idx];
250+
251+
if child.is::<Root>() {
252+
return Ok(self.scope_dtype.clone());
253+
}
254+
255+
let ctx = ExpressionReturnDTypeCtx {
256+
expr: child,
257+
scope_dtype: self.scope_dtype,
258+
};
259+
let child_fn: &ScalarFn = child.deref();
260+
child_fn.return_dtype(&ctx)
261+
}
262+
}

vortex-array/src/expr/exprs/between.rs

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ use crate::expr::Arity;
1919
use crate::expr::ChildName;
2020
use crate::expr::ExecutionArgs;
2121
use crate::expr::ExprId;
22+
use crate::expr::ReturnDTypeCtx;
2223
use crate::expr::StatsCatalog;
2324
use crate::expr::VTable;
2425
use crate::expr::VTableExt;
@@ -112,19 +113,23 @@ impl VTable for Between {
112113
)
113114
}
114115

115-
fn return_dtype(&self, _options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
116-
let arr_dt = &arg_dtypes[0];
117-
let lower_dt = &arg_dtypes[1];
118-
let upper_dt = &arg_dtypes[2];
119-
120-
if !arr_dt.eq_ignore_nullability(lower_dt) {
116+
fn return_dtype(
117+
&self,
118+
_options: &Self::Options,
119+
ctx: &dyn ReturnDTypeCtx,
120+
) -> VortexResult<DType> {
121+
let arr_dt = ctx.return_dtype(0)?;
122+
let lower_dt = ctx.return_dtype(1)?;
123+
let upper_dt = ctx.return_dtype(2)?;
124+
125+
if !arr_dt.eq_ignore_nullability(&lower_dt) {
121126
vortex_bail!(
122127
"Array dtype {} does not match lower dtype {}",
123128
arr_dt,
124129
lower_dt
125130
);
126131
}
127-
if !arr_dt.eq_ignore_nullability(upper_dt) {
132+
if !arr_dt.eq_ignore_nullability(&upper_dt) {
128133
vortex_bail!(
129134
"Array dtype {} does not match upper dtype {}",
130135
arr_dt,

vortex-array/src/expr/exprs/binary.rs

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ use crate::expr::Arity;
2424
use crate::expr::ChildName;
2525
use crate::expr::ExecutionArgs;
2626
use crate::expr::ExprId;
27+
use crate::expr::ReturnDTypeCtx;
2728
use crate::expr::StatsCatalog;
2829
use crate::expr::VTable;
2930
use crate::expr::VTableExt;
@@ -80,12 +81,12 @@ impl VTable for Binary {
8081
write!(f, ")")
8182
}
8283

83-
fn return_dtype(&self, operator: &Operator, arg_dtypes: &[DType]) -> VortexResult<DType> {
84-
let lhs = &arg_dtypes[0];
85-
let rhs = &arg_dtypes[1];
84+
fn return_dtype(&self, operator: &Operator, ctx: &dyn ReturnDTypeCtx) -> VortexResult<DType> {
85+
let lhs = ctx.return_dtype(0)?;
86+
let rhs = ctx.return_dtype(1)?;
8687

8788
if operator.is_arithmetic() {
88-
if lhs.is_primitive() && lhs.eq_ignore_nullability(rhs) {
89+
if lhs.is_primitive() && lhs.eq_ignore_nullability(&rhs) {
8990
return Ok(lhs.with_nullability(lhs.nullability() | rhs.nullability()));
9091
}
9192
vortex_bail!(

vortex-array/src/expr/exprs/cast.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ use crate::expr::Arity;
1818
use crate::expr::ChildName;
1919
use crate::expr::ExecutionArgs;
2020
use crate::expr::ExprId;
21+
use crate::expr::ReturnDTypeCtx;
2122
use crate::expr::StatsCatalog;
2223
use crate::expr::VTable;
2324
use crate::expr::VTableExt;
@@ -69,7 +70,7 @@ impl VTable for Cast {
6970
write!(f, ")")
7071
}
7172

72-
fn return_dtype(&self, dtype: &DType, _arg_dtypes: &[DType]) -> VortexResult<DType> {
73+
fn return_dtype(&self, dtype: &DType, _ctx: &dyn ReturnDTypeCtx) -> VortexResult<DType> {
7374
Ok(dtype.clone())
7475
}
7576

vortex-array/src/expr/exprs/dynamic.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ use crate::expr::ChildName;
2828
use crate::expr::ExecutionArgs;
2929
use crate::expr::ExprId;
3030
use crate::expr::Expression;
31+
use crate::expr::ReturnDTypeCtx;
3132
use crate::expr::StatsCatalog;
3233
use crate::expr::VTable;
3334
use crate::expr::VTableExt;
@@ -76,10 +77,10 @@ impl VTable for DynamicComparison {
7677
fn return_dtype(
7778
&self,
7879
dynamic: &DynamicComparisonExpr,
79-
arg_dtypes: &[DType],
80+
ctx: &dyn ReturnDTypeCtx,
8081
) -> VortexResult<DType> {
81-
let lhs = &arg_dtypes[0];
82-
if !dynamic.rhs.dtype.eq_ignore_nullability(lhs) {
82+
let lhs = ctx.return_dtype(0)?;
83+
if !dynamic.rhs.dtype.eq_ignore_nullability(&lhs) {
8384
vortex_bail!(
8485
"Incompatible dtypes for dynamic comparison: expected {} (ignore nullability) but got {}",
8586
&dynamic.rhs.dtype,

vortex-array/src/expr/exprs/get_item.rs

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ use crate::expr::ExecutionArgs;
2626
use crate::expr::ExprId;
2727
use crate::expr::Expression;
2828
use crate::expr::Pack;
29+
use crate::expr::ReturnDTypeCtx;
2930
use crate::expr::SimplifyCtx;
3031
use crate::expr::StatsCatalog;
3132
use crate::expr::VTable;
@@ -77,8 +78,12 @@ impl VTable for GetItem {
7778
write!(f, ".{}", field_name)
7879
}
7980

80-
fn return_dtype(&self, field_name: &FieldName, arg_dtypes: &[DType]) -> VortexResult<DType> {
81-
let struct_dtype = &arg_dtypes[0];
81+
fn return_dtype(
82+
&self,
83+
field_name: &FieldName,
84+
ctx: &dyn ReturnDTypeCtx,
85+
) -> VortexResult<DType> {
86+
let struct_dtype = ctx.return_dtype(0)?;
8287
let field_dtype = struct_dtype
8388
.as_struct_fields_opt()
8489
.and_then(|st| st.field(field_name))

0 commit comments

Comments
 (0)