diff --git a/vortex-datafusion/src/convert/exprs.rs b/vortex-datafusion/src/convert/exprs.rs index b1726e23068..080a44dfc71 100644 --- a/vortex-datafusion/src/convert/exprs.rs +++ b/vortex-datafusion/src/convert/exprs.rs @@ -35,28 +35,81 @@ use vortex::expr::root; use vortex::scalar::Scalar; use crate::convert::FromDataFusion; -use crate::convert::TryFromDataFusion; /// Tries to convert the expressions into a vortex conjunction. Will return Ok(None) iff the input conjunction is empty. pub(crate) fn make_vortex_predicate( + expr_convertor: &dyn ExpressionConvertor, predicate: &[Arc], ) -> VortexResult> { let exprs = predicate .iter() - .map(|e| Expression::try_from_df(e.as_ref())) + .map(|e| expr_convertor.convert(e.as_ref())) .collect::>>()?; Ok(exprs.into_iter().reduce(and)) } -// TODO(joe): Don't return an error when we have an unsupported node, bubble up "TRUE" as in keep -// for that node, up to any `and` or `or` node. -impl TryFromDataFusion for Expression { - fn try_from_df(df: &dyn PhysicalExpr) -> VortexResult { +/// Trait for converting DataFusion expressions to Vortex ones. +pub trait ExpressionConvertor: Send + Sync { + /// Can an expression be pushed down given a specific schema + fn can_be_pushed_down(&self, expr: &Arc, schema: &Schema) -> bool; + + /// Try and convert a DataFusion [`PhysicalExpr`] into a Vortex [`Expression`]. + fn convert(&self, expr: &dyn PhysicalExpr) -> VortexResult; +} + +/// The default [`ExpressionConvertor`]. +#[derive(Default)] +pub struct DefaultExpressionConvertor {} + +impl DefaultExpressionConvertor { + /// Attempts to convert a DataFusion ScalarFunctionExpr to a Vortex expression. + fn try_convert_scalar_function( + &self, + scalar_fn: &ScalarFunctionExpr, + ) -> VortexResult { + if let Some(get_field_fn) = ScalarFunctionExpr::try_downcast_func::(scalar_fn) + { + let source_expr = get_field_fn + .args() + .first() + .ok_or_else(|| vortex_err!("get_field missing source expression"))? 
+ .as_ref(); + let field_name_expr = get_field_fn + .args() + .get(1) + .ok_or_else(|| vortex_err!("get_field missing field name argument"))?; + let field_name = field_name_expr + .as_any() + .downcast_ref::() + .ok_or_else(|| vortex_err!("get_field field name must be a literal"))? + .value() + .try_as_str() + .flatten() + .ok_or_else(|| vortex_err!("get_field field name must be a UTF-8 string"))?; + return Ok(get_item(field_name.to_string(), self.convert(source_expr)?)); + } + + tracing::debug!( + function_name = scalar_fn.name(), + "Unsupported ScalarFunctionExpr" + ); + vortex_bail!("Unsupported ScalarFunctionExpr: {}", scalar_fn.name()) + } +} + +impl ExpressionConvertor for DefaultExpressionConvertor { + fn can_be_pushed_down(&self, expr: &Arc, schema: &Schema) -> bool { + can_be_pushed_down(expr, schema) + } + + fn convert(&self, df: &dyn PhysicalExpr) -> VortexResult { + // TODO(joe): Don't return an error when we have an unsupported node, bubble up "TRUE" as in keep + // for that node, up to any `and` or `or` node. 
if let Some(binary_expr) = df.as_any().downcast_ref::() { - let left = Expression::try_from_df(binary_expr.left().as_ref())?; - let right = Expression::try_from_df(binary_expr.right().as_ref())?; - let operator = Operator::try_from_df(binary_expr.op())?; + let left = self.convert(binary_expr.left().as_ref())?; + let right = self.convert(binary_expr.right().as_ref())?; + let operator = try_operator_from_df(binary_expr.op())?; return Ok(Binary.new_expr(operator, [left, right])); } @@ -66,8 +119,8 @@ impl TryFromDataFusion for Expression { } if let Some(like) = df.as_any().downcast_ref::() { - let child = Expression::try_from_df(like.expr().as_ref())?; - let pattern = Expression::try_from_df(like.pattern().as_ref())?; + let child = self.convert(like.expr().as_ref())?; + let pattern = self.convert(like.pattern().as_ref())?; return Ok(Like.new_expr( LikeOptions { negated: like.negated(), @@ -84,7 +137,7 @@ impl TryFromDataFusion for Expression { if let Some(cast_expr) = df.as_any().downcast_ref::() { let cast_dtype = DType::from_arrow((cast_expr.cast_type(), Nullability::Nullable)); - let child = Expression::try_from_df(cast_expr.expr().as_ref())?; + let child = self.convert(cast_expr.expr().as_ref())?; return Ok(cast(child, cast_dtype)); } @@ -92,22 +145,22 @@ impl TryFromDataFusion for Expression { let target = cast_col_expr.target_field(); let target_dtype = DType::from_arrow((target.data_type(), target.is_nullable().into())); - let child = Expression::try_from_df(cast_col_expr.expr().as_ref())?; + let child = self.convert(cast_col_expr.expr().as_ref())?; return Ok(cast(child, target_dtype)); } if let Some(is_null_expr) = df.as_any().downcast_ref::() { - let arg = Expression::try_from_df(is_null_expr.arg().as_ref())?; + let arg = self.convert(is_null_expr.arg().as_ref())?; return Ok(is_null(arg)); } if let Some(is_not_null_expr) = df.as_any().downcast_ref::() { - let arg = Expression::try_from_df(is_not_null_expr.arg().as_ref())?; + let arg = 
self.convert(is_not_null_expr.arg().as_ref())?; return Ok(not(is_null(arg))); } if let Some(in_list) = df.as_any().downcast_ref::() { - let value = Expression::try_from_df(in_list.expr().as_ref())?; + let value = self.convert(in_list.expr().as_ref())?; let list_elements: Vec<_> = in_list .list() .iter() @@ -131,94 +184,59 @@ impl TryFromDataFusion for Expression { } if let Some(scalar_fn) = df.as_any().downcast_ref::() { - return try_convert_scalar_function(scalar_fn); + return self.try_convert_scalar_function(scalar_fn); } vortex_bail!("Couldn't convert DataFusion physical {df} expression to a vortex expression") } } -/// Attempts to convert a DataFusion ScalarFunctionExpr to a Vortex expression. -fn try_convert_scalar_function(scalar_fn: &ScalarFunctionExpr) -> VortexResult { - if let Some(get_field_fn) = ScalarFunctionExpr::try_downcast_func::(scalar_fn) { - let source_expr = get_field_fn - .args() - .first() - .ok_or_else(|| vortex_err!("get_field missing source expression"))? - .as_ref(); - let field_name_expr = get_field_fn - .args() - .get(1) - .ok_or_else(|| vortex_err!("get_field missing field name argument"))?; - let field_name = field_name_expr - .as_any() - .downcast_ref::() - .ok_or_else(|| vortex_err!("get_field field name must be a literal"))? 
- .value() - .try_as_str() - .flatten() - .ok_or_else(|| vortex_err!("get_field field name must be a UTF-8 string"))?; - return Ok(get_item( - field_name.to_string(), - Expression::try_from_df(source_expr)?, - )); - } - - tracing::debug!( - function_name = scalar_fn.name(), - "Unsupported ScalarFunctionExpr" - ); - vortex_bail!("Unsupported ScalarFunctionExpr: {}", scalar_fn.name()) -} - -impl TryFromDataFusion for Operator { - fn try_from_df(value: &DFOperator) -> VortexResult { - match value { - DFOperator::Eq => Ok(Operator::Eq), - DFOperator::NotEq => Ok(Operator::NotEq), - DFOperator::Lt => Ok(Operator::Lt), - DFOperator::LtEq => Ok(Operator::Lte), - DFOperator::Gt => Ok(Operator::Gt), - DFOperator::GtEq => Ok(Operator::Gte), - DFOperator::And => Ok(Operator::And), - DFOperator::Or => Ok(Operator::Or), - DFOperator::Plus => Ok(Operator::Add), - DFOperator::Minus => Ok(Operator::Sub), - DFOperator::Multiply => Ok(Operator::Mul), - DFOperator::Divide => Ok(Operator::Div), - DFOperator::IsDistinctFrom - | DFOperator::IsNotDistinctFrom - | DFOperator::RegexMatch - | DFOperator::RegexIMatch - | DFOperator::RegexNotMatch - | DFOperator::RegexNotIMatch - | DFOperator::LikeMatch - | DFOperator::ILikeMatch - | DFOperator::NotLikeMatch - | DFOperator::NotILikeMatch - | DFOperator::BitwiseAnd - | DFOperator::BitwiseOr - | DFOperator::BitwiseXor - | DFOperator::BitwiseShiftRight - | DFOperator::BitwiseShiftLeft - | DFOperator::StringConcat - | DFOperator::AtArrow - | DFOperator::ArrowAt - | DFOperator::Modulo - | DFOperator::Arrow - | DFOperator::LongArrow - | DFOperator::HashArrow - | DFOperator::HashLongArrow - | DFOperator::AtAt - | DFOperator::IntegerDivide - | DFOperator::HashMinus - | DFOperator::AtQuestion - | DFOperator::Question - | DFOperator::QuestionAnd - | DFOperator::QuestionPipe => { - tracing::debug!(operator = %value, "Can't pushdown binary_operator operator"); - Err(vortex_err!("Unsupported datafusion operator {value}")) - } +fn 
try_operator_from_df(value: &DFOperator) -> VortexResult { + match value { + DFOperator::Eq => Ok(Operator::Eq), + DFOperator::NotEq => Ok(Operator::NotEq), + DFOperator::Lt => Ok(Operator::Lt), + DFOperator::LtEq => Ok(Operator::Lte), + DFOperator::Gt => Ok(Operator::Gt), + DFOperator::GtEq => Ok(Operator::Gte), + DFOperator::And => Ok(Operator::And), + DFOperator::Or => Ok(Operator::Or), + DFOperator::Plus => Ok(Operator::Add), + DFOperator::Minus => Ok(Operator::Sub), + DFOperator::Multiply => Ok(Operator::Mul), + DFOperator::Divide => Ok(Operator::Div), + DFOperator::IsDistinctFrom + | DFOperator::IsNotDistinctFrom + | DFOperator::RegexMatch + | DFOperator::RegexIMatch + | DFOperator::RegexNotMatch + | DFOperator::RegexNotIMatch + | DFOperator::LikeMatch + | DFOperator::ILikeMatch + | DFOperator::NotLikeMatch + | DFOperator::NotILikeMatch + | DFOperator::BitwiseAnd + | DFOperator::BitwiseOr + | DFOperator::BitwiseXor + | DFOperator::BitwiseShiftRight + | DFOperator::BitwiseShiftLeft + | DFOperator::StringConcat + | DFOperator::AtArrow + | DFOperator::ArrowAt + | DFOperator::Modulo + | DFOperator::Arrow + | DFOperator::LongArrow + | DFOperator::HashArrow + | DFOperator::HashLongArrow + | DFOperator::AtAt + | DFOperator::IntegerDivide + | DFOperator::HashMinus + | DFOperator::AtQuestion + | DFOperator::Question + | DFOperator::QuestionAnd + | DFOperator::QuestionPipe => { + tracing::debug!(operator = %value, "Can't pushdown binary_operator operator"); + Err(vortex_err!("Unsupported datafusion operator {value}")) } } } @@ -262,7 +280,7 @@ pub(crate) fn can_be_pushed_down(df_expr: &Arc, schema: &Schem } fn can_binary_be_pushed_down(binary: &df_expr::BinaryExpr, schema: &Schema) -> bool { - let is_op_supported = Operator::try_from_df(binary.op()).is_ok(); + let is_op_supported = try_operator_from_df(binary.op()).is_ok(); is_op_supported && can_be_pushed_down(binary.left(), schema) && can_be_pushed_down(binary.right(), schema) @@ -320,8 +338,6 @@ mod tests { use 
datafusion_physical_plan::expressions as df_expr; use insta::assert_snapshot; use rstest::rstest; - use vortex::expr::Expression; - use vortex::expr::Operator; use super::*; @@ -347,22 +363,25 @@ mod tests { #[test] fn test_make_vortex_predicate_empty() { - let result = make_vortex_predicate(&[]).unwrap(); + let expr_convertor = DefaultExpressionConvertor::default(); + let result = make_vortex_predicate(&expr_convertor, &[]).unwrap(); assert!(result.is_none()); } #[test] fn test_make_vortex_predicate_single() { + let expr_convertor = DefaultExpressionConvertor::default(); let col_expr = Arc::new(df_expr::Column::new("test", 0)) as Arc; - let result = make_vortex_predicate(&[col_expr]).unwrap(); + let result = make_vortex_predicate(&expr_convertor, &[col_expr]).unwrap(); assert!(result.is_some()); } #[test] fn test_make_vortex_predicate_multiple() { + let expr_convertor = DefaultExpressionConvertor::default(); let col1 = Arc::new(df_expr::Column::new("col1", 0)) as Arc; let col2 = Arc::new(df_expr::Column::new("col2", 1)) as Arc; - let result = make_vortex_predicate(&[col1, col2]).unwrap(); + let result = make_vortex_predicate(&expr_convertor, &[col1, col2]).unwrap(); assert!(result.is_some()); // Result should be an AND expression combining the two columns } @@ -384,7 +403,7 @@ mod tests { #[case] df_op: DFOperator, #[case] expected_vortex_op: Operator, ) { - let result = Operator::try_from_df(&df_op).unwrap(); + let result = try_operator_from_df(&df_op).unwrap(); assert_eq!(result, expected_vortex_op); } @@ -394,7 +413,7 @@ mod tests { #[case::regex_match(DFOperator::RegexMatch)] #[case::like_match(DFOperator::LikeMatch)] fn test_operator_conversion_unsupported(#[case] df_op: DFOperator) { - let result = Operator::try_from_df(&df_op); + let result = try_operator_from_df(&df_op); assert!(result.is_err()); assert!( result @@ -407,7 +426,9 @@ mod tests { #[test] fn test_expr_from_df_column() { let col_expr = df_expr::Column::new("test_column", 0); - let result = 
Expression::try_from_df(&col_expr).unwrap(); + let result = DefaultExpressionConvertor::default() + .convert(&col_expr) + .unwrap(); assert_snapshot!(result.display_tree().to_string(), @r" vortex.get_item(test_column) @@ -418,7 +439,9 @@ mod tests { #[test] fn test_expr_from_df_literal() { let literal_expr = df_expr::Literal::new(ScalarValue::Int32(Some(42))); - let result = Expression::try_from_df(&literal_expr).unwrap(); + let result = DefaultExpressionConvertor::default() + .convert(&literal_expr) + .unwrap(); assert_snapshot!(result.display_tree().to_string(), @"vortex.literal(42i32)"); } @@ -430,7 +453,9 @@ mod tests { Arc::new(df_expr::Literal::new(ScalarValue::Int32(Some(42)))) as Arc; let binary_expr = df_expr::BinaryExpr::new(left, DFOperator::Eq, right); - let result = Expression::try_from_df(&binary_expr).unwrap(); + let result = DefaultExpressionConvertor::default() + .convert(&binary_expr) + .unwrap(); assert_snapshot!(result.display_tree().to_string(), @r" vortex.binary(=) @@ -452,7 +477,9 @@ mod tests { )))) as Arc; let like_expr = df_expr::LikeExpr::new(negated, case_insensitive, expr, pattern); - let result = Expression::try_from_df(&like_expr).unwrap(); + let result = DefaultExpressionConvertor::default() + .convert(&like_expr) + .unwrap(); let like_opts = result.as_::(); assert_eq!( like_opts, diff --git a/vortex-datafusion/src/convert/mod.rs b/vortex-datafusion/src/convert/mod.rs index d40be1d38fc..51779914b28 100644 --- a/vortex-datafusion/src/convert/mod.rs +++ b/vortex-datafusion/src/convert/mod.rs @@ -6,11 +6,6 @@ use vortex::error::VortexResult; pub(crate) mod exprs; mod scalars; -/// First-party trait for implementing conversion from DataFusion types to Vortex types. -pub(crate) trait TryFromDataFusion: Sized { - fn try_from_df(df: &D) -> VortexResult; -} - /// First-party trait for implementing conversion from DataFusion types to Vortex types. 
pub(crate) trait FromDataFusion: Sized { fn from_df(df: &D) -> Self; diff --git a/vortex-datafusion/src/lib.rs b/vortex-datafusion/src/lib.rs index fceff2d5462..245f5547ac3 100644 --- a/vortex-datafusion/src/lib.rs +++ b/vortex-datafusion/src/lib.rs @@ -12,6 +12,7 @@ mod convert; mod persistent; pub mod vendor; +pub use convert::exprs::ExpressionConvertor; pub use persistent::*; /// Extension trait to convert our [`Precision`](vortex::stats::Precision) to Datafusion's [`Precision`](datafusion_common::stats::Precision) diff --git a/vortex-datafusion/src/persistent/opener.rs b/vortex-datafusion/src/persistent/opener.rs index b6f43338788..91755e7a92f 100644 --- a/vortex-datafusion/src/persistent/opener.rs +++ b/vortex-datafusion/src/persistent/opener.rs @@ -47,6 +47,7 @@ use vortex_utils::aliases::dash_map::Entry; use super::cache::VortexFileCache; use crate::VortexAccessPlan; +use crate::convert::exprs::ExpressionConvertor; use crate::convert::exprs::can_be_pushed_down; use crate::convert::exprs::make_vortex_predicate; @@ -84,6 +85,8 @@ pub(crate) struct VortexOpener { pub layout_readers: Arc>>, /// Whether the query has output ordering specified pub has_output_ordering: bool, + + pub expression_convertor: Arc, } impl FileOpener for VortexOpener { @@ -117,6 +120,8 @@ impl FileOpener for VortexOpener { let partition_fields = self.table_schema.table_partition_cols().clone(); let table_schema = self.table_schema.clone(); + let expr_convertor = self.expression_convertor.clone(); + Ok(async move { // Create FilePruner when we have a predicate and either dynamic expressions // or file statistics available. 
The pruner can eliminate files without @@ -271,7 +276,7 @@ impl FileOpener for VortexOpener { ))); } - make_vortex_predicate(&pushed).transpose() + make_vortex_predicate(expr_convertor.as_ref(), &pushed).transpose() }) .transpose() .map_err(|e| DataFusionError::External(e.into()))?; @@ -398,6 +403,7 @@ mod tests { use super::*; use crate::VortexAccessPlan; + use crate::convert::exprs::DefaultExpressionConvertor; use crate::vendor::schema_rewriter::DF52PhysicalExprAdapterFactory; static SESSION: LazyLock = LazyLock::new(VortexSession::default); @@ -484,6 +490,7 @@ mod tests { metrics: Default::default(), layout_readers: Default::default(), has_output_ordering: false, + expression_convertor: Arc::new(DefaultExpressionConvertor::default()), } } @@ -629,6 +636,7 @@ mod tests { metrics: Default::default(), layout_readers: Default::default(), has_output_ordering: false, + expression_convertor: Arc::new(DefaultExpressionConvertor::default()), }; let filter = col("a").lt(lit(100_i32)); @@ -712,6 +720,7 @@ mod tests { metrics: Default::default(), layout_readers: Default::default(), has_output_ordering: false, + expression_convertor: Arc::new(DefaultExpressionConvertor::default()), }; // The opener should successfully open the file and reorder columns @@ -864,6 +873,7 @@ mod tests { metrics: Default::default(), layout_readers: Default::default(), has_output_ordering: false, + expression_convertor: Arc::new(DefaultExpressionConvertor::default()), }; // This should succeed and return the correctly projected and cast data @@ -920,6 +930,7 @@ mod tests { metrics: Default::default(), layout_readers: Default::default(), has_output_ordering: false, + expression_convertor: Arc::new(DefaultExpressionConvertor::default()), } } diff --git a/vortex-datafusion/src/persistent/source.rs b/vortex-datafusion/src/persistent/source.rs index ac1ad37f03c..584396279fd 100644 --- a/vortex-datafusion/src/persistent/source.rs +++ b/vortex-datafusion/src/persistent/source.rs @@ -37,6 +37,8 @@ use 
vortex_utils::aliases::dash_map::DashMap; use super::cache::VortexFileCache; use super::metrics::PARTITION_LABEL; use super::opener::VortexOpener; +use crate::convert::exprs::DefaultExpressionConvertor; +use crate::convert::exprs::ExpressionConvertor; use crate::convert::exprs::can_be_pushed_down; use crate::vendor::schema_rewriter::DF52PhysicalExprAdapterFactory; @@ -63,6 +65,7 @@ pub struct VortexSource { /// /// Sharing the readers allows us to only read every layout once from the file, even across partitions. layout_readers: Arc>>, + expression_convertor: Arc, } impl VortexSource { @@ -79,8 +82,18 @@ expr_adapter_factory: None, _unused_df_metrics: Default::default(), layout_readers: Arc::new(DashMap::default()), + expression_convertor: Arc::new(DefaultExpressionConvertor::default()), } } + + /// Set an [`ExpressionConvertor`] to control how DataFusion expressions should be converted and pushed down. + pub fn with_expression_convertor( + mut self, + expr_convertor: Arc, + ) -> Self { + self.expression_convertor = expr_convertor; + self + } } impl FileSource for VortexSource { @@ -150,6 +163,7 @@ metrics: partition_metrics, layout_readers: self.layout_readers.clone(), has_output_ordering: !base_config.output_ordering.is_empty(), + expression_convertor: self.expression_convertor.clone(), }; Arc::new(opener)