Skip to content

Commit 31802fc

Browse files
committed
Expressions
Signed-off-by: Nicholas Gates <[email protected]>
1 parent e6b27de commit 31802fc

File tree

19 files changed

+115
-744
lines changed

19 files changed

+115
-744
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

vortex-array/src/expr/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ pub mod session;
3939
mod signature;
4040
mod simplify;
4141
pub mod stats;
42-
// pub mod transform;
42+
pub mod transform;
4343
pub mod traversal;
4444
mod vtable;
4545

vortex-array/src/expr/proto.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ mod tests {
9393
eq(lit(1), root()),
9494
);
9595

96-
let s_expr = (&expr).serialize_proto().unwrap();
96+
let s_expr = expr.serialize_proto().unwrap();
9797
let buf = s_expr.encode_to_vec();
9898
let s_expr = pb::Expr::decode(buf.as_slice()).unwrap();
9999
let deser_expr = deserialize_expr_proto(&s_expr, &registry).unwrap();

vortex-array/src/expr/simplify.rs

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ use vortex_utils::aliases::hash_map::HashMap;
1212
use crate::expr::Expression;
1313
use crate::expr::Root;
1414
use crate::expr::SimplifyCtx;
15+
use crate::expr::transform::match_between::find_between;
1516

1617
impl Expression {
1718
/// Simplify the expression, returning a potentially new expression.
@@ -55,7 +56,15 @@ impl Expression {
5556
dtype_cache: RefCell::new(HashMap::new()),
5657
};
5758

58-
inner(self, &cache)?.map_or_else(|| Ok(self.clone()), Ok)
59+
let simplified = inner(self, &cache)?.unwrap_or_else(|| self.clone());
60+
61+
// TODO(ngates): remove the "between" optimization, or rewrite it to not always convert
62+
// to CNF?
63+
let simplified = find_between(simplified);
64+
65+
// TODO(ngates): perform constant folding by executing expressions with all-literal
66+
// children here
67+
Ok(simplified)
5968
}
6069
}
6170

vortex-array/src/expr/transform/match_between.rs

Lines changed: 28 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -49,65 +49,69 @@ pub fn find_between(expr: Expression) -> Expression {
4949
}
5050

