Skip to content

Commit 75e1599

Browse files
committed
remove, add tests
Signed-off-by: Andrew Duffy <[email protected]>
1 parent 7046d3a commit 75e1599

File tree

3 files changed

+162
-75
lines changed

3 files changed

+162
-75
lines changed

vortex-datafusion/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ vortex-utils = { workspace = true, features = ["dashmap"] }
4343
[dev-dependencies]
4444
anyhow = { workspace = true }
4545
datafusion = { workspace = true }
46+
datafusion-common = { workspace = true }
4647
insta = { workspace = true }
4748
rstest = { workspace = true }
4849
tempfile = { workspace = true }

vortex-datafusion/src/convert/exprs.rs

Lines changed: 160 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -228,8 +228,11 @@ pub(crate) fn can_be_pushed_down(df_expr: &PhysicalExprRef, schema: &Schema) ->
228228
} else if let Some(in_list) = expr.downcast_ref::<df_expr::InListExpr>() {
229229
can_be_pushed_down(in_list.expr(), schema)
230230
&& in_list.list().iter().all(|e| can_be_pushed_down(e, schema))
231-
} else if expr.downcast_ref::<ScalarFunctionExpr>().is_some() {
232-
get_source_data_type(df_expr, schema).is_some()
231+
} else if let Some(scalar_fn) = expr.downcast_ref::<ScalarFunctionExpr>() {
232+
// Only get_field expressions should be pushed down. Note, we know that
233+
// the GetFieldFunc call should be well-formed, because the DataFusion planner
234+
// checks that for us before we even get to the DataSource.
235+
ScalarFunctionExpr::try_downcast_func::<GetFieldFunc>(scalar_fn).is_some()
233236
} else {
234237
tracing::debug!(%df_expr, "DataFusion expression can't be pushed down");
235238
false
@@ -276,64 +279,37 @@ fn supported_data_types(dt: &DataType) -> bool {
276279
is_supported
277280
}
278281

279-
/// Evaluate the source `expr` within the scope of `schema` and return its data type. If the source
280-
/// expression is not composed of valid field accesses that we can pushdown to Vortex, fail.
281-
fn get_source_data_type(expr: &Arc<dyn PhysicalExpr>, schema: &Schema) -> Option<DataType> {
282-
if let Some(col) = expr.as_any().downcast_ref::<df_expr::Column>() {
283-
// Column expression handler
284-
let Ok(field) = schema.field_with_name(col.name()) else {
285-
return None;
286-
};
287-
288-
// Get back the data type here instead.
289-
Some(field.data_type().clone())
290-
} else if let Some(scalar_fn) = expr.as_any().downcast_ref::<ScalarFunctionExpr>() {
291-
// Struct field access handler
292-
let get_field_fn = ScalarFunctionExpr::try_downcast_func::<GetFieldFunc>(scalar_fn)?;
293-
294-
let args = get_field_fn.args();
295-
if args.len() != 2 {
296-
return None;
297-
}
298-
299-
let source = &args[0];
300-
let field_name_expr = &args[1];
301-
302-
let DataType::Struct(fields) = get_source_data_type(source, schema)? else {
303-
return None;
304-
};
305-
306-
let field_name = field_name_expr
307-
.as_any()
308-
.downcast_ref::<df_expr::Literal>()
309-
.and_then(|l| l.value().try_as_str())
310-
.flatten()?;
311-
312-
// Extract the named field from the struct type
313-
fields
314-
.find(field_name)
315-
.map(|(_, dt)| dt.data_type().clone())
316-
} else {
317-
None
318-
}
319-
}
320-
321282
#[cfg(test)]
322283
mod tests {
284+
use std::any::Any;
323285
use std::sync::Arc;
324286

325-
use arrow_schema::{DataType, Field, Fields, Schema, TimeUnit as ArrowTimeUnit};
287+
use arrow_schema::{
288+
DataType, Field, Schema, SchemaBuilder, SchemaRef, TimeUnit as ArrowTimeUnit,
289+
};
326290
use datafusion::functions::core::getfield::GetFieldFunc;
327-
use datafusion_common::ScalarValue;
291+
use datafusion::logical_expr::{ColumnarValue, Signature};
328292
use datafusion_common::config::ConfigOptions;
329-
use datafusion_expr::{Operator as DFOperator, ScalarUDF};
330-
use datafusion_physical_expr::PhysicalExpr;
293+
use datafusion_common::{ScalarValue, ToDFSchema};
294+
use datafusion_datasource::file::FileSource;
295+
use datafusion_expr::execution_props::ExecutionProps;
296+
use datafusion_expr::expr::ScalarFunction;
297+
use datafusion_expr::{
298+
Expr, Operator as DFOperator, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Volatility, col,
299+
};
300+
use datafusion_functions::expr_fn::get_field;
301+
use datafusion_physical_expr::{PhysicalExpr, create_physical_expr};
331302
use datafusion_physical_plan::expressions as df_expr;
303+
use datafusion_physical_plan::filter_pushdown::PushedDown;
332304
use insta::assert_snapshot;
333305
use rstest::rstest;
306+
use vortex::VortexSessionDefault;
334307
use vortex::expr::{Expression, Operator};
308+
use vortex::session::VortexSession;
335309

336310
use super::*;
311+
use crate::VortexSource;
312+
use crate::persistent::cache::VortexFileCache;
337313

338314
#[rstest::fixture]
339315
fn test_schema() -> Schema {
@@ -505,7 +481,8 @@ mod tests {
505481
DataType::List(Arc::new(Field::new("item", DataType::Int32, true))),
506482
false
507483
)]
508-
#[case::struct_type(DataType::Struct(vec![Field::new("field", DataType::Int32, true)].into()), false)]
484+
#[case::struct_type(DataType::Struct(vec![Field::new("field", DataType::Int32, true)].into()
485+
), false)]
509486
// Dictionary types - should be supported if value type is supported
510487
#[case::dict_utf8(
511488
DataType::Dictionary(Box::new(DataType::UInt32), Box::new(DataType::Utf8)),
@@ -647,33 +624,142 @@ mod tests {
647624
"#);
648625
}
649626

