@@ -35,28 +35,81 @@ use vortex::expr::root;
3535use vortex:: scalar:: Scalar ;
3636
3737use crate :: convert:: FromDataFusion ;
38- use crate :: convert:: TryFromDataFusion ;
3938
4039/// Tries to convert the expressions into a vortex conjunction. Will return Ok(None) iff the input conjunction is empty.
4140pub ( crate ) fn make_vortex_predicate (
41+ expr_convertor : & dyn ExpressionConvertor ,
4242 predicate : & [ Arc < dyn PhysicalExpr > ] ,
4343) -> VortexResult < Option < Expression > > {
4444 let exprs = predicate
4545 . iter ( )
46- . map ( |e| Expression :: try_from_df ( e. as_ref ( ) ) )
46+ . map ( |e| expr_convertor . convert ( e. as_ref ( ) ) )
4747 . collect :: < VortexResult < Vec < _ > > > ( ) ?;
4848
4949 Ok ( exprs. into_iter ( ) . reduce ( and) )
5050}
5151
52- // TODO(joe): Don't return an error when we have an unsupported node, bubble up "TRUE" as in keep
53- // for that node, up to any `and` or `or` node.
54- impl TryFromDataFusion < dyn PhysicalExpr > for Expression {
55- fn try_from_df ( df : & dyn PhysicalExpr ) -> VortexResult < Self > {
52+ /// Trait for converting DataFusion expressions to Vortex ones.
53+ pub trait ExpressionConvertor : Send + Sync {
54+ /// Can an expression be pushed down given a specific schema
55+ fn can_be_pushed_down ( & self , expr : & Arc < dyn PhysicalExpr > , schema : & Schema ) -> bool ;
56+
57+ /// Try and convert a DataFusion [`PhysicalExpr`] into a Vortex [`Expression`].
58+ fn convert ( & self , expr : & dyn PhysicalExpr ) -> VortexResult < Expression > ;
59+ }
60+
/// The default [`ExpressionConvertor`].
///
/// The type carries no state. It derives `Debug` and `Clone` in addition to `Default` so it can
/// be logged and duplicated like any other public configuration value.
#[derive(Debug, Clone, Default)]
pub struct DefaultExpressionConvertor {}
64+
65+ impl DefaultExpressionConvertor {
66+ /// Attempts to convert a DataFusion ScalarFunctionExpr to a Vortex expression.
67+ fn try_convert_scalar_function (
68+ & self ,
69+ scalar_fn : & ScalarFunctionExpr ,
70+ ) -> VortexResult < Expression > {
71+ if let Some ( get_field_fn) = ScalarFunctionExpr :: try_downcast_func :: < GetFieldFunc > ( scalar_fn)
72+ {
73+ let source_expr = get_field_fn
74+ . args ( )
75+ . first ( )
76+ . ok_or_else ( || vortex_err ! ( "get_field missing source expression" ) ) ?
77+ . as_ref ( ) ;
78+ let field_name_expr = get_field_fn
79+ . args ( )
80+ . get ( 1 )
81+ . ok_or_else ( || vortex_err ! ( "get_field missing field name argument" ) ) ?;
82+ let field_name = field_name_expr
83+ . as_any ( )
84+ . downcast_ref :: < df_expr:: Literal > ( )
85+ . ok_or_else ( || vortex_err ! ( "get_field field name must be a literal" ) ) ?
86+ . value ( )
87+ . try_as_str ( )
88+ . flatten ( )
89+ . ok_or_else ( || vortex_err ! ( "get_field field name must be a UTF-8 string" ) ) ?;
90+ return Ok ( get_item ( field_name. to_string ( ) , self . convert ( source_expr) ?) ) ;
91+ }
92+
93+ tracing:: debug!(
94+ function_name = scalar_fn. name( ) ,
95+ "Unsupported ScalarFunctionExpr"
96+ ) ;
97+ vortex_bail ! ( "Unsupported ScalarFunctionExpr: {}" , scalar_fn. name( ) )
98+ }
99+ }
100+
101+ impl ExpressionConvertor for DefaultExpressionConvertor {
102+ fn can_be_pushed_down ( & self , expr : & Arc < dyn PhysicalExpr > , schema : & Schema ) -> bool {
103+ can_be_pushed_down ( expr, schema)
104+ }
105+
106+ fn convert ( & self , df : & dyn PhysicalExpr ) -> VortexResult < Expression > {
107+ // TODO(joe): Don't return an error when we have an unsupported node, bubble up "TRUE" as in keep
108+ // for that node, up to any `and` or `or` node.
56109 if let Some ( binary_expr) = df. as_any ( ) . downcast_ref :: < df_expr:: BinaryExpr > ( ) {
57- let left = Expression :: try_from_df ( binary_expr. left ( ) . as_ref ( ) ) ?;
58- let right = Expression :: try_from_df ( binary_expr. right ( ) . as_ref ( ) ) ?;
59- let operator = Operator :: try_from_df ( binary_expr. op ( ) ) ?;
110+ let left = self . convert ( binary_expr. left ( ) . as_ref ( ) ) ?;
111+ let right = self . convert ( binary_expr. right ( ) . as_ref ( ) ) ?;
112+ let operator = try_operator_from_df ( binary_expr. op ( ) ) ?;
60113
61114 return Ok ( Binary . new_expr ( operator, [ left, right] ) ) ;
62115 }
@@ -66,8 +119,8 @@ impl TryFromDataFusion<dyn PhysicalExpr> for Expression {
66119 }
67120
68121 if let Some ( like) = df. as_any ( ) . downcast_ref :: < df_expr:: LikeExpr > ( ) {
69- let child = Expression :: try_from_df ( like. expr ( ) . as_ref ( ) ) ?;
70- let pattern = Expression :: try_from_df ( like. pattern ( ) . as_ref ( ) ) ?;
122+ let child = self . convert ( like. expr ( ) . as_ref ( ) ) ?;
123+ let pattern = self . convert ( like. pattern ( ) . as_ref ( ) ) ?;
71124 return Ok ( Like . new_expr (
72125 LikeOptions {
73126 negated : like. negated ( ) ,
@@ -84,30 +137,30 @@ impl TryFromDataFusion<dyn PhysicalExpr> for Expression {
84137
85138 if let Some ( cast_expr) = df. as_any ( ) . downcast_ref :: < df_expr:: CastExpr > ( ) {
86139 let cast_dtype = DType :: from_arrow ( ( cast_expr. cast_type ( ) , Nullability :: Nullable ) ) ;
87- let child = Expression :: try_from_df ( cast_expr. expr ( ) . as_ref ( ) ) ?;
140+ let child = self . convert ( cast_expr. expr ( ) . as_ref ( ) ) ?;
88141 return Ok ( cast ( child, cast_dtype) ) ;
89142 }
90143
91144 if let Some ( cast_col_expr) = df. as_any ( ) . downcast_ref :: < df_expr:: CastColumnExpr > ( ) {
92145 let target = cast_col_expr. target_field ( ) ;
93146
94147 let target_dtype = DType :: from_arrow ( ( target. data_type ( ) , target. is_nullable ( ) . into ( ) ) ) ;
95- let child = Expression :: try_from_df ( cast_col_expr. expr ( ) . as_ref ( ) ) ?;
148+ let child = self . convert ( cast_col_expr. expr ( ) . as_ref ( ) ) ?;
96149 return Ok ( cast ( child, target_dtype) ) ;
97150 }
98151
99152 if let Some ( is_null_expr) = df. as_any ( ) . downcast_ref :: < df_expr:: IsNullExpr > ( ) {
100- let arg = Expression :: try_from_df ( is_null_expr. arg ( ) . as_ref ( ) ) ?;
153+ let arg = self . convert ( is_null_expr. arg ( ) . as_ref ( ) ) ?;
101154 return Ok ( is_null ( arg) ) ;
102155 }
103156
104157 if let Some ( is_not_null_expr) = df. as_any ( ) . downcast_ref :: < df_expr:: IsNotNullExpr > ( ) {
105- let arg = Expression :: try_from_df ( is_not_null_expr. arg ( ) . as_ref ( ) ) ?;
158+ let arg = self . convert ( is_not_null_expr. arg ( ) . as_ref ( ) ) ?;
106159 return Ok ( not ( is_null ( arg) ) ) ;
107160 }
108161
109162 if let Some ( in_list) = df. as_any ( ) . downcast_ref :: < df_expr:: InListExpr > ( ) {
110- let value = Expression :: try_from_df ( in_list. expr ( ) . as_ref ( ) ) ?;
163+ let value = self . convert ( in_list. expr ( ) . as_ref ( ) ) ?;
111164 let list_elements: Vec < _ > = in_list
112165 . list ( )
113166 . iter ( )
@@ -131,94 +184,59 @@ impl TryFromDataFusion<dyn PhysicalExpr> for Expression {
131184 }
132185
133186 if let Some ( scalar_fn) = df. as_any ( ) . downcast_ref :: < ScalarFunctionExpr > ( ) {
134- return try_convert_scalar_function ( scalar_fn) ;
187+ return self . try_convert_scalar_function ( scalar_fn) ;
135188 }
136189
137190 vortex_bail ! ( "Couldn't convert DataFusion physical {df} expression to a vortex expression" )
138191 }
139192}
140193
141- /// Attempts to convert a DataFusion ScalarFunctionExpr to a Vortex expression.
142- fn try_convert_scalar_function ( scalar_fn : & ScalarFunctionExpr ) -> VortexResult < Expression > {
143- if let Some ( get_field_fn) = ScalarFunctionExpr :: try_downcast_func :: < GetFieldFunc > ( scalar_fn) {
144- let source_expr = get_field_fn
145- . args ( )
146- . first ( )
147- . ok_or_else ( || vortex_err ! ( "get_field missing source expression" ) ) ?
148- . as_ref ( ) ;
149- let field_name_expr = get_field_fn
150- . args ( )
151- . get ( 1 )
152- . ok_or_else ( || vortex_err ! ( "get_field missing field name argument" ) ) ?;
153- let field_name = field_name_expr
154- . as_any ( )
155- . downcast_ref :: < df_expr:: Literal > ( )
156- . ok_or_else ( || vortex_err ! ( "get_field field name must be a literal" ) ) ?
157- . value ( )
158- . try_as_str ( )
159- . flatten ( )
160- . ok_or_else ( || vortex_err ! ( "get_field field name must be a UTF-8 string" ) ) ?;
161- return Ok ( get_item (
162- field_name. to_string ( ) ,
163- Expression :: try_from_df ( source_expr) ?,
164- ) ) ;
165- }
166-
167- tracing:: debug!(
168- function_name = scalar_fn. name( ) ,
169- "Unsupported ScalarFunctionExpr"
170- ) ;
171- vortex_bail ! ( "Unsupported ScalarFunctionExpr: {}" , scalar_fn. name( ) )
172- }
173-
174- impl TryFromDataFusion < DFOperator > for Operator {
175- fn try_from_df ( value : & DFOperator ) -> VortexResult < Self > {
176- match value {
177- DFOperator :: Eq => Ok ( Operator :: Eq ) ,
178- DFOperator :: NotEq => Ok ( Operator :: NotEq ) ,
179- DFOperator :: Lt => Ok ( Operator :: Lt ) ,
180- DFOperator :: LtEq => Ok ( Operator :: Lte ) ,
181- DFOperator :: Gt => Ok ( Operator :: Gt ) ,
182- DFOperator :: GtEq => Ok ( Operator :: Gte ) ,
183- DFOperator :: And => Ok ( Operator :: And ) ,
184- DFOperator :: Or => Ok ( Operator :: Or ) ,
185- DFOperator :: Plus => Ok ( Operator :: Add ) ,
186- DFOperator :: Minus => Ok ( Operator :: Sub ) ,
187- DFOperator :: Multiply => Ok ( Operator :: Mul ) ,
188- DFOperator :: Divide => Ok ( Operator :: Div ) ,
189- DFOperator :: IsDistinctFrom
190- | DFOperator :: IsNotDistinctFrom
191- | DFOperator :: RegexMatch
192- | DFOperator :: RegexIMatch
193- | DFOperator :: RegexNotMatch
194- | DFOperator :: RegexNotIMatch
195- | DFOperator :: LikeMatch
196- | DFOperator :: ILikeMatch
197- | DFOperator :: NotLikeMatch
198- | DFOperator :: NotILikeMatch
199- | DFOperator :: BitwiseAnd
200- | DFOperator :: BitwiseOr
201- | DFOperator :: BitwiseXor
202- | DFOperator :: BitwiseShiftRight
203- | DFOperator :: BitwiseShiftLeft
204- | DFOperator :: StringConcat
205- | DFOperator :: AtArrow
206- | DFOperator :: ArrowAt
207- | DFOperator :: Modulo
208- | DFOperator :: Arrow
209- | DFOperator :: LongArrow
210- | DFOperator :: HashArrow
211- | DFOperator :: HashLongArrow
212- | DFOperator :: AtAt
213- | DFOperator :: IntegerDivide
214- | DFOperator :: HashMinus
215- | DFOperator :: AtQuestion
216- | DFOperator :: Question
217- | DFOperator :: QuestionAnd
218- | DFOperator :: QuestionPipe => {
219- tracing:: debug!( operator = %value, "Can't pushdown binary_operator operator" ) ;
220- Err ( vortex_err ! ( "Unsupported datafusion operator {value}" ) )
221- }
194+ fn try_operator_from_df ( value : & DFOperator ) -> VortexResult < Operator > {
195+ match value {
196+ DFOperator :: Eq => Ok ( Operator :: Eq ) ,
197+ DFOperator :: NotEq => Ok ( Operator :: NotEq ) ,
198+ DFOperator :: Lt => Ok ( Operator :: Lt ) ,
199+ DFOperator :: LtEq => Ok ( Operator :: Lte ) ,
200+ DFOperator :: Gt => Ok ( Operator :: Gt ) ,
201+ DFOperator :: GtEq => Ok ( Operator :: Gte ) ,
202+ DFOperator :: And => Ok ( Operator :: And ) ,
203+ DFOperator :: Or => Ok ( Operator :: Or ) ,
204+ DFOperator :: Plus => Ok ( Operator :: Add ) ,
205+ DFOperator :: Minus => Ok ( Operator :: Sub ) ,
206+ DFOperator :: Multiply => Ok ( Operator :: Mul ) ,
207+ DFOperator :: Divide => Ok ( Operator :: Div ) ,
208+ DFOperator :: IsDistinctFrom
209+ | DFOperator :: IsNotDistinctFrom
210+ | DFOperator :: RegexMatch
211+ | DFOperator :: RegexIMatch
212+ | DFOperator :: RegexNotMatch
213+ | DFOperator :: RegexNotIMatch
214+ | DFOperator :: LikeMatch
215+ | DFOperator :: ILikeMatch
216+ | DFOperator :: NotLikeMatch
217+ | DFOperator :: NotILikeMatch
218+ | DFOperator :: BitwiseAnd
219+ | DFOperator :: BitwiseOr
220+ | DFOperator :: BitwiseXor
221+ | DFOperator :: BitwiseShiftRight
222+ | DFOperator :: BitwiseShiftLeft
223+ | DFOperator :: StringConcat
224+ | DFOperator :: AtArrow
225+ | DFOperator :: ArrowAt
226+ | DFOperator :: Modulo
227+ | DFOperator :: Arrow
228+ | DFOperator :: LongArrow
229+ | DFOperator :: HashArrow
230+ | DFOperator :: HashLongArrow
231+ | DFOperator :: AtAt
232+ | DFOperator :: IntegerDivide
233+ | DFOperator :: HashMinus
234+ | DFOperator :: AtQuestion
235+ | DFOperator :: Question
236+ | DFOperator :: QuestionAnd
237+ | DFOperator :: QuestionPipe => {
238+ tracing:: debug!( operator = %value, "Can't pushdown binary_operator operator" ) ;
239+ Err ( vortex_err ! ( "Unsupported datafusion operator {value}" ) )
222240 }
223241 }
224242}
@@ -262,7 +280,7 @@ pub(crate) fn can_be_pushed_down(df_expr: &Arc<dyn PhysicalExpr>, schema: &Schem
262280}
263281
264282fn can_binary_be_pushed_down ( binary : & df_expr:: BinaryExpr , schema : & Schema ) -> bool {
265- let is_op_supported = Operator :: try_from_df ( binary. op ( ) ) . is_ok ( ) ;
283+ let is_op_supported = try_operator_from_df ( binary. op ( ) ) . is_ok ( ) ;
266284 is_op_supported
267285 && can_be_pushed_down ( binary. left ( ) , schema)
268286 && can_be_pushed_down ( binary. right ( ) , schema)
@@ -320,8 +338,6 @@ mod tests {
320338 use datafusion_physical_plan:: expressions as df_expr;
321339 use insta:: assert_snapshot;
322340 use rstest:: rstest;
323- use vortex:: expr:: Expression ;
324- use vortex:: expr:: Operator ;
325341
326342 use super :: * ;
327343
@@ -347,22 +363,25 @@ mod tests {
347363
#[test]
fn test_make_vortex_predicate_empty() {
    // An empty conjunction has nothing to convert, so no predicate is produced.
    let predicate = make_vortex_predicate(&DefaultExpressionConvertor::default(), &[]).unwrap();
    assert!(predicate.is_none());
}
353370
#[test]
fn test_make_vortex_predicate_single() {
    // A single convertible expression yields a predicate.
    let convertor = DefaultExpressionConvertor::default();
    let only: Arc<dyn PhysicalExpr> = Arc::new(df_expr::Column::new("test", 0));
    let predicate = make_vortex_predicate(&convertor, &[only]).unwrap();
    assert!(predicate.is_some());
}
360378
#[test]
fn test_make_vortex_predicate_multiple() {
    let convertor = DefaultExpressionConvertor::default();
    let first: Arc<dyn PhysicalExpr> = Arc::new(df_expr::Column::new("col1", 0));
    let second: Arc<dyn PhysicalExpr> = Arc::new(df_expr::Column::new("col2", 1));
    let predicate = make_vortex_predicate(&convertor, &[first, second]).unwrap();
    // Only presence is checked here; the result is expected to be an AND of the two columns.
    assert!(predicate.is_some());
}
@@ -384,7 +403,7 @@ mod tests {
384403 #[ case] df_op : DFOperator ,
385404 #[ case] expected_vortex_op : Operator ,
386405 ) {
387- let result = Operator :: try_from_df ( & df_op) . unwrap ( ) ;
406+ let result = try_operator_from_df ( & df_op) . unwrap ( ) ;
388407 assert_eq ! ( result, expected_vortex_op) ;
389408 }
390409
@@ -394,7 +413,7 @@ mod tests {
394413 #[ case:: regex_match( DFOperator :: RegexMatch ) ]
395414 #[ case:: like_match( DFOperator :: LikeMatch ) ]
396415 fn test_operator_conversion_unsupported ( #[ case] df_op : DFOperator ) {
397- let result = Operator :: try_from_df ( & df_op) ;
416+ let result = try_operator_from_df ( & df_op) ;
398417 assert ! ( result. is_err( ) ) ;
399418 assert ! (
400419 result
@@ -407,7 +426,9 @@ mod tests {
407426 #[ test]
408427 fn test_expr_from_df_column ( ) {
409428 let col_expr = df_expr:: Column :: new ( "test_column" , 0 ) ;
410- let result = Expression :: try_from_df ( & col_expr) . unwrap ( ) ;
429+ let result = DefaultExpressionConvertor :: default ( )
430+ . convert ( & col_expr)
431+ . unwrap ( ) ;
411432
412433 assert_snapshot ! ( result. display_tree( ) . to_string( ) , @r"
413434 vortex.get_item(test_column)
@@ -418,7 +439,9 @@ mod tests {
#[test]
fn test_expr_from_df_literal() {
    // An Int32 literal converts to a Vortex literal expression.
    let convertor = DefaultExpressionConvertor::default();
    let forty_two = df_expr::Literal::new(ScalarValue::Int32(Some(42)));
    let converted = convertor.convert(&forty_two).unwrap();

    assert_snapshot!(converted.display_tree().to_string(), @"vortex.literal(42i32)");
}
@@ -430,7 +453,9 @@ mod tests {
430453 Arc :: new ( df_expr:: Literal :: new ( ScalarValue :: Int32 ( Some ( 42 ) ) ) ) as Arc < dyn PhysicalExpr > ;
431454 let binary_expr = df_expr:: BinaryExpr :: new ( left, DFOperator :: Eq , right) ;
432455
433- let result = Expression :: try_from_df ( & binary_expr) . unwrap ( ) ;
456+ let result = DefaultExpressionConvertor :: default ( )
457+ . convert ( & binary_expr)
458+ . unwrap ( ) ;
434459
435460 assert_snapshot ! ( result. display_tree( ) . to_string( ) , @r"
436461 vortex.binary(=)
@@ -452,7 +477,9 @@ mod tests {
452477 ) ) ) ) as Arc < dyn PhysicalExpr > ;
453478 let like_expr = df_expr:: LikeExpr :: new ( negated, case_insensitive, expr, pattern) ;
454479
455- let result = Expression :: try_from_df ( & like_expr) . unwrap ( ) ;
480+ let result = DefaultExpressionConvertor :: default ( )
481+ . convert ( & like_expr)
482+ . unwrap ( ) ;
456483 let like_opts = result. as_ :: < Like > ( ) ;
457484 assert_eq ! (
458485 like_opts,
0 commit comments