Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
216 changes: 206 additions & 10 deletions datafusion-uwheel/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,16 @@ impl UWheelOptimizer {
agg.group_expr.is_empty() && agg.aggr_expr.len() == 1
}

/// checks whether the Aggregate has a single group_by expression
fn single_group_by(agg: &Aggregate) -> bool {
agg.group_expr.len() == 1
}

/// check whether the Aggregate has no group_expr and aggr_expr has a length greater than 1
fn multiple_aggregates(agg: &Aggregate) -> bool {
agg.group_expr.is_empty() && agg.aggr_expr.len() > 1
}

// Attemps to rewrite a top-level Projection plan
fn try_rewrite_projection(
&self,
Expand Down Expand Up @@ -315,7 +325,7 @@ impl UWheelOptimizer {
}
}

LogicalPlan::Aggregate(agg) => {
LogicalPlan::Aggregate(agg) if Self::single_group_by(agg) => {
let group_expr = agg.group_expr.first()?;

// Only continue if the aggregation has a filter
Expand Down Expand Up @@ -372,6 +382,57 @@ impl UWheelOptimizer {
}
None
}

LogicalPlan::Aggregate(agg) if Self::multiple_aggregates(agg) => {
// Only continue if the aggregation has a filter
let LogicalPlan::Filter(filter) = agg.input.as_ref() else {
return None;
};

let agg_exprs = &agg.aggr_expr;

let mut agg_results = Vec::new();

for agg_expr in agg_exprs {
match agg_expr {
// Single Aggregate Function (e.g., SUM(col))
Expr::AggregateFunction(agg) if agg.args.len() == 1 => {
if let Expr::Column(col) = &agg.args[0] {
// Fetch temporal filter range and expr key which is used to identify a wheel
let (range, expr_key) = match extract_filter_expr(
&filter.predicate,
&self.time_column,
)? {
(range, Some(expr)) => {
(range, maybe_replace_table_name(&expr, &self.name))
}
(range, None) => (range, STAR_AGGREGATION_ALIAS.to_string()),
};

// build the key for the wheel
let wheel_key = format!("{}.{}.{}", self.name, col.name, expr_key);

let agg_type = func_def_to_aggregate_type(&agg.func)?;

// get aggregation result
let result =
self.get_aggregate_result(agg_type, &wheel_key, range)?;

agg_results.push(result);
} else {
return None;
}
}
_ => {
return None;
}
}
}

let schema = Arc::new(plan.schema().clone().as_arrow().clone());

uwheel_multiple_aggregations_to_table_scan(agg_results, schema).ok()
}
// Check whether it follows the pattern: SELECT * FROM X WHERE TIME >= X AND TIME <= Y
LogicalPlan::Filter(filter) => self.try_rewrite_filter(filter, plan),
_ => None,
Expand Down Expand Up @@ -453,27 +514,32 @@ impl UWheelOptimizer {
range: WheelRange,
schema: SchemaRef,
) -> Option<LogicalPlan> {
let result = self.get_aggregate_result(agg_type, wheel_key, range)?;
uwheel_agg_to_table_scan(result, schema).ok()
}

fn get_aggregate_result(
&self,
agg_type: UWheelAggregate,
wheel_key: &str,
range: WheelRange,
) -> Option<f64> {
match agg_type {
UWheelAggregate::Sum => {
let wheel = self.wheels.sum.lock().unwrap().get(wheel_key)?.clone();
let result = wheel.combine_range_and_lower(range)?;
uwheel_agg_to_table_scan(result, schema).ok()
wheel.combine_range_and_lower(range)
}
UWheelAggregate::Avg => {
let wheel = self.wheels.avg.lock().unwrap().get(wheel_key)?.clone();
let result = wheel.combine_range_and_lower(range)?;

uwheel_agg_to_table_scan(result, schema).ok()
wheel.combine_range_and_lower(range)
}
UWheelAggregate::Min => {
let wheel = self.wheels.min.lock().unwrap().get(wheel_key)?.clone();
let result = wheel.combine_range_and_lower(range)?;
uwheel_agg_to_table_scan(result, schema).ok()
wheel.combine_range_and_lower(range)
}
UWheelAggregate::Max => {
let wheel = self.wheels.max.lock().unwrap().get(wheel_key)?.clone();
let result = wheel.combine_range_and_lower(range)?;
uwheel_agg_to_table_scan(result, schema).ok()
wheel.combine_range_and_lower(range)
}
_ => unimplemented!(),
}
Expand Down Expand Up @@ -517,6 +583,24 @@ fn uwheel_group_by_to_table_scan(
mem_table_as_table_scan(mem_table, df_schema)
}

fn uwheel_multiple_aggregations_to_table_scan(
agg_results: Vec<f64>,
schema: SchemaRef,
) -> Result<LogicalPlan> {
let mut columns = Vec::new();

for result in agg_results {
let data = Float64Array::from(vec![result]);
columns.push(Arc::new(data) as Arc<dyn Array>);
}

let record_batch = RecordBatch::try_new(schema.clone(), columns)?;

let df_schema = Arc::new(DFSchema::try_from(schema.clone())?);
let mem_table = MemTable::try_new(schema, vec![vec![record_batch]])?;
mem_table_as_table_scan(mem_table, df_schema)
}

// helper for possibly removing the table name from the expression key
fn maybe_replace_table_name(expr: &Expr, table_name: &str) -> String {
let expr_str = expr.to_string();
Expand Down Expand Up @@ -1575,4 +1659,116 @@ mod tests {

Ok(())
}

#[tokio::test]
async fn multiple_aggregation_rewrite() -> Result<()> {
let optimizer = test_optimizer().await?;

optimizer
.build_index(IndexBuilder::with_col_and_aggregate(
"agg_col",
UWheelAggregate::Avg,
))
.await?;

optimizer
.build_index(IndexBuilder::with_col_and_aggregate(
"agg_col",
UWheelAggregate::Sum,
))
.await?;

let temporal_filter = col("timestamp")
.gt_eq(lit("2024-05-10T00:00:00Z"))
.and(col("timestamp").lt(lit("2024-05-10T00:00:10Z")));

let plan =
LogicalPlanBuilder::scan("test", provider_as_source(optimizer.provider()), None)?
.filter(temporal_filter)?
.aggregate(
Vec::<Expr>::new(),
vec![avg(col("agg_col")), sum(col("agg_col"))],
)?
.project(vec![avg(col("agg_col")), sum(col("agg_col"))])?
.build()?;

// Assert that the original plan is a Projection
assert!(matches!(plan, LogicalPlan::Projection(_)));

let rewritten = optimizer.try_rewrite(&plan).unwrap();
// assert it was rewritten to a TableScan
assert!(matches!(rewritten, LogicalPlan::TableScan(_)));

Ok(())
}

#[tokio::test]
async fn multiple_aggregation_exec() -> Result<()> {
let optimizer = test_optimizer().await?;

optimizer
.build_index(IndexBuilder::with_col_and_aggregate(
"agg_col",
UWheelAggregate::Avg,
))
.await?;

optimizer
.build_index(IndexBuilder::with_col_and_aggregate(
"agg_col",
UWheelAggregate::Sum,
))
.await?;

let temporal_filter = col("timestamp")
.gt_eq(lit("2024-05-10T00:00:00Z"))
.and(col("timestamp").lt(lit("2024-05-10T00:00:10Z")));

let plan =
LogicalPlanBuilder::scan("test", provider_as_source(optimizer.provider()), None)?
.filter(temporal_filter)?
.aggregate(
Vec::<Expr>::new(),
vec![avg(col("agg_col")), sum(col("agg_col"))],
)?
.project(vec![avg(col("agg_col")), sum(col("agg_col"))])?
.build()?;

let ctx = SessionContext::new();
ctx.register_table("test", optimizer.provider().clone())?;

// Set UWheelOptimizer as optimizer rule
let session_state = SessionStateBuilder::new()
.with_optimizer_rules(vec![optimizer.clone()])
.build();
let uwheel_ctx = SessionContext::new_with_state(session_state);

// Run the query through the ctx that has our OptimizerRule
let df = uwheel_ctx.execute_logical_plan(plan).await?;
let results = df.collect().await?;

assert_eq!(results.len(), 1);

assert_eq!(
results[0]
.column(0)
.as_any()
.downcast_ref::<Float64Array>()
.unwrap()
.value(0),
5.5
);

assert_eq!(
results[0]
.column(1)
.as_any()
.downcast_ref::<Float64Array>()
.unwrap()
.value(0),
55.0
);

Ok(())
}
}