@@ -5,20 +5,24 @@ use std::sync::Arc;
55
66use arrow_schema:: DataType ;
77use arrow_schema:: Schema ;
8+ use datafusion_common:: Result as DFResult ;
9+ use datafusion_common:: exec_datafusion_err;
10+ use datafusion_common:: tree_node:: TreeNode ;
11+ use datafusion_common:: tree_node:: TreeNodeRecursion ;
812use datafusion_expr:: Operator as DFOperator ;
913use datafusion_functions:: core:: getfield:: GetFieldFunc ;
1014use datafusion_physical_expr:: PhysicalExpr ;
1115use datafusion_physical_expr:: ScalarFunctionExpr ;
16+ use datafusion_physical_expr:: projection:: ProjectionExpr ;
17+ use datafusion_physical_expr:: projection:: ProjectionExprs ;
18+ use datafusion_physical_expr:: utils:: collect_columns;
1219use datafusion_physical_expr_common:: physical_expr:: is_dynamic_physical_expr;
1320use datafusion_physical_plan:: expressions as df_expr;
1421use itertools:: Itertools ;
1522use vortex:: compute:: LikeOptions ;
1623use vortex:: dtype:: DType ;
1724use vortex:: dtype:: Nullability ;
1825use vortex:: dtype:: arrow:: FromArrowType ;
19- use vortex:: error:: VortexResult ;
20- use vortex:: error:: vortex_bail;
21- use vortex:: error:: vortex_err;
2226use vortex:: expr:: Binary ;
2327use vortex:: expr:: Expression ;
2428use vortex:: expr:: Like ;
@@ -31,20 +35,27 @@ use vortex::expr::is_null;
3135use vortex:: expr:: list_contains;
3236use vortex:: expr:: lit;
3337use vortex:: expr:: not;
38+ use vortex:: expr:: pack;
3439use vortex:: expr:: root;
3540use vortex:: scalar:: Scalar ;
3641
3742use crate :: convert:: FromDataFusion ;
3843
44+ /// Result of splitting a projection into Vortex expressions and leftover DataFusion projections.
45+ pub struct ProcessedProjection {
46+ pub scan_projection : Expression ,
47+ pub leftover_projection : ProjectionExprs ,
48+ }
49+
3950/// Tries to convert the expressions into a vortex conjunction. Will return Ok(None) iff the input conjunction is empty.
4051pub ( crate ) fn make_vortex_predicate (
4152 expr_convertor : & dyn ExpressionConvertor ,
4253 predicate : & [ Arc < dyn PhysicalExpr > ] ,
43- ) -> VortexResult < Option < Expression > > {
54+ ) -> DFResult < Option < Expression > > {
4455 let exprs = predicate
4556 . iter ( )
4657 . map ( |e| expr_convertor. convert ( e. as_ref ( ) ) )
47- . collect :: < VortexResult < Vec < _ > > > ( ) ?;
58+ . collect :: < DFResult < Vec < _ > > > ( ) ?;
4859
4960 Ok ( exprs. into_iter ( ) . reduce ( and) )
5061}
@@ -55,7 +66,16 @@ pub trait ExpressionConvertor: Send + Sync {
5566 fn can_be_pushed_down ( & self , expr : & Arc < dyn PhysicalExpr > , schema : & Schema ) -> bool ;
5667
5768 /// Try and convert a DataFusion [`PhysicalExpr`] into a Vortex [`Expression`].
58- fn convert ( & self , expr : & dyn PhysicalExpr ) -> VortexResult < Expression > ;
69+ fn convert ( & self , expr : & dyn PhysicalExpr ) -> DFResult < Expression > ;
70+
71+ /// Split a projection into Vortex expressions that can be pushed down and leftover
72+ /// DataFusion projections that need to be evaluated after the scan.
73+ fn split_projection (
74+ & self ,
75+ source_projection : ProjectionExprs ,
76+ input_schema : & Schema ,
77+ output_schema : & Schema ,
78+ ) -> DFResult < ProcessedProjection > ;
5979}
6080
6181/// The default [`ExpressionConvertor`].
@@ -64,37 +84,35 @@ pub struct DefaultExpressionConvertor {}
6484
6585impl DefaultExpressionConvertor {
6686 /// Attempts to convert a DataFusion ScalarFunctionExpr to a Vortex expression.
67- fn try_convert_scalar_function (
68- & self ,
69- scalar_fn : & ScalarFunctionExpr ,
70- ) -> VortexResult < Expression > {
87+ fn try_convert_scalar_function ( & self , scalar_fn : & ScalarFunctionExpr ) -> DFResult < Expression > {
7188 if let Some ( get_field_fn) = ScalarFunctionExpr :: try_downcast_func :: < GetFieldFunc > ( scalar_fn)
7289 {
7390 let source_expr = get_field_fn
7491 . args ( )
7592 . first ( )
76- . ok_or_else ( || vortex_err ! ( "get_field missing source expression" ) ) ?
93+ . ok_or_else ( || exec_datafusion_err ! ( "get_field missing source expression" ) ) ?
7794 . as_ref ( ) ;
7895 let field_name_expr = get_field_fn
7996 . args ( )
8097 . get ( 1 )
81- . ok_or_else ( || vortex_err ! ( "get_field missing field name argument" ) ) ?;
98+ . ok_or_else ( || exec_datafusion_err ! ( "get_field missing field name argument" ) ) ?;
8299 let field_name = field_name_expr
83100 . as_any ( )
84101 . downcast_ref :: < df_expr:: Literal > ( )
85- . ok_or_else ( || vortex_err ! ( "get_field field name must be a literal" ) ) ?
102+ . ok_or_else ( || exec_datafusion_err ! ( "get_field field name must be a literal" ) ) ?
86103 . value ( )
87104 . try_as_str ( )
88105 . flatten ( )
89- . ok_or_else ( || vortex_err ! ( "get_field field name must be a UTF-8 string" ) ) ?;
106+ . ok_or_else ( || {
107+ exec_datafusion_err ! ( "get_field field name must be a UTF-8 string" )
108+ } ) ?;
90109 return Ok ( get_item ( field_name. to_string ( ) , self . convert ( source_expr) ?) ) ;
91110 }
92111
93- tracing:: debug!(
94- function_name = scalar_fn. name( ) ,
95- "Unsupported ScalarFunctionExpr"
96- ) ;
97- vortex_bail ! ( "Unsupported ScalarFunctionExpr: {}" , scalar_fn. name( ) )
112+ Err ( exec_datafusion_err ! (
113+ "Unsupported ScalarFunctionExpr: {}" ,
114+ scalar_fn. name( )
115+ ) )
98116 }
99117}
100118
@@ -103,7 +121,7 @@ impl ExpressionConvertor for DefaultExpressionConvertor {
103121 can_be_pushed_down_impl ( expr, schema)
104122 }
105123
106- fn convert ( & self , df : & dyn PhysicalExpr ) -> VortexResult < Expression > {
124+ fn convert ( & self , df : & dyn PhysicalExpr ) -> DFResult < Expression > {
107125 // TODO(joe): Don't return an error when we have an unsupported node, bubble up "TRUE" as in keep
108126 // for that node, up to any `and` or `or` node.
109127 if let Some ( binary_expr) = df. as_any ( ) . downcast_ref :: < df_expr:: BinaryExpr > ( ) {
@@ -168,7 +186,7 @@ impl ExpressionConvertor for DefaultExpressionConvertor {
168186 if let Some ( lit) = e. as_any ( ) . downcast_ref :: < df_expr:: Literal > ( ) {
169187 Ok ( Scalar :: from_df ( lit. value ( ) ) )
170188 } else {
171- Err ( vortex_err ! ( "Failed to cast sub-expression" ) )
189+ Err ( exec_datafusion_err ! ( "Failed to cast sub-expression" ) )
172190 }
173191 } )
174192 . try_collect ( ) ?;
@@ -187,11 +205,93 @@ impl ExpressionConvertor for DefaultExpressionConvertor {
187205 return self . try_convert_scalar_function ( scalar_fn) ;
188206 }
189207
190- vortex_bail ! ( "Couldn't convert DataFusion physical {df} expression to a vortex expression" )
208+ Err ( exec_datafusion_err ! (
209+ "Couldn't convert DataFusion physical {df} expression to a vortex expression"
210+ ) )
211+ }
212+
213+ fn split_projection (
214+ & self ,
215+ source_projection : ProjectionExprs ,
216+ input_schema : & Schema ,
217+ output_schema : & Schema ,
218+ ) -> DFResult < ProcessedProjection > {
219+ let mut scan_projection = vec ! [ ] ;
220+ let mut leftover_projection: Vec < ProjectionExpr > = vec ! [ ] ;
221+
222+ for projection_expr in source_projection. iter ( ) {
223+ let r = projection_expr. expr . apply ( |node| {
224+ // We only pull column children of scalar functions that we can't push into the scan.
225+ if let Some ( scalar_fn_expr) = node. as_any ( ) . downcast_ref :: < ScalarFunctionExpr > ( )
226+ && !can_scalar_fn_be_pushed_down ( scalar_fn_expr)
227+ {
228+ scan_projection. extend (
229+ collect_columns ( node)
230+ . into_iter ( )
231+ . map ( |c| ( c. name ( ) . to_string ( ) , get_item ( c. name ( ) , root ( ) ) ) ) ,
232+ ) ;
233+
234+ leftover_projection. push ( projection_expr. clone ( ) ) ;
235+ return Ok ( TreeNodeRecursion :: Stop ) ;
236+ }
237+
238+ // If the projection contains a `CastColumnExpr` that casts to physical types that don't have a 1:1 mapping
239+ // with Vortex's types system, we make sure to pull the input from the file and keep the projection
240+ if let Some ( cast_expr) = node. as_any ( ) . downcast_ref :: < df_expr:: CastColumnExpr > ( )
241+ && is_dtype_incompatible ( cast_expr. target_field ( ) . data_type ( ) )
242+ {
243+ scan_projection. push ( (
244+ cast_expr. input_field ( ) . name ( ) . clone ( ) ,
245+ self . convert ( cast_expr. expr ( ) . as_ref ( ) ) ?,
246+ ) ) ;
247+ leftover_projection. push ( projection_expr. clone ( ) ) ;
248+ return Ok ( TreeNodeRecursion :: Stop ) ;
249+ }
250+
251+ // DataFusion assumes different decimal types can be coerced.
252+ // Vortex expects a perfect match so we don't push it down.
253+ if let Some ( binary_expr) = node. as_any ( ) . downcast_ref :: < df_expr:: BinaryExpr > ( )
254+ && binary_expr. op ( ) . is_numerical_operators ( )
255+ && ( is_decimal ( & binary_expr. left ( ) . data_type ( input_schema) ?)
256+ && is_decimal ( & binary_expr. right ( ) . data_type ( input_schema) ?) )
257+ {
258+ scan_projection. extend (
259+ collect_columns ( node)
260+ . into_iter ( )
261+ . map ( |c| ( c. name ( ) . to_string ( ) , get_item ( c. name ( ) , root ( ) ) ) ) ,
262+ ) ;
263+
264+ leftover_projection. push ( projection_expr. clone ( ) ) ;
265+ return Ok ( TreeNodeRecursion :: Stop ) ;
266+ }
267+
268+ Ok ( TreeNodeRecursion :: Continue )
269+ } ) ?;
270+
271+ // if we didn't stop early
272+ if matches ! ( r, TreeNodeRecursion :: Continue ) {
273+ scan_projection. push ( (
274+ projection_expr. alias . clone ( ) ,
275+ self . convert ( projection_expr. expr . as_ref ( ) ) ?,
276+ ) ) ;
277+ leftover_projection. push ( ProjectionExpr {
278+ expr : Arc :: new ( df_expr:: Column :: new_with_schema (
279+ projection_expr. alias . as_str ( ) ,
280+ output_schema,
281+ ) ?) ,
282+ alias : projection_expr. alias . clone ( ) ,
283+ } ) ;
284+ }
285+ }
286+
287+ Ok ( ProcessedProjection {
288+ scan_projection : pack ( scan_projection, Nullability :: NonNullable ) ,
289+ leftover_projection : leftover_projection. into ( ) ,
290+ } )
191291 }
192292}
193293
194- fn try_operator_from_df ( value : & DFOperator ) -> VortexResult < Operator > {
294+ fn try_operator_from_df ( value : & DFOperator ) -> DFResult < Operator > {
195295 match value {
196296 DFOperator :: Eq => Ok ( Operator :: Eq ) ,
197297 DFOperator :: NotEq => Ok ( Operator :: NotEq ) ,
@@ -236,7 +336,9 @@ fn try_operator_from_df(value: &DFOperator) -> VortexResult<Operator> {
236336 | DFOperator :: QuestionAnd
237337 | DFOperator :: QuestionPipe => {
238338 tracing:: debug!( operator = %value, "Can't pushdown binary_operator operator" ) ;
239- Err ( vortex_err ! ( "Unsupported datafusion operator {value}" ) )
339+ Err ( exec_datafusion_err ! (
340+ "Unsupported datafusion operator {value}"
341+ ) )
240342 }
241343 }
242344}
@@ -328,6 +430,35 @@ fn can_scalar_fn_be_pushed_down(scalar_fn: &ScalarFunctionExpr) -> bool {
328430 ScalarFunctionExpr :: try_downcast_func :: < GetFieldFunc > ( scalar_fn) . is_some ( )
329431}
330432
433+ fn is_dtype_incompatible ( dt : & DataType ) -> bool {
434+ use DataType :: * ;
435+
436+ dt. is_run_ends_type ( )
437+ || is_decimal ( dt)
438+ || matches ! (
439+ dt,
440+ Dictionary ( ..)
441+ | Utf8
442+ | LargeUtf8
443+ | Binary
444+ | LargeBinary
445+ | FixedSizeBinary ( _)
446+ | FixedSizeList ( ..)
447+ | ListView ( ..)
448+ | LargeListView ( ..)
449+ )
450+ }
451+
452+ fn is_decimal ( dt : & DataType ) -> bool {
453+ matches ! (
454+ dt,
455+ DataType :: Decimal32 ( _, _)
456+ | DataType :: Decimal64 ( _, _)
457+ | DataType :: Decimal128 ( _, _)
458+ | DataType :: Decimal256 ( _, _)
459+ )
460+ }
461+
331462#[ cfg( test) ]
332463mod tests {
333464 use std:: sync:: Arc ;
0 commit comments