Skip to content

Commit 437392b

Browse files
authored
chore: introduce ExprRef, teach expressions new_ref (#1258)
I find this makes the pruner a bit easier to read.
1 parent ff00dec commit 437392b

File tree

18 files changed

+316
-344
lines changed

18 files changed

+316
-344
lines changed

pyvortex/src/expr.rs

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use pyo3::types::*;
66
use vortex::dtype::field::Field;
77
use vortex::dtype::half::f16;
88
use vortex::dtype::{DType, Nullability, PType};
9-
use vortex::expr::{BinaryExpr, Column, Literal, Operator, VortexExpr};
9+
use vortex::expr::{BinaryExpr, Column, ExprRef, Literal, Operator};
1010
use vortex::scalar::{PValue, Scalar, ScalarValue};
1111

1212
use crate::dtype::PyDType;
@@ -119,11 +119,11 @@ use crate::dtype::PyDType;
119119
/// ]
120120
#[pyclass(name = "Expr", module = "vortex")]
121121
pub struct PyExpr {
122-
inner: Arc<dyn VortexExpr>,
122+
inner: ExprRef,
123123
}
124124

125125
impl PyExpr {
126-
pub fn unwrap(&self) -> &Arc<dyn VortexExpr> {
126+
pub fn unwrap(&self) -> &ExprRef {
127127
&self.inner
128128
}
129129
}
@@ -136,11 +136,7 @@ fn py_binary_opeartor<'py>(
136136
Bound::new(
137137
left.py(),
138138
PyExpr {
139-
inner: Arc::new(BinaryExpr::new(
140-
left.inner.clone(),
141-
operator,
142-
right.borrow().inner.clone(),
143-
)),
139+
inner: BinaryExpr::new_expr(left.inner.clone(), operator, right.borrow().inner.clone()),
144140
},
145141
)
146142
}
@@ -252,7 +248,7 @@ pub fn column<'py>(name: &Bound<'py, PyString>) -> PyResult<Bound<'py, PyExpr>>
252248
Bound::new(
253249
py,
254250
PyExpr {
255-
inner: Arc::new(Column::new(Field::Name(name))),
251+
inner: Column::new_expr(Field::Name(name)),
256252
},
257253
)
258254
}
@@ -270,10 +266,7 @@ pub fn scalar<'py>(dtype: DType, value: &Bound<'py, PyAny>) -> PyResult<Bound<'p
270266
Bound::new(
271267
py,
272268
PyExpr {
273-
inner: Arc::new(Literal::new(Scalar::new(
274-
dtype.clone(),
275-
scalar_value(dtype, value)?,
276-
))),
269+
inner: Literal::new_expr(Scalar::new(dtype.clone(), scalar_value(dtype, value)?)),
277270
},
278271
)
279272
}

vortex-datafusion/src/memory.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ use vortex_array::arrow::infer_schema;
1717
use vortex_array::{Array, ArrayDType as _};
1818
use vortex_error::{VortexError, VortexExpect as _};
1919
use vortex_expr::datafusion::convert_expr_to_vortex;
20-
use vortex_expr::VortexExpr;
20+
use vortex_expr::ExprRef;
2121

