16 changes: 16 additions & 0 deletions test/sql/measures.test
@@ -34,6 +34,22 @@ SEMANTIC SELECT year, region, AGGREGATE(revenue) FROM sales_v;
2023 EU 75.0
2023 US 150.0

# =============================================================================
# Test: CTE with AGGREGATE
# =============================================================================

query IIR rowsort
SEMANTIC WITH a AS (
SELECT year, region, AGGREGATE(revenue) AS revenue
FROM sales_v
)
SELECT * FROM a;
----
2022 EU 50.0
2022 US 100.0
2023 EU 75.0
2023 US 150.0

# =============================================================================
# Test: AT (ALL dimension) - remove dimension from context
# =============================================================================
251 changes: 212 additions & 39 deletions yardstick-rs/src/sql/measures.rs
@@ -588,13 +588,11 @@ pub fn extract_table_and_alias_from_sql(sql: &str) -> Option<(String, Option<Str
.chars()
.map(|c| if c.is_whitespace() { ' ' } else { c })
.collect();
let normalized_upper = normalized.to_uppercase();
let from_pos = normalized_upper.find(" FROM ")?;
let from_pos = find_top_level_keyword(&normalized, "FROM", 0)?;
let after_from = &normalized[from_pos..];

// Parse: FROM table_name [AS] [alias]
let (rest, _) = multispace1::<_, nom::error::Error<&str>>(after_from).ok()?;
let (rest, _) = tag_no_case::<_, _, nom::error::Error<&str>>("FROM")(rest).ok()?;
let (rest, _) = tag_no_case::<_, _, nom::error::Error<&str>>("FROM")(after_from).ok()?;
let (rest, _) = multispace1::<_, nom::error::Error<&str>>(rest).ok()?;
let (rest, table) = identifier(rest).ok()?;

@@ -841,6 +839,165 @@ fn find_first_top_level_keyword(sql: &str, start: usize, keywords: &[&str]) -> O
.min()
}

fn skip_whitespace(sql: &str, mut idx: usize) -> usize {
while idx < sql.len() && sql.as_bytes()[idx].is_ascii_whitespace() {
idx += 1;
}
idx
}

fn matches_keyword_at(upper: &str, idx: usize, keyword: &str) -> bool {
if idx + keyword.len() > upper.len() {
return false;
}
if &upper[idx..idx + keyword.len()] != keyword {
return false;
}

let prev = if idx == 0 {
None
} else {
upper[..idx].chars().rev().next()
};
let next = if idx + keyword.len() >= upper.len() {
None
} else {
upper[idx + keyword.len()..].chars().next()
};

is_boundary_char(prev) && is_boundary_char(next)
}

struct CteExpansion {
sql: String,
had_aggregate: bool,
}

fn expand_cte_queries(sql: &str) -> CteExpansion {
let with_pos = match find_top_level_keyword(sql, "WITH", 0) {
Some(pos) => pos,
None => {
return CteExpansion {
sql: sql.to_string(),
had_aggregate: false,
};
}
};

let upper = sql.to_uppercase();
let mut idx = with_pos + "WITH".len();
let mut replacements: Vec<(usize, usize, String)> = Vec::new();
let mut had_aggregate = false;

idx = skip_whitespace(sql, idx);
if matches_keyword_at(&upper, idx, "RECURSIVE") {
idx += "RECURSIVE".len();
}

loop {
idx = skip_whitespace(sql, idx);
if idx >= sql.len() {
return CteExpansion {
sql: sql.to_string(),
had_aggregate: false,
};
}

let rest = &sql[idx..];
let (after_ident, _) = match identifier(rest) {
Ok(value) => value,
Err(_) => {
return CteExpansion {
sql: sql.to_string(),
had_aggregate: false,
};
}
};
idx += rest.len() - after_ident.len();

idx = skip_whitespace(sql, idx);
if idx < sql.len() && sql.as_bytes()[idx] == b'(' {
let sub = &sql[idx + 1..];
let (rest_after_cols, _) = match balanced_parens(sub) {
Ok(value) => value,
Err(_) => {
return CteExpansion {
sql: sql.to_string(),
had_aggregate: false,
};
}
};
let cols_len = sub.len() - rest_after_cols.len();
idx = idx + 1 + cols_len + 1;
idx = skip_whitespace(sql, idx);
}

if !matches_keyword_at(&upper, idx, "AS") {
return CteExpansion {
sql: sql.to_string(),
had_aggregate: false,
};
}
idx += "AS".len();
idx = skip_whitespace(sql, idx);

if idx >= sql.len() || sql.as_bytes()[idx] != b'(' {
return CteExpansion {
sql: sql.to_string(),
had_aggregate: false,
};
}

let sub = &sql[idx + 1..];
let (rest_after_query, _) = match balanced_parens(sub) {
Ok(value) => value,
Err(_) => {
return CteExpansion {
sql: sql.to_string(),
had_aggregate: false,
};
}
};
let query_len = sub.len() - rest_after_query.len();
let query_start = idx + 1;
let query_end = query_start + query_len;
let query_sql = &sql[query_start..query_end];
let expanded = expand_aggregate_with_at(query_sql);
if expanded.had_aggregate {
had_aggregate = true;
}
if expanded.expanded_sql != query_sql {
replacements.push((query_start, query_end, expanded.expanded_sql));
}

idx = query_end + 1;
idx = skip_whitespace(sql, idx);
if idx < sql.len() && sql.as_bytes()[idx] == b',' {
idx += 1;
continue;
}
break;
}

if replacements.is_empty() {
return CteExpansion {
sql: sql.to_string(),
had_aggregate,
};
}

let mut result = sql.to_string();
replacements.sort_by(|a, b| b.0.cmp(&a.0));
for (start, end, replacement) in replacements {
result.replace_range(start..end, &replacement);
}

CteExpansion {
sql: result,
had_aggregate,
}
}

fn extract_base_relation_sql(view_query: &str) -> Option<String> {
let query = view_query.trim().trim_end_matches(';').trim();
if query.is_empty() {
@@ -1911,21 +2068,26 @@ fn find_expression_start(sql: &str, end: usize) -> usize {
/// Expand AGGREGATE() function calls in a SELECT statement
/// Uses position-based replacement with the C++ FFI parser
pub fn expand_aggregate(sql: &str) -> AggregateExpandResult {
if !has_aggregate_function(sql) {
let cte_expansion = expand_cte_queries(sql);
let mut sql = cte_expansion.sql;
let mut had_aggregate = cte_expansion.had_aggregate;

if !has_aggregate_function(&sql) {
return AggregateExpandResult {
had_aggregate: false,
expanded_sql: sql.to_string(),
had_aggregate,
expanded_sql: sql,
error: None,
};
}
had_aggregate = true;

// Parse the SQL to get table info and GROUP BY columns
let select_info = match parser_ffi::parse_select(sql) {
let select_info = match parser_ffi::parse_select(&sql) {
Ok(info) => info,
Err(e) => {
return AggregateExpandResult {
had_aggregate: false,
expanded_sql: sql.to_string(),
had_aggregate,
expanded_sql: sql,
error: Some(format!("SQL parse error: {e}")),
};
}
@@ -1939,7 +2101,7 @@ pub fn expand_aggregate(sql: &str) -> AggregateExpandResult {
let measure_view = views.get(&table_name);

// Extract all AGGREGATE() calls (without AT modifiers)
let mut aggregate_calls = extract_all_aggregate_calls(sql);
let mut aggregate_calls = extract_all_aggregate_calls(&sql);

if aggregate_calls.is_empty() {
return AggregateExpandResult {
@@ -1967,7 +2129,7 @@ pub fn expand_aggregate(sql: &str) -> AggregateExpandResult {
}

// Build replacements
let mut result_sql = sql.to_string();
let mut result_sql = sql;

for (measure_name, start, end) in aggregate_calls {
// Look up measure definition
@@ -2028,7 +2190,7 @@ pub fn expand_aggregate(sql: &str) -> AggregateExpandResult {
}

AggregateExpandResult {
had_aggregate: true,
had_aggregate,
expanded_sql: result_sql,
error: None,
}
@@ -3331,25 +3493,30 @@ fn expand_modifiers_to_sql_derived(

/// Expand AGGREGATE() with AT modifiers in SQL
pub fn expand_aggregate_with_at(sql: &str) -> AggregateExpandResult {
let cte_expansion = expand_cte_queries(sql);
let mut sql = cte_expansion.sql;
let mut had_aggregate = cte_expansion.had_aggregate;

// Check if we need the full expansion path (AT modifiers or non-decomposable measures)
let has_aggregate = has_aggregate_function(sql);
let has_aggregate = has_aggregate_function(&sql);

// If no AGGREGATE function at all, nothing to do
if !has_aggregate {
return AggregateExpandResult {
had_aggregate: false,
expanded_sql: sql.to_string(),
had_aggregate,
expanded_sql: sql,
error: None,
};
}
had_aggregate = true;

let at_patterns = extract_aggregate_with_at_full(sql);
let at_patterns = extract_aggregate_with_at_full(&sql);
// Keep full expansion path even without AT to handle non-decomposable measures safely

// Extract table info using string-based approach (works with AGGREGATE syntax)
// Note: DuckDB's parser can't parse AGGREGATE() since it's our custom syntax
let (primary_table_name, existing_alias) =
extract_table_and_alias_from_sql(sql).unwrap_or_else(|| ("t".to_string(), None));
extract_table_and_alias_from_sql(&sql).unwrap_or_else(|| ("t".to_string(), None));

// Build FromClauseInfo from string-based extraction for now
// TODO: For proper JOIN support, we'd need to extract all tables from the FROM clause
Expand All @@ -3367,15 +3534,15 @@ pub fn expand_aggregate_with_at(sql: &str) -> AggregateExpandResult {
from_info.primary_table = Some(primary_table);

// Extract outer WHERE clause for VISIBLE semantics
let outer_where = extract_where_clause(sql);
let outer_where = extract_where_clause(&sql);
let outer_where_ref = outer_where.as_deref();

// Extract GROUP BY columns for AT (ALL dim) correlation
let group_by_cols = extract_group_by_columns(sql);
let group_by_cols = extract_group_by_columns(&sql);

// Extract dimension columns from original SQL for implicit GROUP BY
// (must be done before expansion since expanded SQL has SUM() etc)
let original_dim_cols = extract_dimension_columns_from_select(sql);
let original_dim_cols = extract_dimension_columns_from_select(&sql);

// Check if any AT modifier needs correlation (for alias handling)
let needs_outer_alias = at_patterns.iter().any(|(_, modifiers, _, _)| {
@@ -3386,7 +3553,7 @@ pub fn expand_aggregate_with_at(sql: &str) -> AggregateExpandResult {
})
});

let mut result_sql = sql.to_string();
let mut result_sql = sql;

// Handle alias for the primary table if needed for correlation
let primary_alias: Option<String> = if needs_outer_alias {
@@ -3634,7 +3801,7 @@ pub fn expand_aggregate_with_at(sql: &str) -> AggregateExpandResult {
}

AggregateExpandResult {
had_aggregate: true,
had_aggregate,
expanded_sql: result_sql,
error: None,
}
@@ -3694,19 +3861,19 @@ pub fn get_measure_aggregation(column_name: &str) -> Option<(String, String)> {
}

fn extract_group_by_columns(sql: &str) -> Vec<String> {
let sql_upper = sql.to_uppercase();
let mut columns = Vec::new();

if let Some(group_by_pos) = sql_upper.find("GROUP BY") {
let after_group_by = &sql[group_by_pos + 8..];

let end_pos = ["ORDER BY", "LIMIT", "HAVING", ";"]
.iter()
.filter_map(|kw| after_group_by.to_uppercase().find(kw))
.min()
.unwrap_or(after_group_by.len());
let query = sql.trim().trim_end_matches(';').trim();
if let Some(group_by_pos) = find_top_level_keyword(query, "GROUP BY", 0) {
let start = group_by_pos + "GROUP BY".len();
let end = find_first_top_level_keyword(
Comment on lines 3882 to 3886

P2: Advance past actual GROUP BY spacing

The new find_top_level_keyword will match a valid GROUP BY written with multiple spaces, but start is calculated as group_by_pos + "GROUP BY".len(), which assumes exactly one space between the keywords. For an input with extra whitespace such as GROUP   BY col, start points at the B in BY, so group_by_content becomes BY col and the first extracted column is BY col, which then feeds incorrect group_by_cols into AT correlation and expansion. Consider advancing past the real keyword end (e.g., skip whitespace after GROUP and then after BY) so arbitrary spacing doesn't corrupt the parsed columns; a rough sketch of one way to do this follows below.
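A minimal sketch of that suggestion, assuming a standalone helper; advance_past_group_by and the query in main are illustrative only and not part of this PR:

fn advance_past_group_by(sql: &str, group_pos: usize) -> Option<usize> {
    // Step past "GROUP", require at least one whitespace byte, skip the rest of
    // the whitespace run, then match "BY" case-insensitively.
    let bytes = sql.as_bytes();
    let mut idx = group_pos + "GROUP".len();
    if idx >= bytes.len() || !bytes[idx].is_ascii_whitespace() {
        return None;
    }
    while idx < bytes.len() && bytes[idx].is_ascii_whitespace() {
        idx += 1;
    }
    // A word-boundary check after "BY" is elided here for brevity.
    match sql.get(idx..idx + 2) {
        Some(tok) if tok.eq_ignore_ascii_case("BY") => Some(idx + 2),
        _ => None,
    }
}

fn main() {
    // Extra spaces between GROUP and BY no longer shift the column slice.
    let sql = "SELECT region, SUM(x) FROM t GROUP   BY region";
    let group_pos = sql.to_uppercase().find("GROUP").unwrap();
    let start = advance_past_group_by(sql, group_pos).unwrap();
    assert_eq!(sql[start..].trim(), "region");
}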

query,
start,
&["ORDER BY", "LIMIT", "HAVING", "QUALIFY", "WINDOW", "UNION", "INTERSECT", "EXCEPT"],
)
.unwrap_or(query.len());

let group_by_content = after_group_by[..end_pos].trim();
let group_by_content = query[start..end].trim();

for part in group_by_content.split(',') {
let col = part.trim();
Expand All @@ -3726,18 +3893,24 @@ fn extract_group_by_columns(sql: &str) -> Vec<String> {

/// Extract non-AGGREGATE columns from SELECT clause to use as implicit GROUP BY columns
fn extract_dimension_columns_from_select(sql: &str) -> Vec<String> {
let sql_upper = sql.to_uppercase();
let mut columns = Vec::new();

// Find SELECT ... FROM
let select_pos = sql_upper.find("SELECT").unwrap_or(0) + 6;
let from_pos = sql_upper.find("FROM").unwrap_or(sql.len());
let query = sql.trim().trim_end_matches(';').trim();
let select_pos = match find_top_level_keyword(query, "SELECT", 0) {
Some(pos) => pos,
None => return columns,
};
let from_pos = match find_top_level_keyword(query, "FROM", select_pos) {
Some(pos) => pos,
None => return columns,
};

if select_pos >= from_pos {
let select_start = select_pos + "SELECT".len();
if select_start >= from_pos {
return columns;
}

let select_content = &sql[select_pos..from_pos];
let select_content = &query[select_start..from_pos];

// Split by comma, but be careful about nested parens
let mut depth = 0;