Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
190 changes: 180 additions & 10 deletions yardstick-rs/src/sql/measures.rs
Original file line number Diff line number Diff line change
Expand Up @@ -203,16 +203,135 @@ pub fn has_as_measure(sql: &str) -> bool {

/// Check if SQL contains a call to the `AGGREGATE(` function.
///
/// The SQL text is scanned token by token rather than by substring search, so
/// the name must be a complete identifier: `TOTAL_AGGREGATE(x)` and
/// `myaggregate(x)` do not match. The following forms are recognised:
///
/// * a bare name, compared ASCII case-insensitively: `AGGREGATE(x)`, `aggregate (x)`
/// * a double-quoted name: `"AGGREGATE"(x)`
/// * a dotted chain whose last segment is the name: `schema.AGGREGATE(x)`,
///   `"schema"."AGGREGATE" (x)`
///
/// Whitespace is allowed around the dots and before the opening parenthesis.
/// Text inside single-quoted string literals (with `''` escapes), `--` line
/// comments and `/* ... */` block comments is ignored.
///
/// Returns `true` as soon as one matching call is found, `false` otherwise.
pub fn has_aggregate_function(sql: &str) -> bool {
    const AGGREGATE: &str = "AGGREGATE";

    fn is_ident_start(c: char) -> bool {
        c.is_alphabetic() || c == '_'
    }

    fn is_ident_char(c: char) -> bool {
        c.is_alphanumeric() || c == '_'
    }

    // Index of the first non-whitespace character at or after `idx`.
    fn skip_whitespace(chars: &[char], mut idx: usize) -> usize {
        while idx < chars.len() && chars[idx].is_whitespace() {
            idx += 1;
        }
        idx
    }

    // `idx` is just past an opening `'`; returns the index just past the
    // closing quote (or the end of input when the literal is unterminated).
    fn skip_string_literal(chars: &[char], mut idx: usize) -> usize {
        while idx < chars.len() {
            if chars[idx] == '\'' {
                if chars.get(idx + 1) == Some(&'\'') {
                    // `''` is an escaped quote, not the end of the literal.
                    idx += 2;
                    continue;
                }
                return idx + 1;
            }
            idx += 1;
        }
        idx
    }

    // ASCII case-insensitive comparison of a raw token against `AGGREGATE`.
    fn is_aggregate(token: &[char]) -> bool {
        token.len() == AGGREGATE.len()
            && token
                .iter()
                .zip(AGGREGATE.chars())
                .all(|(a, b)| a.eq_ignore_ascii_case(&b))
    }

    // Reads one name segment (bare or double-quoted) starting at `idx`.
    // Returns whether the segment is `AGGREGATE` and the index just past it,
    // or `None` when no identifier starts at `idx`.
    fn read_segment(chars: &[char], idx: usize) -> Option<(bool, usize)> {
        match chars.get(idx).copied()? {
            '"' => {
                let content_start = idx + 1;
                let mut end = content_start;
                while end < chars.len() {
                    if chars[end] == '"' {
                        if chars.get(end + 1) == Some(&'"') {
                            // `""` is an escaped quote inside the identifier.
                            end += 2;
                            continue;
                        }
                        break;
                    }
                    end += 1;
                }
                // A name containing an escaped quote can never equal
                // AGGREGATE, so comparing the raw (still escaped) text is
                // equivalent to comparing the unescaped identifier.
                let matched = is_aggregate(&chars[content_start..end]);
                // Step over the closing quote when there is one.
                let next = if end < chars.len() { end + 1 } else { end };
                Some((matched, next))
            }
            c if is_ident_start(c) => {
                let mut end = idx + 1;
                while end < chars.len() && is_ident_char(chars[end]) {
                    end += 1;
                }
                Some((is_aggregate(&chars[idx..end]), end))
            }
            _ => None,
        }
    }

    let chars: Vec<char> = sql.chars().collect();
    let len = chars.len();
    let mut i = 0;

    while i < len {
        match chars[i] {
            '\'' => i = skip_string_literal(&chars, i + 1),
            '-' if chars.get(i + 1) == Some(&'-') => {
                // Line comment: resume at the newline (or at end of input).
                let body_start = i + 2;
                i = chars[body_start..]
                    .iter()
                    .position(|&c| c == '\n')
                    .map_or(len, |off| body_start + off);
            }
            '/' if chars.get(i + 1) == Some(&'*') => {
                // Block comment: resume just past `*/`; an unterminated
                // comment swallows the rest of the input.
                let body_start = i + 2;
                i = chars[body_start..]
                    .windows(2)
                    .position(|w| w == ['*', '/'])
                    .map_or(len, |off| body_start + off + 2);
            }
            _ => match read_segment(&chars, i) {
                Some((mut last_is_aggregate, mut end)) => {
                    // Follow `. segment` links; only the final segment of the
                    // chain decides whether this is the AGGREGATE function.
                    loop {
                        let dot = skip_whitespace(&chars, end);
                        if chars.get(dot) != Some(&'.') {
                            break;
                        }
                        let segment_start = skip_whitespace(&chars, dot + 1);
                        match read_segment(&chars, segment_start) {
                            Some((matched, next)) => {
                                last_is_aggregate = matched;
                                end = next;
                            }
                            // A dot that is not followed by an identifier
                            // ends the chain without being consumed.
                            None => break,
                        }
                    }

                    let paren = skip_whitespace(&chars, end);
                    if last_is_aggregate && chars.get(paren) == Some(&'(') {
                        return true;
                    }
                    i = end;
                }
                None => i += 1,
            },
        }
    }

    false
}
Expand Down Expand Up @@ -3767,10 +3886,10 @@ fn extract_dimension_columns_from_select(sql: &str) -> Vec<String> {

// Filter out AGGREGATE() calls and extract column names
for item in items {
let item_upper = item.to_uppercase();
if item_upper.contains("AGGREGATE(") {
if has_aggregate_function(&item) {

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Avoid skipping non-AGGREGATE functions that end with AGGREGATE

This now calls has_aggregate_function to decide whether to drop a SELECT item from implicit GROUP BY. That helper scans for any occurrence of AGGREGATE and returns true if the following non-whitespace char is (, so an expression like TOTAL_AGGREGATE (revenue) or a UDF named myaggregate will be treated as the special AGGREGATE function and excluded from grouping. In those cases, the dimension column is silently removed, changing query semantics. Consider matching AGGREGATE as a standalone function name (e.g., word boundary or full token) before skipping.

Useful? React with 👍 / 👎.

continue;
}
let item_upper = item.to_uppercase();
// Handle "col AS alias" - use the column name, not alias
let col = if let Some(as_pos) = item_upper.find(" AS ") {
item[..as_pos].trim()
Expand Down Expand Up @@ -3807,9 +3926,60 @@ mod tests {
#[test]
fn test_has_aggregate_function() {
    // Queries that really call AGGREGATE: bare, spaced, quoted and qualified.
    let calls = [
        "SELECT AGGREGATE(revenue) FROM foo",
        "SELECT AGGREGATE (revenue) FROM foo",
        "SELECT \"AGGREGATE\"(revenue) FROM foo",
        "SELECT schema.AGGREGATE(revenue) FROM foo",
        "SELECT \"schema\".\"AGGREGATE\" (revenue) FROM foo",
    ];
    for sql in &calls {
        assert!(has_aggregate_function(sql), "expected a match for: {}", sql);
    }

    // Functions whose names merely contain AGGREGATE, or are unrelated.
    let non_calls = [
        "SELECT TOTAL_AGGREGATE(revenue) FROM foo",
        "SELECT \"TOTAL_AGGREGATE\"(revenue) FROM foo",
        "SELECT myaggregate(revenue) FROM foo",
        "SELECT SUM(amount) FROM foo",
    ];
    for sql in &non_calls {
        assert!(!has_aggregate_function(sql), "unexpected match for: {}", sql);
    }
}

#[test]
fn test_extract_dimension_columns_ignores_aggregate_with_space() {
    // A space between AGGREGATE and `(` must not turn the call into a dimension.
    let region_only = vec!["region".to_string()];

    assert_eq!(
        extract_dimension_columns_from_select("SELECT region, AGGREGATE (revenue) FROM sales_v"),
        region_only
    );

    // Same call followed by an AT modifier.
    assert_eq!(
        extract_dimension_columns_from_select(
            "SELECT region, AGGREGATE (revenue) AT (ALL region) FROM sales_v"
        ),
        region_only
    );

    // With nothing but the AGGREGATE call selected, no dimensions remain.
    let no_dimensions =
        extract_dimension_columns_from_select("SELECT AGGREGATE (revenue) FROM sales_v");
    assert!(no_dimensions.is_empty());
}

#[test]
fn test_extract_dimension_columns_keeps_non_aggregate_suffix() {
    // TOTAL_AGGREGATE only ends in "AGGREGATE"; it is an ordinary expression
    // and has to stay in the extracted column list.
    let expected: Vec<String> = ["region", "TOTAL_AGGREGATE(revenue)"]
        .iter()
        .map(|name| name.to_string())
        .collect();

    let actual = extract_dimension_columns_from_select(
        "SELECT region, TOTAL_AGGREGATE(revenue) FROM sales_v",
    );

    assert_eq!(actual, expected);
}

#[test]
fn test_extract_dimension_columns_ignores_quoted_and_qualified_aggregate() {
    // Quoted and schema-qualified spellings of AGGREGATE are both dropped,
    // leaving only the plain dimension column.
    let queries = [
        "SELECT region, \"AGGREGATE\"(revenue) FROM sales_v",
        "SELECT region, schema.AGGREGATE(revenue) FROM sales_v",
    ];

    for sql in &queries {
        let cols = extract_dimension_columns_from_select(sql);
        assert_eq!(cols, vec!["region".to_string()]);
    }
}

#[test]
#[serial]
fn test_process_create_view_basic() {
Expand Down
Loading