@@ -5,20 +5,23 @@ use std::sync::Arc;
55
66use arrow_schema:: DataType ;
77use arrow_schema:: Schema ;
8+ use datafusion_common:: Result as DFResult ;
9+ use datafusion_common:: exec_datafusion_err;
10+ use datafusion_common:: tree_node:: TreeNode ;
11+ use datafusion_common:: tree_node:: TreeNodeRecursion ;
812use datafusion_expr:: Operator as DFOperator ;
913use datafusion_functions:: core:: getfield:: GetFieldFunc ;
1014use datafusion_physical_expr:: PhysicalExpr ;
1115use datafusion_physical_expr:: ScalarFunctionExpr ;
16+ use datafusion_physical_expr:: projection:: ProjectionExpr ;
17+ use datafusion_physical_expr:: projection:: ProjectionExprs ;
1218use datafusion_physical_expr_common:: physical_expr:: is_dynamic_physical_expr;
1319use datafusion_physical_plan:: expressions as df_expr;
1420use itertools:: Itertools ;
1521use vortex:: compute:: LikeOptions ;
1622use vortex:: dtype:: DType ;
1723use vortex:: dtype:: Nullability ;
1824use vortex:: dtype:: arrow:: FromArrowType ;
19- use vortex:: error:: VortexResult ;
20- use vortex:: error:: vortex_bail;
21- use vortex:: error:: vortex_err;
2225use vortex:: expr:: Binary ;
2326use vortex:: expr:: Expression ;
2427use vortex:: expr:: Like ;
@@ -31,20 +34,27 @@ use vortex::expr::is_null;
3134use vortex:: expr:: list_contains;
3235use vortex:: expr:: lit;
3336use vortex:: expr:: not;
37+ use vortex:: expr:: pack;
3438use vortex:: expr:: root;
3539use vortex:: scalar:: Scalar ;
3640
3741use crate :: convert:: FromDataFusion ;
3842
43+ /// Result of splitting a projection into Vortex expressions and leftover DataFusion projections.
44+ pub struct ProcessedProjection {
45+ pub scan_projection : Expression ,
46+ pub leftover_projection : ProjectionExprs ,
47+ }
48+
3949/// Tries to convert the expressions into a vortex conjunction. Will return Ok(None) iff the input conjunction is empty.
4050pub ( crate ) fn make_vortex_predicate (
4151 expr_convertor : & dyn ExpressionConvertor ,
4252 predicate : & [ Arc < dyn PhysicalExpr > ] ,
43- ) -> VortexResult < Option < Expression > > {
53+ ) -> DFResult < Option < Expression > > {
4454 let exprs = predicate
4555 . iter ( )
4656 . map ( |e| expr_convertor. convert ( e. as_ref ( ) ) )
47- . collect :: < VortexResult < Vec < _ > > > ( ) ?;
57+ . collect :: < DFResult < Vec < _ > > > ( ) ?;
4858
4959 Ok ( exprs. into_iter ( ) . reduce ( and) )
5060}
@@ -55,7 +65,16 @@ pub trait ExpressionConvertor: Send + Sync {
5565 fn can_be_pushed_down ( & self , expr : & Arc < dyn PhysicalExpr > , schema : & Schema ) -> bool ;
5666
5767 /// Try and convert a DataFusion [`PhysicalExpr`] into a Vortex [`Expression`].
58- fn convert ( & self , expr : & dyn PhysicalExpr ) -> VortexResult < Expression > ;
68+ fn convert ( & self , expr : & dyn PhysicalExpr ) -> DFResult < Expression > ;
69+
70+ /// Split a projection into Vortex expressions that can be pushed down and leftover
71+ /// DataFusion projections that need to be evaluated after the scan.
72+ fn split_projection (
73+ & self ,
74+ source_projection : ProjectionExprs ,
75+ input_schema : & Schema ,
76+ output_schema : & Schema ,
77+ ) -> DFResult < ProcessedProjection > ;
5978}
6079
6180/// The default [`ExpressionConvertor`].
@@ -64,37 +83,35 @@ pub struct DefaultExpressionConvertor {}
6483
6584impl DefaultExpressionConvertor {
6685 /// Attempts to convert a DataFusion ScalarFunctionExpr to a Vortex expression.
67- fn try_convert_scalar_function (
68- & self ,
69- scalar_fn : & ScalarFunctionExpr ,
70- ) -> VortexResult < Expression > {
86+ fn try_convert_scalar_function ( & self , scalar_fn : & ScalarFunctionExpr ) -> DFResult < Expression > {
7187 if let Some ( get_field_fn) = ScalarFunctionExpr :: try_downcast_func :: < GetFieldFunc > ( scalar_fn)
7288 {
7389 let source_expr = get_field_fn
7490 . args ( )
7591 . first ( )
76- . ok_or_else ( || vortex_err ! ( "get_field missing source expression" ) ) ?
92+ . ok_or_else ( || exec_datafusion_err ! ( "get_field missing source expression" ) ) ?
7793 . as_ref ( ) ;
7894 let field_name_expr = get_field_fn
7995 . args ( )
8096 . get ( 1 )
81- . ok_or_else ( || vortex_err ! ( "get_field missing field name argument" ) ) ?;
97+ . ok_or_else ( || exec_datafusion_err ! ( "get_field missing field name argument" ) ) ?;
8298 let field_name = field_name_expr
8399 . as_any ( )
84100 . downcast_ref :: < df_expr:: Literal > ( )
85- . ok_or_else ( || vortex_err ! ( "get_field field name must be a literal" ) ) ?
101+ . ok_or_else ( || exec_datafusion_err ! ( "get_field field name must be a literal" ) ) ?
86102 . value ( )
87103 . try_as_str ( )
88104 . flatten ( )
89- . ok_or_else ( || vortex_err ! ( "get_field field name must be a UTF-8 string" ) ) ?;
105+ . ok_or_else ( || {
106+ exec_datafusion_err ! ( "get_field field name must be a UTF-8 string" )
107+ } ) ?;
90108 return Ok ( get_item ( field_name. to_string ( ) , self . convert ( source_expr) ?) ) ;
91109 }
92110
93- tracing:: debug!(
94- function_name = scalar_fn. name( ) ,
95- "Unsupported ScalarFunctionExpr"
96- ) ;
97- vortex_bail ! ( "Unsupported ScalarFunctionExpr: {}" , scalar_fn. name( ) )
111+ Err ( exec_datafusion_err ! (
112+ "Unsupported ScalarFunctionExpr: {}" ,
113+ scalar_fn. name( )
114+ ) )
98115 }
99116}
100117
@@ -103,7 +120,7 @@ impl ExpressionConvertor for DefaultExpressionConvertor {
103120 can_be_pushed_down_impl ( expr, schema)
104121 }
105122
106- fn convert ( & self , df : & dyn PhysicalExpr ) -> VortexResult < Expression > {
123+ fn convert ( & self , df : & dyn PhysicalExpr ) -> DFResult < Expression > {
107124 // TODO(joe): Don't return an error when we have an unsupported node, bubble up "TRUE" as in keep
108125 // for that node, up to any `and` or `or` node.
109126 if let Some ( binary_expr) = df. as_any ( ) . downcast_ref :: < df_expr:: BinaryExpr > ( ) {
@@ -168,7 +185,7 @@ impl ExpressionConvertor for DefaultExpressionConvertor {
168185 if let Some ( lit) = e. as_any ( ) . downcast_ref :: < df_expr:: Literal > ( ) {
169186 Ok ( Scalar :: from_df ( lit. value ( ) ) )
170187 } else {
171- Err ( vortex_err ! ( "Failed to cast sub-expression" ) )
188+ Err ( exec_datafusion_err ! ( "Failed to cast sub-expression" ) )
172189 }
173190 } )
174191 . try_collect ( ) ?;
@@ -187,11 +204,114 @@ impl ExpressionConvertor for DefaultExpressionConvertor {
187204 return self . try_convert_scalar_function ( scalar_fn) ;
188205 }
189206
190- vortex_bail ! ( "Couldn't convert DataFusion physical {df} expression to a vortex expression" )
207+ Err ( exec_datafusion_err ! (
208+ "Couldn't convert DataFusion physical {df} expression to a vortex expression"
209+ ) )
210+ }
211+
212+ fn split_projection (
213+ & self ,
214+ source_projection : ProjectionExprs ,
215+ input_schema : & Schema ,
216+ output_schema : & Schema ,
217+ ) -> DFResult < ProcessedProjection > {
218+ let mut scan_projection = vec ! [ ] ;
219+ let mut leftover_projection: Vec < ProjectionExpr > = vec ! [ ] ;
220+
221+ for projection_expr in source_projection. iter ( ) {
222+ let r = projection_expr. expr . apply ( |node| {
223+ // We only pull column children of scalar functions that we can't push into the scan.
224+ if let Some ( scalar_fn_expr) = node. as_any ( ) . downcast_ref :: < ScalarFunctionExpr > ( )
225+ && !can_scalar_fn_be_pushed_down ( scalar_fn_expr)
226+ {
227+ for col_expr in scalar_fn_expr
228+ . children ( )
229+ . iter ( )
230+ . filter_map ( |c| c. as_any ( ) . downcast_ref :: < df_expr:: Column > ( ) )
231+ {
232+ let name = col_expr. name ( ) . to_string ( ) ;
233+ let child = get_item ( col_expr. name ( ) , root ( ) ) ;
234+
235+ scan_projection. push ( ( name, child) ) ;
236+ }
237+
238+ leftover_projection. push ( projection_expr. clone ( ) ) ;
239+ return Ok ( TreeNodeRecursion :: Stop ) ;
240+ }
241+
242+ // If the projection contains a `CastColumnExpr` that casts to physical types that don't have a 1:1 mapping
243+ // with Vortex's types system, we make sure to pull the input from the file and keep the projection
244+ if let Some ( cast_expr) = node. as_any ( ) . downcast_ref :: < df_expr:: CastColumnExpr > ( )
245+ && is_dtype_incompatible ( cast_expr. target_field ( ) . data_type ( ) )
246+ {
247+ scan_projection. push ( (
248+ cast_expr. input_field ( ) . name ( ) . clone ( ) ,
249+ self . convert ( cast_expr. expr ( ) . as_ref ( ) ) ?,
250+ ) ) ;
251+ leftover_projection. push ( projection_expr. clone ( ) ) ;
252+ return Ok ( TreeNodeRecursion :: Stop ) ;
253+ }
254+
255+ // DataFusion assumes different decimal types can be coerced.
256+ // Vortex expects a perfect match so we don't push it down.
257+ if let Some ( binary_expr) = node. as_any ( ) . downcast_ref :: < df_expr:: BinaryExpr > ( )
258+ && binary_expr. op ( ) . is_numerical_operators ( )
259+ && ( is_decimal ( & binary_expr. left ( ) . data_type ( input_schema) ?)
260+ && is_decimal ( & binary_expr. right ( ) . data_type ( input_schema) ?) )
261+ {
262+ if let Some ( col) = binary_expr
263+ . left ( )
264+ . as_any ( )
265+ . downcast_ref :: < df_expr:: Column > ( )
266+ {
267+ let name = col. name ( ) . to_string ( ) ;
268+ let col_expr = self . convert ( col) ?;
269+ scan_projection. push ( ( name, col_expr) ) ;
270+ }
271+
272+ if let Some ( col) = binary_expr
273+ . right ( )
274+ . as_any ( )
275+ . downcast_ref :: < df_expr:: Column > ( )
276+ {
277+ let name = col. name ( ) . to_string ( ) ;
278+ let col_expr = self . convert ( col) ?;
279+ scan_projection. push ( ( name, col_expr) ) ;
280+ }
281+
282+ leftover_projection. push ( projection_expr. clone ( ) ) ;
283+ return Ok ( TreeNodeRecursion :: Stop ) ;
284+ }
285+
286+ Ok ( TreeNodeRecursion :: Continue )
287+ } ) ?;
288+
289+ match r {
290+ TreeNodeRecursion :: Continue => {
291+ scan_projection. push ( (
292+ projection_expr. alias . clone ( ) ,
293+ self . convert ( projection_expr. expr . as_ref ( ) ) ?,
294+ ) ) ;
295+ leftover_projection. push ( ProjectionExpr {
296+ expr : Arc :: new ( df_expr:: Column :: new_with_schema (
297+ projection_expr. alias . as_str ( ) ,
298+ output_schema,
299+ ) ?) ,
300+ alias : projection_expr. alias . clone ( ) ,
301+ } ) ;
302+ }
303+ TreeNodeRecursion :: Jump | TreeNodeRecursion :: Stop => { }
304+ }
305+ }
306+
307+ Ok ( ProcessedProjection {
308+ scan_projection : pack ( scan_projection, Nullability :: NonNullable ) ,
309+ leftover_projection : leftover_projection. into ( ) ,
310+ } )
191311 }
192312}
193313
194- fn try_operator_from_df ( value : & DFOperator ) -> VortexResult < Operator > {
314+ fn try_operator_from_df ( value : & DFOperator ) -> DFResult < Operator > {
195315 match value {
196316 DFOperator :: Eq => Ok ( Operator :: Eq ) ,
197317 DFOperator :: NotEq => Ok ( Operator :: NotEq ) ,
@@ -236,7 +356,9 @@ fn try_operator_from_df(value: &DFOperator) -> VortexResult<Operator> {
236356 | DFOperator :: QuestionAnd
237357 | DFOperator :: QuestionPipe => {
238358 tracing:: debug!( operator = %value, "Can't pushdown binary_operator operator" ) ;
239- Err ( vortex_err ! ( "Unsupported datafusion operator {value}" ) )
359+ Err ( exec_datafusion_err ! (
360+ "Unsupported datafusion operator {value}"
361+ ) )
240362 }
241363 }
242364}
@@ -328,6 +450,35 @@ fn can_scalar_fn_be_pushed_down(scalar_fn: &ScalarFunctionExpr) -> bool {
328450 ScalarFunctionExpr :: try_downcast_func :: < GetFieldFunc > ( scalar_fn) . is_some ( )
329451}
330452
453+ fn is_dtype_incompatible ( dt : & DataType ) -> bool {
454+ use DataType :: * ;
455+
456+ dt. is_run_ends_type ( )
457+ || is_decimal ( dt)
458+ || matches ! (
459+ dt,
460+ Dictionary ( ..)
461+ | Utf8
462+ | LargeUtf8
463+ | Binary
464+ | LargeBinary
465+ | FixedSizeBinary ( _)
466+ | FixedSizeList ( ..)
467+ | ListView ( ..)
468+ | LargeListView ( ..)
469+ )
470+ }
471+
472+ fn is_decimal ( dt : & DataType ) -> bool {
473+ matches ! (
474+ dt,
475+ DataType :: Decimal32 ( _, _)
476+ | DataType :: Decimal64 ( _, _)
477+ | DataType :: Decimal128 ( _, _)
478+ | DataType :: Decimal256 ( _, _)
479+ )
480+ }
481+
331482#[ cfg( test) ]
332483mod tests {
333484 use std:: sync:: Arc ;
0 commit comments