5151
fn maybe_match(lhs: &Expression, rhs: &Expression) -> Option<Expression> {
52-
let (Some(lhs_e), Some(rhs_e)) = (lhs.as_opt::<Binary>(), rhs.as_opt::<Binary>()) else {
52+
let (Some(lhs_op), Some(rhs_op)) = (lhs.as_opt::<Binary>(), rhs.as_opt::<Binary>()) else {
5353
return None;
5454
};
5555

56+
// Extract the operands of each comparison (the grandchildren of the enclosing conjunction)
57+
let lhs_lhs = lhs.child(0);
58+
let lhs_rhs = lhs.child(1);
59+
let rhs_lhs = rhs.child(0);
60+
let rhs_rhs = rhs.child(1);
61+
5662
// Cannot compare to self
57-
if lhs_e.lhs().eq(lhs_e.rhs()) || rhs_e.lhs().eq(rhs_e.rhs()) {
63+
if lhs_lhs.eq(lhs_rhs) || rhs_lhs.eq(rhs_rhs) {
5864
return None;
5965
}
6066

6167
// First, get both halves to have GetItem on the left
62-
let lhs = match (lhs_e.lhs().is::<GetItem>(), lhs_e.rhs().is::<GetItem>()) {
68+
let lhs = match (lhs_lhs.is::<GetItem>(), lhs_rhs.is::<GetItem>()) {
6369
(true, false) => lhs.clone(),
64-
(false, true) => Binary.new_expr(
65-
lhs_e.operator().swap()?,
66-
[lhs_e.rhs().clone(), lhs_e.lhs().clone()],
67-
),
70+
(false, true) => Binary.new_expr(lhs_op.swap()?, [lhs_rhs.clone(), lhs_lhs.clone()]),
6871
_ => return None,
6972
};
70-
let lhs_e = lhs.as_::<Binary>();
73+
let lhs_op = lhs.as_::<Binary>();
74+
let lhs_lhs = lhs.child(0);
7175

72-
let rhs = match (rhs_e.lhs().is::<GetItem>(), rhs_e.rhs().is::<GetItem>()) {
76+
let rhs = match (rhs_lhs.is::<GetItem>(), rhs_rhs.is::<GetItem>()) {
7377
(true, false) => rhs.clone(),
74-
(false, true) => Binary.new_expr(
75-
rhs_e.operator().swap()?,
76-
[rhs_e.rhs().clone(), rhs_e.lhs().clone()],
77-
),
78+
(false, true) => Binary.new_expr(rhs_op.swap()?, [rhs_rhs.clone(), rhs_lhs.clone()]),
7879
_ => return None,
7980
};
80-
let rhs_e = rhs.as_::<Binary>();
81+
let rhs_op = rhs.as_::<Binary>();
82+
let rhs_lhs = rhs.child(0);
8183

8284
// Both conjuncts must reference the same GetItem column
83-
if !lhs_e.lhs().eq(rhs_e.lhs()) {
85+
if !lhs_lhs.eq(rhs_lhs) {
8486
return None;
8587
}
8688

87-
let target = lhs_e.lhs().clone();
89+
let target = lhs_lhs.clone();
8890

8991
// Find the lower bound
90-
let (lower, upper) = match (lhs_e.operator(), rhs_e.operator()) {
92+
let (lower, upper) = match (lhs_op, rhs_op) {
9193
(Operator::Lt | Operator::Lte, Operator::Gt | Operator::Gte) => (rhs, lhs),
9294
(Operator::Gt | Operator::Gte, Operator::Lt | Operator::Lte) => (lhs, rhs),
9395
_ => return None,
9496
};
95-
let lower_e = lower.as_::<Binary>();
96-
let upper_e = upper.as_::<Binary>();
97+
let lower_op = lower.as_::<Binary>();
98+
let lower_rhs = lower.child(1);
99+
let upper_op = upper.as_::<Binary>();
100+
let upper_rhs = upper.child(1);
97101

98102
// Ensure bounds are literals
99-
let _ = lower_e.rhs().as_opt::<Literal>()?;
100-
let _ = upper_e.rhs().as_opt::<Literal>()?;
103+
let _ = lower_rhs.as_opt::<Literal>()?;
104+
let _ = upper_rhs.as_opt::<Literal>()?;
101105

102-
let lower_strict = is_strict_comparison(lower_e.operator())?;
103-
let upper_strict = is_strict_comparison(upper_e.operator())?;
106+
let lower_strict = is_strict_comparison(*lower_op)?;
107+
let upper_strict = is_strict_comparison(*upper_op)?;
104108

105109
Some(Between.new_expr(
106110
BetweenOptions {
107111
lower_strict,
108112
upper_strict,
109113
},
110-
[target, lower_e.rhs().clone(), upper_e.rhs().clone()],
114+
[target, lower_rhs.clone(), upper_rhs.clone()],
111115
))
112116
}
113117

vortex-array/src/expr/transform/mod.rs

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,8 @@
33

44
//! A collection of transformations that can be applied to a [`crate::expr::Expression`].
55
pub(crate) mod match_between;
6-
mod optimizer;
76
mod partition;
87
mod replace;
9-
pub mod rules;
10-
mod simplify;
11-
mod simplify_typed;
128

13-
pub use optimizer::*;
149
pub use partition::*;
1510
pub use replace::*;

vortex-array/src/expr/transform/optimizer.rs

Lines changed: 0 additions & 36 deletions
This file was deleted.

vortex-array/src/expr/transform/partition.rs

Lines changed: 12 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ use crate::expr::analysis::descendent_annotations;
2222
use crate::expr::exprs::get_item::get_item;
2323
use crate::expr::exprs::pack::pack;
2424
use crate::expr::exprs::root::root;
25-
use crate::expr::transform::ExprOptimizer;
2625
use crate::expr::traversal::NodeExt;
2726
use crate::expr::traversal::NodeRewriter;
2827
use crate::expr::traversal::Transformed;
@@ -43,7 +42,6 @@ pub fn partition<A: AnnotationFn>(
4342
expr: Expression,
4443
scope: &DType,
4544
annotate_fn: A,
46-
optimizer: &ExprOptimizer,
4745
) -> VortexResult<PartitionedExpr<A::Annotation>>
4846
where
4947
A::Annotation: Display,
@@ -73,7 +71,7 @@ where
7371
Nullability::NonNullable,
7472
);
7573

76-
let expr = optimizer.optimize_typed(expr.clone(), scope)?;
74+
let expr = expr.simplify(scope)?;
7775
let expr_dtype = expr.return_dtype(scope)?;
7876

7977
partitions.push(expr);
@@ -91,7 +89,7 @@ where
9189
);
9290

9391
Ok(PartitionedExpr {
94-
root: optimizer.optimize_typed(root, &root_scope)?,
92+
root: root.simplify(&root_scope)?,
9593
partitions: partitions.into_boxed_slice(),
9694
partition_names,
9795
partition_dtypes: partition_dtypes.into_boxed_slice(),
@@ -222,9 +220,7 @@ mod tests {
222220
use crate::expr::exprs::pack::pack;
223221
use crate::expr::exprs::root::root;
224222
use crate::expr::exprs::select::select;
225-
use crate::expr::session::ExprSession;
226223
use crate::expr::transform::replace::replace_root_fields;
227-
use crate::expr::transform::simplify_typed::simplify_typed;
228224

229225
#[fixture]
230226
fn dtype() -> DType {
@@ -247,48 +243,34 @@ mod tests {
247243
#[rstest]
248244
fn test_expr_top_level_ref(dtype: DType) {
249245
let fields = dtype.as_struct_fields_opt().unwrap();
250-
let session = ExprSession::default();
251-
let optimizer = ExprOptimizer::new(&session);
252246

253247
let expr = root();
254-
let partitioned = partition(
255-
expr.clone(),
256-
&dtype,
257-
annotate_scope_access(fields),
258-
&optimizer,
259-
)
260-
.unwrap();
248+
let partitioned = partition(expr.clone(), &dtype, annotate_scope_access(fields)).unwrap();
261249

262250
// An un-expanded root expression is annotated by all fields, but since it is a single node
// it cannot be split: no partitions are produced and the root is left unchanged.
263251
assert_eq!(partitioned.partitions.len(), 0);
264252
assert_eq!(&partitioned.root, &root());
265253

266254
// Instead, callers must expand the root expression themselves.
267255
let expr = replace_root_fields(expr, fields);
268-
let partitioned =
269-
partition(expr, &dtype, annotate_scope_access(fields), &optimizer).unwrap();
256+
let partitioned = partition(expr, &dtype, annotate_scope_access(fields)).unwrap();
270257

271258
assert_eq!(partitioned.partitions.len(), fields.names().len());
272259
}
273260

274261
#[rstest]
275262
fn test_expr_top_level_ref_get_item_and_split(dtype: DType) {
276263
let fields = dtype.as_struct_fields_opt().unwrap();
277-
let session = ExprSession::default();
278-
let optimizer = ExprOptimizer::new(&session);
279264

280265
let expr = get_item("y", get_item("a", root()));
281266

282-
let partitioned =
283-
partition(expr, &dtype, annotate_scope_access(fields), &optimizer).unwrap();
267+
let partitioned = partition(expr, &dtype, annotate_scope_access(fields)).unwrap();
284268
assert_eq!(&partitioned.root, &get_item("a_0", get_item("a", root())));
285269
}
286270

287271
#[rstest]
288272
fn test_expr_top_level_ref_get_item_and_split_pack(dtype: DType) {
289273
let fields = dtype.as_struct_fields_opt().unwrap();
290-
let session = ExprSession::default();
291-
let optimizer = ExprOptimizer::new(&session);
292274

293275
let expr = pack(
294276
[
@@ -298,17 +280,11 @@ mod tests {
298280
],
299281
NonNullable,
300282
);
301-
let partitioned =
302-
partition(expr, &dtype, annotate_scope_access(fields), &optimizer).unwrap();
283+
let partitioned = partition(expr, &dtype, annotate_scope_access(fields)).unwrap();
303284

304285
let split_a = partitioned.find_partition(&"a".into()).unwrap();
305286
assert_eq!(
306-
&simplify_typed(
307-
split_a.clone(),
308-
&dtype,
309-
ExprSession::default().rewrite_rules()
310-
)
311-
.unwrap(),
287+
&split_a.simplify(&dtype).unwrap(),
312288
&pack(
313289
[
314290
("a_0", get_item("x", get_item("a", root()))),
@@ -322,12 +298,9 @@ mod tests {
322298
#[rstest]
323299
fn test_expr_top_level_ref_get_item_add(dtype: DType) {
324300
let fields = dtype.as_struct_fields_opt().unwrap();
325-
let session = ExprSession::default();
326-
let optimizer = ExprOptimizer::new(&session);
327301

328302
let expr = and(get_item("y", get_item("a", root())), lit(1));
329-
let partitioned =
330-
partition(expr, &dtype, annotate_scope_access(fields), &optimizer).unwrap();
303+
let partitioned = partition(expr, &dtype, annotate_scope_access(fields)).unwrap();
331304

332305
// Whole expr is a single split
333306
assert_eq!(partitioned.partitions.len(), 1);
@@ -336,12 +309,9 @@ mod tests {
336309
#[rstest]
337310
fn test_expr_top_level_ref_get_item_add_cannot_split(dtype: DType) {
338311
let fields = dtype.as_struct_fields_opt().unwrap();
339-
let session = ExprSession::default();
340-
let optimizer = ExprOptimizer::new(&session);
341312

342313
let expr = and(get_item("y", get_item("a", root())), get_item("b", root()));
343-
let partitioned =
344-
partition(expr, &dtype, annotate_scope_access(fields), &optimizer).unwrap();
314+
let partitioned = partition(expr, &dtype, annotate_scope_access(fields)).unwrap();
345315

346316
// One for id.a and id.b
347317
assert_eq!(partitioned.partitions.len(), 2);
@@ -351,16 +321,13 @@ mod tests {
351321
#[rstest]
352322
fn test_expr_partition_many_occurrences_of_field(dtype: DType) {
353323
let fields = dtype.as_struct_fields_opt().unwrap();
354-
let session = ExprSession::default();
355-
let optimizer = ExprOptimizer::new(&session);
356324

357325
let expr = and(
358326
get_item("y", get_item("a", root())),
359327
select(["a", "b"], root()),
360328
);
361-
let expr = simplify_typed(expr, &dtype, ExprSession::default().rewrite_rules()).unwrap();
362-
let partitioned =
363-
partition(expr, &dtype, annotate_scope_access(fields), &optimizer).unwrap();
329+
let expr = expr.simplify(&dtype).unwrap();
330+
let partitioned = partition(expr, &dtype, annotate_scope_access(fields)).unwrap();
364331

365332
// One for id.a and id.b
366333
assert_eq!(partitioned.partitions.len(), 2);
@@ -394,13 +361,10 @@ mod tests {
394361
#[rstest]
395362
fn test_expr_merge(dtype: DType) {
396363
let fields = dtype.as_struct_fields_opt().unwrap();
397-
let session = ExprSession::default();
398-
let optimizer = ExprOptimizer::new(&session);
399364

400365
let expr = merge([col("a"), pack([("b", col("b"))], NonNullable)]);
401366

402-
let partitioned =
403-
partition(expr, &dtype, annotate_scope_access(fields), &optimizer).unwrap();
367+
let partitioned = partition(expr, &dtype, annotate_scope_access(fields)).unwrap();
404368
let expected = pack(
405369
[
406370
("x", get_item("x", get_item("a_0", col("a")))),

0 commit comments

Comments
 (0)