diff --git a/crates/squawk_ide/src/goto_definition.rs b/crates/squawk_ide/src/goto_definition.rs index 85c6172c..0c6c4b8e 100644 --- a/crates/squawk_ide/src/goto_definition.rs +++ b/crates/squawk_ide/src/goto_definition.rs @@ -2851,4 +2851,155 @@ select f.b$0 from t as f(x); ╰╴ ─ 1. source "); } + + #[test] + fn goto_join_table() { + assert_snapshot!(goto(" +create table users(id int, email text); +create table messages(id int, user_id int, message text); +select * from users join messages$0 on users.id = messages.user_id; +"), @r" + ╭▸ + 3 │ create table messages(id int, user_id int, message text); + │ ──────── 2. destination + 4 │ select * from users join messages on users.id = messages.user_id; + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_join_qualified_column_from_joined_table() { + assert_snapshot!(goto(" +create table users(id int, email text); +create table messages(id int, user_id int, message text); +select messages.user_id$0 from users join messages on users.id = messages.user_id; +"), @r" + ╭▸ + 3 │ create table messages(id int, user_id int, message text); + │ ─────── 2. destination + 4 │ select messages.user_id from users join messages on users.id = messages.user_id; + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_join_qualified_column_from_base_table() { + assert_snapshot!(goto(" +create table users(id int, email text); +create table messages(id int, user_id int, message text); +select users.id$0 from users join messages on users.id = messages.user_id; +"), @r" + ╭▸ + 2 │ create table users(id int, email text); + │ ── 2. destination + 3 │ create table messages(id int, user_id int, message text); + 4 │ select users.id from users join messages on users.id = messages.user_id; + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_join_multiple_joins() { + assert_snapshot!(goto(" +create table users(id int, name text); +create table messages(id int, user_id int, message text); +create table comments(id int, message_id int, text text); +select comments.text$0 from users + join messages on users.id = messages.user_id + join comments on messages.id = comments.message_id; +"), @r" + ╭▸ + 4 │ create table comments(id int, message_id int, text text); + │ ──── 2. destination + 5 │ select comments.text from users + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_join_with_aliases() { + assert_snapshot!(goto(" +create table users(id int, name text); +create table messages(id int, user_id int, message text); +select m.message$0 from users as u join messages as m on u.id = m.user_id; +"), @r" + ╭▸ + 3 │ create table messages(id int, user_id int, message text); + │ ─────── 2. destination + 4 │ select m.message from users as u join messages as m on u.id = m.user_id; + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_join_unqualified_column() { + assert_snapshot!(goto(" +create table users(id int, email text); +create table messages(id int, user_id int, message text); +select message$0 from users join messages on users.id = messages.user_id; +"), @r" + ╭▸ + 3 │ create table messages(id int, user_id int, message text); + │ ─────── 2. destination + 4 │ select message from users join messages on users.id = messages.user_id; + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_join_with_many_tables() { + assert_snapshot!(goto(" +create table users(id int, email text); +create table messages(id int, user_id int, message text); +create table logins(id int, user_id int, at timestamptz); +create table posts(id int, user_id int, post text); + +select post$0 + from users + join messages + on users.id = messages.user_id + join logins + on users.id = logins.user_id + join posts + on users.id = posts.user_id +"), @r" + ╭▸ + 5 │ create table posts(id int, user_id int, post text); + │ ──── 2. destination + 6 │ + 7 │ select post + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_join_with_schema() { + assert_snapshot!(goto(" +create schema foo; +create table foo.users(id int, email text); +create table foo.messages(id int, user_id int, message text); +select foo.messages.message$0 from foo.users join foo.messages on foo.users.id = foo.messages.user_id; +"), @r" + ╭▸ + 4 │ create table foo.messages(id int, user_id int, message text); + │ ─────── 2. destination + 5 │ select foo.messages.message from foo.users join foo.messages on foo.users.id = foo.messages.user_id; + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_join_left_join() { + assert_snapshot!(goto(" +create table users(id int, email text); +create table messages(id int, user_id int, message text); +select messages.message$0 from users left join messages on users.id = messages.user_id; +"), @r" + ╭▸ + 3 │ create table messages(id int, user_id int, message text); + │ ─────── 2. destination + 4 │ select messages.message from users left join messages on users.id = messages.user_id; + ╰╴ ─ 1. source + "); + } } diff --git a/crates/squawk_ide/src/resolve.rs b/crates/squawk_ide/src/resolve.rs index 654aba3a..01a906a5 100644 --- a/crates/squawk_ide/src/resolve.rs +++ b/crates/squawk_ide/src/resolve.rs @@ -186,7 +186,7 @@ pub(crate) fn resolve_name_ref(binder: &Binder, name_ref: &ast::NameRef) -> Opti // select a(t) from t; // ``` if schema.is_none() - && let Some(ptr) = resolve_function_call_style_column(binder, name_ref) + && let Some(ptr) = resolve_fn_call_column(binder, name_ref) { return Some(ptr); } @@ -630,7 +630,7 @@ fn resolve_select_qualified_column_table( let select = name_ref.syntax().ancestors().find_map(ast::Select::cast)?; let from_clause = select.from_clause()?; - let from_item = from_clause.from_items().next()?; + let from_item = find_from_item_in_from_clause(&from_clause, &table_name)?; if let Some(alias_name) = from_item.alias().and_then(|a| a.name()) && Name::from_node(&alias_name) == table_name @@ -702,7 +702,7 @@ fn resolve_select_qualified_column( } else { let select = name_ref.syntax().ancestors().find_map(ast::Select::cast)?; let from_clause = select.from_clause()?; - let from_item = from_clause.from_items().next()?; + let from_item = find_from_item_in_from_clause(&from_clause, &column_table_name)?; // `from t as u` // `from t as u(a, b, c)` @@ -791,13 +791,12 @@ fn resolve_select_qualified_column( resolve_function(binder, &column_name, &schema, None, position) } -fn resolve_select_column(binder: &Binder, name_ref: &ast::NameRef) -> Option { +fn resolve_from_item_for_column( + binder: &Binder, + from_item: &ast::FromItem, + name_ref: &ast::NameRef, +) -> Option { let column_name = Name::from_node(name_ref); - - let select = name_ref.syntax().ancestors().find_map(ast::Select::cast)?; - let from_clause = select.from_clause()?; - let from_item = from_clause.from_items().next()?; - if let Some(paren_select) = from_item.paren_select() { return resolve_subquery_column(&paren_select, &column_name); } @@ -855,6 +854,50 @@ fn resolve_select_column(binder: &Binder, name_ref: &ast::NameRef) -> Option(join_expr: &ast::JoinExpr, try_resolve: &F) -> Option +where + F: Fn(&ast::FromItem) -> Option, +{ + if let Some(nested_join) = join_expr.join_expr() + && let Some(result) = resolve_from_join_expr(&nested_join, try_resolve) + { + return Some(result); + } + if let Some(from_item) = join_expr.from_item() + && let Some(result) = try_resolve(&from_item) + { + return Some(result); + } + if let Some(join) = join_expr.join() + && let Some(from_item) = join.from_item() + && let Some(result) = try_resolve(&from_item) + { + return Some(result); + } + None +} + +fn resolve_select_column(binder: &Binder, name_ref: &ast::NameRef) -> Option { + let select = name_ref.syntax().ancestors().find_map(ast::Select::cast)?; + let from_clause = select.from_clause()?; + + for from_item in from_clause.from_items() { + if let Some(result) = resolve_from_item_for_column(binder, &from_item, name_ref) { + return Some(result); + } + } + + for join_expr in from_clause.join_exprs() { + if let Some(result) = resolve_from_join_expr(&join_expr, &|from_item: &ast::FromItem| { + resolve_from_item_for_column(binder, from_item, name_ref) + }) { + return Some(result); + } + } + + None +} + fn resolve_delete_where_column(binder: &Binder, name_ref: &ast::NameRef) -> Option { let column_name = Name::from_node(name_ref); @@ -887,10 +930,7 @@ fn resolve_delete_where_column(binder: &Binder, name_ref: &ast::NameRef) -> Opti None } -fn resolve_function_call_style_column( - binder: &Binder, - name_ref: &ast::NameRef, -) -> Option { +fn resolve_fn_call_column(binder: &Binder, name_ref: &ast::NameRef) -> Option { let column_name = Name::from_node(name_ref); // function call syntax for columns is only valid if there is one argument @@ -905,9 +945,32 @@ fn resolve_function_call_style_column( let select = name_ref.syntax().ancestors().find_map(ast::Select::cast)?; let from_clause = select.from_clause()?; - let from_item = from_clause.from_items().next()?; - // get the table name and schema from the FROM clause + for from_item in from_clause.from_items() { + if let Some(result) = + resolve_from_item_for_fn_call_column(binder, &from_item, &column_name, name_ref) + { + return Some(result); + } + } + + for join_expr in from_clause.join_exprs() { + if let Some(result) = resolve_from_join_expr(&join_expr, &|from_item: &ast::FromItem| { + resolve_from_item_for_fn_call_column(binder, from_item, &column_name, name_ref) + }) { + return Some(result); + } + } + + None +} + +fn resolve_from_item_for_fn_call_column( + binder: &Binder, + from_item: &ast::FromItem, + column_name: &Name, + name_ref: &ast::NameRef, +) -> Option { let (table_name, schema) = if let Some(name_ref_node) = from_item.name_ref() { (Name::from_node(&name_ref_node), None) } else { @@ -931,7 +994,7 @@ fn resolve_function_call_style_column( for arg in create_table.table_arg_list()?.args() { if let ast::TableArg::Column(column) = arg && let Some(col_name) = column.name() - && Name::from_node(&col_name) == column_name + && Name::from_node(&col_name) == *column_name { return Some(SyntaxNodePtr::new(col_name.syntax())); } @@ -940,6 +1003,74 @@ fn resolve_function_call_style_column( None } +fn is_from_item_match(from_item: &ast::FromItem, qualifier: &Name) -> bool { + if let Some(alias_name) = from_item.alias().and_then(|a| a.name()) + && Name::from_node(&alias_name) == *qualifier + { + return true; + } + + if let Some(name_ref) = from_item.name_ref() + && Name::from_node(&name_ref) == *qualifier + { + return true; + } + + if let Some(field_expr) = from_item.field_expr() + && let Some(field) = field_expr.field() + && Name::from_node(&field) == *qualifier + { + return true; + } + + false +} + +fn find_from_item_in_join_expr( + join_expr: &ast::JoinExpr, + qualifier: &Name, +) -> Option { + if let Some(nested_join_expr) = join_expr.join_expr() + && let Some(found) = find_from_item_in_join_expr(&nested_join_expr, qualifier) + { + return Some(found); + } + + if let Some(from_item) = join_expr.from_item() + && is_from_item_match(&from_item, qualifier) + { + return Some(from_item); + } + + if let Some(join) = join_expr.join() + && let Some(from_item) = join.from_item() + && is_from_item_match(&from_item, qualifier) + { + return Some(from_item); + } + + None +} + +fn find_from_item_in_from_clause( + from_clause: &ast::FromClause, + qualifier: &Name, +) -> Option { + for from_item in from_clause.from_items() { + if is_from_item_match(&from_item, qualifier) { + return Some(from_item); + } + } + + for join_expr in from_clause.join_exprs() { + if let Some(found) = find_from_item_in_join_expr(&join_expr, qualifier) { + return Some(found); + } + } + + None +} + fn find_containing_path(name_ref: &ast::NameRef) -> Option { for ancestor in name_ref.syntax().ancestors() { if let Some(path) = ast::Path::cast(ancestor) {