Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 127 additions & 0 deletions datafusion-uwheel/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,10 @@ impl UWheelOptimizer {
agg.group_expr.is_empty() && agg.aggr_expr.len() > 1
}

/// Returns `true` when `has_filter` reports no filter for `agg` and `single_agg`
/// accepts it.
///
/// `has_filter` is consulted first; `single_agg` is only evaluated when no filter
/// was reported.
fn single_aggregate_without_filter(agg: &Aggregate) -> bool {
    if Self::has_filter(agg) {
        return false;
    }
    Self::single_agg(agg)
}

// Attempts to rewrite a top-level Projection plan
fn try_rewrite_projection(
&self,
Expand Down Expand Up @@ -433,6 +437,31 @@ impl UWheelOptimizer {

uwheel_multiple_aggregations_to_table_scan(agg_results, schema).ok()
}

LogicalPlan::Aggregate(agg) if Self::single_aggregate_without_filter(agg) => {
let agg_expr = agg.aggr_expr.first().unwrap();
match agg_expr {
// Single Aggregate Function (e.g., SUM(col))
Expr::AggregateFunction(agg) if agg.args.len() == 1 => {
if let Expr::Column(col) = &agg.args[0] {
// build the key for the wheel
let wheel_key =
format!("{}.{}.{}", self.name, col.name, STAR_AGGREGATION_ALIAS);

let agg_type = func_def_to_aggregate_type(&agg.func)?;
let schema = Arc::new(plan.schema().clone().as_arrow().clone());

let result =
self.get_aggregate_landmark_result(agg_type, &wheel_key)?;

uwheel_agg_to_table_scan(result, schema).ok()
} else {
None
}
}
_ => None,
}
}
// Check whether it follows the pattern: SELECT * FROM X WHERE TIME >= X AND TIME <= Y
LogicalPlan::Filter(filter) => self.try_rewrite_filter(filter, plan),
_ => None,
Expand Down Expand Up @@ -544,6 +573,32 @@ impl UWheelOptimizer {
_ => unimplemented!(),
}
}

/// Looks up the wheel stored under `wheel_key` for the requested aggregate type and
/// returns the value of its `landmark()` query.
///
/// # Parameters
/// - `agg_type`: which wheel collection to consult (`Sum`, `Avg`, `Min` or `Max`).
/// - `wheel_key`: key identifying the wheel inside that collection.
///
/// # Returns
/// - `Some(value)` with the landmark result. For `Avg` the wheel yields a
///   `(sum, count)` pair which is reduced to `sum / count`.
/// - `None` when no wheel is registered under `wheel_key`, when `landmark()` itself
///   yields nothing, or when `agg_type` is not one of the four supported types.
///
/// # Panics
/// Panics if acquiring the lock around a wheel collection fails (`lock().unwrap()`).
fn get_aggregate_landmark_result(
    &self,
    agg_type: UWheelAggregate,
    wheel_key: &str,
) -> Option<f64> {
    match agg_type {
        UWheelAggregate::Sum => {
            // The lock guard is a temporary of this statement, so it is released
            // before the cloned wheel is queried.
            let wheel = self.wheels.sum.lock().unwrap().get(wheel_key)?.clone();
            wheel.landmark()
        }
        UWheelAggregate::Avg => {
            let wheel = self.wheels.avg.lock().unwrap().get(wheel_key)?.clone();
            // The average wheel reports (sum, count); turn it into the mean.
            wheel.landmark().map(|(sum, count)| sum / count)
        }
        UWheelAggregate::Min => {
            let wheel = self.wheels.min.lock().unwrap().get(wheel_key)?.clone();
            wheel.landmark()
        }
        UWheelAggregate::Max => {
            let wheel = self.wheels.max.lock().unwrap().get(wheel_key)?.clone();
            wheel.landmark()
        }
        // Aggregate types without landmark support yield no result instead of
        // panicking; callers propagating with `?` simply see `None`.
        _ => None,
    }
}
}

fn count_scan(count: u32, schema: SchemaRef) -> Result<LogicalPlan> {
Expand Down Expand Up @@ -1771,4 +1826,76 @@ mod tests {

Ok(())
}

/// Checks that an unfiltered `SUM(agg_col)` aggregation under a projection is
/// rewritten into a `TableScan` after a Sum index has been built for `agg_col`.
#[tokio::test]
async fn full_scan_rewrite() -> Result<()> {
    let optimizer = test_optimizer().await?;

    // Build the Sum index on the aggregated column before attempting the rewrite.
    let sum_index = IndexBuilder::with_col_and_aggregate("agg_col", UWheelAggregate::Sum);
    optimizer.build_index(sum_index).await?;

    // Scan -> Aggregate(SUM(agg_col)) -> Projection, with no filter in between.
    let sum_of_agg_col = || sum(col("agg_col"));
    let source = provider_as_source(optimizer.provider());
    let plan = LogicalPlanBuilder::scan("test", source, None)?
        .aggregate(Vec::<Expr>::new(), vec![sum_of_agg_col()])?
        .project(vec![sum_of_agg_col()])?
        .build()?;

    // The plan handed to the optimizer has a Projection at its root.
    assert!(matches!(plan, LogicalPlan::Projection(_)));

    // The rewrite must succeed and produce a TableScan.
    let rewritten = optimizer.try_rewrite(&plan).unwrap();
    assert!(matches!(rewritten, LogicalPlan::TableScan(_)));

    Ok(())
}

/// Executes an unfiltered `SUM(agg_col)` query through a session whose optimizer
/// rules consist of the `UWheelOptimizer`, and checks the single returned value.
#[tokio::test]
async fn full_scan_exec() -> Result<()> {
    let optimizer = test_optimizer().await?;

    // Build the Sum index on the aggregated column.
    let sum_index = IndexBuilder::with_col_and_aggregate("agg_col", UWheelAggregate::Sum);
    optimizer.build_index(sum_index).await?;

    // Scan -> Aggregate(SUM(agg_col)) -> Projection, with no filter in between.
    let sum_of_agg_col = || sum(col("agg_col"));
    let source = provider_as_source(optimizer.provider());
    let plan = LogicalPlanBuilder::scan("test", source, None)?
        .aggregate(Vec::<Expr>::new(), vec![sum_of_agg_col()])?
        .project(vec![sum_of_agg_col()])?
        .build()?;

    // NOTE(review): this context only receives the table registration; the query
    // below runs on `uwheel_ctx` instead. Confirm whether this is still needed.
    let ctx = SessionContext::new();
    ctx.register_table("test", optimizer.provider().clone())?;

    // Install the UWheelOptimizer as the session's optimizer rule.
    let uwheel_ctx = SessionContext::new_with_state(
        SessionStateBuilder::new()
            .with_optimizer_rules(vec![optimizer.clone()])
            .build(),
    );

    // Run the query through the context that carries our OptimizerRule.
    let batches = uwheel_ctx
        .execute_logical_plan(plan)
        .await?
        .collect()
        .await?;

    // Exactly one record batch is expected.
    assert_eq!(batches.len(), 1);

    // The first column must be a Float64 array; its first value is the sum.
    // 55.0 is the expected total for the test fixture (fixture not visible here).
    let sums = batches[0]
        .column(0)
        .as_any()
        .downcast_ref::<Float64Array>()
        .unwrap();
    assert_eq!(sums.value(0), 55.0);

    Ok(())
}
}