Skip to content

Commit 1e09b23

Browse files
authored
ide: goto def with update (#795)
1 parent 680274a commit 1e09b23

File tree

2 files changed

+284
-1
lines changed

2 files changed

+284
-1
lines changed

crates/squawk_ide/src/goto_definition.rs

Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3077,4 +3077,179 @@ delete from users where id in (select id$0 from old_data);
30773077
╰╴ ─ 1. source
30783078
");
30793079
}
3080+
3081+
#[test]
3082+
fn goto_update_table() {
3083+
assert_snapshot!(goto("
3084+
create table users(id int, email text);
3085+
update users$0 set email = '[email protected]';
3086+
"), @r"
3087+
╭▸
3088+
2 │ create table users(id int, email text);
3089+
│ ───── 2. destination
3090+
3 │ update users set email = '[email protected]';
3091+
╰╴ ─ 1. source
3092+
");
3093+
}
3094+
3095+
#[test]
3096+
fn goto_update_table_with_schema() {
3097+
assert_snapshot!(goto("
3098+
create table public.users(id int, email text);
3099+
update public.users$0 set email = '[email protected]';
3100+
"), @r"
3101+
╭▸
3102+
2 │ create table public.users(id int, email text);
3103+
│ ───── 2. destination
3104+
3 │ update public.users set email = '[email protected]';
3105+
╰╴ ─ 1. source
3106+
");
3107+
}
3108+
3109+
#[test]
3110+
fn goto_update_table_with_search_path() {
3111+
assert_snapshot!(goto("
3112+
set search_path to foo;
3113+
create table foo.users(id int, email text);
3114+
update users$0 set email = '[email protected]';
3115+
"), @r"
3116+
╭▸
3117+
3 │ create table foo.users(id int, email text);
3118+
│ ───── 2. destination
3119+
4 │ update users set email = '[email protected]';
3120+
╰╴ ─ 1. source
3121+
");
3122+
}
3123+
3124+
#[test]
3125+
fn goto_update_where_column() {
3126+
assert_snapshot!(goto("
3127+
create table users(id int, email text);
3128+
update users set email = '[email protected]' where id$0 = 1;
3129+
"), @r"
3130+
╭▸
3131+
2 │ create table users(id int, email text);
3132+
│ ── 2. destination
3133+
3 │ update users set email = '[email protected]' where id = 1;
3134+
╰╴ ─ 1. source
3135+
");
3136+
}
3137+
3138+
#[test]
3139+
fn goto_update_where_column_with_schema() {
3140+
assert_snapshot!(goto("
3141+
create table public.users(id int, email text);
3142+
update public.users set email = '[email protected]' where id$0 = 1;
3143+
"), @r"
3144+
╭▸
3145+
2 │ create table public.users(id int, email text);
3146+
│ ── 2. destination
3147+
3 │ update public.users set email = '[email protected]' where id = 1;
3148+
╰╴ ─ 1. source
3149+
");
3150+
}
3151+
3152+
#[test]
3153+
fn goto_update_where_column_with_search_path() {
3154+
assert_snapshot!(goto("
3155+
set search_path to foo;
3156+
create table foo.users(id int, email text);
3157+
update users set email = '[email protected]' where id$0 = 1;
3158+
"), @r"
3159+
╭▸
3160+
3 │ create table foo.users(id int, email text);
3161+
│ ── 2. destination
3162+
4 │ update users set email = '[email protected]' where id = 1;
3163+
╰╴ ─ 1. source
3164+
");
3165+
}
3166+
3167+
#[test]
3168+
fn goto_update_set_column() {
3169+
assert_snapshot!(goto("
3170+
create table users(id int, email text);
3171+
update users set email$0 = '[email protected]' where id = 1;
3172+
"), @r"
3173+
╭▸
3174+
2 │ create table users(id int, email text);
3175+
│ ───── 2. destination
3176+
3 │ update users set email = '[email protected]' where id = 1;
3177+
╰╴ ─ 1. source
3178+
");
3179+
}
3180+
3181+
#[test]
3182+
fn goto_update_set_column_with_schema() {
3183+
assert_snapshot!(goto("
3184+
create table public.users(id int, email text);
3185+
update public.users set email$0 = '[email protected]' where id = 1;
3186+
"), @r"
3187+
╭▸
3188+
2 │ create table public.users(id int, email text);
3189+
│ ───── 2. destination
3190+
3 │ update public.users set email = '[email protected]' where id = 1;
3191+
╰╴ ─ 1. source
3192+
");
3193+
}
3194+
3195+
#[test]
3196+
fn goto_update_set_column_with_search_path() {
3197+
assert_snapshot!(goto("
3198+
set search_path to foo;
3199+
create table foo.users(id int, email text);
3200+
update users set email$0 = '[email protected]' where id = 1;
3201+
"), @r"
3202+
╭▸
3203+
3 │ create table foo.users(id int, email text);
3204+
│ ───── 2. destination
3205+
4 │ update users set email = '[email protected]' where id = 1;
3206+
╰╴ ─ 1. source
3207+
");
3208+
}
3209+
3210+
#[test]
3211+
fn goto_update_from_table() {
3212+
assert_snapshot!(goto("
3213+
create table users(id int, email text);
3214+
create table messages(id int, user_id int, email text);
3215+
update users set email = messages.email from messages$0 where users.id = messages.user_id;
3216+
"), @r"
3217+
╭▸
3218+
3 │ create table messages(id int, user_id int, email text);
3219+
│ ──────── 2. destination
3220+
4 │ update users set email = messages.email from messages where users.id = messages.user_id;
3221+
╰╴ ─ 1. source
3222+
");
3223+
}
3224+
3225+
#[test]
3226+
fn goto_update_from_table_with_schema() {
3227+
assert_snapshot!(goto("
3228+
create table users(id int, email text);
3229+
create table public.messages(id int, user_id int, email text);
3230+
update users set email = messages.email from public.messages$0 where users.id = messages.user_id;
3231+
"), @r"
3232+
╭▸
3233+
3 │ create table public.messages(id int, user_id int, email text);
3234+
│ ──────── 2. destination
3235+
4 │ update users set email = messages.email from public.messages where users.id = messages.user_id;
3236+
╰╴ ─ 1. source
3237+
");
3238+
}
3239+
3240+
#[test]
3241+
fn goto_update_from_table_with_search_path() {
3242+
assert_snapshot!(goto("
3243+
set search_path to foo;
3244+
create table users(id int, email text);
3245+
create table foo.messages(id int, user_id int, email text);
3246+
update users set email = messages.email from messages$0 where users.id = messages.user_id;
3247+
"), @r"
3248+
╭▸
3249+
4 │ create table foo.messages(id int, user_id int, email text);
3250+
│ ──────── 2. destination
3251+
5 │ update users set email = messages.email from messages where users.id = messages.user_id;
3252+
╰╴ ─ 1. source
3253+
");
3254+
}
30803255
}

crates/squawk_ide/src/resolve.rs

Lines changed: 109 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,10 @@ enum NameRefContext {
3131
InsertColumn,
3232
DeleteTable,
3333
DeleteWhereColumn,
34+
UpdateTable,
35+
UpdateWhereColumn,
36+
UpdateSetColumn,
37+
UpdateFromTable,
3438
SchemaQualifier,
3539
}
3640

@@ -42,7 +46,8 @@ pub(crate) fn resolve_name_ref(binder: &Binder, name_ref: &ast::NameRef) -> Opti
4246
| NameRefContext::Table
4347
| NameRefContext::CreateIndex
4448
| NameRefContext::InsertTable
45-
| NameRefContext::DeleteTable => {
49+
| NameRefContext::DeleteTable
50+
| NameRefContext::UpdateTable => {
4651
let path = find_containing_path(name_ref)?;
4752
let table_name = extract_table_name(&path)?;
4853
let schema = extract_schema_name(&path);
@@ -201,6 +206,29 @@ pub(crate) fn resolve_name_ref(binder: &Binder, name_ref: &ast::NameRef) -> Opti
201206
NameRefContext::SelectQualifiedColumn => resolve_select_qualified_column(binder, name_ref),
202207
NameRefContext::InsertColumn => resolve_insert_column(binder, name_ref),
203208
NameRefContext::DeleteWhereColumn => resolve_delete_where_column(binder, name_ref),
209+
NameRefContext::UpdateWhereColumn => resolve_update_where_column(binder, name_ref),
210+
NameRefContext::UpdateSetColumn => resolve_update_set_column(binder, name_ref),
211+
NameRefContext::UpdateFromTable => {
212+
let table_name = Name::from_node(name_ref);
213+
let schema = if let Some(parent) = name_ref.syntax().parent()
214+
&& let Some(field_expr) = ast::FieldExpr::cast(parent)
215+
&& let Some(base) = field_expr.base()
216+
&& let Some(schema_name_ref) = ast::NameRef::cast(base.syntax().clone())
217+
{
218+
Some(Schema(Name::from_node(&schema_name_ref)))
219+
} else {
220+
None
221+
};
222+
223+
if schema.is_none()
224+
&& let Some(cte_ptr) = resolve_cte_table(name_ref, &table_name)
225+
{
226+
return Some(cte_ptr);
227+
}
228+
229+
let position = name_ref.syntax().text_range().start();
230+
resolve_table(binder, &table_name, &schema, position)
231+
}
204232
}
205233
}
206234

@@ -211,6 +239,7 @@ fn classify_name_ref_context(name_ref: &ast::NameRef) -> Option<NameRefContext>
211239
let mut in_column_list = false;
212240
let mut in_where_clause = false;
213241
let mut in_from_clause = false;
242+
let mut in_set_clause = false;
214243

215244
// TODO: can we combine this if and the one that follows?
216245
if let Some(parent) = name_ref.syntax().parent()
@@ -368,12 +397,27 @@ fn classify_name_ref_context(name_ref: &ast::NameRef) -> Option<NameRefContext>
368397
if ast::WhereClause::can_cast(ancestor.kind()) {
369398
in_where_clause = true;
370399
}
400+
if ast::SetClause::can_cast(ancestor.kind()) {
401+
in_set_clause = true;
402+
}
371403
if ast::Delete::can_cast(ancestor.kind()) {
372404
if in_where_clause {
373405
return Some(NameRefContext::DeleteWhereColumn);
374406
}
375407
return Some(NameRefContext::DeleteTable);
376408
}
409+
if ast::Update::can_cast(ancestor.kind()) {
410+
if in_where_clause {
411+
return Some(NameRefContext::UpdateWhereColumn);
412+
}
413+
if in_set_clause {
414+
return Some(NameRefContext::UpdateSetColumn);
415+
}
416+
if in_from_clause {
417+
return Some(NameRefContext::UpdateFromTable);
418+
}
419+
return Some(NameRefContext::UpdateTable);
420+
}
377421
}
378422

379423
None
@@ -930,6 +974,70 @@ fn resolve_delete_where_column(binder: &Binder, name_ref: &ast::NameRef) -> Opti
930974
None
931975
}
932976

977+
fn resolve_update_where_column(binder: &Binder, name_ref: &ast::NameRef) -> Option<SyntaxNodePtr> {
978+
let column_name = Name::from_node(name_ref);
979+
980+
let update = name_ref.syntax().ancestors().find_map(ast::Update::cast)?;
981+
let relation_name = update.relation_name()?;
982+
let path = relation_name.path()?;
983+
984+
let table_name = extract_table_name(&path)?;
985+
let schema = extract_schema_name(&path);
986+
let position = name_ref.syntax().text_range().start();
987+
988+
let table_ptr = resolve_table(binder, &table_name, &schema, position)?;
989+
990+
let root = &name_ref.syntax().ancestors().last()?;
991+
let table_name_node = table_ptr.to_node(root);
992+
993+
let create_table = table_name_node
994+
.ancestors()
995+
.find_map(ast::CreateTable::cast)?;
996+
997+
for arg in create_table.table_arg_list()?.args() {
998+
if let ast::TableArg::Column(column) = arg
999+
&& let Some(col_name) = column.name()
1000+
&& Name::from_node(&col_name) == column_name
1001+
{
1002+
return Some(SyntaxNodePtr::new(col_name.syntax()));
1003+
}
1004+
}
1005+
1006+
None
1007+
}
1008+
1009+
fn resolve_update_set_column(binder: &Binder, name_ref: &ast::NameRef) -> Option<SyntaxNodePtr> {
1010+
let column_name = Name::from_node(name_ref);
1011+
1012+
let update = name_ref.syntax().ancestors().find_map(ast::Update::cast)?;
1013+
let relation_name = update.relation_name()?;
1014+
let path = relation_name.path()?;
1015+
1016+
let table_name = extract_table_name(&path)?;
1017+
let schema = extract_schema_name(&path);
1018+
let position = name_ref.syntax().text_range().start();
1019+
1020+
let table_ptr = resolve_table(binder, &table_name, &schema, position)?;
1021+
1022+
let root = &name_ref.syntax().ancestors().last()?;
1023+
let table_name_node = table_ptr.to_node(root);
1024+
1025+
let create_table = table_name_node
1026+
.ancestors()
1027+
.find_map(ast::CreateTable::cast)?;
1028+
1029+
for arg in create_table.table_arg_list()?.args() {
1030+
if let ast::TableArg::Column(column) = arg
1031+
&& let Some(col_name) = column.name()
1032+
&& Name::from_node(&col_name) == column_name
1033+
{
1034+
return Some(SyntaxNodePtr::new(col_name.syntax()));
1035+
}
1036+
}
1037+
1038+
None
1039+
}
1040+
9331041
fn resolve_fn_call_column(binder: &Binder, name_ref: &ast::NameRef) -> Option<SyntaxNodePtr> {
9341042
let column_name = Name::from_node(name_ref);
9351043

0 commit comments

Comments
 (0)