Skip to content

Commit 1b73409

Browse files
authored
Pluggable expression conversion for DF (#5853)
This PR introduces an extendable interface for converting DataFusion physical expressions into Vortex ones, so that users can define their own behavior for UDFs or any other home-brewed physical expressions. I think this is an OK way to expose the defaults for now, but would appreciate thoughts. Signed-off-by: Adam Gutglick <[email protected]>
1 parent 59bcf21 commit 1b73409

File tree

5 files changed

+164
-116
lines changed

5 files changed

+164
-116
lines changed

vortex-datafusion/src/convert/exprs.rs

Lines changed: 137 additions & 110 deletions
Original file line numberDiff line numberDiff line change
@@ -35,28 +35,81 @@ use vortex::expr::root;
3535
use vortex::scalar::Scalar;
3636

3737
use crate::convert::FromDataFusion;
38-
use crate::convert::TryFromDataFusion;
3938

4039
/// Tries to convert the expressions into a vortex conjunction. Will return Ok(None) iff the input conjunction is empty.
4140
pub(crate) fn make_vortex_predicate(
41+
expr_convertor: &dyn ExpressionConvertor,
4242
predicate: &[Arc<dyn PhysicalExpr>],
4343
) -> VortexResult<Option<Expression>> {
4444
let exprs = predicate
4545
.iter()
46-
.map(|e| Expression::try_from_df(e.as_ref()))
46+
.map(|e| expr_convertor.convert(e.as_ref()))
4747
.collect::<VortexResult<Vec<_>>>()?;
4848

4949
Ok(exprs.into_iter().reduce(and))
5050
}
5151

52-
// TODO(joe): Don't return an error when we have an unsupported node, bubble up "TRUE" as in keep
53-
// for that node, up to any `and` or `or` node.
54-
impl TryFromDataFusion<dyn PhysicalExpr> for Expression {
55-
fn try_from_df(df: &dyn PhysicalExpr) -> VortexResult<Self> {
52+
/// Trait for converting DataFusion expressions to Vortex ones.
53+
pub trait ExpressionConvertor: Send + Sync {
54+
/// Can an expression be pushed down given a specific schema
55+
fn can_be_pushed_down(&self, expr: &Arc<dyn PhysicalExpr>, schema: &Schema) -> bool;
56+
57+
/// Try and convert a DataFusion [`PhysicalExpr`] into a Vortex [`Expression`].
58+
fn convert(&self, expr: &dyn PhysicalExpr) -> VortexResult<Expression>;
59+
}
60+
61+
/// The default [`ExpressionConvertor`].
62+
#[derive(Default)]
63+
pub struct DefaultExpressionConvertor {}
64+
65+
impl DefaultExpressionConvertor {
66+
/// Attempts to convert a DataFusion ScalarFunctionExpr to a Vortex expression.
67+
fn try_convert_scalar_function(
68+
&self,
69+
scalar_fn: &ScalarFunctionExpr,
70+
) -> VortexResult<Expression> {
71+
if let Some(get_field_fn) = ScalarFunctionExpr::try_downcast_func::<GetFieldFunc>(scalar_fn)
72+
{
73+
let source_expr = get_field_fn
74+
.args()
75+
.first()
76+
.ok_or_else(|| vortex_err!("get_field missing source expression"))?
77+
.as_ref();
78+
let field_name_expr = get_field_fn
79+
.args()
80+
.get(1)
81+
.ok_or_else(|| vortex_err!("get_field missing field name argument"))?;
82+
let field_name = field_name_expr
83+
.as_any()
84+
.downcast_ref::<df_expr::Literal>()
85+
.ok_or_else(|| vortex_err!("get_field field name must be a literal"))?
86+
.value()
87+
.try_as_str()
88+
.flatten()
89+
.ok_or_else(|| vortex_err!("get_field field name must be a UTF-8 string"))?;
90+
return Ok(get_item(field_name.to_string(), self.convert(source_expr)?));
91+
}
92+
93+
tracing::debug!(
94+
function_name = scalar_fn.name(),
95+
"Unsupported ScalarFunctionExpr"
96+
);
97+
vortex_bail!("Unsupported ScalarFunctionExpr: {}", scalar_fn.name())
98+
}
99+
}
100+
101+
impl ExpressionConvertor for DefaultExpressionConvertor {
102+
fn can_be_pushed_down(&self, expr: &Arc<dyn PhysicalExpr>, schema: &Schema) -> bool {
103+
can_be_pushed_down(expr, schema)
104+
}
105+
106+
fn convert(&self, df: &dyn PhysicalExpr) -> VortexResult<Expression> {
107+
// TODO(joe): Don't return an error when we have an unsupported node, bubble up "TRUE" as in keep
108+
// for that node, up to any `and` or `or` node.
56109
if let Some(binary_expr) = df.as_any().downcast_ref::<df_expr::BinaryExpr>() {
57-
let left = Expression::try_from_df(binary_expr.left().as_ref())?;
58-
let right = Expression::try_from_df(binary_expr.right().as_ref())?;
59-
let operator = Operator::try_from_df(binary_expr.op())?;
110+
let left = self.convert(binary_expr.left().as_ref())?;
111+
let right = self.convert(binary_expr.right().as_ref())?;
112+
let operator = try_operator_from_df(binary_expr.op())?;
60113

61114
return Ok(Binary.new_expr(operator, [left, right]));
62115
}
@@ -66,8 +119,8 @@ impl TryFromDataFusion<dyn PhysicalExpr> for Expression {
66119
}
67120

68121
if let Some(like) = df.as_any().downcast_ref::<df_expr::LikeExpr>() {
69-
let child = Expression::try_from_df(like.expr().as_ref())?;
70-
let pattern = Expression::try_from_df(like.pattern().as_ref())?;
122+
let child = self.convert(like.expr().as_ref())?;
123+
let pattern = self.convert(like.pattern().as_ref())?;
71124
return Ok(Like.new_expr(
72125
LikeOptions {
73126
negated: like.negated(),
@@ -84,30 +137,30 @@ impl TryFromDataFusion<dyn PhysicalExpr> for Expression {
84137

85138
if let Some(cast_expr) = df.as_any().downcast_ref::<df_expr::CastExpr>() {
86139
let cast_dtype = DType::from_arrow((cast_expr.cast_type(), Nullability::Nullable));
87-
let child = Expression::try_from_df(cast_expr.expr().as_ref())?;
140+
let child = self.convert(cast_expr.expr().as_ref())?;
88141
return Ok(cast(child, cast_dtype));
89142
}
90143

91144
if let Some(cast_col_expr) = df.as_any().downcast_ref::<df_expr::CastColumnExpr>() {
92145
let target = cast_col_expr.target_field();
93146

94147
let target_dtype = DType::from_arrow((target.data_type(), target.is_nullable().into()));
95-
let child = Expression::try_from_df(cast_col_expr.expr().as_ref())?;
148+
let child = self.convert(cast_col_expr.expr().as_ref())?;
96149
return Ok(cast(child, target_dtype));
97150
}
98151

99152
if let Some(is_null_expr) = df.as_any().downcast_ref::<df_expr::IsNullExpr>() {
100-
let arg = Expression::try_from_df(is_null_expr.arg().as_ref())?;
153+
let arg = self.convert(is_null_expr.arg().as_ref())?;
101154
return Ok(is_null(arg));
102155
}
103156

104157
if let Some(is_not_null_expr) = df.as_any().downcast_ref::<df_expr::IsNotNullExpr>() {
105-
let arg = Expression::try_from_df(is_not_null_expr.arg().as_ref())?;
158+
let arg = self.convert(is_not_null_expr.arg().as_ref())?;
106159
return Ok(not(is_null(arg)));
107160
}
108161

109162
if let Some(in_list) = df.as_any().downcast_ref::<df_expr::InListExpr>() {
110-
let value = Expression::try_from_df(in_list.expr().as_ref())?;
163+
let value = self.convert(in_list.expr().as_ref())?;
111164
let list_elements: Vec<_> = in_list
112165
.list()
113166
.iter()
@@ -131,94 +184,59 @@ impl TryFromDataFusion<dyn PhysicalExpr> for Expression {
131184
}
132185

133186
if let Some(scalar_fn) = df.as_any().downcast_ref::<ScalarFunctionExpr>() {
134-
return try_convert_scalar_function(scalar_fn);
187+
return self.try_convert_scalar_function(scalar_fn);
135188
}
136189

137190
vortex_bail!("Couldn't convert DataFusion physical {df} expression to a vortex expression")
138191
}
139192
}
140193

141-
/// Attempts to convert a DataFusion ScalarFunctionExpr to a Vortex expression.
142-
fn try_convert_scalar_function(scalar_fn: &ScalarFunctionExpr) -> VortexResult<Expression> {
143-
if let Some(get_field_fn) = ScalarFunctionExpr::try_downcast_func::<GetFieldFunc>(scalar_fn) {
144-
let source_expr = get_field_fn
145-
.args()
146-
.first()
147-
.ok_or_else(|| vortex_err!("get_field missing source expression"))?
148-
.as_ref();
149-
let field_name_expr = get_field_fn
150-
.args()
151-
.get(1)
152-
.ok_or_else(|| vortex_err!("get_field missing field name argument"))?;
153-
let field_name = field_name_expr
154-
.as_any()
155-
.downcast_ref::<df_expr::Literal>()
156-
.ok_or_else(|| vortex_err!("get_field field name must be a literal"))?
157-
.value()
158-
.try_as_str()
159-
.flatten()
160-
.ok_or_else(|| vortex_err!("get_field field name must be a UTF-8 string"))?;
161-
return Ok(get_item(
162-
field_name.to_string(),
163-
Expression::try_from_df(source_expr)?,
164-
));
165-
}
166-
167-
tracing::debug!(
168-
function_name = scalar_fn.name(),
169-
"Unsupported ScalarFunctionExpr"
170-
);
171-
vortex_bail!("Unsupported ScalarFunctionExpr: {}", scalar_fn.name())
172-
}
173-
174-
impl TryFromDataFusion<DFOperator> for Operator {
175-
fn try_from_df(value: &DFOperator) -> VortexResult<Self> {
176-
match value {
177-
DFOperator::Eq => Ok(Operator::Eq),
178-
DFOperator::NotEq => Ok(Operator::NotEq),
179-
DFOperator::Lt => Ok(Operator::Lt),
180-
DFOperator::LtEq => Ok(Operator::Lte),
181-
DFOperator::Gt => Ok(Operator::Gt),
182-
DFOperator::GtEq => Ok(Operator::Gte),
183-
DFOperator::And => Ok(Operator::And),
184-
DFOperator::Or => Ok(Operator::Or),
185-
DFOperator::Plus => Ok(Operator::Add),
186-
DFOperator::Minus => Ok(Operator::Sub),
187-
DFOperator::Multiply => Ok(Operator::Mul),
188-
DFOperator::Divide => Ok(Operator::Div),
189-
DFOperator::IsDistinctFrom
190-
| DFOperator::IsNotDistinctFrom
191-
| DFOperator::RegexMatch
192-
| DFOperator::RegexIMatch
193-
| DFOperator::RegexNotMatch
194-
| DFOperator::RegexNotIMatch
195-
| DFOperator::LikeMatch
196-
| DFOperator::ILikeMatch
197-
| DFOperator::NotLikeMatch
198-
| DFOperator::NotILikeMatch
199-
| DFOperator::BitwiseAnd
200-
| DFOperator::BitwiseOr
201-
| DFOperator::BitwiseXor
202-
| DFOperator::BitwiseShiftRight
203-
| DFOperator::BitwiseShiftLeft
204-
| DFOperator::StringConcat
205-
| DFOperator::AtArrow
206-
| DFOperator::ArrowAt
207-
| DFOperator::Modulo
208-
| DFOperator::Arrow
209-
| DFOperator::LongArrow
210-
| DFOperator::HashArrow
211-
| DFOperator::HashLongArrow
212-
| DFOperator::AtAt
213-
| DFOperator::IntegerDivide
214-
| DFOperator::HashMinus
215-
| DFOperator::AtQuestion
216-
| DFOperator::Question
217-
| DFOperator::QuestionAnd
218-
| DFOperator::QuestionPipe => {
219-
tracing::debug!(operator = %value, "Can't pushdown binary_operator operator");
220-
Err(vortex_err!("Unsupported datafusion operator {value}"))
221-
}
194+
fn try_operator_from_df(value: &DFOperator) -> VortexResult<Operator> {
195+
match value {
196+
DFOperator::Eq => Ok(Operator::Eq),
197+
DFOperator::NotEq => Ok(Operator::NotEq),
198+
DFOperator::Lt => Ok(Operator::Lt),
199+
DFOperator::LtEq => Ok(Operator::Lte),
200+
DFOperator::Gt => Ok(Operator::Gt),
201+
DFOperator::GtEq => Ok(Operator::Gte),
202+
DFOperator::And => Ok(Operator::And),
203+
DFOperator::Or => Ok(Operator::Or),
204+
DFOperator::Plus => Ok(Operator::Add),
205+
DFOperator::Minus => Ok(Operator::Sub),
206+
DFOperator::Multiply => Ok(Operator::Mul),
207+
DFOperator::Divide => Ok(Operator::Div),
208+
DFOperator::IsDistinctFrom
209+
| DFOperator::IsNotDistinctFrom
210+
| DFOperator::RegexMatch
211+
| DFOperator::RegexIMatch
212+
| DFOperator::RegexNotMatch
213+
| DFOperator::RegexNotIMatch
214+
| DFOperator::LikeMatch
215+
| DFOperator::ILikeMatch
216+
| DFOperator::NotLikeMatch
217+
| DFOperator::NotILikeMatch
218+
| DFOperator::BitwiseAnd
219+
| DFOperator::BitwiseOr
220+
| DFOperator::BitwiseXor
221+
| DFOperator::BitwiseShiftRight
222+
| DFOperator::BitwiseShiftLeft
223+
| DFOperator::StringConcat
224+
| DFOperator::AtArrow
225+
| DFOperator::ArrowAt
226+
| DFOperator::Modulo
227+
| DFOperator::Arrow
228+
| DFOperator::LongArrow
229+
| DFOperator::HashArrow
230+
| DFOperator::HashLongArrow
231+
| DFOperator::AtAt
232+
| DFOperator::IntegerDivide
233+
| DFOperator::HashMinus
234+
| DFOperator::AtQuestion
235+
| DFOperator::Question
236+
| DFOperator::QuestionAnd
237+
| DFOperator::QuestionPipe => {
238+
tracing::debug!(operator = %value, "Can't pushdown binary_operator operator");
239+
Err(vortex_err!("Unsupported datafusion operator {value}"))
222240
}
223241
}
224242
}
@@ -262,7 +280,7 @@ pub(crate) fn can_be_pushed_down(df_expr: &Arc<dyn PhysicalExpr>, schema: &Schem
262280
}
263281

264282
fn can_binary_be_pushed_down(binary: &df_expr::BinaryExpr, schema: &Schema) -> bool {
265-
let is_op_supported = Operator::try_from_df(binary.op()).is_ok();
283+
let is_op_supported = try_operator_from_df(binary.op()).is_ok();
266284
is_op_supported
267285
&& can_be_pushed_down(binary.left(), schema)
268286
&& can_be_pushed_down(binary.right(), schema)
@@ -320,8 +338,6 @@ mod tests {
320338
use datafusion_physical_plan::expressions as df_expr;
321339
use insta::assert_snapshot;
322340
use rstest::rstest;
323-
use vortex::expr::Expression;
324-
use vortex::expr::Operator;
325341

326342
use super::*;
327343

@@ -347,22 +363,25 @@ mod tests {
347363

348364
#[test]
349365
fn test_make_vortex_predicate_empty() {
350-
let result = make_vortex_predicate(&[]).unwrap();
366+
let expr_convertor = DefaultExpressionConvertor::default();
367+
let result = make_vortex_predicate(&expr_convertor, &[]).unwrap();
351368
assert!(result.is_none());
352369
}
353370

354371
#[test]
355372
fn test_make_vortex_predicate_single() {
373+
let expr_convertor = DefaultExpressionConvertor::default();
356374
let col_expr = Arc::new(df_expr::Column::new("test", 0)) as Arc<dyn PhysicalExpr>;
357-
let result = make_vortex_predicate(&[col_expr]).unwrap();
375+
let result = make_vortex_predicate(&expr_convertor, &[col_expr]).unwrap();
358376
assert!(result.is_some());
359377
}
360378

361379
#[test]
362380
fn test_make_vortex_predicate_multiple() {
381+
let expr_convertor = DefaultExpressionConvertor::default();
363382
let col1 = Arc::new(df_expr::Column::new("col1", 0)) as Arc<dyn PhysicalExpr>;
364383
let col2 = Arc::new(df_expr::Column::new("col2", 1)) as Arc<dyn PhysicalExpr>;
365-
let result = make_vortex_predicate(&[col1, col2]).unwrap();
384+
let result = make_vortex_predicate(&expr_convertor, &[col1, col2]).unwrap();
366385
assert!(result.is_some());
367386
// Result should be an AND expression combining the two columns
368387
}
@@ -384,7 +403,7 @@ mod tests {
384403
#[case] df_op: DFOperator,
385404
#[case] expected_vortex_op: Operator,
386405
) {
387-
let result = Operator::try_from_df(&df_op).unwrap();
406+
let result = try_operator_from_df(&df_op).unwrap();
388407
assert_eq!(result, expected_vortex_op);
389408
}
390409

@@ -394,7 +413,7 @@ mod tests {
394413
#[case::regex_match(DFOperator::RegexMatch)]
395414
#[case::like_match(DFOperator::LikeMatch)]
396415
fn test_operator_conversion_unsupported(#[case] df_op: DFOperator) {
397-
let result = Operator::try_from_df(&df_op);
416+
let result = try_operator_from_df(&df_op);
398417
assert!(result.is_err());
399418
assert!(
400419
result
@@ -407,7 +426,9 @@ mod tests {
407426
#[test]
408427
fn test_expr_from_df_column() {
409428
let col_expr = df_expr::Column::new("test_column", 0);
410-
let result = Expression::try_from_df(&col_expr).unwrap();
429+
let result = DefaultExpressionConvertor::default()
430+
.convert(&col_expr)
431+
.unwrap();
411432

412433
assert_snapshot!(result.display_tree().to_string(), @r"
413434
vortex.get_item(test_column)
@@ -418,7 +439,9 @@ mod tests {
418439
#[test]
419440
fn test_expr_from_df_literal() {
420441
let literal_expr = df_expr::Literal::new(ScalarValue::Int32(Some(42)));
421-
let result = Expression::try_from_df(&literal_expr).unwrap();
442+
let result = DefaultExpressionConvertor::default()
443+
.convert(&literal_expr)
444+
.unwrap();
422445

423446
assert_snapshot!(result.display_tree().to_string(), @"vortex.literal(42i32)");
424447
}
@@ -430,7 +453,9 @@ mod tests {
430453
Arc::new(df_expr::Literal::new(ScalarValue::Int32(Some(42)))) as Arc<dyn PhysicalExpr>;
431454
let binary_expr = df_expr::BinaryExpr::new(left, DFOperator::Eq, right);
432455

433-
let result = Expression::try_from_df(&binary_expr).unwrap();
456+
let result = DefaultExpressionConvertor::default()
457+
.convert(&binary_expr)
458+
.unwrap();
434459

435460
assert_snapshot!(result.display_tree().to_string(), @r"
436461
vortex.binary(=)
@@ -452,7 +477,9 @@ mod tests {
452477
)))) as Arc<dyn PhysicalExpr>;
453478
let like_expr = df_expr::LikeExpr::new(negated, case_insensitive, expr, pattern);
454479

455-
let result = Expression::try_from_df(&like_expr).unwrap();
480+
let result = DefaultExpressionConvertor::default()
481+
.convert(&like_expr)
482+
.unwrap();
456483
let like_opts = result.as_::<Like>();
457484
assert_eq!(
458485
like_opts,

0 commit comments

Comments
 (0)