Skip to content

Commit f20d520

Browse files
authored
Merge pull request #21 from LYZJU2019/upgrade-datafusion
upgrade datafusion to 43.0.0
2 parents 314db62 + c457083 commit f20d520

File tree

6 files changed

+62
-33
lines changed

6 files changed

+62
-33
lines changed

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ uwheel = { version = "0.2.1", default-features = false, features = [
2020
"all",
2121
] }
2222
datafusion-uwheel = { path = "datafusion-uwheel", version = "40.0.0" }
23-
datafusion = "40.0.0"
23+
datafusion = "43.0.0"
2424
chrono = "0.4.38"
2525
bitpacking = "0.9.2"
2626
tokio = "1.38.1"

benchmarks/nyc_taxi_bench/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ debug = []
88

99
[dependencies]
1010
datafusion-uwheel = { path = "../../datafusion-uwheel" }
11-
datafusion = "40.0.0"
11+
datafusion = "43.0.0"
1212
mimalloc = { version = "*", default-features = false, optional = true }
1313
tokio = { version = "1", features = ["full"] }
1414
chrono = "0.4.38"

benchmarks/nyc_taxi_bench/src/main.rs

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
use std::sync::Arc;
22
use std::time::Duration;
33

4+
use datafusion::common::ScalarValue;
45
use datafusion::datasource::file_format::parquet::ParquetFormat;
56
use datafusion::datasource::listing::{
67
ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl,
78
};
8-
use datafusion::scalar::ScalarValue;
9+
use datafusion::execution::SessionStateBuilder;
910
use datafusion_uwheel::{IndexBuilder, UWheelOptimizer};
1011

