Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
247 changes: 137 additions & 110 deletions vortex-datafusion/src/convert/exprs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,28 +35,81 @@ use vortex::expr::root;
use vortex::scalar::Scalar;

use crate::convert::FromDataFusion;
use crate::convert::TryFromDataFusion;

/// Tries to convert the expressions into a vortex conjunction. Will return Ok(None) iff the input conjunction is empty.
pub(crate) fn make_vortex_predicate(
expr_convertor: &dyn ExpressionConvertor,
predicate: &[Arc<dyn PhysicalExpr>],
) -> VortexResult<Option<Expression>> {
let exprs = predicate
.iter()
.map(|e| Expression::try_from_df(e.as_ref()))
.map(|e| expr_convertor.convert(e.as_ref()))
.collect::<VortexResult<Vec<_>>>()?;

Ok(exprs.into_iter().reduce(and))
}

// TODO(joe): Don't return an error when we have an unsupported node, bubble up "TRUE" as in keep
// for that node, up to any `and` or `or` node.
impl TryFromDataFusion<dyn PhysicalExpr> for Expression {
fn try_from_df(df: &dyn PhysicalExpr) -> VortexResult<Self> {
/// Trait for converting DataFusion expressions to Vortex ones.
pub trait ExpressionConvertor: Send + Sync {
/// Can an expression be pushed down given a specific schema
fn can_be_pushed_down(&self, expr: &Arc<dyn PhysicalExpr>, schema: &Schema) -> bool;

/// Try and convert a DataFusion [`PhysicalExpr`] into a Vortex [`Expression`].
fn convert(&self, expr: &dyn PhysicalExpr) -> VortexResult<Expression>;
}

/// The default [`ExpressionConvertor`].
#[derive(Default)]
pub struct DefaultExpressionConvertor {}

impl DefaultExpressionConvertor {
/// Attempts to convert a DataFusion ScalarFunctionExpr to a Vortex expression.
fn try_convert_scalar_function(
&self,
scalar_fn: &ScalarFunctionExpr,
) -> VortexResult<Expression> {
if let Some(get_field_fn) = ScalarFunctionExpr::try_downcast_func::<GetFieldFunc>(scalar_fn)
{
let source_expr = get_field_fn
.args()
.first()
.ok_or_else(|| vortex_err!("get_field missing source expression"))?
.as_ref();
let field_name_expr = get_field_fn
.args()
.get(1)
.ok_or_else(|| vortex_err!("get_field missing field name argument"))?;
let field_name = field_name_expr
.as_any()
.downcast_ref::<df_expr::Literal>()
.ok_or_else(|| vortex_err!("get_field field name must be a literal"))?
.value()
.try_as_str()
.flatten()
.ok_or_else(|| vortex_err!("get_field field name must be a UTF-8 string"))?;
return Ok(get_item(field_name.to_string(), self.convert(source_expr)?));
}

tracing::debug!(
function_name = scalar_fn.name(),
"Unsupported ScalarFunctionExpr"
);
vortex_bail!("Unsupported ScalarFunctionExpr: {}", scalar_fn.name())
}
}

impl ExpressionConvertor for DefaultExpressionConvertor {
fn can_be_pushed_down(&self, expr: &Arc<dyn PhysicalExpr>, schema: &Schema) -> bool {
can_be_pushed_down(expr, schema)
}

fn convert(&self, df: &dyn PhysicalExpr) -> VortexResult<Expression> {
// TODO(joe): Don't return an error when we have an unsupported node, bubble up "TRUE" as in keep
// for that node, up to any `and` or `or` node.
if let Some(binary_expr) = df.as_any().downcast_ref::<df_expr::BinaryExpr>() {
let left = Expression::try_from_df(binary_expr.left().as_ref())?;
let right = Expression::try_from_df(binary_expr.right().as_ref())?;
let operator = Operator::try_from_df(binary_expr.op())?;
let left = self.convert(binary_expr.left().as_ref())?;
let right = self.convert(binary_expr.right().as_ref())?;
let operator = try_operator_from_df(binary_expr.op())?;

return Ok(Binary.new_expr(operator, [left, right]));
}
Expand All @@ -66,8 +119,8 @@ impl TryFromDataFusion<dyn PhysicalExpr> for Expression {
}

if let Some(like) = df.as_any().downcast_ref::<df_expr::LikeExpr>() {
let child = Expression::try_from_df(like.expr().as_ref())?;
let pattern = Expression::try_from_df(like.pattern().as_ref())?;
let child = self.convert(like.expr().as_ref())?;
let pattern = self.convert(like.pattern().as_ref())?;
return Ok(Like.new_expr(
LikeOptions {
negated: like.negated(),
Expand All @@ -84,30 +137,30 @@ impl TryFromDataFusion<dyn PhysicalExpr> for Expression {

if let Some(cast_expr) = df.as_any().downcast_ref::<df_expr::CastExpr>() {
let cast_dtype = DType::from_arrow((cast_expr.cast_type(), Nullability::Nullable));
let child = Expression::try_from_df(cast_expr.expr().as_ref())?;
let child = self.convert(cast_expr.expr().as_ref())?;
return Ok(cast(child, cast_dtype));
}

if let Some(cast_col_expr) = df.as_any().downcast_ref::<df_expr::CastColumnExpr>() {
let target = cast_col_expr.target_field();

let target_dtype = DType::from_arrow((target.data_type(), target.is_nullable().into()));
let child = Expression::try_from_df(cast_col_expr.expr().as_ref())?;
let child = self.convert(cast_col_expr.expr().as_ref())?;
return Ok(cast(child, target_dtype));
}

if let Some(is_null_expr) = df.as_any().downcast_ref::<df_expr::IsNullExpr>() {
let arg = Expression::try_from_df(is_null_expr.arg().as_ref())?;
let arg = self.convert(is_null_expr.arg().as_ref())?;
return Ok(is_null(arg));
}

if let Some(is_not_null_expr) = df.as_any().downcast_ref::<df_expr::IsNotNullExpr>() {
let arg = Expression::try_from_df(is_not_null_expr.arg().as_ref())?;
let arg = self.convert(is_not_null_expr.arg().as_ref())?;
return Ok(not(is_null(arg)));
}

if let Some(in_list) = df.as_any().downcast_ref::<df_expr::InListExpr>() {
let value = Expression::try_from_df(in_list.expr().as_ref())?;
let value = self.convert(in_list.expr().as_ref())?;
let list_elements: Vec<_> = in_list
.list()
.iter()
Expand All @@ -131,94 +184,59 @@ impl TryFromDataFusion<dyn PhysicalExpr> for Expression {
}

if let Some(scalar_fn) = df.as_any().downcast_ref::<ScalarFunctionExpr>() {
return try_convert_scalar_function(scalar_fn);
return self.try_convert_scalar_function(scalar_fn);
}

vortex_bail!("Couldn't convert DataFusion physical {df} expression to a vortex expression")
}
}

/// Attempts to convert a DataFusion ScalarFunctionExpr to a Vortex expression.
fn try_convert_scalar_function(scalar_fn: &ScalarFunctionExpr) -> VortexResult<Expression> {
if let Some(get_field_fn) = ScalarFunctionExpr::try_downcast_func::<GetFieldFunc>(scalar_fn) {
let source_expr = get_field_fn
.args()
.first()
.ok_or_else(|| vortex_err!("get_field missing source expression"))?
.as_ref();
let field_name_expr = get_field_fn
.args()
.get(1)
.ok_or_else(|| vortex_err!("get_field missing field name argument"))?;
let field_name = field_name_expr
.as_any()
.downcast_ref::<df_expr::Literal>()
.ok_or_else(|| vortex_err!("get_field field name must be a literal"))?
.value()
.try_as_str()
.flatten()
.ok_or_else(|| vortex_err!("get_field field name must be a UTF-8 string"))?;
return Ok(get_item(
field_name.to_string(),
Expression::try_from_df(source_expr)?,
));
}

tracing::debug!(
function_name = scalar_fn.name(),
"Unsupported ScalarFunctionExpr"
);
vortex_bail!("Unsupported ScalarFunctionExpr: {}", scalar_fn.name())
}

impl TryFromDataFusion<DFOperator> for Operator {
fn try_from_df(value: &DFOperator) -> VortexResult<Self> {
match value {
DFOperator::Eq => Ok(Operator::Eq),
DFOperator::NotEq => Ok(Operator::NotEq),
DFOperator::Lt => Ok(Operator::Lt),
DFOperator::LtEq => Ok(Operator::Lte),
DFOperator::Gt => Ok(Operator::Gt),
DFOperator::GtEq => Ok(Operator::Gte),
DFOperator::And => Ok(Operator::And),
DFOperator::Or => Ok(Operator::Or),
DFOperator::Plus => Ok(Operator::Add),
DFOperator::Minus => Ok(Operator::Sub),
DFOperator::Multiply => Ok(Operator::Mul),
DFOperator::Divide => Ok(Operator::Div),
DFOperator::IsDistinctFrom
| DFOperator::IsNotDistinctFrom
| DFOperator::RegexMatch
| DFOperator::RegexIMatch
| DFOperator::RegexNotMatch
| DFOperator::RegexNotIMatch
| DFOperator::LikeMatch
| DFOperator::ILikeMatch
| DFOperator::NotLikeMatch
| DFOperator::NotILikeMatch
| DFOperator::BitwiseAnd
| DFOperator::BitwiseOr
| DFOperator::BitwiseXor
| DFOperator::BitwiseShiftRight
| DFOperator::BitwiseShiftLeft
| DFOperator::StringConcat
| DFOperator::AtArrow
| DFOperator::ArrowAt
| DFOperator::Modulo
| DFOperator::Arrow
| DFOperator::LongArrow
| DFOperator::HashArrow
| DFOperator::HashLongArrow
| DFOperator::AtAt
| DFOperator::IntegerDivide
| DFOperator::HashMinus
| DFOperator::AtQuestion
| DFOperator::Question
| DFOperator::QuestionAnd
| DFOperator::QuestionPipe => {
tracing::debug!(operator = %value, "Can't pushdown binary_operator operator");
Err(vortex_err!("Unsupported datafusion operator {value}"))
}
fn try_operator_from_df(value: &DFOperator) -> VortexResult<Operator> {
match value {
DFOperator::Eq => Ok(Operator::Eq),
DFOperator::NotEq => Ok(Operator::NotEq),
DFOperator::Lt => Ok(Operator::Lt),
DFOperator::LtEq => Ok(Operator::Lte),
DFOperator::Gt => Ok(Operator::Gt),
DFOperator::GtEq => Ok(Operator::Gte),
DFOperator::And => Ok(Operator::And),
DFOperator::Or => Ok(Operator::Or),
DFOperator::Plus => Ok(Operator::Add),
DFOperator::Minus => Ok(Operator::Sub),
DFOperator::Multiply => Ok(Operator::Mul),
DFOperator::Divide => Ok(Operator::Div),
DFOperator::IsDistinctFrom
| DFOperator::IsNotDistinctFrom
| DFOperator::RegexMatch
| DFOperator::RegexIMatch
| DFOperator::RegexNotMatch
| DFOperator::RegexNotIMatch
| DFOperator::LikeMatch
| DFOperator::ILikeMatch
| DFOperator::NotLikeMatch
| DFOperator::NotILikeMatch
| DFOperator::BitwiseAnd
| DFOperator::BitwiseOr
| DFOperator::BitwiseXor
| DFOperator::BitwiseShiftRight
| DFOperator::BitwiseShiftLeft
| DFOperator::StringConcat
| DFOperator::AtArrow
| DFOperator::ArrowAt
| DFOperator::Modulo
| DFOperator::Arrow
| DFOperator::LongArrow
| DFOperator::HashArrow
| DFOperator::HashLongArrow
| DFOperator::AtAt
| DFOperator::IntegerDivide
| DFOperator::HashMinus
| DFOperator::AtQuestion
| DFOperator::Question
| DFOperator::QuestionAnd
| DFOperator::QuestionPipe => {
tracing::debug!(operator = %value, "Can't pushdown binary_operator operator");
Err(vortex_err!("Unsupported datafusion operator {value}"))
}
}
}
Expand Down Expand Up @@ -262,7 +280,7 @@ pub(crate) fn can_be_pushed_down(df_expr: &Arc<dyn PhysicalExpr>, schema: &Schem
}

fn can_binary_be_pushed_down(binary: &df_expr::BinaryExpr, schema: &Schema) -> bool {
let is_op_supported = Operator::try_from_df(binary.op()).is_ok();
let is_op_supported = try_operator_from_df(binary.op()).is_ok();
is_op_supported
&& can_be_pushed_down(binary.left(), schema)
&& can_be_pushed_down(binary.right(), schema)
Expand Down Expand Up @@ -320,8 +338,6 @@ mod tests {
use datafusion_physical_plan::expressions as df_expr;
use insta::assert_snapshot;
use rstest::rstest;
use vortex::expr::Expression;
use vortex::expr::Operator;

use super::*;

Expand All @@ -347,22 +363,25 @@ mod tests {

#[test]
fn test_make_vortex_predicate_empty() {
let result = make_vortex_predicate(&[]).unwrap();
let expr_convertor = DefaultExpressionConvertor::default();
let result = make_vortex_predicate(&expr_convertor, &[]).unwrap();
assert!(result.is_none());
}

#[test]
fn test_make_vortex_predicate_single() {
let expr_convertor = DefaultExpressionConvertor::default();
let col_expr = Arc::new(df_expr::Column::new("test", 0)) as Arc<dyn PhysicalExpr>;
let result = make_vortex_predicate(&[col_expr]).unwrap();
let result = make_vortex_predicate(&expr_convertor, &[col_expr]).unwrap();
assert!(result.is_some());
}

#[test]
fn test_make_vortex_predicate_multiple() {
let expr_convertor = DefaultExpressionConvertor::default();
let col1 = Arc::new(df_expr::Column::new("col1", 0)) as Arc<dyn PhysicalExpr>;
let col2 = Arc::new(df_expr::Column::new("col2", 1)) as Arc<dyn PhysicalExpr>;
let result = make_vortex_predicate(&[col1, col2]).unwrap();
let result = make_vortex_predicate(&expr_convertor, &[col1, col2]).unwrap();
assert!(result.is_some());
// Result should be an AND expression combining the two columns
}
Expand All @@ -384,7 +403,7 @@ mod tests {
#[case] df_op: DFOperator,
#[case] expected_vortex_op: Operator,
) {
let result = Operator::try_from_df(&df_op).unwrap();
let result = try_operator_from_df(&df_op).unwrap();
assert_eq!(result, expected_vortex_op);
}

Expand All @@ -394,7 +413,7 @@ mod tests {
#[case::regex_match(DFOperator::RegexMatch)]
#[case::like_match(DFOperator::LikeMatch)]
fn test_operator_conversion_unsupported(#[case] df_op: DFOperator) {
let result = Operator::try_from_df(&df_op);
let result = try_operator_from_df(&df_op);
assert!(result.is_err());
assert!(
result
Expand All @@ -407,7 +426,9 @@ mod tests {
#[test]
fn test_expr_from_df_column() {
let col_expr = df_expr::Column::new("test_column", 0);
let result = Expression::try_from_df(&col_expr).unwrap();
let result = DefaultExpressionConvertor::default()
.convert(&col_expr)
.unwrap();

assert_snapshot!(result.display_tree().to_string(), @r"
vortex.get_item(test_column)
Expand All @@ -418,7 +439,9 @@ mod tests {
#[test]
fn test_expr_from_df_literal() {
let literal_expr = df_expr::Literal::new(ScalarValue::Int32(Some(42)));
let result = Expression::try_from_df(&literal_expr).unwrap();
let result = DefaultExpressionConvertor::default()
.convert(&literal_expr)
.unwrap();

assert_snapshot!(result.display_tree().to_string(), @"vortex.literal(42i32)");
}
Expand All @@ -430,7 +453,9 @@ mod tests {
Arc::new(df_expr::Literal::new(ScalarValue::Int32(Some(42)))) as Arc<dyn PhysicalExpr>;
let binary_expr = df_expr::BinaryExpr::new(left, DFOperator::Eq, right);

let result = Expression::try_from_df(&binary_expr).unwrap();
let result = DefaultExpressionConvertor::default()
.convert(&binary_expr)
.unwrap();

assert_snapshot!(result.display_tree().to_string(), @r"
vortex.binary(=)
Expand All @@ -452,7 +477,9 @@ mod tests {
)))) as Arc<dyn PhysicalExpr>;
let like_expr = df_expr::LikeExpr::new(negated, case_insensitive, expr, pattern);

let result = Expression::try_from_df(&like_expr).unwrap();
let result = DefaultExpressionConvertor::default()
.convert(&like_expr)
.unwrap();
let like_opts = result.as_::<Like>();
assert_eq!(
like_opts,
Expand Down
Loading
Loading