Skip to content

Commit 3daa075

Browse files
committed
Expressions
Signed-off-by: Nicholas Gates <[email protected]>
1 parent ef98cb1 commit 3daa075

File tree

10 files changed

+65
-46
lines changed

10 files changed

+65
-46
lines changed

vortex-array/src/expr/bound.rs

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,18 @@ use std::fmt::Display;
77
use std::fmt::Formatter;
88
use std::hash::Hash;
99
use std::hash::Hasher;
10+
use std::ops::Deref;
1011

1112
use vortex_dtype::DType;
1213
use vortex_error::VortexResult;
1314
use vortex_utils::debug_with::DebugWith;
1415
use vortex_vector::Datum;
1516

17+
use crate::ArrayRef;
1618
use crate::expr::ExecutionArgs;
1719
use crate::expr::ExprId;
1820
use crate::expr::ExprVTable;
21+
use crate::expr::Expression;
1922
use crate::expr::VTable;
2023
use crate::expr::options::ExpressionOptions;
2124
use crate::expr::signature::ExpressionSignature;
@@ -67,36 +70,44 @@ impl BoundExpression {
6770
pub fn options(&self) -> ExpressionOptions<'_> {
6871
ExpressionOptions {
6972
vtable: &self.vtable,
70-
options: self.options.as_ref(),
73+
options: self.options.deref(),
7174
}
7275
}
7376

7477
/// Signature information for this expression.
7578
pub fn signature(&self) -> ExpressionSignature<'_> {
7679
ExpressionSignature {
7780
vtable: &self.vtable,
78-
options: self.options.as_ref(),
81+
options: self.options.deref(),
7982
}
8083
}
8184

8285
/// Compute the return [`DType`] of this expression given the input argument types.
8386
pub fn return_dtype(&self, arg_types: &[DType]) -> VortexResult<DType> {
8487
self.vtable
8588
.as_dyn()
86-
.return_dtype(self.options.as_ref(), arg_types)
89+
.return_dtype(self.options.deref(), arg_types)
90+
}
91+
92+
/// Evaluate the expression, returning an ArrayRef.
93+
///
94+
/// NOTE: this function will soon be deprecated as all expressions will evaluate trivially
95+
/// into an ExprArray.
96+
pub fn evaluate(&self, expr: &Expression, scope: &ArrayRef) -> VortexResult<ArrayRef> {
97+
self.vtable.as_dyn().evaluate(expr, scope)
8798
}
8899

89100
/// Execute the expression given the input arguments.
90101
pub fn execute(&self, ctx: ExecutionArgs) -> VortexResult<Datum> {
91-
self.vtable.as_dyn().execute(self.options.as_ref(), ctx)
102+
self.vtable.as_dyn().execute(self.options.deref(), ctx)
92103
}
93104
}
94105

95106
impl Clone for BoundExpression {
96107
fn clone(&self) -> Self {
97108
BoundExpression {
98109
vtable: self.vtable.clone(),
99-
options: self.vtable.as_dyn().options_clone(self.options.as_ref()),
110+
options: self.vtable.as_dyn().options_clone(self.options.deref()),
100111
}
101112
}
102113
}
@@ -110,7 +121,7 @@ impl Debug for BoundExpression {
110121
&DebugWith(|fmt| {
111122
self.vtable
112123
.as_dyn()
113-
.options_debug(self.options.as_ref(), fmt)
124+
.options_debug(self.options.deref(), fmt)
114125
}),
115126
)
116127
.finish()
@@ -122,7 +133,7 @@ impl Display for BoundExpression {
122133
write!(f, "{}(", self.vtable.id())?;
123134
self.vtable
124135
.as_dyn()
125-
.options_display(self.options.as_ref(), f)?;
136+
.options_display(self.options.deref(), f)?;
126137
write!(f, ")")
127138
}
128139
}
@@ -133,7 +144,7 @@ impl PartialEq for BoundExpression {
133144
&& self
134145
.vtable
135146
.as_dyn()
136-
.options_eq(self.options.as_ref(), other.options.as_ref())
147+
.options_eq(self.options.deref(), other.options.deref())
137148
}
138149
}
139150
impl Eq for BoundExpression {}
@@ -143,6 +154,6 @@ impl Hash for BoundExpression {
143154
self.vtable.hash(state);
144155
self.vtable
145156
.as_dyn()
146-
.options_hash(self.options.as_ref(), state);
157+
.options_hash(self.options.deref(), state);
147158
}
148159
}