650-
#[rstest]
651-
#[case::valid_field("field1", true)]
652-
#[case::missing_field("nonexistent_field", false)]
653-
fn test_can_be_pushed_down_get_field(#[case] field_name: &str, #[case] expected: bool) {
654-
let struct_fields = Fields::from(vec![
655-
Field::new("field1", DataType::Utf8, true),
656-
Field::new("field2", DataType::Int32, true),
657-
]);
658-
let schema = Schema::new(vec![Field::new(
659-
"my_struct",
660-
DataType::Struct(struct_fields),
661-
true,
662-
)]);
627+
#[test]
628+
fn test_pushdown_nested_filter() {
629+
// schema:
630+
// a: struct
631+
// |- one: i32
632+
// b: struct
633+
// |- two: i32
634+
let mut test_schema = SchemaBuilder::new();
635+
test_schema.push(Field::new_struct(
636+
"a",
637+
vec![Field::new("one", DataType::Int32, false)],
638+
false,
639+
));
640+
test_schema.push(Field::new_struct(
641+
"b",
642+
vec![Field::new("two", DataType::Int32, false)],
643+
false,
644+
));
663645

664-
let struct_col = Arc::new(df_expr::Column::new("my_struct", 0)) as Arc<dyn PhysicalExpr>;
665-
let field_name_lit = Arc::new(df_expr::Literal::new(ScalarValue::Utf8(Some(
666-
field_name.to_string(),
667-
)))) as Arc<dyn PhysicalExpr>;
646+
let test_schema = Arc::new(test_schema.finish());
647+
// Make sure filter is pushed down
648+
let filter = get_field(col("b"), "two").eq(datafusion_expr::lit(10i32));
668649

669-
let get_field_expr = Arc::new(ScalarFunctionExpr::new(
670-
"get_field",
671-
Arc::new(ScalarUDF::from(GetFieldFunc::new())),
672-
vec![struct_col, field_name_lit],
673-
Arc::new(Field::new(field_name, DataType::Utf8, true)),
674-
Arc::new(ConfigOptions::new()),
675-
)) as Arc<dyn PhysicalExpr>;
650+
let df_schema = test_schema.clone().to_dfschema().unwrap();
651+
652+
let physical_filter =
653+
create_physical_expr(&filter, &df_schema, &ExecutionProps::default()).unwrap();
654+
655+
let source = vortex_source(&test_schema);
656+
657+
let prop = source
658+
.try_pushdown_filters(vec![physical_filter], &ConfigOptions::default())
659+
.unwrap();
660+
assert!(matches!(prop.filters[0], PushedDown::Yes));
661+
}
662+
663+
#[test]
664+
fn test_pushdown_deeply_nested_filter() {
665+
// schema:
666+
// a: struct
667+
// |- b: struct
668+
// |- c: i32
669+
let mut schema = SchemaBuilder::new();
670+
671+
let c = Field::new("c", DataType::Int32, false);
672+
let b = Field::new_struct("b", vec![c], false);
673+
let a = Field::new_struct("a", vec![b], false);
674+
schema.push(a);
675+
676+
let schema = Arc::new(schema.finish());
677+
let df_schema = schema.clone().to_dfschema().unwrap();
678+
679+
let source = vortex_source(&schema);
680+
681+
let deep_filter = get_field(get_field(col("a"), "b"), "c").eq(datafusion_expr::lit(10i32));
682+
683+
let physical_filter =
684+
create_physical_expr(&deep_filter, &df_schema, &ExecutionProps::default()).unwrap();
685+
686+
let prop = source
687+
.try_pushdown_filters(vec![physical_filter], &ConfigOptions::default())
688+
.unwrap();
689+
assert!(matches!(prop.filters[0], PushedDown::Yes));
690+
}
691+
692+
#[test]
693+
fn test_unknown_scalar_function() {
694+
#[derive(Debug, PartialEq, Eq, Hash)]
695+
pub struct UnknownImpl {
696+
signature: Signature,
697+
}
698+
699+
impl ScalarUDFImpl for UnknownImpl {
700+
fn as_any(&self) -> &dyn Any {
701+
self
702+
}
703+
704+
fn name(&self) -> &str {
705+
"unknown"
706+
}
707+
708+
fn signature(&self) -> &Signature {
709+
&self.signature
710+
}
711+
712+
fn return_type(&self, _arg_types: &[DataType]) -> datafusion_common::Result<DataType> {
713+
Ok(DataType::Int32)
714+
}
715+
716+
fn invoke_with_args(
717+
&self,
718+
_args: ScalarFunctionArgs,
719+
) -> datafusion_common::Result<ColumnarValue> {
720+
Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(1))))
721+
}
722+
}
723+
724+
// schema:
725+
// a: struct
726+
// |- b: struct
727+
// |- c: i32
728+
let mut schema = SchemaBuilder::new();
729+
730+
let c = Field::new("c", DataType::Int32, false);
731+
let b = Field::new_struct("b", vec![c], false);
732+
let a = Field::new_struct("a", vec![b], false);
733+
schema.push(a);
734+
735+
let schema = Arc::new(schema.finish());
736+
let df_schema = schema.clone().to_dfschema().unwrap();
737+
738+
let source = vortex_source(&schema);
739+
740+
let unknown_func = Expr::ScalarFunction(ScalarFunction {
741+
func: Arc::new(ScalarUDF::new_from_impl(UnknownImpl {
742+
signature: Signature::nullary(Volatility::Immutable),
743+
})),
744+
args: vec![],
745+
});
746+
747+
// Another weird ScalarFunction that we can't push down
748+
let deep_filter = unknown_func.eq(datafusion_expr::lit(10i32));
749+
750+
let physical_filter =
751+
create_physical_expr(&deep_filter, &df_schema, &ExecutionProps::default()).unwrap();
752+
753+
let prop = source
754+
.try_pushdown_filters(vec![physical_filter], &ConfigOptions::default())
755+
.unwrap();
756+
assert!(matches!(prop.filters[0], PushedDown::No));
757+
}
758+
759+
fn vortex_source(schema: &SchemaRef) -> Arc<dyn FileSource> {
760+
let session = VortexSession::default();
761+
let cache = VortexFileCache::new(1024, 1024, session.clone());
676762

677-
assert_eq!(can_be_pushed_down(&get_field_expr, &schema), expected);
763+
Arc::new(VortexSource::new(session.clone(), cache)).with_schema(schema.clone())
678764
}
679765
}

vortex-datafusion/src/persistent/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

44
//! Persistent implementation of a Vortex table provider.
5-
mod cache;
5+
pub(crate) mod cache;
66
mod format;
77
pub mod metrics;
88
mod opener;

0 commit comments

Comments (0)