Skip to content

Commit 8654db7

Browse files
authored
ide: goto def for joins (#793)
1 parent ac3fccf commit 8654db7

File tree

2 files changed

+298
-16
lines changed

2 files changed

+298
-16
lines changed

crates/squawk_ide/src/goto_definition.rs

Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2851,4 +2851,155 @@ select f.b$0 from t as f(x);
28512851
╰╴ ─ 1. source
28522852
");
28532853
}
2854+
2855+
#[test]
2856+
fn goto_join_table() {
2857+
assert_snapshot!(goto("
2858+
create table users(id int, email text);
2859+
create table messages(id int, user_id int, message text);
2860+
select * from users join messages$0 on users.id = messages.user_id;
2861+
"), @r"
2862+
╭▸
2863+
3 │ create table messages(id int, user_id int, message text);
2864+
│ ──────── 2. destination
2865+
4 │ select * from users join messages on users.id = messages.user_id;
2866+
╰╴ ─ 1. source
2867+
");
2868+
}
2869+
2870+
#[test]
2871+
fn goto_join_qualified_column_from_joined_table() {
2872+
assert_snapshot!(goto("
2873+
create table users(id int, email text);
2874+
create table messages(id int, user_id int, message text);
2875+
select messages.user_id$0 from users join messages on users.id = messages.user_id;
2876+
"), @r"
2877+
╭▸
2878+
3 │ create table messages(id int, user_id int, message text);
2879+
│ ─────── 2. destination
2880+
4 │ select messages.user_id from users join messages on users.id = messages.user_id;
2881+
╰╴ ─ 1. source
2882+
");
2883+
}
2884+
2885+
#[test]
2886+
fn goto_join_qualified_column_from_base_table() {
2887+
assert_snapshot!(goto("
2888+
create table users(id int, email text);
2889+
create table messages(id int, user_id int, message text);
2890+
select users.id$0 from users join messages on users.id = messages.user_id;
2891+
"), @r"
2892+
╭▸
2893+
2 │ create table users(id int, email text);
2894+
│ ── 2. destination
2895+
3 │ create table messages(id int, user_id int, message text);
2896+
4 │ select users.id from users join messages on users.id = messages.user_id;
2897+
╰╴ ─ 1. source
2898+
");
2899+
}
2900+
2901+
#[test]
2902+
fn goto_join_multiple_joins() {
2903+
assert_snapshot!(goto("
2904+
create table users(id int, name text);
2905+
create table messages(id int, user_id int, message text);
2906+
create table comments(id int, message_id int, text text);
2907+
select comments.text$0 from users
2908+
join messages on users.id = messages.user_id
2909+
join comments on messages.id = comments.message_id;
2910+
"), @r"
2911+
╭▸
2912+
4 │ create table comments(id int, message_id int, text text);
2913+
│ ──── 2. destination
2914+
5 │ select comments.text from users
2915+
╰╴ ─ 1. source
2916+
");
2917+
}
2918+
2919+
#[test]
2920+
fn goto_join_with_aliases() {
2921+
assert_snapshot!(goto("
2922+
create table users(id int, name text);
2923+
create table messages(id int, user_id int, message text);
2924+
select m.message$0 from users as u join messages as m on u.id = m.user_id;
2925+
"), @r"
2926+
╭▸
2927+
3 │ create table messages(id int, user_id int, message text);
2928+
│ ─────── 2. destination
2929+
4 │ select m.message from users as u join messages as m on u.id = m.user_id;
2930+
╰╴ ─ 1. source
2931+
");
2932+
}
2933+
2934+
#[test]
2935+
fn goto_join_unqualified_column() {
2936+
assert_snapshot!(goto("
2937+
create table users(id int, email text);
2938+
create table messages(id int, user_id int, message text);
2939+
select message$0 from users join messages on users.id = messages.user_id;
2940+
"), @r"
2941+
╭▸
2942+
3 │ create table messages(id int, user_id int, message text);
2943+
│ ─────── 2. destination
2944+
4 │ select message from users join messages on users.id = messages.user_id;
2945+
╰╴ ─ 1. source
2946+
");
2947+
}
2948+
2949+
#[test]
2950+
fn goto_join_with_many_tables() {
2951+
assert_snapshot!(goto("
2952+
create table users(id int, email text);
2953+
create table messages(id int, user_id int, message text);
2954+
create table logins(id int, user_id int, at timestamptz);
2955+
create table posts(id int, user_id int, post text);
2956+
2957+
select post$0
2958+
from users
2959+
join messages
2960+
on users.id = messages.user_id
2961+
join logins
2962+
on users.id = logins.user_id
2963+
join posts
2964+
on users.id = posts.user_id
2965+
"), @r"
2966+
╭▸
2967+
5 │ create table posts(id int, user_id int, post text);
2968+
│ ──── 2. destination
2969+
6 │
2970+
7 │ select post
2971+
╰╴ ─ 1. source
2972+
");
2973+
}
2974+
2975+
#[test]
2976+
fn goto_join_with_schema() {
2977+
assert_snapshot!(goto("
2978+
create schema foo;
2979+
create table foo.users(id int, email text);
2980+
create table foo.messages(id int, user_id int, message text);
2981+
select foo.messages.message$0 from foo.users join foo.messages on foo.users.id = foo.messages.user_id;
2982+
"), @r"
2983+
╭▸
2984+
4 │ create table foo.messages(id int, user_id int, message text);
2985+
│ ─────── 2. destination
2986+
5 │ select foo.messages.message from foo.users join foo.messages on foo.users.id = foo.messages.user_id;
2987+
╰╴ ─ 1. source
2988+
");
2989+
}
2990+
2991+
#[test]
2992+
fn goto_join_left_join() {
2993+
assert_snapshot!(goto("
2994+
create table users(id int, email text);
2995+
create table messages(id int, user_id int, message text);
2996+
select messages.message$0 from users left join messages on users.id = messages.user_id;
2997+
"), @r"
2998+
╭▸
2999+
3 │ create table messages(id int, user_id int, message text);
3000+
│ ─────── 2. destination
3001+
4 │ select messages.message from users left join messages on users.id = messages.user_id;
3002+
╰╴ ─ 1. source
3003+
");
3004+
}
28543005
}

crates/squawk_ide/src/resolve.rs

Lines changed: 147 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ pub(crate) fn resolve_name_ref(binder: &Binder, name_ref: &ast::NameRef) -> Opti
186186
// select a(t) from t;
187187
// ```
188188
if schema.is_none()
189-
&& let Some(ptr) = resolve_function_call_style_column(binder, name_ref)
189+
&& let Some(ptr) = resolve_fn_call_column(binder, name_ref)
190190
{
191191
return Some(ptr);
192192
}
@@ -630,7 +630,7 @@ fn resolve_select_qualified_column_table(
630630

631631
let select = name_ref.syntax().ancestors().find_map(ast::Select::cast)?;
632632
let from_clause = select.from_clause()?;
633-
let from_item = from_clause.from_items().next()?;
633+
let from_item = find_from_item_in_from_clause(&from_clause, &table_name)?;
634634

635635
if let Some(alias_name) = from_item.alias().and_then(|a| a.name())
636636
&& Name::from_node(&alias_name) == table_name
@@ -702,7 +702,7 @@ fn resolve_select_qualified_column(
702702
} else {
703703
let select = name_ref.syntax().ancestors().find_map(ast::Select::cast)?;
704704
let from_clause = select.from_clause()?;
705-
let from_item = from_clause.from_items().next()?;
705+
let from_item = find_from_item_in_from_clause(&from_clause, &column_table_name)?;
706706

707707
// `from t as u`
708708
// `from t as u(a, b, c)`
@@ -791,13 +791,12 @@ fn resolve_select_qualified_column(
791791
resolve_function(binder, &column_name, &schema, None, position)
792792
}
793793

794-
fn resolve_select_column(binder: &Binder, name_ref: &ast::NameRef) -> Option<SyntaxNodePtr> {
794+
fn resolve_from_item_for_column(
795+
binder: &Binder,
796+
from_item: &ast::FromItem,
797+
name_ref: &ast::NameRef,
798+
) -> Option<SyntaxNodePtr> {
795799
let column_name = Name::from_node(name_ref);
796-
797-
let select = name_ref.syntax().ancestors().find_map(ast::Select::cast)?;
798-
let from_clause = select.from_clause()?;
799-
let from_item = from_clause.from_items().next()?;
800-
801800
if let Some(paren_select) = from_item.paren_select() {
802801
return resolve_subquery_column(&paren_select, &column_name);
803802
}
@@ -855,6 +854,50 @@ fn resolve_select_column(binder: &Binder, name_ref: &ast::NameRef) -> Option<Syn
855854
None
856855
}
857856

857+
fn resolve_from_join_expr<F>(join_expr: &ast::JoinExpr, try_resolve: &F) -> Option<SyntaxNodePtr>
858+
where
859+
F: Fn(&ast::FromItem) -> Option<SyntaxNodePtr>,
860+
{
861+
if let Some(nested_join) = join_expr.join_expr()
862+
&& let Some(result) = resolve_from_join_expr(&nested_join, try_resolve)
863+
{
864+
return Some(result);
865+
}
866+
if let Some(from_item) = join_expr.from_item()
867+
&& let Some(result) = try_resolve(&from_item)
868+
{
869+
return Some(result);
870+
}
871+
if let Some(join) = join_expr.join()
872+
&& let Some(from_item) = join.from_item()
873+
&& let Some(result) = try_resolve(&from_item)
874+
{
875+
return Some(result);
876+
}
877+
None
878+
}
879+
880+
fn resolve_select_column(binder: &Binder, name_ref: &ast::NameRef) -> Option<SyntaxNodePtr> {
881+
let select = name_ref.syntax().ancestors().find_map(ast::Select::cast)?;
882+
let from_clause = select.from_clause()?;
883+
884+
for from_item in from_clause.from_items() {
885+
if let Some(result) = resolve_from_item_for_column(binder, &from_item, name_ref) {
886+
return Some(result);
887+
}
888+
}
889+
890+
for join_expr in from_clause.join_exprs() {
891+
if let Some(result) = resolve_from_join_expr(&join_expr, &|from_item: &ast::FromItem| {
892+
resolve_from_item_for_column(binder, from_item, name_ref)
893+
}) {
894+
return Some(result);
895+
}
896+
}
897+
898+
None
899+
}
900+
858901
fn resolve_delete_where_column(binder: &Binder, name_ref: &ast::NameRef) -> Option<SyntaxNodePtr> {
859902
let column_name = Name::from_node(name_ref);
860903

@@ -887,10 +930,7 @@ fn resolve_delete_where_column(binder: &Binder, name_ref: &ast::NameRef) -> Opti
887930
None
888931
}
889932

890-
fn resolve_function_call_style_column(
891-
binder: &Binder,
892-
name_ref: &ast::NameRef,
893-
) -> Option<SyntaxNodePtr> {
933+
fn resolve_fn_call_column(binder: &Binder, name_ref: &ast::NameRef) -> Option<SyntaxNodePtr> {
894934
let column_name = Name::from_node(name_ref);
895935

896936
// function call syntax for columns is only valid if there is one argument
@@ -905,9 +945,32 @@ fn resolve_function_call_style_column(
905945

906946
let select = name_ref.syntax().ancestors().find_map(ast::Select::cast)?;
907947
let from_clause = select.from_clause()?;
908-
let from_item = from_clause.from_items().next()?;
909948

910-
// get the table name and schema from the FROM clause
949+
for from_item in from_clause.from_items() {
950+
if let Some(result) =
951+
resolve_from_item_for_fn_call_column(binder, &from_item, &column_name, name_ref)
952+
{
953+
return Some(result);
954+
}
955+
}
956+
957+
for join_expr in from_clause.join_exprs() {
958+
if let Some(result) = resolve_from_join_expr(&join_expr, &|from_item: &ast::FromItem| {
959+
resolve_from_item_for_fn_call_column(binder, from_item, &column_name, name_ref)
960+
}) {
961+
return Some(result);
962+
}
963+
}
964+
965+
None
966+
}
967+
968+
fn resolve_from_item_for_fn_call_column(
969+
binder: &Binder,
970+
from_item: &ast::FromItem,
971+
column_name: &Name,
972+
name_ref: &ast::NameRef,
973+
) -> Option<SyntaxNodePtr> {
911974
let (table_name, schema) = if let Some(name_ref_node) = from_item.name_ref() {
912975
(Name::from_node(&name_ref_node), None)
913976
} else {
@@ -931,7 +994,7 @@ fn resolve_function_call_style_column(
931994
for arg in create_table.table_arg_list()?.args() {
932995
if let ast::TableArg::Column(column) = arg
933996
&& let Some(col_name) = column.name()
934-
&& Name::from_node(&col_name) == column_name
997+
&& Name::from_node(&col_name) == *column_name
935998
{
936999
return Some(SyntaxNodePtr::new(col_name.syntax()));
9371000
}
@@ -940,6 +1003,74 @@ fn resolve_function_call_style_column(
9401003
None
9411004
}
9421005

1006+
fn is_from_item_match(from_item: &ast::FromItem, qualifier: &Name) -> bool {
1007+
if let Some(alias_name) = from_item.alias().and_then(|a| a.name())
1008+
&& Name::from_node(&alias_name) == *qualifier
1009+
{
1010+
return true;
1011+
}
1012+
1013+
if let Some(name_ref) = from_item.name_ref()
1014+
&& Name::from_node(&name_ref) == *qualifier
1015+
{
1016+
return true;
1017+
}
1018+
1019+
if let Some(field_expr) = from_item.field_expr()
1020+
&& let Some(field) = field_expr.field()
1021+
&& Name::from_node(&field) == *qualifier
1022+
{
1023+
return true;
1024+
}
1025+
1026+
false
1027+
}
1028+
1029+
fn find_from_item_in_join_expr(
1030+
join_expr: &ast::JoinExpr,
1031+
qualifier: &Name,
1032+
) -> Option<ast::FromItem> {
1033+
if let Some(nested_join_expr) = join_expr.join_expr()
1034+
&& let Some(found) = find_from_item_in_join_expr(&nested_join_expr, qualifier)
1035+
{
1036+
return Some(found);
1037+
}
1038+
1039+
if let Some(from_item) = join_expr.from_item()
1040+
&& is_from_item_match(&from_item, qualifier)
1041+
{
1042+
return Some(from_item);
1043+
}
1044+
1045+
if let Some(join) = join_expr.join()
1046+
&& let Some(from_item) = join.from_item()
1047+
&& is_from_item_match(&from_item, qualifier)
1048+
{
1049+
return Some(from_item);
1050+
}
1051+
1052+
None
1053+
}
1054+
1055+
fn find_from_item_in_from_clause(
1056+
from_clause: &ast::FromClause,
1057+
qualifier: &Name,
1058+
) -> Option<ast::FromItem> {
1059+
for from_item in from_clause.from_items() {
1060+
if is_from_item_match(&from_item, qualifier) {
1061+
return Some(from_item);
1062+
}
1063+
}
1064+
1065+
for join_expr in from_clause.join_exprs() {
1066+
if let Some(found) = find_from_item_in_join_expr(&join_expr, qualifier) {
1067+
return Some(found);
1068+
}
1069+
}
1070+
1071+
None
1072+
}
1073+
9431074
fn find_containing_path(name_ref: &ast::NameRef) -> Option<ast::Path> {
9441075
for ancestor in name_ref.syntax().ancestors() {
9451076
if let Some(path) = ast::Path::cast(ancestor) {

0 commit comments

Comments
 (0)