Skip to content

Commit 680e25a

Browse files
committed
Fix AT correlation with joins
1 parent 6001370 commit 680e25a

File tree

2 files changed

+65
-7
lines changed

2 files changed

+65
-7
lines changed

test/sql/measures.test

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -818,6 +818,40 @@ FROM fact_orders_v o JOIN fact_returns_v r ON o.year = r.year AND o.region = r.r
818818
2023 EU 75.0 225.0
819819
2023 US 150.0 225.0
820820

821+
# =============================================================================
822+
# Test: JOIN with extra dimension from second table
823+
# =============================================================================
824+
825+
statement ok
826+
CREATE TABLE salesdetails (year INT, region TEXT, product TEXT, amount DOUBLE);
827+
828+
statement ok
829+
INSERT INTO salesdetails VALUES
830+
(2022, 'US', 'Shoes', 2), (2022, 'US', 'Cars', 1),
831+
(2022, 'EU', 'Shoes', 3),
832+
(2023, 'US', 'Shoes', 4), (2023, 'US', 'Cars', 2),
833+
(2023, 'EU', 'Cars', 5);
834+
835+
statement ok
836+
CREATE VIEW salesdetails_v AS
837+
SELECT year, region, product, SUM(amount) AS MEASURE quantity
838+
FROM salesdetails;
839+
840+
query IIIRRR rowsort
841+
SEMANTIC SELECT s.year, s.region, sd.product,
842+
AGGREGATE(revenue) AS year_sales_revenue,
843+
AGGREGATE(revenue) AT (ALL year) AS region_total,
844+
AGGREGATE(quantity) AS product_qty
845+
FROM sales_v s JOIN salesdetails_v sd ON s.year = sd.year AND s.region = sd.region
846+
;
847+
----
848+
2022 EU Shoes 50.0 125.0 3.0
849+
2022 US Cars 100.0 250.0 1.0
850+
2022 US Shoes 100.0 250.0 2.0
851+
2023 EU Cars 75.0 125.0 5.0
852+
2023 US Cars 150.0 250.0 2.0
853+
2023 US Shoes 150.0 250.0 4.0
854+
821855
# =============================================================================
822856
# Test: SET reaches beyond WHERE clause (paper semantics)
823857
# Per paper: SET should evaluate over data removed by outer WHERE clause

yardstick-rs/src/sql/measures.rs

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
//!
88
//! Reference: https://arxiv.org/abs/2406.00251
99
10-
use std::collections::HashMap;
10+
use std::collections::{HashMap, HashSet};
1111
use std::sync::Mutex;
1212

1313
use nom::{
@@ -1051,6 +1051,26 @@ fn group_by_matches_view(outer_cols: &[String], view_cols: &[String]) -> bool {
10511051
!outer_set.is_empty() && outer_set == view_set
10521052
}
10531053

1054+
fn filter_group_by_cols_for_measure(
1055+
outer_cols: &[String],
1056+
view_cols: &[String],
1057+
) -> Vec<String> {
1058+
if view_cols.is_empty() {
1059+
return outer_cols.to_vec();
1060+
}
1061+
1062+
let view_set: HashSet<String> = view_cols
1063+
.iter()
1064+
.map(|col| normalize_group_by_col(col))
1065+
.collect();
1066+
1067+
outer_cols
1068+
.iter()
1069+
.filter(|col| view_set.contains(&normalize_group_by_col(col)))
1070+
.cloned()
1071+
.collect()
1072+
}
1073+
10541074
fn can_use_view_measure_directly(resolved: &ResolvedMeasure, outer_group_by: &[String]) -> bool {
10551075
group_by_matches_view(outer_group_by, &resolved.view_group_by_cols)
10561076
}
@@ -3416,6 +3436,8 @@ pub fn expand_aggregate_with_at(sql: &str) -> AggregateExpandResult {
34163436
for (measure_name, modifiers, start, end) in patterns {
34173437
// Look up which view contains this measure (for JOIN support)
34183438
let resolved = resolve_measure_source(&measure_name, &primary_table_name);
3439+
let measure_group_by_cols =
3440+
filter_group_by_cols_for_measure(&group_by_cols, &resolved.view_group_by_cols);
34193441

34203442
// Non-decomposable measures are recomputed from base rows (including AT modifiers)
34213443

@@ -3452,7 +3474,7 @@ pub fn expand_aggregate_with_at(sql: &str) -> AggregateExpandResult {
34523474
&resolved.source_view,
34533475
outer_alias_ref,
34543476
outer_where_ref,
3455-
&group_by_cols,
3477+
&measure_group_by_cols,
34563478
)
34573479
} else if !resolved.is_decomposable {
34583480
let outer_ref_for_non_decomp =
@@ -3480,7 +3502,7 @@ pub fn expand_aggregate_with_at(sql: &str) -> AggregateExpandResult {
34803502
&base_relation_sql,
34813503
outer_ref_for_non_decomp,
34823504
outer_where_ref,
3483-
&group_by_cols,
3505+
&measure_group_by_cols,
34843506
&modifiers,
34853507
&resolved.dimension_exprs,
34863508
&format!("_nd_{join_counter}"),
@@ -3496,7 +3518,7 @@ pub fn expand_aggregate_with_at(sql: &str) -> AggregateExpandResult {
34963518
&base_relation_sql,
34973519
outer_ref_for_non_decomp,
34983520
outer_where_ref,
3499-
&group_by_cols,
3521+
&measure_group_by_cols,
35003522
&modifiers,
35013523
&resolved.dimension_exprs,
35023524
),
@@ -3517,7 +3539,7 @@ pub fn expand_aggregate_with_at(sql: &str) -> AggregateExpandResult {
35173539
&resolved.source_view,
35183540
outer_alias_ref,
35193541
outer_where_ref,
3520-
&group_by_cols,
3542+
&measure_group_by_cols,
35213543
)
35223544
};
35233545
result_sql = format!("{}{}{}", &result_sql[..start], expanded, &result_sql[end..]);
@@ -3529,6 +3551,8 @@ pub fn expand_aggregate_with_at(sql: &str) -> AggregateExpandResult {
35293551

35303552
for (measure_name, start, end) in plain_calls {
35313553
let resolved = resolve_measure_source(&measure_name, &primary_table_name);
3554+
let measure_group_by_cols =
3555+
filter_group_by_cols_for_measure(&group_by_cols, &resolved.view_group_by_cols);
35323556

35333557

35343558
// For derived measures, use the expanded expression; otherwise use AGG(measure_name)
@@ -3558,7 +3582,7 @@ pub fn expand_aggregate_with_at(sql: &str) -> AggregateExpandResult {
35583582
&base_relation_sql,
35593583
outer_ref_for_non_decomp,
35603584
outer_where_ref,
3561-
&group_by_cols,
3585+
&measure_group_by_cols,
35623586
&[], // No modifiers for plain AGGREGATE()
35633587
&resolved.dimension_exprs,
35643588
&format!("_nd_{join_counter}"),
@@ -3574,7 +3598,7 @@ pub fn expand_aggregate_with_at(sql: &str) -> AggregateExpandResult {
35743598
&base_relation_sql,
35753599
outer_ref_for_non_decomp,
35763600
outer_where_ref,
3577-
&group_by_cols,
3601+
&measure_group_by_cols,
35783602
&[], // No modifiers for plain AGGREGATE()
35793603
&resolved.dimension_exprs,
35803604
),

0 commit comments

Comments
 (0)