2222
use crate::plans::{RowSelectorExec, TakeRowsExec};
2323
use crate::{can_be_pushed_down, VortexScanExec};
@@ -190,7 +190,7 @@ impl VortexMemTableOptions {
190190
/// columns.
191191
fn make_filter_then_take_plan(
192192
schema: SchemaRef,
193-
filter_expr: Arc<dyn VortexExpr>,
193+
filter_expr: ExprRef,
194194
chunked_array: ChunkedArray,
195195
output_projection: Vec<usize>,
196196
_session_state: &dyn Session,

vortex-datafusion/src/plans.rs

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,13 @@ use vortex_array::compute::take;
2424
use vortex_array::{Array, IntoArrayVariant, IntoCanonical};
2525
use vortex_dtype::field::Field;
2626
use vortex_error::{vortex_err, vortex_panic, VortexError};
27-
use vortex_expr::VortexExpr;
27+
use vortex_expr::ExprRef;
2828

2929
/// Physical plan operator that applies a set of [filters][Expr] against the input, producing a
3030
/// row mask that can be used downstream to force a take against the corresponding struct array
3131
/// chunks but for different columns.
3232
pub(crate) struct RowSelectorExec {
33-
filter_expr: Arc<dyn VortexExpr>,
33+
filter_expr: ExprRef,
3434
/// cached PlanProperties object. We do not make use of this.
3535
cached_plan_props: PlanProperties,
3636
/// Full array. We only access partitions of this data.
@@ -46,10 +46,7 @@ static ROW_SELECTOR_SCHEMA_REF: LazyLock<SchemaRef> = LazyLock::new(|| {
4646
});
4747

4848
impl RowSelectorExec {
49-
pub(crate) fn try_new(
50-
filter_expr: Arc<dyn VortexExpr>,
51-
chunked_array: &ChunkedArray,
52-
) -> DFResult<Self> {
49+
pub(crate) fn try_new(filter_expr: ExprRef, chunked_array: &ChunkedArray) -> DFResult<Self> {
5350
let cached_plan_props = PlanProperties::new(
5451
EquivalenceProperties::new(ROW_SELECTOR_SCHEMA_REF.clone()),
5552
Partitioning::UnknownPartitioning(1),
@@ -134,7 +131,7 @@ impl ExecutionPlan for RowSelectorExec {
134131
pub(crate) struct RowIndicesStream {
135132
chunked_array: ChunkedArray,
136133
chunk_idx: usize,
137-
conjunction_expr: Arc<dyn VortexExpr>,
134+
conjunction_expr: ExprRef,
138135
filter_projection: Vec<Field>,
139136
}
140137

vortex-expr/src/binary.rs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,25 +8,25 @@ use vortex_array::Array;
88
use vortex_dtype::field::Field;
99
use vortex_error::VortexResult;
1010

11-
use crate::{unbox_any, Operator, VortexExpr};
11+
use crate::{unbox_any, ExprRef, Operator, VortexExpr};
1212

1313
#[derive(Debug, Clone)]
1414
pub struct BinaryExpr {
15-
lhs: Arc<dyn VortexExpr>,
15+
lhs: ExprRef,
1616
operator: Operator,
17-
rhs: Arc<dyn VortexExpr>,
17+
rhs: ExprRef,
1818
}
1919

2020
impl BinaryExpr {
21-
pub fn new(lhs: Arc<dyn VortexExpr>, operator: Operator, rhs: Arc<dyn VortexExpr>) -> Self {
22-
Self { lhs, operator, rhs }
21+
pub fn new_expr(lhs: ExprRef, operator: Operator, rhs: ExprRef) -> ExprRef {
22+
Arc::new(Self { lhs, operator, rhs })
2323
}
2424

25-
pub fn lhs(&self) -> &Arc<dyn VortexExpr> {
25+
pub fn lhs(&self) -> &ExprRef {
2626
&self.lhs
2727
}
2828

29-
pub fn rhs(&self) -> &Arc<dyn VortexExpr> {
29+
pub fn rhs(&self) -> &ExprRef {
3030
&self.rhs
3131
}
3232

vortex-expr/src/column.rs

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use std::any::Any;
22
use std::fmt::Display;
3+
use std::sync::Arc;
34

45
use vortex_array::aliases::hash_set::HashSet;
56
use vortex_array::array::StructArray;
@@ -8,16 +9,16 @@ use vortex_array::Array;
89
use vortex_dtype::field::Field;
910
use vortex_error::{vortex_err, VortexResult};
1011

11-
use crate::{unbox_any, VortexExpr};
12+
use crate::{unbox_any, ExprRef, VortexExpr};
1213

1314
#[derive(Debug, PartialEq, Hash, Clone, Eq)]
1415
pub struct Column {
1516
field: Field,
1617
}
1718

1819
impl Column {
19-
pub fn new(field: Field) -> Self {
20-
Self { field }
20+
pub fn new_expr(field: Field) -> ExprRef {
21+
Arc::new(Self { field })
2122
}
2223

2324
pub fn field(&self) -> &Field {
@@ -27,13 +28,17 @@ impl Column {
2728

2829
impl From<String> for Column {
2930
fn from(value: String) -> Self {
30-
Column::new(value.into())
31+
Column {
32+
field: value.into(),
33+
}
3134
}
3235
}
3336

3437
impl From<usize> for Column {
3538
fn from(value: usize) -> Self {
36-
Column::new(value.into())
39+
Column {
40+
field: value.into(),
41+
}
3742
}
3843
}
3944

vortex-expr/src/datafusion.rs

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,8 @@ use datafusion_physical_expr::{expressions, PhysicalExpr};
77
use vortex_error::{vortex_bail, vortex_err, VortexError, VortexResult};
88
use vortex_scalar::Scalar;
99

10-
use crate::{BinaryExpr, Column, Literal, Operator, VortexExpr};
11-
12-
pub fn convert_expr_to_vortex(
13-
physical_expr: Arc<dyn PhysicalExpr>,
14-
) -> VortexResult<Arc<dyn VortexExpr>> {
10+
use crate::{BinaryExpr, Column, ExprRef, Literal, Operator};
11+
pub fn convert_expr_to_vortex(physical_expr: Arc<dyn PhysicalExpr>) -> VortexResult<ExprRef> {
1512
if let Some(binary_expr) = physical_expr
1613
.as_any()
1714
.downcast_ref::<expressions::BinaryExpr>()
@@ -20,7 +17,7 @@ pub fn convert_expr_to_vortex(
2017
let right = convert_expr_to_vortex(binary_expr.right().clone())?;
2118
let operator = *binary_expr.op();
2219

23-
return Ok(Arc::new(BinaryExpr::new(left, operator.try_into()?, right)) as _);
20+
return Ok(BinaryExpr::new_expr(left, operator.try_into()?, right));
2421
}
2522

2623
if let Some(col_expr) = physical_expr.as_any().downcast_ref::<expressions::Column>() {
@@ -34,7 +31,7 @@ pub fn convert_expr_to_vortex(
3431
.downcast_ref::<expressions::Literal>()
3532
{
3633
let value = Scalar::from(lit.value().clone());
37-
return Ok(Arc::new(Literal::new(value)) as _);
34+
return Ok(Literal::new_expr(value));
3835
}
3936

4037
vortex_bail!("Couldn't convert DataFusion physical expression to a vortex expression")

vortex-expr/src/lib.rs

Lines changed: 41 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ use vortex_array::Array;
2424
use vortex_dtype::field::Field;
2525
use vortex_error::{VortexExpect, VortexResult};
2626

27+
pub type ExprRef = Arc<dyn VortexExpr>;
28+
2729
/// Represents logical operation on [`Array`]s
2830
pub trait VortexExpr: Debug + Send + Sync + PartialEq<dyn Any> + Display {
2931
/// Convert expression reference to reference of [`Any`] type
@@ -44,13 +46,13 @@ pub trait VortexExpr: Debug + Send + Sync + PartialEq<dyn Any> + Display {
4446
}
4547

4648
/// Splits top level and operations into separate expressions
47-
pub fn split_conjunction(expr: &Arc<dyn VortexExpr>) -> Vec<Arc<dyn VortexExpr>> {
49+
pub fn split_conjunction(expr: &ExprRef) -> Vec<ExprRef> {
4850
let mut conjunctions = vec![];
4951
split_inner(expr, &mut conjunctions);
5052
conjunctions
5153
}
5254

53-
fn split_inner(expr: &Arc<dyn VortexExpr>, exprs: &mut Vec<Arc<dyn VortexExpr>>) {
55+
fn split_inner(expr: &ExprRef, exprs: &mut Vec<ExprRef>) {
5456
match expr.as_any().downcast_ref::<BinaryExpr>() {
5557
Some(bexp) if bexp.op() == Operator::And => {
5658
split_inner(bexp.lhs(), exprs);
@@ -64,9 +66,9 @@ fn split_inner(expr: &Arc<dyn VortexExpr>, exprs: &mut Vec<Arc<dyn VortexExpr>>)
6466

6567
// Taken from apache-datafusion, necessary since you can't require VortexExpr implement PartialEq<dyn VortexExpr>
6668
pub fn unbox_any(any: &dyn Any) -> &dyn Any {
67-
if any.is::<Arc<dyn VortexExpr>>() {
68-
any.downcast_ref::<Arc<dyn VortexExpr>>()
69-
.vortex_expect("any.is::<Arc<dyn VortexExpr>> returned true but downcast_ref failed")
69+
if any.is::<ExprRef>() {
70+
any.downcast_ref::<ExprRef>()
71+
.vortex_expect("any.is::<ExprRef> returned true but downcast_ref failed")
7072
.as_any()
7173
} else if any.is::<Box<dyn VortexExpr>>() {
7274
any.downcast_ref::<Box<dyn VortexExpr>>()
@@ -87,75 +89,78 @@ mod tests {
8789

8890
#[test]
8991
fn basic_expr_split_test() {
90-
let lhs = Arc::new(Column::new(Field::Name("a".to_string()))) as _;
91-
let rhs = Arc::new(Literal::new(1.into())) as _;
92-
let expr = Arc::new(BinaryExpr::new(lhs, Operator::Eq, rhs)) as _;
92+
let lhs = Column::new_expr(Field::Name("a".to_string()));
93+
let rhs = Literal::new_expr(1.into());
94+
let expr = BinaryExpr::new_expr(lhs, Operator::Eq, rhs);
9395
let conjunction = split_conjunction(&expr);
9496
assert_eq!(conjunction.len(), 1);
9597
}
9698

9799
#[test]
98100
fn basic_conjunction_split_test() {
99-
let lhs = Arc::new(Column::new(Field::Name("a".to_string()))) as _;
100-
let rhs = Arc::new(Literal::new(1.into())) as _;
101-
let expr = Arc::new(BinaryExpr::new(lhs, Operator::And, rhs)) as _;
101+
let lhs = Column::new_expr(Field::Name("a".to_string()));
102+
let rhs = Literal::new_expr(1.into());
103+
let expr = BinaryExpr::new_expr(lhs, Operator::And, rhs);
102104
let conjunction = split_conjunction(&expr);
103105
assert_eq!(conjunction.len(), 2, "Conjunction is {conjunction:?}");
104106
}
105107

106108
#[test]
107109
fn expr_display() {
108-
assert_eq!(Column::new(Field::Name("a".to_string())).to_string(), "$a");
109-
assert_eq!(Column::new(Field::Index(1)).to_string(), "[1]");
110+
assert_eq!(
111+
Column::new_expr(Field::Name("a".to_string())).to_string(),
112+
"$a"
113+
);
114+
assert_eq!(Column::new_expr(Field::Index(1)).to_string(), "[1]");
110115
assert_eq!(Identity.to_string(), "[]");
111116
assert_eq!(Identity.to_string(), "[]");
112117

113-
let col1: Arc<dyn VortexExpr> = Arc::new(Column::new(Field::Name("col1".to_string())));
114-
let col2: Arc<dyn VortexExpr> = Arc::new(Column::new(Field::Name("col2".to_string())));
118+
let col1: Arc<dyn VortexExpr> = Column::new_expr(Field::Name("col1".to_string()));
119+
let col2: Arc<dyn VortexExpr> = Column::new_expr(Field::Name("col2".to_string()));
115120
assert_eq!(
116-
BinaryExpr::new(col1.clone(), Operator::And, col2.clone()).to_string(),
121+
BinaryExpr::new_expr(col1.clone(), Operator::And, col2.clone()).to_string(),
117122
"($col1 and $col2)"
118123
);
119124
assert_eq!(
120-
BinaryExpr::new(col1.clone(), Operator::Or, col2.clone()).to_string(),
125+
BinaryExpr::new_expr(col1.clone(), Operator::Or, col2.clone()).to_string(),
121126
"($col1 or $col2)"
122127
);
123128
assert_eq!(
124-
BinaryExpr::new(col1.clone(), Operator::Eq, col2.clone()).to_string(),
129+
BinaryExpr::new_expr(col1.clone(), Operator::Eq, col2.clone()).to_string(),
125130
"($col1 = $col2)"
126131
);
127132
assert_eq!(
128-
BinaryExpr::new(col1.clone(), Operator::NotEq, col2.clone()).to_string(),
133+
BinaryExpr::new_expr(col1.clone(), Operator::NotEq, col2.clone()).to_string(),
129134
"($col1 != $col2)"
130135
);
131136
assert_eq!(
132-
BinaryExpr::new(col1.clone(), Operator::Gt, col2.clone()).to_string(),
137+
BinaryExpr::new_expr(col1.clone(), Operator::Gt, col2.clone()).to_string(),
133138
"($col1 > $col2)"
134139
);
135140
assert_eq!(
136-
BinaryExpr::new(col1.clone(), Operator::Gte, col2.clone()).to_string(),
141+
BinaryExpr::new_expr(col1.clone(), Operator::Gte, col2.clone()).to_string(),
137142
"($col1 >= $col2)"
138143
);
139144
assert_eq!(
140-
BinaryExpr::new(col1.clone(), Operator::Lt, col2.clone()).to_string(),
145+
BinaryExpr::new_expr(col1.clone(), Operator::Lt, col2.clone()).to_string(),
141146
"($col1 < $col2)"
142147
);
143148
assert_eq!(
144-
BinaryExpr::new(col1.clone(), Operator::Lte, col2.clone()).to_string(),
149+
BinaryExpr::new_expr(col1.clone(), Operator::Lte, col2.clone()).to_string(),
145150
"($col1 <= $col2)"
146151
);
147152

148153
assert_eq!(
149-
BinaryExpr::new(
150-
Arc::new(BinaryExpr::new(col1.clone(), Operator::Lt, col2.clone())),
154+
BinaryExpr::new_expr(
155+
BinaryExpr::new_expr(col1.clone(), Operator::Lt, col2.clone()),
151156
Operator::Or,
152-
Arc::new(BinaryExpr::new(col1.clone(), Operator::NotEq, col2.clone()))
157+
BinaryExpr::new_expr(col1.clone(), Operator::NotEq, col2.clone())
153158
)
154159
.to_string(),
155160
"(($col1 < $col2) or ($col1 != $col2))"
156161
);
157162

158-
assert_eq!(Not::new(col1.clone()).to_string(), "!$col1");
163+
assert_eq!(Not::new_expr(col1.clone()).to_string(), "!$col1");
159164

160165
assert_eq!(
161166
Select::include(vec![Field::Name("col1".to_string())]).to_string(),
@@ -179,20 +184,23 @@ mod tests {
179184
"Exclude($col1,$col2,[1])"
180185
);
181186

182-
assert_eq!(Literal::new(Scalar::from(0_u8)).to_string(), "0_u8");
183-
assert_eq!(Literal::new(Scalar::from(0.0_f32)).to_string(), "0_f32");
187+
assert_eq!(Literal::new_expr(Scalar::from(0_u8)).to_string(), "0_u8");
188+
assert_eq!(
189+
Literal::new_expr(Scalar::from(0.0_f32)).to_string(),
190+
"0_f32"
191+
);
184192
assert_eq!(
185-
Literal::new(Scalar::from(i64::MAX)).to_string(),
193+
Literal::new_expr(Scalar::from(i64::MAX)).to_string(),
186194
"9223372036854775807_i64"
187195
);
188-
assert_eq!(Literal::new(Scalar::from(true)).to_string(), "true");
196+
assert_eq!(Literal::new_expr(Scalar::from(true)).to_string(), "true");
189197
assert_eq!(
190-
Literal::new(Scalar::null(DType::Bool(Nullability::Nullable))).to_string(),
198+
Literal::new_expr(Scalar::null(DType::Bool(Nullability::Nullable))).to_string(),
191199
"null"
192200
);
193201

194202
assert_eq!(
195-
Literal::new(Scalar::new(
203+
Literal::new_expr(Scalar::new(
196204
DType::Struct(
197205
StructDType::new(
198206
Arc::from([Arc::from("dog"), Arc::from("cat")]),

0 commit comments

Comments
 (0)