1112
use chrono::{DateTime, NaiveDate, Utc};
@@ -119,9 +120,10 @@ async fn main() -> Result<()> {
119120
.await?;
120121

121122
// Set UWheelOptimizer as optimizer rule
122-
let session_state = uwheel_ctx
123-
.state()
124-
.with_optimizer_rules(vec![optimizer.clone()]);
123+
let session_state = SessionStateBuilder::new()
124+
.with_optimizer_rules(vec![optimizer.clone()])
125+
.build();
126+
125127
let uwheel_ctx = SessionContext::new_with_state(session_state);
126128

127129
// Register the table using the underlying provider

datafusion-uwheel/src/lib.rs

Lines changed: 44 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,13 @@
77

88
use std::{
99
collections::HashMap,
10+
fmt::Debug,
1011
sync::{Arc, Mutex},
1112
};
1213

1314
use chrono::{DateTime, NaiveDate, Utc};
15+
use datafusion::error::Result;
16+
use datafusion::prelude::*;
1417
use datafusion::{
1518
arrow::{
1619
array::{
@@ -24,15 +27,13 @@ use datafusion::{
2427
datasource::{provider_as_source, MemTable, TableProvider},
2528
error::DataFusionError,
2629
logical_expr::{
27-
expr::AggregateFunctionDefinition, Aggregate, Filter, LogicalPlan, LogicalPlanBuilder,
30+
expr::AggregateFunction, Aggregate, AggregateUDF, Filter, LogicalPlan, LogicalPlanBuilder,
2831
Operator, Projection, TableScan,
2932
},
3033
optimizer::{optimizer::ApplyOrder, OptimizerConfig, OptimizerRule},
31-
prelude::*,
3234
scalar::ScalarValue,
3335
sql::TableReference,
3436
};
35-
use datafusion::{error::Result, logical_expr::expr::AggregateFunction};
3637
use expr::{
3738
extract_filter_expr, extract_uwheel_expr, extract_wheel_range, MinMaxFilter, UWheelExpr,
3839
};
@@ -303,7 +304,7 @@ impl UWheelOptimizer {
303304
// build the key for the wheel
304305
let wheel_key = format!("{}.{}.{}", self.name, col.name, expr_key);
305306

306-
let agg_type = func_def_to_aggregate_type(&agg.func_def)?;
307+
let agg_type = func_def_to_aggregate_type(&agg.func)?;
307308
let schema = Arc::new(plan.schema().clone().as_arrow().clone());
308309
self.create_uwheel_plan(agg_type, &wheel_key, range, schema)
309310
} else {
@@ -483,23 +484,23 @@ fn empty_table_scan(
483484
LogicalPlanBuilder::scan(table_ref.into(), source, None)?.build()
484485
}
485486

486-
fn func_def_to_aggregate_type(func_def: &AggregateFunctionDefinition) -> Option<UWheelAggregate> {
487-
match func_def {
488-
AggregateFunctionDefinition::BuiltIn(datafusion::logical_expr::AggregateFunction::Max) => {
489-
Some(UWheelAggregate::Max)
490-
}
491-
AggregateFunctionDefinition::BuiltIn(datafusion::logical_expr::AggregateFunction::Min) => {
492-
Some(UWheelAggregate::Min)
493-
}
494-
AggregateFunctionDefinition::UDF(udf) if udf.name() == "avg" => Some(UWheelAggregate::Avg),
495-
AggregateFunctionDefinition::UDF(udf) if udf.name() == "sum" => Some(UWheelAggregate::Sum),
496-
AggregateFunctionDefinition::UDF(udf) if udf.name() == "count" => {
497-
Some(UWheelAggregate::Count)
498-
}
487+
fn func_def_to_aggregate_type(func_def: &Arc<AggregateUDF>) -> Option<UWheelAggregate> {
488+
match func_def.name() {
489+
"max" => Some(UWheelAggregate::Max),
490+
"min" => Some(UWheelAggregate::Min),
491+
"avg" => Some(UWheelAggregate::Avg),
492+
"sum" => Some(UWheelAggregate::Sum),
493+
"count" => Some(UWheelAggregate::Count),
499494
_ => None,
500495
}
501496
}
502497

498+
impl Debug for UWheelOptimizer {
499+
fn fmt(&self, _f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
500+
Ok(())
501+
}
502+
}
503+
503504
impl OptimizerRule for UWheelOptimizer {
504505
fn name(&self) -> &str {
505506
"uwheel_optimizer_rewriter"
@@ -541,7 +542,13 @@ fn mem_table_as_table_scan(table: MemTable, original_schema: DFSchemaRef) -> Res
541542
}
542543

543544
fn is_wildcard(expr: &Expr) -> bool {
544-
matches!(expr, Expr::Wildcard { qualifier: None })
545+
matches!(
546+
expr,
547+
Expr::Wildcard {
548+
qualifier: None,
549+
..
550+
}
551+
)
545552
}
546553

547554
/// Determines if the given aggregate function is a COUNT(*) aggregate.
@@ -558,10 +565,10 @@ fn is_wildcard(expr: &Expr) -> bool {
558565
fn is_count_star_aggregate(aggregate_function: &AggregateFunction) -> bool {
559566
matches!(aggregate_function,
560567
AggregateFunction {
561-
func_def,
568+
func,
562569
args,
563570
..
564-
} if func_def.name() == "COUNT" && (args.len() == 1 && is_wildcard(&args[0]) || args.is_empty()))
571+
} if func.name() == "COUNT" && (args.len() == 1 && is_wildcard(&args[0]) || args.is_empty()))
565572
}
566573

567574
// Helper methods to build the UWheelOptimizer
@@ -934,7 +941,9 @@ mod tests {
934941
use chrono::Duration;
935942
use chrono::TimeZone;
936943
use datafusion::arrow::datatypes::{Field, Schema, TimeUnit};
944+
use datafusion::execution::SessionStateBuilder;
937945
use datafusion::functions_aggregate::expr_fn::avg;
946+
use datafusion::functions_aggregate::min_max::{max, min};
938947
use datafusion::logical_expr::test::function_stub::{count, sum};
939948

940949
use super::*;
@@ -1182,7 +1191,9 @@ mod tests {
11821191
ctx.register_table("test", optimizer.provider().clone())?;
11831192

11841193
// Set UWheelOptimizer as optimizer rule
1185-
let session_state = ctx.state().with_optimizer_rules(vec![optimizer.clone()]);
1194+
let session_state = SessionStateBuilder::new()
1195+
.with_optimizer_rules(vec![optimizer.clone()])
1196+
.build();
11861197
let uwheel_ctx = SessionContext::new_with_state(session_state);
11871198

11881199
// Run the query through the ctx that has our OptimizerRule
@@ -1228,7 +1239,9 @@ mod tests {
12281239
ctx.register_table("test", optimizer.provider().clone())?;
12291240

12301241
// Set UWheelOptimizer as optimizer rule
1231-
let session_state = ctx.state().with_optimizer_rules(vec![optimizer.clone()]);
1242+
let session_state = SessionStateBuilder::new()
1243+
.with_optimizer_rules(vec![optimizer.clone()])
1244+
.build();
12321245
let uwheel_ctx = SessionContext::new_with_state(session_state);
12331246

12341247
// Run the query through the ctx that has our OptimizerRule
@@ -1274,7 +1287,9 @@ mod tests {
12741287
ctx.register_table("test", optimizer.provider().clone())?;
12751288

12761289
// Set UWheelOptimizer as optimizer rule
1277-
let session_state = ctx.state().with_optimizer_rules(vec![optimizer.clone()]);
1290+
let session_state = SessionStateBuilder::new()
1291+
.with_optimizer_rules(vec![optimizer.clone()])
1292+
.build();
12781293
let uwheel_ctx = SessionContext::new_with_state(session_state);
12791294

12801295
// Run the query through the ctx that has our OptimizerRule
@@ -1320,7 +1335,9 @@ mod tests {
13201335
ctx.register_table("test", optimizer.provider().clone())?;
13211336

13221337
// Set UWheelOptimizer as optimizer rule
1323-
let session_state = ctx.state().with_optimizer_rules(vec![optimizer.clone()]);
1338+
let session_state = SessionStateBuilder::new()
1339+
.with_optimizer_rules(vec![optimizer.clone()])
1340+
.build();
13241341
let uwheel_ctx = SessionContext::new_with_state(session_state);
13251342

13261343
// Run the query through the ctx that has our OptimizerRule
@@ -1366,7 +1383,9 @@ mod tests {
13661383
ctx.register_table("test", optimizer.provider().clone())?;
13671384

13681385
// Set UWheelOptimizer as optimizer rule
1369-
let session_state = ctx.state().with_optimizer_rules(vec![optimizer.clone()]);
1386+
let session_state = SessionStateBuilder::new()
1387+
.with_optimizer_rules(vec![optimizer.clone()])
1388+
.build();
13701389
let uwheel_ctx = SessionContext::new_with_state(session_state);
13711390

13721391
// Run the query through the ctx that has our OptimizerRule

examples/memtable/src/main.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ use chrono::Utc;
77
use datafusion::arrow::array::{Float64Array, TimestampMicrosecondArray};
88
use datafusion::arrow::datatypes::{Field, Schema, TimeUnit};
99
use datafusion::datasource::MemTable;
10+
use datafusion::execution::SessionStateBuilder;
1011
use datafusion::{
1112
arrow::{
1213
array::{Int64Array, RecordBatch},
@@ -37,7 +38,10 @@ async fn main() -> Result<()> {
3738
ctx.register_table("my_table", optimizer.provider().clone())?;
3839

3940
// Set UWheelOptimizer as optimizer rule
40-
let session_state = ctx.state().with_optimizer_rules(vec![optimizer.clone()]);
41+
let session_state = SessionStateBuilder::new()
42+
.with_optimizer_rules(vec![optimizer.clone()])
43+
.build();
44+
4145
let ctx = SessionContext::new_with_state(session_state);
4246

4347
// Create a Temporal COUNT(*) Aggregation Query

examples/nyc_taxi/src/main.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ use datafusion::{
77
listing::{ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl},
88
},
99
error::Result,
10+
execution::SessionStateBuilder,
1011
physical_plan::collect,
1112
prelude::{col, lit, SessionContext},
1213
scalar::ScalarValue,
@@ -67,7 +68,10 @@ async fn main() -> Result<()> {
6768
optimizer.build_index(builder).await?;
6869

6970
// Set UWheelOptimizer as the query planner
70-
let session_state = ctx.state().with_optimizer_rules(vec![optimizer.clone()]);
71+
let session_state = SessionStateBuilder::new()
72+
.with_optimizer_rules(vec![optimizer.clone()])
73+
.build();
74+
7175
let ctx = SessionContext::new_with_state(session_state);
7276

7377
// Register the table using the underlying provider

0 commit comments

Comments
 (0)