From 7c39936e16de96f953e0905e097ffe46a352542f Mon Sep 17 00:00:00 2001 From: Nico Ritschel Date: Thu, 8 Jan 2026 18:14:15 -0800 Subject: [PATCH 1/3] Fix AGGREGATE expansion in CTEs --- test/sql/measures.test | 16 ++ yardstick-rs/src/sql/measures.rs | 251 ++++++++++++++++++++++++++----- 2 files changed, 228 insertions(+), 39 deletions(-) diff --git a/test/sql/measures.test b/test/sql/measures.test index 8e503cb..753ef55 100644 --- a/test/sql/measures.test +++ b/test/sql/measures.test @@ -34,6 +34,22 @@ SEMANTIC SELECT year, region, AGGREGATE(revenue) FROM sales_v; 2023 EU 75.0 2023 US 150.0 +# ============================================================================= +# Test: CTE with AGGREGATE +# ============================================================================= + +query IIR rowsort +SEMANTIC WITH a AS ( + SELECT year, region, AGGREGATE(revenue) AS revenue + FROM sales_v +) +SELECT * FROM a; +---- +2022 EU 50.0 +2022 US 100.0 +2023 EU 75.0 +2023 US 150.0 + # ============================================================================= # Test: AT (ALL dimension) - remove dimension from context # ============================================================================= diff --git a/yardstick-rs/src/sql/measures.rs b/yardstick-rs/src/sql/measures.rs index 2e01a52..6dc275b 100644 --- a/yardstick-rs/src/sql/measures.rs +++ b/yardstick-rs/src/sql/measures.rs @@ -588,13 +588,11 @@ pub fn extract_table_and_alias_from_sql(sql: &str) -> Option<(String, Option>(after_from).ok()?; - let (rest, _) = tag_no_case::<_, _, nom::error::Error<&str>>("FROM")(rest).ok()?; + let (rest, _) = tag_no_case::<_, _, nom::error::Error<&str>>("FROM")(after_from).ok()?; let (rest, _) = multispace1::<_, nom::error::Error<&str>>(rest).ok()?; let (rest, table) = identifier(rest).ok()?; @@ -841,6 +839,165 @@ fn find_first_top_level_keyword(sql: &str, start: usize, keywords: &[&str]) -> O .min() } +fn skip_whitespace(sql: &str, mut idx: usize) -> usize { + while idx < sql.len() && sql.as_bytes()[idx].is_ascii_whitespace() { + idx += 1; + } + idx +} + +fn matches_keyword_at(upper: &str, idx: usize, keyword: &str) -> bool { + if idx + keyword.len() > upper.len() { + return false; + } + if &upper[idx..idx + keyword.len()] != keyword { + return false; + } + + let prev = if idx == 0 { + None + } else { + upper[..idx].chars().rev().next() + }; + let next = if idx + keyword.len() >= upper.len() { + None + } else { + upper[idx + keyword.len()..].chars().next() + }; + + is_boundary_char(prev) && is_boundary_char(next) +} + +struct CteExpansion { + sql: String, + had_aggregate: bool, +} + +fn expand_cte_queries(sql: &str) -> CteExpansion { + let with_pos = match find_top_level_keyword(sql, "WITH", 0) { + Some(pos) => pos, + None => { + return CteExpansion { + sql: sql.to_string(), + had_aggregate: false, + }; + } + }; + + let upper = sql.to_uppercase(); + let mut idx = with_pos + "WITH".len(); + let mut replacements: Vec<(usize, usize, String)> = Vec::new(); + let mut had_aggregate = false; + + idx = skip_whitespace(sql, idx); + if matches_keyword_at(&upper, idx, "RECURSIVE") { + idx += "RECURSIVE".len(); + } + + loop { + idx = skip_whitespace(sql, idx); + if idx >= sql.len() { + return CteExpansion { + sql: sql.to_string(), + had_aggregate: false, + }; + } + + let rest = &sql[idx..]; + let (after_ident, _) = match identifier(rest) { + Ok(value) => value, + Err(_) => { + return CteExpansion { + sql: sql.to_string(), + had_aggregate: false, + }; + } + }; + idx += rest.len() - after_ident.len(); + + idx = skip_whitespace(sql, idx); + if idx < sql.len() && sql.as_bytes()[idx] == b'(' { + let sub = &sql[idx + 1..]; + let (rest_after_cols, _) = match balanced_parens(sub) { + Ok(value) => value, + Err(_) => { + return CteExpansion { + sql: sql.to_string(), + had_aggregate: false, + }; + } + }; + let cols_len = sub.len() - rest_after_cols.len(); + idx = idx + 1 + cols_len + 1; + idx = skip_whitespace(sql, idx); + } + + if !matches_keyword_at(&upper, idx, "AS") { + return CteExpansion { + sql: sql.to_string(), + had_aggregate: false, + }; + } + idx += "AS".len(); + idx = skip_whitespace(sql, idx); + + if idx >= sql.len() || sql.as_bytes()[idx] != b'(' { + return CteExpansion { + sql: sql.to_string(), + had_aggregate: false, + }; + } + + let sub = &sql[idx + 1..]; + let (rest_after_query, _) = match balanced_parens(sub) { + Ok(value) => value, + Err(_) => { + return CteExpansion { + sql: sql.to_string(), + had_aggregate: false, + }; + } + }; + let query_len = sub.len() - rest_after_query.len(); + let query_start = idx + 1; + let query_end = query_start + query_len; + let query_sql = &sql[query_start..query_end]; + let expanded = expand_aggregate_with_at(query_sql); + if expanded.had_aggregate { + had_aggregate = true; + } + if expanded.expanded_sql != query_sql { + replacements.push((query_start, query_end, expanded.expanded_sql)); + } + + idx = query_end + 1; + idx = skip_whitespace(sql, idx); + if idx < sql.len() && sql.as_bytes()[idx] == b',' { + idx += 1; + continue; + } + break; + } + + if replacements.is_empty() { + return CteExpansion { + sql: sql.to_string(), + had_aggregate, + }; + } + + let mut result = sql.to_string(); + replacements.sort_by(|a, b| b.0.cmp(&a.0)); + for (start, end, replacement) in replacements { + result.replace_range(start..end, &replacement); + } + + CteExpansion { + sql: result, + had_aggregate, + } +} + fn extract_base_relation_sql(view_query: &str) -> Option { let query = view_query.trim().trim_end_matches(';').trim(); if query.is_empty() { @@ -1911,21 +2068,26 @@ fn find_expression_start(sql: &str, end: usize) -> usize { /// Expand AGGREGATE() function calls in a SELECT statement /// Uses position-based replacement with the C++ FFI parser pub fn expand_aggregate(sql: &str) -> AggregateExpandResult { - if !has_aggregate_function(sql) { + let cte_expansion = expand_cte_queries(sql); + let mut sql = cte_expansion.sql; + let mut had_aggregate = cte_expansion.had_aggregate; + + if !has_aggregate_function(&sql) { return AggregateExpandResult { - had_aggregate: false, - expanded_sql: sql.to_string(), + had_aggregate, + expanded_sql: sql, error: None, }; } + had_aggregate = true; // Parse the SQL to get table info and GROUP BY columns - let select_info = match parser_ffi::parse_select(sql) { + let select_info = match parser_ffi::parse_select(&sql) { Ok(info) => info, Err(e) => { return AggregateExpandResult { - had_aggregate: false, - expanded_sql: sql.to_string(), + had_aggregate, + expanded_sql: sql, error: Some(format!("SQL parse error: {e}")), }; } @@ -1939,7 +2101,7 @@ pub fn expand_aggregate(sql: &str) -> AggregateExpandResult { let measure_view = views.get(&table_name); // Extract all AGGREGATE() calls (without AT modifiers) - let mut aggregate_calls = extract_all_aggregate_calls(sql); + let mut aggregate_calls = extract_all_aggregate_calls(&sql); if aggregate_calls.is_empty() { return AggregateExpandResult { @@ -1967,7 +2129,7 @@ pub fn expand_aggregate(sql: &str) -> AggregateExpandResult { } // Build replacements - let mut result_sql = sql.to_string(); + let mut result_sql = sql; for (measure_name, start, end) in aggregate_calls { // Look up measure definition @@ -2028,7 +2190,7 @@ pub fn expand_aggregate(sql: &str) -> AggregateExpandResult { } AggregateExpandResult { - had_aggregate: true, + had_aggregate, expanded_sql: result_sql, error: None, } @@ -3331,25 +3493,30 @@ fn expand_modifiers_to_sql_derived( /// Expand AGGREGATE() with AT modifiers in SQL pub fn expand_aggregate_with_at(sql: &str) -> AggregateExpandResult { + let cte_expansion = expand_cte_queries(sql); + let mut sql = cte_expansion.sql; + let mut had_aggregate = cte_expansion.had_aggregate; + // Check if we need the full expansion path (AT modifiers or non-decomposable measures) - let has_aggregate = has_aggregate_function(sql); + let has_aggregate = has_aggregate_function(&sql); // If no AGGREGATE function at all, nothing to do if !has_aggregate { return AggregateExpandResult { - had_aggregate: false, - expanded_sql: sql.to_string(), + had_aggregate, + expanded_sql: sql, error: None, }; } + had_aggregate = true; - let at_patterns = extract_aggregate_with_at_full(sql); + let at_patterns = extract_aggregate_with_at_full(&sql); // Keep full expansion path even without AT to handle non-decomposable measures safely // Extract table info using string-based approach (works with AGGREGATE syntax) // Note: DuckDB's parser can't parse AGGREGATE() since it's our custom syntax let (primary_table_name, existing_alias) = - extract_table_and_alias_from_sql(sql).unwrap_or_else(|| ("t".to_string(), None)); + extract_table_and_alias_from_sql(&sql).unwrap_or_else(|| ("t".to_string(), None)); // Build FromClauseInfo from string-based extraction for now // TODO: For proper JOIN support, we'd need to extract all tables from the FROM clause @@ -3367,15 +3534,15 @@ pub fn expand_aggregate_with_at(sql: &str) -> AggregateExpandResult { from_info.primary_table = Some(primary_table); // Extract outer WHERE clause for VISIBLE semantics - let outer_where = extract_where_clause(sql); + let outer_where = extract_where_clause(&sql); let outer_where_ref = outer_where.as_deref(); // Extract GROUP BY columns for AT (ALL dim) correlation - let group_by_cols = extract_group_by_columns(sql); + let group_by_cols = extract_group_by_columns(&sql); // Extract dimension columns from original SQL for implicit GROUP BY // (must be done before expansion since expanded SQL has SUM() etc) - let original_dim_cols = extract_dimension_columns_from_select(sql); + let original_dim_cols = extract_dimension_columns_from_select(&sql); // Check if any AT modifier needs correlation (for alias handling) let needs_outer_alias = at_patterns.iter().any(|(_, modifiers, _, _)| { @@ -3386,7 +3553,7 @@ pub fn expand_aggregate_with_at(sql: &str) -> AggregateExpandResult { }) }); - let mut result_sql = sql.to_string(); + let mut result_sql = sql; // Handle alias for the primary table if needed for correlation let primary_alias: Option = if needs_outer_alias { @@ -3634,7 +3801,7 @@ pub fn expand_aggregate_with_at(sql: &str) -> AggregateExpandResult { } AggregateExpandResult { - had_aggregate: true, + had_aggregate, expanded_sql: result_sql, error: None, } @@ -3694,19 +3861,19 @@ pub fn get_measure_aggregation(column_name: &str) -> Option<(String, String)> { } fn extract_group_by_columns(sql: &str) -> Vec { - let sql_upper = sql.to_uppercase(); let mut columns = Vec::new(); - if let Some(group_by_pos) = sql_upper.find("GROUP BY") { - let after_group_by = &sql[group_by_pos + 8..]; - - let end_pos = ["ORDER BY", "LIMIT", "HAVING", ";"] - .iter() - .filter_map(|kw| after_group_by.to_uppercase().find(kw)) - .min() - .unwrap_or(after_group_by.len()); + let query = sql.trim().trim_end_matches(';').trim(); + if let Some(group_by_pos) = find_top_level_keyword(query, "GROUP BY", 0) { + let start = group_by_pos + "GROUP BY".len(); + let end = find_first_top_level_keyword( + query, + start, + &["ORDER BY", "LIMIT", "HAVING", "QUALIFY", "WINDOW", "UNION", "INTERSECT", "EXCEPT"], + ) + .unwrap_or(query.len()); - let group_by_content = after_group_by[..end_pos].trim(); + let group_by_content = query[start..end].trim(); for part in group_by_content.split(',') { let col = part.trim(); @@ -3726,18 +3893,24 @@ fn extract_group_by_columns(sql: &str) -> Vec { /// Extract non-AGGREGATE columns from SELECT clause to use as implicit GROUP BY columns fn extract_dimension_columns_from_select(sql: &str) -> Vec { - let sql_upper = sql.to_uppercase(); let mut columns = Vec::new(); - // Find SELECT ... FROM - let select_pos = sql_upper.find("SELECT").unwrap_or(0) + 6; - let from_pos = sql_upper.find("FROM").unwrap_or(sql.len()); + let query = sql.trim().trim_end_matches(';').trim(); + let select_pos = match find_top_level_keyword(query, "SELECT", 0) { + Some(pos) => pos, + None => return columns, + }; + let from_pos = match find_top_level_keyword(query, "FROM", select_pos) { + Some(pos) => pos, + None => return columns, + }; - if select_pos >= from_pos { + let select_start = select_pos + "SELECT".len(); + if select_start >= from_pos { return columns; } - let select_content = &sql[select_pos..from_pos]; + let select_content = &query[select_start..from_pos]; // Split by comma, but be careful about nested parens let mut depth = 0; From 1b45909dbadba7d724dd80aac35289bdbd419d12 Mon Sep 17 00:00:00 2001 From: Nico Ritschel Date: Thu, 8 Jan 2026 18:19:04 -0800 Subject: [PATCH 2/3] Fix CI compile error --- yardstick-rs/src/sql/measures.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/yardstick-rs/src/sql/measures.rs b/yardstick-rs/src/sql/measures.rs index 6dc275b..334a4ed 100644 --- a/yardstick-rs/src/sql/measures.rs +++ b/yardstick-rs/src/sql/measures.rs @@ -2069,7 +2069,7 @@ fn find_expression_start(sql: &str, end: usize) -> usize { /// Uses position-based replacement with the C++ FFI parser pub fn expand_aggregate(sql: &str) -> AggregateExpandResult { let cte_expansion = expand_cte_queries(sql); - let mut sql = cte_expansion.sql; + let sql = cte_expansion.sql; let mut had_aggregate = cte_expansion.had_aggregate; if !has_aggregate_function(&sql) { @@ -2124,7 +2124,7 @@ pub fn expand_aggregate(sql: &str) -> AggregateExpandResult { .unwrap_or(false) }); if uses_non_decomposable { - return expand_aggregate_with_at(sql); + return expand_aggregate_with_at(&sql); } } @@ -3494,7 +3494,7 @@ fn expand_modifiers_to_sql_derived( /// Expand AGGREGATE() with AT modifiers in SQL pub fn expand_aggregate_with_at(sql: &str) -> AggregateExpandResult { let cte_expansion = expand_cte_queries(sql); - let mut sql = cte_expansion.sql; + let sql = cte_expansion.sql; let mut had_aggregate = cte_expansion.had_aggregate; // Check if we need the full expansion path (AT modifiers or non-decomposable measures) From 0f1f4cad2b782c14fcab9f0044b7b2ddc4192019 Mon Sep 17 00:00:00 2001 From: Nico Ritschel Date: Thu, 8 Jan 2026 18:40:18 -0800 Subject: [PATCH 3/3] Handle spacing in GROUP BY parsing --- yardstick-rs/src/sql/measures.rs | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/yardstick-rs/src/sql/measures.rs b/yardstick-rs/src/sql/measures.rs index 334a4ed..8e01ff0 100644 --- a/yardstick-rs/src/sql/measures.rs +++ b/yardstick-rs/src/sql/measures.rs @@ -868,6 +868,21 @@ fn matches_keyword_at(upper: &str, idx: usize, keyword: &str) -> bool { is_boundary_char(prev) && is_boundary_char(next) } +fn advance_after_group_by(query: &str, group_pos: usize) -> Option { + let upper = query.to_uppercase(); + let mut idx = group_pos; + if !matches_keyword_at(&upper, idx, "GROUP") { + return None; + } + idx += "GROUP".len(); + idx = skip_whitespace(query, idx); + if !matches_keyword_at(&upper, idx, "BY") { + return None; + } + idx += "BY".len(); + Some(skip_whitespace(query, idx)) +} + struct CteExpansion { sql: String, had_aggregate: bool, @@ -1110,7 +1125,8 @@ fn extract_view_group_by_cols(view_query: &str) -> Vec { None => return Vec::new(), }; - let start = group_pos + "GROUP BY".len(); + let start = advance_after_group_by(query, group_pos) + .unwrap_or_else(|| group_pos + "GROUP BY".len()); let end = find_first_top_level_keyword( query, start, @@ -3865,7 +3881,8 @@ fn extract_group_by_columns(sql: &str) -> Vec { let query = sql.trim().trim_end_matches(';').trim(); if let Some(group_by_pos) = find_top_level_keyword(query, "GROUP BY", 0) { - let start = group_by_pos + "GROUP BY".len(); + let start = advance_after_group_by(query, group_by_pos) + .unwrap_or_else(|| group_by_pos + "GROUP BY".len()); let end = find_first_top_level_keyword( query, start,