vortex-array/src/expr/display.rs

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33

44
use std::fmt::Display;
55
use std::fmt::Formatter;
6+
use std::ops::Deref;
67

8+
use crate::expr::BoundExpression;
79
use crate::expr::Expression;
810

911
pub enum DisplayFormat {
@@ -17,7 +19,8 @@ impl Display for DisplayTreeExpr<'_> {
1719
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1820
pub use termtree::Tree;
1921
fn make_tree(expr: &Expression) -> Result<Tree<String>, std::fmt::Error> {
20-
let node_name = format!("{}", expr);
22+
let bound: &BoundExpression = expr.deref();
23+
let node_name = format!("{}", bound);
2124

2225
// Get child names for display purposes
2326
let child_names = (0..expr.children().len()).map(|i| expr.signature().child_name(i));
@@ -96,10 +99,10 @@ mod tests {
9699
use insta::assert_snapshot;
97100

98101
let root_expr = root();
99-
assert_snapshot!(root_expr.display_tree().to_string(), @"vortex.root");
102+
assert_snapshot!(root_expr.display_tree().to_string(), @"vortex.root()");
100103

101104
let lit_expr = lit(42);
102-
assert_snapshot!(lit_expr.display_tree().to_string(), @"vortex.literal 42i32");
105+
assert_snapshot!(lit_expr.display_tree().to_string(), @"vortex.literal(42i32)");
103106

104107
let get_item_expr = get_item("my_field", root());
105108
assert_snapshot!(get_item_expr.display_tree().to_string(), @r#"
@@ -123,24 +126,24 @@ mod tests {
123126
vortex.binary and
124127
├── lhs: vortex.binary =
125128
│ ├── lhs: vortex.get_item "name"
126-
│ │ └── input: vortex.root
129+
│ │ └── input: $
127130
│ └── rhs: vortex.literal "alice"
128131
└── rhs: vortex.binary >
129132
├── lhs: vortex.get_item "age"
130-
│ └── input: vortex.root
133+
│ └── input: $
131134
└── rhs: vortex.literal 18i32
132135
"#);
133136

134137
let select_expr = select(["name", "age"], root());
135138
assert_snapshot!(select_expr.display_tree().to_string(), @r"
136139
vortex.select include={name, age}
137-
└── child: vortex.root
140+
└── child: $
138141
");
139142

140143
let select_exclude_expr = select_exclude(["internal_id", "metadata"], root());
141144
assert_snapshot!(select_exclude_expr.display_tree().to_string(), @r"
142145
vortex.select exclude={internal_id, metadata}
143-
└── child: vortex.root
146+
└── child: $
144147
");
145148

146149
let cast_expr = cast(
@@ -150,7 +153,7 @@ mod tests {
150153
assert_snapshot!(cast_expr.display_tree().to_string(), @r#"
151154
vortex.cast i64
152155
└── input: vortex.get_item "value"
153-
└── input: vortex.root
156+
└── input: $
154157
"#);
155158

156159
let not_expr = not(eq(get_item("active", root()), lit(true)));

vortex-array/src/expr/expression.rs

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ use vortex_error::VortexResult;
1616
use vortex_error::vortex_ensure;
1717

1818
use crate::ArrayRef;
19+
use crate::expr::Root;
1920
use crate::expr::StatsCatalog;
2021
use crate::expr::VTable;
2122
use crate::expr::bound::BoundExpression;
@@ -101,6 +102,10 @@ impl Expression {
101102

102103
/// Computes the return dtype of this expression given the input dtype.
103104
pub fn return_dtype(&self, scope: &DType) -> VortexResult<DType> {
105+
if self.is::<Root>() {
106+
return Ok(scope.clone());
107+
}
108+
104109
let dtypes: Vec<_> = self
105110
.children
106111
.iter()
@@ -110,8 +115,11 @@ impl Expression {
110115
}
111116

112117
/// Evaluates the expression in the given scope, returning an array.
113-
pub fn evaluate(&self, _scope: &ArrayRef) -> VortexResult<ArrayRef> {
114-
todo!("Return an ExprArray")
118+
pub fn evaluate(&self, scope: &ArrayRef) -> VortexResult<ArrayRef> {
119+
if self.is::<Root>() {
120+
return Ok(scope.clone());
121+
}
122+
self.bound.evaluate(self, scope)
115123
}
116124

117125
/// An expression over zone-statistics which implies all records in the zone evaluate to false.

vortex-array/src/expr/exprs/between.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,14 +117,14 @@ impl VTable for Between {
117117
let lower_dt = &arg_dtypes[1];
118118
let upper_dt = &arg_dtypes[2];
119119

120-
if !arr_dt.eq_ignore_nullability(&lower_dt) {
120+
if !arr_dt.eq_ignore_nullability(lower_dt) {
121121
vortex_bail!(
122122
"Array dtype {} does not match lower dtype {}",
123123
arr_dt,
124124
lower_dt
125125
);
126126
}
127-
if !arr_dt.eq_ignore_nullability(&upper_dt) {
127+
if !arr_dt.eq_ignore_nullability(upper_dt) {
128128
vortex_bail!(
129129
"Array dtype {} does not match upper dtype {}",
130130
arr_dt,

vortex-array/src/expr/exprs/binary.rs

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ impl VTable for Binary {
5252

5353
fn deserialize(&self, metadata: &[u8]) -> VortexResult<Self::Options> {
5454
let opts = pb::BinaryOpts::decode(metadata)?;
55-
Ok(Operator::try_from(opts.op)?)
55+
Operator::try_from(opts.op)
5656
}
5757

5858
fn arity(&self, _options: &Self::Options) -> Arity {
@@ -85,7 +85,7 @@ impl VTable for Binary {
8585
let rhs = &arg_dtypes[1];
8686

8787
if operator.is_arithmetic() {
88-
if lhs.is_primitive() && lhs.eq_ignore_nullability(&rhs) {
88+
if lhs.is_primitive() && lhs.eq_ignore_nullability(rhs) {
8989
return Ok(lhs.with_nullability(lhs.nullability() | rhs.nullability()));
9090
}
9191
vortex_bail!(
@@ -601,15 +601,6 @@ mod tests {
601601
);
602602
}
603603

604-
#[test]
605-
fn test_debug_print() {
606-
let expr = gt(lit(1), lit(2));
607-
assert_eq!(
608-
format!("{expr:?}"),
609-
"Expression { vtable: vortex.binary, data: >, children: [Expression { vtable: vortex.literal, data: 1i32, children: [] }, Expression { vtable: vortex.literal, data: 2i32, children: [] }] }"
610-
);
611-
}
612-
613604
#[test]
614605
fn test_display_print() {
615606
let expr = gt(lit(1), lit(2));

vortex-array/src/expr/exprs/dynamic.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ impl VTable for DynamicComparison {
7979
arg_dtypes: &[DType],
8080
) -> VortexResult<DType> {
8181
let lhs = &arg_dtypes[0];
82-
if !dynamic.rhs.dtype.eq_ignore_nullability(&lhs) {
82+
if !dynamic.rhs.dtype.eq_ignore_nullability(lhs) {
8383
vortex_bail!(
8484
"Incompatible dtypes for dynamic comparison: expected {} (ignore nullability) but got {}",
8585
&dynamic.rhs.dtype,

vortex-array/src/expr/exprs/literal.rs

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -47,11 +47,10 @@ impl VTable for Literal {
4747

4848
fn deserialize(&self, metadata: &[u8]) -> VortexResult<Self::Options> {
4949
let ops = pb::LiteralOpts::decode(metadata)?;
50-
Ok(ops
51-
.value
50+
ops.value
5251
.as_ref()
5352
.ok_or_else(|| vortex_err!("Literal metadata missing value"))?
54-
.try_into()?)
53+
.try_into()
5554
}
5655

5756
fn arity(&self, _options: &Self::Options) -> Arity {
@@ -154,8 +153,8 @@ impl VTable for Literal {
154153
///
155154
/// let number = lit(34i32);
156155
///
157-
/// let literal = number.as_::<Literal>();
158-
/// assert_eq!(literal.data(), &Scalar::primitive(34i32, Nullability::NonNullable));
156+
/// let scalar = number.as_::<Literal>();
157+
/// assert_eq!(scalar, &Scalar::primitive(34i32, Nullability::NonNullable));
159158
/// ```
160159
pub fn lit(value: impl Into<Scalar>) -> Expression {
161160
Literal.new_expr(value.into(), [])

vortex-array/src/expr/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ pub struct ExactExpr(pub Expression);
9292
impl PartialEq for ExactExpr {
9393
fn eq(&self, other: &Self) -> bool {
9494
self.0.id() == other.0.id()
95-
&& ptr::addr_eq(&*self.0.options().as_any(), &*other.0.options().as_any())
95+
&& ptr::addr_eq(self.0.options().as_any(), other.0.options().as_any())
9696
}
9797
}
9898
impl Eq for ExactExpr {}

vortex-array/src/expr/simplify.rs

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ impl Expression {
2222
let children: Vec<_> = expr
2323
.children()
2424
.iter()
25-
.map(|c| inner(c, &cache))
25+
.map(|c| inner(c, cache))
2626
.try_collect()?;
2727

2828
if children.iter().any(|c| c.is_some()) {
@@ -35,11 +35,18 @@ impl Expression {
3535

3636
let new_expr = expr.clone().with_children(new_children)?;
3737

38-
// Then we simplify the new expression
39-
new_expr.vtable().as_dyn().simplify(&new_expr, cache)
38+
// Then we simplify the new expression, and since we rewrote the expression we must
39+
// always return a new expression (even if simplification returns None)
40+
Ok(Some(
41+
new_expr
42+
.vtable()
43+
.as_dyn()
44+
.simplify(&new_expr, cache)?
45+
.unwrap_or(new_expr),
46+
))
4047
} else {
4148
// Otherwise, we attempt to simplify the current expression
42-
expr.vtable().as_dyn().simplify(&expr, cache)
49+
expr.vtable().as_dyn().simplify(expr, cache)
4350
}
4451
}
4552

vortex-array/src/expr/vtable.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -338,7 +338,7 @@ impl<V: VTable> DynExprVTable for VTableAdapter<V> {
338338
}
339339

340340
fn options_deserialize(&self, bytes: &[u8]) -> VortexResult<Box<dyn Any + Send + Sync>> {
341-
Ok(Box::new(V::deserialize(&self.0, bytes)))
341+
Ok(Box::new(V::deserialize(&self.0, bytes)?))
342342
}
343343

344344
fn options_clone(&self, options: &dyn Any) -> Box<dyn Any + Send + Sync> {
@@ -357,11 +357,11 @@ impl<V: VTable> DynExprVTable for VTableAdapter<V> {
357357
}
358358

359359
fn options_display(&self, options: &dyn Any, fmt: &mut Formatter<'_>) -> fmt::Result {
360-
write!(fmt, "{}", downcast::<V>(options))
360+
Display::fmt(downcast::<V>(options), fmt)
361361
}
362362

363363
fn options_debug(&self, options: &dyn Any, fmt: &mut Formatter<'_>) -> fmt::Result {
364-
write!(fmt, "{:?}", downcast::<V>(options))
364+
Debug::fmt(downcast::<V>(options), fmt)
365365
}
366366

367367
fn return_dtype(&self, options: &dyn Any, arg_dtypes: &[DType]) -> VortexResult<DType> {

0 commit comments

Comments
 (0)