diff --git a/crates/squawk_ide/src/classify.rs b/crates/squawk_ide/src/classify.rs index 9196e3d9..55ec72f3 100644 --- a/crates/squawk_ide/src/classify.rs +++ b/crates/squawk_ide/src/classify.rs @@ -254,6 +254,7 @@ pub(crate) fn classify_name_ref(name_ref: &ast::NameRef) -> Option } } if let Some(references_constraint) = ast::ReferencesConstraint::cast(ancestor.clone()) { + // TODO: the ast is too flat here if let Some(column_ref) = references_constraint.column() && column_ref .syntax() diff --git a/crates/squawk_ide/src/expand_selection.rs b/crates/squawk_ide/src/expand_selection.rs index 0a6eb3a9..c153074c 100644 --- a/crates/squawk_ide/src/expand_selection.rs +++ b/crates/squawk_ide/src/expand_selection.rs @@ -307,7 +307,7 @@ mod tests { let root = file.syntax(); let mut range = TextRange::empty(offset); - let mut results = Vec::new(); + let mut results = vec![]; for _ in 0..20 { let new_range = extend_selection(root, range); diff --git a/crates/squawk_ide/src/find_references.rs b/crates/squawk_ide/src/find_references.rs index 464c68ca..2d7bb074 100644 --- a/crates/squawk_ide/src/find_references.rs +++ b/crates/squawk_ide/src/find_references.rs @@ -3,6 +3,7 @@ use crate::offsets::token_from_offset; use crate::resolve; use rowan::{TextRange, TextSize}; use smallvec::{SmallVec, smallvec}; +use squawk_syntax::SyntaxNode; use squawk_syntax::{ SyntaxNodePtr, ast::{self, AstNode}, @@ -11,17 +12,17 @@ use squawk_syntax::{ pub fn find_references(file: &ast::SourceFile, offset: TextSize) -> Vec { let binder = binder::bind(file); - let Some(targets) = find_targets(file, offset, &binder) else { + let root = file.syntax(); + let Some(targets) = find_targets(file, root, offset, &binder) else { return vec![]; }; let mut refs = vec![]; - for node in file.syntax().descendants() { match_ast! { match node { ast::NameRef(name_ref) => { - if let Some(found_refs) = resolve::resolve_name_ref(&binder, &name_ref) + if let Some(found_refs) = resolve::resolve_name_ref(&binder, root, &name_ref) && found_refs.iter().any(|ptr| targets.contains(ptr)) { refs.push(name_ref.syntax().text_range()); @@ -44,6 +45,7 @@ pub fn find_references(file: &ast::SourceFile, offset: TextSize) -> Vec Option> { @@ -55,7 +57,7 @@ fn find_targets( } if let Some(name_ref) = ast::NameRef::cast(parent.clone()) { - return resolve::resolve_name_ref(binder, &name_ref); + return resolve::resolve_name_ref(binder, root, &name_ref); } None diff --git a/crates/squawk_ide/src/goto_definition.rs b/crates/squawk_ide/src/goto_definition.rs index 111aa635..6f2af077 100644 --- a/crates/squawk_ide/src/goto_definition.rs +++ b/crates/squawk_ide/src/goto_definition.rs @@ -58,7 +58,8 @@ pub fn goto_definition(file: ast::SourceFile, offset: TextSize) -> SmallVec<[Tex if let Some(name_ref) = ast::NameRef::cast(parent.clone()) { let binder_output = binder::bind(&file); - if let Some(ptrs) = resolve::resolve_name_ref(&binder_output, &name_ref) { + let root = file.syntax(); + if let Some(ptrs) = resolve::resolve_name_ref(&binder_output, root, &name_ref) { return ptrs .iter() .map(|ptr| ptr.to_node(file.syntax()).text_range()) @@ -243,6 +244,34 @@ drop table t$0; "); } + #[test] + fn goto_definition_on_dot_prefers_previous_token() { + assert_snapshot!(goto(" +create table t(a int); +select t.$0a from t; +"), @r" + ╭▸ + 2 │ create table t(a int); + │ ─ 2. destination + 3 │ select t.a from t; + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_with_table_star() { + assert_snapshot!(goto(" +with t as (select 1 a) +select t$0.* from t; +"), @r" + ╭▸ + 2 │ with t as (select 1 a) + │ ─ 2. destination + 3 │ select t.* from t; + ╰╴ ─ 1. source + "); + } + #[test] fn goto_drop_sequence() { assert_snapshot!(goto(" @@ -2415,6 +2444,32 @@ select a$0 from x; "); } + #[test] + fn goto_cte_qualified_column_prefers_cte_over_table() { + assert_snapshot!(goto(" +create table u(id int, b int); +with u as (select 1 id, 2 b) +select u.id$0 from u; +"), @r" + ╭▸ + 3 │ with u as (select 1 id, 2 b) + │ ── 2. destination + 4 │ select u.id from u; + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_subquery_qualified_column() { + assert_snapshot!(goto(" +select t.a$0 from (select 1 a) t; +"), @r" + ╭▸ + 2 │ select t.a from (select 1 a) t; + ╰╴ ─ 1. source ─ 2. destination + "); + } + #[test] fn goto_cte_multiple_columns() { assert_snapshot!(goto(" @@ -2474,6 +2529,50 @@ select a$0 from y; "); } + #[test] + fn goto_cte_qualified_star_join_column() { + assert_snapshot!(goto(" +create table u(id int, b int); +create table t(id int, a int); + +with k as ( + select u.* from t join u on a = b +) +select b$0 from k; +"), @r" + ╭▸ + 2 │ create table u(id int, b int); + │ ─ 2. destination + ‡ + 8 │ select b from k; + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_cte_qualified_star_join_column_with_partial_column_list() { + assert_snapshot!(goto(" +with + u as ( + select 1 id, 2 b + ), + t as ( + select 1 id, 2 a + ), + k(x) as ( + select u.* from t join u on a = b + ) +select b$0 from k; +"), @r" + ╭▸ + 4 │ select 1 id, 2 b + │ ─ 2. destination + ‡ + 12 │ select b from k; + ╰╴ ─ 1. source + "); + } + #[test] fn goto_cte_reference_inside_cte() { assert_snapshot!(goto(" @@ -2614,6 +2713,32 @@ select a$0 from (select * from foo.t); "); } + #[test] + fn goto_subquery_column_qualified_star_join() { + assert_snapshot!(goto(" +create table t(a int); +create table u(b int); +select b$0 from (select u.* from t join u on a = b); +"), @r" + ╭▸ + 3 │ create table u(b int); + │ ─ 2. destination + 4 │ select b from (select u.* from t join u on a = b); + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_subquery_column_qualified_star_join_not_found() { + goto_not_found( + " +create table t(a int); +create table u(b int); +select a$0 from (select u.* from t join u on a = b); +", + ); + } + #[test] fn goto_insert_table() { assert_snapshot!(goto(" diff --git a/crates/squawk_ide/src/hover.rs b/crates/squawk_ide/src/hover.rs index f9fb9e5f..dc0f77b3 100644 --- a/crates/squawk_ide/src/hover.rs +++ b/crates/squawk_ide/src/hover.rs @@ -3,14 +3,41 @@ use crate::offsets::token_from_offset; use crate::resolve; use crate::{binder, symbols::Name}; use rowan::TextSize; -use squawk_syntax::ast::{self, AstNode}; +use squawk_syntax::SyntaxNode; +use squawk_syntax::{ + SyntaxKind, + ast::{self, AstNode}, +}; pub fn hover(file: &ast::SourceFile, offset: TextSize) -> Option { let token = token_from_offset(file, offset)?; let parent = token.parent()?; + let root = file.syntax(); let binder = binder::bind(file); + if token.kind() == SyntaxKind::STAR { + if let Some(field_expr) = ast::FieldExpr::cast(parent.clone()) + && field_expr.star_token().is_some() + && let Some(result) = hover_qualified_star(root, &field_expr, &binder) + { + return Some(result); + } + + if let Some(arg_list) = ast::ArgList::cast(parent.clone()) + && let Some(result) = hover_unqualified_star_in_arg_list(root, &arg_list, &binder) + { + return Some(result); + } + + if let Some(target) = ast::Target::cast(parent.clone()) + && target.star_token().is_some() + && let Some(result) = hover_unqualified_star(root, &target, &binder) + { + return Some(result); + } + } + if let Some(name_ref) = ast::NameRef::cast(parent.clone()) { let context = classify_name_ref(&name_ref)?; match context { @@ -27,25 +54,25 @@ pub fn hover(file: &ast::SourceFile, offset: TextSize) -> Option { | NameRefClass::ExcludeConstraintColumn | NameRefClass::PartitionByColumn | NameRefClass::JoinUsingColumn => { - return hover_column(file, &name_ref, &binder); + return hover_column(root, &name_ref, &binder); } NameRefClass::TypeReference | NameRefClass::DropType => { - return hover_type(file, &name_ref, &binder); + return hover_type(root, &name_ref, &binder); } NameRefClass::CompositeTypeField => { - return hover_composite_type_field(file, &name_ref, &binder); + return hover_composite_type_field(root, &name_ref, &binder); } NameRefClass::SelectColumn | NameRefClass::SelectQualifiedColumn => { // Try hover as column first - if let Some(result) = hover_column(file, &name_ref, &binder) { + if let Some(result) = hover_column(root, &name_ref, &binder) { return Some(result); } // If no column, try as function (handles field-style function calls like `t.b`) - if let Some(result) = hover_function(file, &name_ref, &binder) { + if let Some(result) = hover_function(root, &name_ref, &binder) { return Some(result); } // Finally try as table (handles case like `select t from t;` where t is the table) - return hover_table(file, &name_ref, &binder); + return hover_table(root, &name_ref, &binder); } NameRefClass::Table | NameRefClass::DropTable @@ -62,38 +89,38 @@ pub fn hover(file: &ast::SourceFile, offset: TextSize) -> Option { | NameRefClass::LikeTable | NameRefClass::InheritsTable | NameRefClass::PartitionOfTable => { - return hover_table(file, &name_ref, &binder); + return hover_table(root, &name_ref, &binder); } NameRefClass::ForeignKeyColumn | NameRefClass::ForeignKeyLocalColumn | NameRefClass::SequenceOwnedByColumn => { - return hover_column(file, &name_ref, &binder); + return hover_column(root, &name_ref, &binder); } - NameRefClass::DropSequence => return hover_sequence(file, &name_ref, &binder), - NameRefClass::DropDatabase => return hover_database(file, &name_ref, &binder), - NameRefClass::Tablespace => return hover_tablespace(file, &name_ref, &binder), - NameRefClass::DropIndex => return hover_index(file, &name_ref, &binder), - NameRefClass::DropFunction => return hover_function(file, &name_ref, &binder), - NameRefClass::DropAggregate => return hover_aggregate(file, &name_ref, &binder), + NameRefClass::DropSequence => return hover_sequence(root, &name_ref, &binder), + NameRefClass::DropDatabase => return hover_database(root, &name_ref, &binder), + NameRefClass::Tablespace => return hover_tablespace(root, &name_ref, &binder), + NameRefClass::DropIndex => return hover_index(root, &name_ref, &binder), + NameRefClass::DropFunction => return hover_function(root, &name_ref, &binder), + NameRefClass::DropAggregate => return hover_aggregate(root, &name_ref, &binder), NameRefClass::DropProcedure | NameRefClass::CallProcedure => { - return hover_procedure(file, &name_ref, &binder); + return hover_procedure(root, &name_ref, &binder); } - NameRefClass::DropRoutine => return hover_routine(file, &name_ref, &binder), + NameRefClass::DropRoutine => return hover_routine(root, &name_ref, &binder), NameRefClass::DefaultConstraintFunctionCall => { - return hover_function(file, &name_ref, &binder); + return hover_function(root, &name_ref, &binder); } NameRefClass::SelectFunctionCall => { // Try function first, but fall back to column if no function found // (handles function-call-style column access like `select a(t)`) - if let Some(result) = hover_function(file, &name_ref, &binder) { + if let Some(result) = hover_function(root, &name_ref, &binder) { return Some(result); } - return hover_column(file, &name_ref, &binder); + return hover_column(root, &name_ref, &binder); } NameRefClass::SchemaQualifier | NameRefClass::DropSchema | NameRefClass::CreateSchema => { - return hover_schema(file, &name_ref, &binder); + return hover_schema(root, &name_ref, &binder); } } } @@ -148,26 +175,37 @@ pub fn hover(file: &ast::SourceFile, offset: TextSize) -> Option { None } -fn format_column( - schema: &str, - table_name: &str, - column_name: &str, - ty: &impl std::fmt::Display, -) -> String { - format!("column {schema}.{table_name}.{column_name} {ty}") +struct ColumnHover {} +impl ColumnHover { + fn table_column(table_name: &str, column_name: &str) -> String { + format!("column {table_name}.{column_name}") + } + fn schema_table_column_type( + schema: &str, + table_name: &str, + column_name: &str, + ty: &str, + ) -> String { + format!("column {schema}.{table_name}.{column_name} {ty}") + } + fn schema_table_column(schema: &str, table_name: &str, column_name: &str) -> String { + format!("column {schema}.{table_name}.{column_name}") + } } fn hover_column( - file: &ast::SourceFile, + root: &SyntaxNode, name_ref: &ast::NameRef, binder: &binder::Binder, ) -> Option { - let column_ptrs = resolve::resolve_name_ref(binder, name_ref)?; + let column_ptrs = resolve::resolve_name_ref(binder, root, name_ref)?; - let root = file.syntax(); let results: Vec = column_ptrs .iter() - .filter_map(|column_ptr| format_hover_for_column_ptr(binder, root, column_ptr, name_ref)) + .filter_map(|column_ptr| { + let column_name_node = column_ptr.to_node(root); + format_hover_for_column_node(binder, &column_name_node, name_ref) + }) .collect(); if results.is_empty() { @@ -177,39 +215,50 @@ fn hover_column( Some(results.join("\n")) } -fn format_hover_for_column_ptr( +fn format_hover_for_column_node( binder: &binder::Binder, - root: &squawk_syntax::SyntaxNode, - column_ptr: &squawk_syntax::SyntaxNodePtr, + column_name_node: &squawk_syntax::SyntaxNode, name_ref: &ast::NameRef, ) -> Option { - let column_name_node = column_ptr.to_node(root); - - if let Some(with_table) = column_name_node.ancestors().find_map(ast::WithTable::cast) { - let cte_name = with_table.name()?; - let column_name = if column_name_node - .ancestors() - .any(|a| ast::Values::can_cast(a.kind())) + for a in column_name_node.ancestors() { + if let Some(with_table) = ast::WithTable::cast(a.clone()) { + let cte_name = with_table.name()?; + let column_name = if column_name_node + .ancestors() + .any(|a| ast::Values::can_cast(a.kind())) + { + Name::from_node(name_ref) + } else { + Name::from_string(column_name_node.text().to_string()) + }; + let table_name = Name::from_node(&cte_name); + return Some(ColumnHover::table_column( + &table_name.to_string(), + &column_name.to_string(), + )); + } + if ast::ParenSelect::can_cast(a.kind()) + && let Some(field_expr) = name_ref.syntax().parent().and_then(ast::FieldExpr::cast) + && let Some(base) = field_expr.base() + && let ast::Expr::NameRef(table_name_ref) = base { - Name::from_node(name_ref) - } else { - Name::from_string(column_name_node.text().to_string()) - }; - return Some(format!( - "column {}.{}", - cte_name.syntax().text(), - column_name - )); - } + let table_name = Name::from_node(&table_name_ref); + let column_name = Name::from_string(column_name_node.text().to_string()); + return Some(ColumnHover::table_column( + &table_name.to_string(), + &column_name.to_string(), + )); + } - // create view v(a) as select 1; - // select a from v; - // ^ - if let Some(create_view) = column_name_node.ancestors().find_map(ast::CreateView::cast) - && let Some(column_name) = - ast::Name::cast(column_name_node.clone()).map(|name| Name::from_node(&name)) - { - return format_view_column(&create_view, column_name, binder); + // create view v(a) as select 1; + // select a from v; + // ^ + if let Some(create_view) = ast::CreateView::cast(a.clone()) + && let Some(column_name) = + ast::Name::cast(column_name_node.clone()).map(|name| Name::from_node(&name)) + { + return format_view_column(&create_view, column_name, binder); + } } let column = column_name_node.ancestors().find_map(ast::Column::cast)?; @@ -221,31 +270,27 @@ fn format_hover_for_column_ptr( .ancestors() .find_map(ast::CreateTable::cast)?; let path = create_table.path()?; - let table_name = path.segment()?.name()?; - - let schema = if let Some(qualifier) = path.qualifier() { - qualifier.syntax().text().to_string() - } else { - table_schema(&create_table, binder)? - }; + let (schema, table_name) = resolve::resolve_table_info(binder, &path)?; - Some(format_column( + let schema = schema.to_string(); + let column_name = Name::from_node(&column_name); + let ty = &ty.syntax().text().to_string(); + Some(ColumnHover::schema_table_column_type( &schema, - &table_name.syntax().text().to_string(), - &column_name.syntax().text().to_string(), - &ty.syntax().text(), + &table_name, + &column_name.to_string(), + ty, )) } fn hover_composite_type_field( - file: &ast::SourceFile, + root: &SyntaxNode, name_ref: &ast::NameRef, binder: &binder::Binder, ) -> Option { - let field_ptr = resolve::resolve_name_ref(binder, name_ref)? + let field_ptr = resolve::resolve_name_ref(binder, root, name_ref)? .into_iter() .next()?; - let root = file.syntax(); let field_name_node = field_ptr.to_node(root); let column = field_name_node.ancestors().find_map(ast::Column::cast)?; @@ -257,13 +302,7 @@ fn hover_composite_type_field( .ancestors() .find_map(ast::CreateType::cast)?; let type_path = create_type.path()?; - let type_name = type_path.segment()?.name()?.syntax().text().to_string(); - - let schema = if let Some(qualifier) = type_path.qualifier() { - qualifier.syntax().text().to_string() - } else { - type_schema(&create_type, binder)? - }; + let (schema, type_name) = resolve::resolve_type_info(binder, &type_path)?; Some(format!( "field {}.{}.{} {}", @@ -282,62 +321,249 @@ fn hover_column_definition( let column_name = column.name()?.syntax().text().to_string(); let ty = column.ty()?; let path = create_table.path()?; - let table_name = path.segment()?.name()?.syntax().text().to_string(); - - let schema = if let Some(qualifier) = path.qualifier() { - qualifier.syntax().text().to_string() - } else { - table_schema(create_table, binder)? - }; - - Some(format_column( - &schema, + let (schema, table_name) = resolve::resolve_table_info(binder, &path)?; + let ty = ty.syntax().text().to_string(); + Some(ColumnHover::schema_table_column_type( + &schema.to_string(), &table_name, &column_name, - &ty.syntax().text(), + &ty, )) } fn hover_table( - file: &ast::SourceFile, + root: &SyntaxNode, name_ref: &ast::NameRef, binder: &binder::Binder, ) -> Option { - let table_ptr = resolve::resolve_name_ref(binder, name_ref)? + if let Some(result) = hover_subquery_table(name_ref) { + return Some(result); + } + + let table_ptr = resolve::resolve_name_ref(binder, root, name_ref)? .into_iter() .next()?; - let root = file.syntax(); + hover_table_from_ptr(root, &table_ptr, binder) +} + +fn hover_table_from_ptr( + root: &SyntaxNode, + table_ptr: &squawk_syntax::SyntaxNodePtr, + binder: &binder::Binder, +) -> Option { let table_name_node = table_ptr.to_node(root); - if let Some(with_table) = table_name_node.ancestors().find_map(ast::WithTable::cast) { - return format_with_table(&with_table); + match resolve::find_table_source(&table_name_node)? { + resolve::TableSource::WithTable(with_table) => format_with_table(&with_table), + resolve::TableSource::CreateView(create_view) => format_create_view(&create_view, binder), + resolve::TableSource::CreateTable(create_table) => { + format_create_table(&create_table, binder) + } } +} + +fn hover_qualified_star( + root: &SyntaxNode, + field_expr: &ast::FieldExpr, + binder: &binder::Binder, +) -> Option { + let table_ptr = resolve::resolve_qualified_star_table(binder, field_expr)?; + hover_qualified_star_columns(root, &table_ptr, binder) +} - // create view v as select 1 a; - // select a from v; - // ^ - if let Some(create_view) = table_name_node.ancestors().find_map(ast::CreateView::cast) { - return format_create_view(&create_view, binder); +fn hover_unqualified_star( + root: &SyntaxNode, + target: &ast::Target, + binder: &binder::Binder, +) -> Option { + let table_ptrs = resolve::resolve_unqualified_star_tables(binder, target)?; + let mut results = vec![]; + for table_ptr in table_ptrs { + if let Some(columns) = hover_qualified_star_columns(root, &table_ptr, binder) { + results.push(columns); + } } - let create_table = table_name_node - .ancestors() - .find_map(ast::CreateTable::cast)?; + if results.is_empty() { + return None; + } + + Some(results.join("\n")) +} + +fn hover_unqualified_star_in_arg_list( + root: &SyntaxNode, + arg_list: &ast::ArgList, + binder: &binder::Binder, +) -> Option { + let table_ptrs = resolve::resolve_unqualified_star_tables_in_arg_list(binder, arg_list)?; + let mut results = vec![]; + for table_ptr in table_ptrs { + if let Some(columns) = hover_qualified_star_columns(root, &table_ptr, binder) { + results.push(columns); + } + } + + if results.is_empty() { + return None; + } + + Some(results.join("\n")) +} + +fn hover_subquery_table(name_ref: &ast::NameRef) -> Option { + let select = name_ref.syntax().ancestors().find_map(ast::Select::cast)?; + let from_clause = select.from_clause()?; + let qualifier = Name::from_node(name_ref); + let from_item = resolve::find_from_item_in_from_clause(&from_clause, &qualifier)?; + let paren_select = from_item.paren_select()?; + format_subquery_table(name_ref, &paren_select) +} + +fn format_subquery_table( + name_ref: &ast::NameRef, + paren_select: &ast::ParenSelect, +) -> Option { + let name = name_ref.syntax().text().to_string(); + let query = paren_select.syntax().text().to_string(); + Some(format!("subquery {} as {}", name, query)) +} - format_create_table(&create_table, binder) +fn hover_qualified_star_columns( + root: &SyntaxNode, + table_ptr: &squawk_syntax::SyntaxNodePtr, + binder: &binder::Binder, +) -> Option { + let table_name_node = table_ptr.to_node(root); + + if let Some(paren_select) = ast::ParenSelect::cast(table_name_node.clone()) { + return hover_qualified_star_columns_from_subquery(root, &paren_select, binder); + } + + match resolve::find_table_source(&table_name_node)? { + resolve::TableSource::WithTable(with_table) => { + hover_qualified_star_columns_from_cte(&with_table) + } + resolve::TableSource::CreateTable(create_table) => { + hover_qualified_star_columns_from_table(&create_table, binder) + } + resolve::TableSource::CreateView(create_view) => { + hover_qualified_star_columns_from_view(&create_view, binder) + } + } +} + +fn hover_qualified_star_columns_from_table( + create_table: &ast::CreateTable, + binder: &binder::Binder, +) -> Option { + let path = create_table.path()?; + let (schema, table_name) = resolve::resolve_table_info(binder, &path)?; + let schema = schema.to_string(); + let results: Vec = resolve::collect_table_columns(create_table) + .into_iter() + .filter_map(|column| { + let column_name = Name::from_node(&column.name()?); + let ty = column.ty()?; + let ty = &ty.syntax().text().to_string(); + Some(ColumnHover::schema_table_column_type( + &schema, + &table_name, + &column_name.to_string(), + ty, + )) + }) + .collect(); + + if results.is_empty() { + return None; + } + + Some(results.join("\n")) +} + +fn hover_qualified_star_columns_from_cte(with_table: &ast::WithTable) -> Option { + let cte_name = Name::from_node(&with_table.name()?); + let column_names = resolve::collect_with_table_column_names(with_table); + let results: Vec = column_names + .iter() + .map(|column_name| { + ColumnHover::table_column(&cte_name.to_string(), &column_name.to_string()) + }) + .collect(); + + if results.is_empty() { + return None; + } + + Some(results.join("\n")) +} + +fn hover_qualified_star_columns_from_view( + create_view: &ast::CreateView, + binder: &binder::Binder, +) -> Option { + let path = create_view.path()?; + let (schema, view_name) = resolve::resolve_view_info(binder, &path)?; + + let schema_str = schema.to_string(); + let column_names = resolve::collect_view_column_names(create_view); + let results: Vec = column_names + .iter() + .map(|column_name| { + ColumnHover::schema_table_column(&schema_str, &view_name, &column_name.to_string()) + }) + .collect(); + + if results.is_empty() { + return None; + } + + Some(results.join("\n")) +} + +fn hover_qualified_star_columns_from_subquery( + root: &SyntaxNode, + paren_select: &ast::ParenSelect, + binder: &binder::Binder, +) -> Option { + let ast::SelectVariant::Select(select) = paren_select.select()? else { + return None; + }; + + let select_clause = select.select_clause()?; + let target_list = select_clause.target_list()?; + + let mut results = vec![]; + + for target in target_list.targets() { + if target.star_token().is_some() { + let table_ptrs = resolve::resolve_unqualified_star_tables(binder, &target)?; + for table_ptr in table_ptrs { + if let Some(columns) = hover_qualified_star_columns(root, &table_ptr, binder) { + results.push(columns) + } + } + } + } + + if results.is_empty() { + return None; + } + + Some(results.join("\n")) } fn hover_index( - file: &ast::SourceFile, + root: &SyntaxNode, name_ref: &ast::NameRef, binder: &binder::Binder, ) -> Option { - let index_ptr = resolve::resolve_name_ref(binder, name_ref)? + let index_ptr = resolve::resolve_name_ref(binder, root, name_ref)? .into_iter() .next()?; - let root = file.syntax(); let index_name_node = index_ptr.to_node(root); let create_index = index_name_node @@ -348,15 +574,14 @@ fn hover_index( } fn hover_sequence( - file: &ast::SourceFile, + root: &SyntaxNode, name_ref: &ast::NameRef, binder: &binder::Binder, ) -> Option { - let sequence_ptr = resolve::resolve_name_ref(binder, name_ref)? + let sequence_ptr = resolve::resolve_name_ref(binder, root, name_ref)? .into_iter() .next()?; - let root = file.syntax(); let sequence_name_node = sequence_ptr.to_node(root); let create_sequence = sequence_name_node @@ -367,41 +592,38 @@ fn hover_sequence( } fn hover_tablespace( - file: &ast::SourceFile, + root: &SyntaxNode, name_ref: &ast::NameRef, binder: &binder::Binder, ) -> Option { - let tablespace_ptr = resolve::resolve_name_ref(binder, name_ref)? + let tablespace_ptr = resolve::resolve_name_ref(binder, root, name_ref)? .into_iter() .next()?; - let root = file.syntax(); let tablespace_name_node = tablespace_ptr.to_node(root); Some(format!("tablespace {}", tablespace_name_node.text())) } fn hover_database( - file: &ast::SourceFile, + root: &SyntaxNode, name_ref: &ast::NameRef, binder: &binder::Binder, ) -> Option { - let database_ptr = resolve::resolve_name_ref(binder, name_ref)? + let database_ptr = resolve::resolve_name_ref(binder, root, name_ref)? .into_iter() .next()?; - let root = file.syntax(); let database_name_node = database_ptr.to_node(root); Some(format!("database {}", database_name_node.text())) } fn hover_type( - file: &ast::SourceFile, + root: &SyntaxNode, name_ref: &ast::NameRef, binder: &binder::Binder, ) -> Option { - let type_ptr = resolve::resolve_name_ref(binder, name_ref)? + let type_ptr = resolve::resolve_name_ref(binder, root, name_ref)? .into_iter() .next()?; - let root = file.syntax(); let type_name_node = type_ptr.to_node(root); let create_type = type_name_node.ancestors().find_map(ast::CreateType::cast)?; @@ -412,15 +634,8 @@ fn hover_type( // Insert inferred schema into the create table hover info fn format_create_table(create_table: &ast::CreateTable, binder: &binder::Binder) -> Option { let path = create_table.path()?; - let segment = path.segment()?; - let table_name = segment.name()?.syntax().text().to_string(); - - let schema = if let Some(qualifier) = path.qualifier() { - qualifier.syntax().text().to_string() - } else { - table_schema(create_table, binder)? - }; - + let (schema, table_name) = resolve::resolve_table_info(binder, &path)?; + let schema = schema.to_string(); let args = create_table.table_arg_list()?.syntax().text().to_string(); Some(format!("table {}.{}{}", schema, table_name, args)) @@ -428,14 +643,8 @@ fn format_create_table(create_table: &ast::CreateTable, binder: &binder::Binder) fn format_create_view(create_view: &ast::CreateView, binder: &binder::Binder) -> Option { let path = create_view.path()?; - let segment = path.segment()?; - let view_name = segment.name()?.syntax().text().to_string(); - - let schema = if let Some(qualifier) = path.qualifier() { - qualifier.syntax().text().to_string() - } else { - view_schema(create_view, binder)? - }; + let (schema, view_name) = resolve::resolve_view_info(binder, &path)?; + let schema = schema.to_string(); let column_list = create_view .column_list() @@ -456,16 +665,12 @@ fn format_view_column( binder: &binder::Binder, ) -> Option { let path = create_view.path()?; - let segment = path.segment()?; - let view_name = Name::from_node(&segment.name()?); - - let schema = if let Some(qualifier) = path.qualifier() { - Name::from_string(qualifier.syntax().text().to_string()) - } else { - Name::from_string(view_schema(create_view, binder)?) - }; - - Some(format!("column {}.{}.{}", schema, view_name, column_name)) + let (schema, view_name) = resolve::resolve_view_info(binder, &path)?; + Some(ColumnHover::schema_table_column( + &schema.to_string(), + &view_name, + &column_name.to_string(), + )) } fn format_with_table(with_table: &ast::WithTable) -> Option { @@ -474,43 +679,6 @@ fn format_with_table(with_table: &ast::WithTable) -> Option { Some(format!("with {} as ({})", name, query)) } -fn table_schema(create_table: &ast::CreateTable, binder: &binder::Binder) -> Option { - let is_temp = create_table.temp_token().is_some() || create_table.temporary_token().is_some(); - if is_temp { - return Some("pg_temp".to_string()); - } - - let position = create_table.syntax().text_range().start(); - let search_path = binder.search_path_at(position); - search_path.first().map(|s| s.to_string()) -} - -fn view_schema(create_view: &ast::CreateView, binder: &binder::Binder) -> Option { - let is_temp = create_view.temp_token().is_some() || create_view.temporary_token().is_some(); - if is_temp { - return Some("pg_temp".to_string()); - } - - let position = create_view.syntax().text_range().start(); - let search_path = binder.search_path_at(position); - search_path.first().map(|s| s.to_string()) -} - -fn sequence_schema( - create_sequence: &ast::CreateSequence, - binder: &binder::Binder, -) -> Option { - let is_temp = - create_sequence.temp_token().is_some() || create_sequence.temporary_token().is_some(); - if is_temp { - return Some("pg_temp".to_string()); - } - - let position = create_sequence.syntax().text_range().start(); - let search_path = binder.search_path_at(position); - search_path.first().map(|s| s.to_string()) -} - fn format_create_index(create_index: &ast::CreateIndex, binder: &binder::Binder) -> Option { let index_name = create_index.name()?.syntax().text().to_string(); @@ -534,14 +702,7 @@ fn format_create_sequence( binder: &binder::Binder, ) -> Option { let path = create_sequence.path()?; - let segment = path.segment()?; - let sequence_name = segment.name()?.syntax().text().to_string(); - - let schema = if let Some(qualifier) = path.qualifier() { - qualifier.syntax().text().to_string() - } else { - sequence_schema(create_sequence, binder)? - }; + let (schema, sequence_name) = resolve::resolve_sequence_info(binder, &path)?; Some(format!("sequence {}.{}", schema, sequence_name)) } @@ -564,14 +725,7 @@ fn index_schema(create_index: &ast::CreateIndex, binder: &binder::Binder) -> Opt fn format_create_type(create_type: &ast::CreateType, binder: &binder::Binder) -> Option { let path = create_type.path()?; - let segment = path.segment()?; - let type_name = segment.name()?.syntax().text().to_string(); - - let schema = if let Some(qualifier) = path.qualifier() { - qualifier.syntax().text().to_string() - } else { - type_schema(create_type, binder)? - }; + let (schema, type_name) = resolve::resolve_type_info(binder, &path)?; if let Some(variant_list) = create_type.variant_list() { let variants = variant_list.syntax().text().to_string(); @@ -594,22 +748,15 @@ fn format_create_type(create_type: &ast::CreateType, binder: &binder::Binder) -> Some(format!("type {}.{}", schema, type_name)) } -fn type_schema(create_type: &ast::CreateType, binder: &binder::Binder) -> Option { - let position = create_type.syntax().text_range().start(); - let search_path = binder.search_path_at(position); - search_path.first().map(|s| s.to_string()) -} - fn hover_schema( - file: &ast::SourceFile, + root: &SyntaxNode, name_ref: &ast::NameRef, binder: &binder::Binder, ) -> Option { - let schema_ptr = resolve::resolve_name_ref(binder, name_ref)? + let schema_ptr = resolve::resolve_name_ref(binder, root, name_ref)? .into_iter() .next()?; - let root = file.syntax(); let schema_name_node = schema_ptr.to_node(root); let create_schema = schema_name_node @@ -637,15 +784,14 @@ fn format_create_schema(create_schema: &ast::CreateSchema) -> Option { } fn hover_function( - file: &ast::SourceFile, + root: &SyntaxNode, name_ref: &ast::NameRef, binder: &binder::Binder, ) -> Option { - let function_ptr = resolve::resolve_name_ref(binder, name_ref)? + let function_ptr = resolve::resolve_name_ref(binder, root, name_ref)? .into_iter() .next()?; - let root = file.syntax(); let function_name_node = function_ptr.to_node(root); let create_function = function_name_node @@ -660,15 +806,7 @@ fn format_create_function( binder: &binder::Binder, ) -> Option { let path = create_function.path()?; - let segment = path.segment()?; - let name = segment.name()?; - let function_name = name.syntax().text().to_string(); - - let schema = if let Some(qualifier) = path.qualifier() { - qualifier.syntax().text().to_string() - } else { - function_schema(create_function, binder)? - }; + let (schema, function_name) = resolve::resolve_function_info(binder, &path)?; let param_list = create_function.param_list()?; let params = param_list.syntax().text().to_string(); @@ -682,25 +820,15 @@ fn format_create_function( )) } -fn function_schema( - create_function: &ast::CreateFunction, - binder: &binder::Binder, -) -> Option { - let position = create_function.syntax().text_range().start(); - let search_path = binder.search_path_at(position); - search_path.first().map(|s| s.to_string()) -} - fn hover_aggregate( - file: &ast::SourceFile, + root: &SyntaxNode, name_ref: &ast::NameRef, binder: &binder::Binder, ) -> Option { - let aggregate_ptr = resolve::resolve_name_ref(binder, name_ref)? + let aggregate_ptr = resolve::resolve_name_ref(binder, root, name_ref)? .into_iter() .next()?; - let root = file.syntax(); let aggregate_name_node = aggregate_ptr.to_node(root); let create_aggregate = aggregate_name_node @@ -715,15 +843,7 @@ fn format_create_aggregate( binder: &binder::Binder, ) -> Option { let path = create_aggregate.path()?; - let segment = path.segment()?; - let name = segment.name()?; - let aggregate_name = name.syntax().text().to_string(); - - let schema = if let Some(qualifier) = path.qualifier() { - qualifier.syntax().text().to_string() - } else { - aggregate_schema(create_aggregate, binder)? - }; + let (schema, aggregate_name) = resolve::resolve_aggregate_info(binder, &path)?; let param_list = create_aggregate.param_list()?; let params = param_list.syntax().text().to_string(); @@ -731,25 +851,15 @@ fn format_create_aggregate( Some(format!("aggregate {}.{}{}", schema, aggregate_name, params)) } -fn aggregate_schema( - create_aggregate: &ast::CreateAggregate, - binder: &binder::Binder, -) -> Option { - let position = create_aggregate.syntax().text_range().start(); - let search_path = binder.search_path_at(position); - search_path.first().map(|s| s.to_string()) -} - fn hover_procedure( - file: &ast::SourceFile, + root: &SyntaxNode, name_ref: &ast::NameRef, binder: &binder::Binder, ) -> Option { - let procedure_ptr = resolve::resolve_name_ref(binder, name_ref)? + let procedure_ptr = resolve::resolve_name_ref(binder, root, name_ref)? .into_iter() .next()?; - let root = file.syntax(); let procedure_name_node = procedure_ptr.to_node(root); let create_procedure = procedure_name_node @@ -764,15 +874,7 @@ fn format_create_procedure( binder: &binder::Binder, ) -> Option { let path = create_procedure.path()?; - let segment = path.segment()?; - let name = segment.name()?; - let procedure_name = name.syntax().text().to_string(); - - let schema = if let Some(qualifier) = path.qualifier() { - qualifier.syntax().text().to_string() - } else { - procedure_schema(create_procedure, binder)? - }; + let (schema, procedure_name) = resolve::resolve_procedure_info(binder, &path)?; let param_list = create_procedure.param_list()?; let params = param_list.syntax().text().to_string(); @@ -780,46 +882,26 @@ fn format_create_procedure( Some(format!("procedure {}.{}{}", schema, procedure_name, params)) } -fn procedure_schema( - create_procedure: &ast::CreateProcedure, - binder: &binder::Binder, -) -> Option { - let position = create_procedure.syntax().text_range().start(); - let search_path = binder.search_path_at(position); - search_path.first().map(|s| s.to_string()) -} - fn hover_routine( - file: &ast::SourceFile, + root: &SyntaxNode, name_ref: &ast::NameRef, binder: &binder::Binder, ) -> Option { - let routine_ptr = resolve::resolve_name_ref(binder, name_ref)? + let routine_ptr = resolve::resolve_name_ref(binder, root, name_ref)? .into_iter() .next()?; + let routine_name = routine_ptr.to_node(root); - let root = file.syntax(); - let routine_name_node = routine_ptr.to_node(root); - - if let Some(create_function) = routine_name_node - .ancestors() - .find_map(ast::CreateFunction::cast) - { - return format_create_function(&create_function, binder); - } - - if let Some(create_aggregate) = routine_name_node - .ancestors() - .find_map(ast::CreateAggregate::cast) - { - return format_create_aggregate(&create_aggregate, binder); - } - - if let Some(create_procedure) = routine_name_node - .ancestors() - .find_map(ast::CreateProcedure::cast) - { - return format_create_procedure(&create_procedure, binder); + for a in routine_name.ancestors() { + if let Some(create_function) = ast::CreateFunction::cast(a.clone()) { + return format_create_function(&create_function, binder); + } + if let Some(create_aggregate) = ast::CreateAggregate::cast(a.clone()) { + return format_create_aggregate(&create_aggregate, binder); + } + if let Some(create_procedure) = ast::CreateProcedure::cast(a) { + return format_create_procedure(&create_procedure, binder); + } } None @@ -1766,6 +1848,30 @@ select * from users$0; "); } + #[test] + fn hover_on_subquery_qualified_table_ref() { + assert_snapshot!(check_hover(" +select t$0.a from (select 1 a) t; +"), @r" + hover: subquery t as (select 1 a) + ╭▸ + 2 │ select t.a from (select 1 a) t; + ╰╴ ─ hover + "); + } + + #[test] + fn hover_on_subquery_qualified_column_ref() { + assert_snapshot!(check_hover(" +select t.a$0 from (select 1 a) t; +"), @r" + hover: column t.a + ╭▸ + 2 │ select t.a from (select 1 a) t; + ╰╴ ─ hover + "); + } + #[test] fn hover_on_select_from_table_with_schema() { assert_snapshot!(check_hover(" @@ -1880,6 +1986,48 @@ select id$0 from users; "); } + #[test] + fn hover_on_select_qualified_star() { + assert_snapshot!(check_hover(" +create table u(id int, b int); +select u.*$0 from u; +"), @r" + hover: column public.u.id int + column public.u.b int + ╭▸ + 3 │ select u.* from u; + ╰╴ ─ hover + "); + } + + #[test] + fn hover_on_select_unqualified_star() { + assert_snapshot!(check_hover(" +create table u(id int, b int); +select *$0 from u; +"), @r" + hover: column public.u.id int + column public.u.b int + ╭▸ + 3 │ select * from u; + ╰╴ ─ hover + "); + } + + #[test] + fn hover_on_select_count_star() { + assert_snapshot!(check_hover(" +create table u(id int, b int); +select count(*$0) from u; +"), @r" + hover: column public.u.id int + column public.u.b int + ╭▸ + 3 │ select count(*) from u; + ╰╴ ─ hover + "); + } + #[test] fn hover_on_insert_table() { assert_snapshot!(check_hover(" @@ -2279,6 +2427,66 @@ select COLUMN1$0, COLUMN2 from t; "); } + #[test] + fn hover_on_cte_qualified_star() { + assert_snapshot!(check_hover(" +with u as (select 1 id, 2 b) +select u.*$0 from u; +"), @r" + hover: column u.id + column u.b + ╭▸ + 3 │ select u.* from u; + ╰╴ ─ hover + "); + } + + #[test] + fn hover_on_cte_values_qualified_star() { + assert_snapshot!(check_hover(" +with t as (values (1, 2), (3, 4)) +select t.*$0 from t; +"), @r" + hover: column t.column1 + column t.column2 + ╭▸ + 3 │ select t.* from t; + ╰╴ ─ hover + "); + } + + #[test] + fn hover_on_star_with_subquery_from_cte() { + assert_snapshot!(check_hover(" +with u as (select 1 id, 2 b) +select *$0 from (select *, *, * from u); +"), @r" + hover: column u.id + column u.b + column u.id + column u.b + column u.id + column u.b + ╭▸ + 3 │ select * from (select *, *, * from u); + ╰╴ ─ hover + "); + } + + #[test] + fn hover_on_view_qualified_star() { + assert_snapshot!(check_hover(" +create view v as select 1 id, 2 b; +select v.*$0 from v; +"), @r" + hover: column public.v.id + column public.v.b + ╭▸ + 3 │ select v.* from v; + ╰╴ ─ hover + "); + } + #[test] fn hover_on_drop_procedure() { assert_snapshot!(check_hover(" diff --git a/crates/squawk_ide/src/inlay_hints.rs b/crates/squawk_ide/src/inlay_hints.rs index d533b51d..043be56a 100644 --- a/crates/squawk_ide/src/inlay_hints.rs +++ b/crates/squawk_ide/src/inlay_hints.rs @@ -3,7 +3,10 @@ use crate::binder::Binder; use crate::resolve; use crate::symbols::Name; use rowan::{TextRange, TextSize}; -use squawk_syntax::ast::{self, AstNode}; +use squawk_syntax::{ + SyntaxNode, + ast::{self, AstNode}, +}; /// `VSCode` has some theming options based on these types. #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -23,12 +26,13 @@ pub struct InlayHint { pub fn inlay_hints(file: &ast::SourceFile) -> Vec { let mut hints = vec![]; let binder = binder::bind(file); + let root = file.syntax(); - for node in file.syntax().descendants() { + for node in root.descendants() { if let Some(call_expr) = ast::CallExpr::cast(node.clone()) { - inlay_hint_call_expr(&mut hints, file, &binder, call_expr); + inlay_hint_call_expr(&mut hints, root, &binder, call_expr); } else if let Some(insert) = ast::Insert::cast(node) { - inlay_hint_insert(&mut hints, file, &binder, insert); + inlay_hint_insert(&mut hints, root, &binder, insert); } } @@ -37,7 +41,7 @@ pub fn inlay_hints(file: &ast::SourceFile) -> Vec { fn inlay_hint_call_expr( hints: &mut Vec, - file: &ast::SourceFile, + root: &SyntaxNode, binder: &Binder, call_expr: ast::CallExpr, ) -> Option<()> { @@ -50,11 +54,10 @@ fn inlay_hint_call_expr( ast::FieldExpr::cast(expr.syntax().clone())?.field()? }; - let function_ptr = resolve::resolve_name_ref(binder, &name_ref)? + let function_ptr = resolve::resolve_name_ref(binder, root, &name_ref)? .into_iter() .next()?; - let root = file.syntax(); let function_name_node = function_ptr.to_node(root); if let Some(create_function) = function_name_node @@ -81,13 +84,13 @@ fn inlay_hint_call_expr( fn inlay_hint_insert( hints: &mut Vec, - file: &ast::SourceFile, + root: &SyntaxNode, binder: &Binder, insert: ast::Insert, ) -> Option<()> { let values = insert.values()?; let row_list = values.row_list()?; - let create_table = resolve::resolve_insert_create_table(file, binder, &insert); + let create_table = resolve::resolve_insert_create_table(root, binder, &insert); let columns: Vec<(Name, Option)> = if let Some(column_list) = insert.column_list() { // `insert into t(a, b, c) values (1, 2, 3)` diff --git a/crates/squawk_ide/src/offsets.rs b/crates/squawk_ide/src/offsets.rs index 2a342b62..2d6b18de 100644 --- a/crates/squawk_ide/src/offsets.rs +++ b/crates/squawk_ide/src/offsets.rs @@ -6,10 +6,14 @@ use squawk_syntax::{ pub(crate) fn token_from_offset(file: &ast::SourceFile, offset: TextSize) -> Option { let mut token = file.syntax().token_at_offset(offset).right_biased()?; - // want to be lenient in case someone clicks the trailing `;` of a line - // instead of an identifier - // or if someone clicks the `,` in a target list, like `select a, b, c` - if token.kind() == SyntaxKind::SEMICOLON || token.kind() == SyntaxKind::COMMA { + // want to be lenient in case someone clicks: + // - the trailing `;` of a line + // - the `,` in a target list, like `select a, b, c` + // - the `.` following a table/schema/column, like `select t.a from t` + if matches!( + token.kind(), + SyntaxKind::SEMICOLON | SyntaxKind::COMMA | SyntaxKind::DOT + ) { token = token.prev_token()?; } return Some(token); diff --git a/crates/squawk_ide/src/resolve.rs b/crates/squawk_ide/src/resolve.rs index 6effbb88..3ef95157 100644 --- a/crates/squawk_ide/src/resolve.rs +++ b/crates/squawk_ide/src/resolve.rs @@ -13,6 +13,7 @@ use crate::symbols::{Name, SymbolKind}; pub(crate) fn resolve_name_ref( binder: &Binder, + root: &SyntaxNode, name_ref: &ast::NameRef, ) -> Option> { let context = classify_name_ref(name_ref)?; @@ -114,7 +115,8 @@ pub(crate) fn resolve_name_ref( let path = sequence_option.path()?; let column_name = Name::from_node(name_ref); let table_path = path.qualifier()?; - resolve_column_for_path(binder, &table_path, column_name).map(|ptr| smallvec![ptr]) + resolve_column_for_path(binder, root, &table_path, column_name) + .map(|ptr| smallvec![ptr]) } NameRefClass::Tablespace => { let tablespace_name = Name::from_node(name_ref); @@ -145,7 +147,7 @@ pub(crate) fn resolve_name_ref( return None; }; let column_name = Name::from_node(name_ref); - resolve_column_for_path(binder, &path, column_name).map(|ptr| smallvec![ptr]) + resolve_column_for_path(binder, root, &path, column_name).map(|ptr| smallvec![ptr]) } NameRefClass::ForeignKeyLocalColumn => { let create_table = name_ref @@ -370,7 +372,7 @@ pub(crate) fn resolve_name_ref( // select a(t) from t; // ``` if schema.is_none() - && let Some(ptr) = resolve_fn_call_column(binder, name_ref) + && let Some(ptr) = resolve_fn_call_column(binder, root, name_ref) { return Some(smallvec![ptr]); } @@ -378,30 +380,30 @@ pub(crate) fn resolve_name_ref( None } NameRefClass::CreateIndexColumn => { - resolve_create_index_column(binder, name_ref).map(|ptr| smallvec![ptr]) + resolve_create_index_column(binder, root, name_ref).map(|ptr| smallvec![ptr]) } NameRefClass::SelectColumn => { - resolve_select_column(binder, name_ref).map(|ptr| smallvec![ptr]) + resolve_select_column(binder, root, name_ref).map(|ptr| smallvec![ptr]) } NameRefClass::SelectQualifiedColumnTable => { resolve_select_qualified_column_table(binder, name_ref).map(|ptr| smallvec![ptr]) } NameRefClass::SelectQualifiedColumn => { - resolve_select_qualified_column(binder, name_ref).map(|ptr| smallvec![ptr]) + resolve_select_qualified_column(binder, root, name_ref).map(|ptr| smallvec![ptr]) } NameRefClass::CompositeTypeField => { - resolve_composite_type_field(binder, name_ref).map(|ptr| smallvec![ptr]) + resolve_composite_type_field(binder, root, name_ref).map(|ptr| smallvec![ptr]) } NameRefClass::InsertColumn => { - resolve_insert_column(binder, name_ref).map(|ptr| smallvec![ptr]) + resolve_insert_column(binder, root, name_ref).map(|ptr| smallvec![ptr]) } NameRefClass::DeleteWhereColumn => { - resolve_delete_where_column(binder, name_ref).map(|ptr| smallvec![ptr]) + resolve_delete_where_column(binder, root, name_ref).map(|ptr| smallvec![ptr]) } NameRefClass::UpdateWhereColumn | NameRefClass::UpdateSetColumn => { - resolve_update_where_column(binder, name_ref).map(|ptr| smallvec![ptr]) + resolve_update_where_column(binder, root, name_ref).map(|ptr| smallvec![ptr]) } - NameRefClass::JoinUsingColumn => resolve_join_using_columns(binder, name_ref), + NameRefClass::JoinUsingColumn => resolve_join_using_columns(binder, root, name_ref), NameRefClass::UpdateFromTable => { let table_name = Name::from_node(name_ref); let schema = if let Some(parent) = name_ref.syntax().parent() @@ -631,7 +633,11 @@ fn resolve_schema(binder: &Binder, schema_name: &Name) -> Option Some(binder.symbols[symbol_id].ptr) } -fn resolve_create_index_column(binder: &Binder, name_ref: &ast::NameRef) -> Option { +fn resolve_create_index_column( + binder: &Binder, + root: &SyntaxNode, + name_ref: &ast::NameRef, +) -> Option { let column_name = Name::from_node(name_ref); let create_index = name_ref @@ -641,11 +647,12 @@ fn resolve_create_index_column(binder: &Binder, name_ref: &ast::NameRef) -> Opti let relation_name = create_index.relation_name()?; let path = relation_name.path()?; - resolve_column_for_path(binder, &path, column_name) + resolve_column_for_path(binder, root, &path, column_name) } fn resolve_column_for_path( binder: &Binder, + root: &SyntaxNode, path: &ast::Path, column_name: Name, ) -> Option { @@ -655,7 +662,6 @@ fn resolve_column_for_path( let table_ptr = resolve_table(binder, &table_name, &schema, position)?; - let root = &path.syntax().ancestors().last()?; let table_name_node = table_ptr.to_node(root); let create_table = table_name_node @@ -665,13 +671,17 @@ fn resolve_column_for_path( find_column_in_create_table(&create_table, &column_name) } -fn resolve_insert_column(binder: &Binder, name_ref: &ast::NameRef) -> Option { +fn resolve_insert_column( + binder: &Binder, + root: &SyntaxNode, + name_ref: &ast::NameRef, +) -> Option { let column_name = Name::from_node(name_ref); let insert = name_ref.syntax().ancestors().find_map(ast::Insert::cast)?; let path = insert.path()?; - resolve_column_for_path(binder, &path, column_name) + resolve_column_for_path(binder, root, &path, column_name) } fn resolve_select_qualified_column_table( @@ -720,6 +730,10 @@ fn resolve_select_qualified_column_table( } let (table_name, schema) = if let Some(name_ref_node) = from_item.name_ref() { + if let Some(cte_ptr) = resolve_cte_table(name_ref, &table_name) { + return Some(cte_ptr); + } + // `from foo` let from_table_name = Name::from_node(&name_ref_node); if from_table_name == table_name { @@ -748,6 +762,7 @@ fn resolve_select_qualified_column_table( // TODO: this is similar to resolve_from_item_for_column, maybe we can simplify fn resolve_select_qualified_column( binder: &Binder, + root: &SyntaxNode, name_ref: &ast::NameRef, ) -> Option { let column_name = Name::from_node(name_ref); @@ -790,6 +805,16 @@ fn resolve_select_qualified_column( && let Some(alias_name) = alias.name() && Name::from_node(&alias_name) == column_table_name { + if let Some(paren_select) = from_item.paren_select() { + return resolve_subquery_column( + binder, + root, + &paren_select, + name_ref, + &column_name, + ); + } + // `from t as u(a, b, c)` if let Some(column_list) = alias.column_list() { for column in column_list.columns() { @@ -807,7 +832,7 @@ fn resolve_select_qualified_column( // ``` if let Some(name_ref_node) = from_item.name_ref() { let cte_name = Name::from_node(&name_ref_node); - return resolve_cte_column(name_ref, &cte_name, &column_name); + return resolve_cte_column(binder, root, name_ref, &cte_name, &column_name); } } @@ -848,7 +873,12 @@ fn resolve_select_qualified_column( } }; - let root = &name_ref.syntax().ancestors().last()?; + if schema.is_none() + && let Some(cte_column_ptr) = + resolve_cte_column(binder, root, name_ref, &table_name, &column_name) + { + return Some(cte_column_ptr); + } if let Some(table_ptr) = resolve_table(binder, &table_name, &schema, position) { let table_name_node = table_ptr.to_node(root); @@ -882,45 +912,47 @@ fn resolve_select_qualified_column( fn resolve_from_item_for_column( binder: &Binder, + root: &SyntaxNode, from_item: &ast::FromItem, name_ref: &ast::NameRef, ) -> Option { let column_name = Name::from_node(name_ref); if let Some(paren_select) = from_item.paren_select() { - return resolve_subquery_column(binder, &paren_select, name_ref, &column_name); + return resolve_subquery_column(binder, root, &paren_select, name_ref, &column_name); } if let Some(paren_expr) = from_item.paren_expr() { - return resolve_column_from_paren_expr(binder, &paren_expr, name_ref, &column_name); + return resolve_column_from_paren_expr(binder, root, &paren_expr, name_ref, &column_name); } - let (table_name, schema) = if let Some(name_ref_node) = from_item.name_ref() { - (Name::from_node(&name_ref_node), None) - } else { - let field_expr = from_item.field_expr()?; - let table_name = Name::from_node(&field_expr.field()?); - let ast::Expr::NameRef(schema_name_ref) = field_expr.base()? else { - return None; - }; - let schema = Schema(Name::from_node(&schema_name_ref)); - (table_name, Some(schema)) - }; + let (table_name, schema) = table_and_schema_from_from_item(from_item)?; if schema.is_none() - && let Some(cte_column_ptr) = resolve_cte_column(name_ref, &table_name, &column_name) + && let Some(cte_column_ptr) = + resolve_cte_column(binder, root, name_ref, &table_name, &column_name) { return Some(cte_column_ptr); } + resolve_column_from_table_or_view(binder, root, name_ref, &table_name, &schema, &column_name) +} + +fn resolve_column_from_table_or_view( + binder: &Binder, + root: &SyntaxNode, + name_ref: &ast::NameRef, + table_name: &Name, + schema: &Option, + column_name: &Name, +) -> Option { let position = name_ref.syntax().text_range().start(); - let root = &name_ref.syntax().ancestors().last()?; - if let Some(table_ptr) = resolve_table(binder, &table_name, &schema, position) { + if let Some(table_ptr) = resolve_table(binder, table_name, schema, position) { let table_name_node = table_ptr.to_node(root); if let Some(create_table) = table_name_node.ancestors().find_map(ast::CreateTable::cast) { // 1. try to find a matching column - if let Some(ptr) = find_column_in_create_table(&create_table, &column_name) { + if let Some(ptr) = find_column_in_create_table(&create_table, column_name) { return Some(ptr); } @@ -937,11 +969,11 @@ fn resolve_from_item_for_column( } // ditto as above but with view - if let Some(view_ptr) = resolve_view(binder, &table_name, &schema, position) { + if let Some(view_ptr) = resolve_view(binder, table_name, schema, position) { let view_name_node = view_ptr.to_node(root); if let Some(create_view) = view_name_node.ancestors().find_map(ast::CreateView::cast) { - if let Some(ptr) = find_column_in_create_view(&create_view, &column_name) { + if let Some(ptr) = find_column_in_create_view(&create_view, column_name) { return Some(ptr); } @@ -954,6 +986,30 @@ fn resolve_from_item_for_column( None } +fn resolve_from_item_for_cte_star( + binder: &Binder, + root: &SyntaxNode, + from_item: &ast::FromItem, + name_ref: &ast::NameRef, + cte_name: &Name, + column_name: &Name, +) -> Option { + if let Some((table_name, schema)) = table_and_schema_from_from_item(from_item) + && table_name == *cte_name + { + return resolve_column_from_table_or_view( + binder, + root, + name_ref, + &table_name, + &schema, + column_name, + ); + } + + resolve_from_item_for_column(binder, root, from_item, name_ref) +} + fn resolve_from_join_expr(join_expr: &ast::JoinExpr, try_resolve: &F) -> Option where F: Fn(&ast::FromItem) -> Option, @@ -977,19 +1033,23 @@ where None } -fn resolve_select_column(binder: &Binder, name_ref: &ast::NameRef) -> Option { +fn resolve_select_column( + binder: &Binder, + root: &SyntaxNode, + name_ref: &ast::NameRef, +) -> Option { let select = name_ref.syntax().ancestors().find_map(ast::Select::cast)?; let from_clause = select.from_clause()?; for from_item in from_clause.from_items() { - if let Some(result) = resolve_from_item_for_column(binder, &from_item, name_ref) { + if let Some(result) = resolve_from_item_for_column(binder, root, &from_item, name_ref) { return Some(result); } } for join_expr in from_clause.join_exprs() { if let Some(result) = resolve_from_join_expr(&join_expr, &|from_item: &ast::FromItem| { - resolve_from_item_for_column(binder, from_item, name_ref) + resolve_from_item_for_column(binder, root, from_item, name_ref) }) { return Some(result); } @@ -998,18 +1058,23 @@ fn resolve_select_column(binder: &Binder, name_ref: &ast::NameRef) -> Option Option { +fn resolve_delete_where_column( + binder: &Binder, + root: &SyntaxNode, + name_ref: &ast::NameRef, +) -> Option { let column_name = Name::from_node(name_ref); let delete = name_ref.syntax().ancestors().find_map(ast::Delete::cast)?; let relation_name = delete.relation_name()?; let path = relation_name.path()?; - resolve_column_for_path(binder, &path, column_name) + resolve_column_for_path(binder, root, &path, column_name) } fn resolve_join_using_columns( binder: &Binder, + root: &SyntaxNode, name_ref: &ast::NameRef, ) -> Option> { let join_expr = name_ref @@ -1020,7 +1085,7 @@ fn resolve_join_using_columns( let mut results: SmallVec<[SyntaxNodePtr; 1]> = SmallVec::new(); collect_from_join_expr(&join_expr, &mut results, &|from_item: &ast::FromItem| { - resolve_from_item_for_column(binder, from_item, name_ref) + resolve_from_item_for_column(binder, root, from_item, name_ref) }); (!results.is_empty()).then_some(results) @@ -1050,7 +1115,11 @@ fn collect_from_join_expr( } } -fn resolve_update_where_column(binder: &Binder, name_ref: &ast::NameRef) -> Option { +fn resolve_update_where_column( + binder: &Binder, + root: &SyntaxNode, + name_ref: &ast::NameRef, +) -> Option { let column_name = Name::from_node(name_ref); let update = name_ref.syntax().ancestors().find_map(ast::Update::cast)?; @@ -1058,7 +1127,7 @@ fn resolve_update_where_column(binder: &Binder, name_ref: &ast::NameRef) -> Opti // `update t set a = b from u` if let Some(from_clause) = update.from_clause() { for from_item in from_clause.from_items() { - if let Some(result) = resolve_from_item_for_column(binder, &from_item, name_ref) { + if let Some(result) = resolve_from_item_for_column(binder, root, &from_item, name_ref) { return Some(result); } } @@ -1066,7 +1135,7 @@ fn resolve_update_where_column(binder: &Binder, name_ref: &ast::NameRef) -> Opti for join_expr in from_clause.join_exprs() { if let Some(result) = resolve_from_join_expr(&join_expr, &|from_item: &ast::FromItem| { - resolve_from_item_for_column(binder, from_item, name_ref) + resolve_from_item_for_column(binder, root, from_item, name_ref) }) { return Some(result); @@ -1078,10 +1147,14 @@ fn resolve_update_where_column(binder: &Binder, name_ref: &ast::NameRef) -> Opti let relation_name = update.relation_name()?; let path = relation_name.path()?; - resolve_column_for_path(binder, &path, column_name) + resolve_column_for_path(binder, root, &path, column_name) } -fn resolve_fn_call_column(binder: &Binder, name_ref: &ast::NameRef) -> Option { +fn resolve_fn_call_column( + binder: &Binder, + root: &SyntaxNode, + name_ref: &ast::NameRef, +) -> Option { let column_name = Name::from_node(name_ref); // function call syntax for columns is only valid if there is one argument @@ -1099,7 +1172,7 @@ fn resolve_fn_call_column(binder: &Binder, name_ref: &ast::NameRef) -> Option Option Option Option { - let (table_name, schema) = if let Some(name_ref_node) = from_item.name_ref() { - (Name::from_node(&name_ref_node), None) - } else { - let field_expr = from_item.field_expr()?; - let table_name = Name::from_node(&field_expr.field()?); - let ast::Expr::NameRef(schema_name_ref) = field_expr.base()? else { - return None; - }; - let schema = Schema(Name::from_node(&schema_name_ref)); - (table_name, Some(schema)) - }; + let (table_name, schema) = table_and_schema_from_from_item(from_item)?; let position = name_ref.syntax().text_range().start(); let table_ptr = resolve_table(binder, &table_name, &schema, position)?; - let root = &name_ref.syntax().ancestors().last()?; let table_name_node = table_ptr.to_node(root); let create_table = table_name_node .ancestors() @@ -1146,6 +1209,20 @@ fn resolve_from_item_for_fn_call_column( find_column_in_create_table(&create_table, column_name) } +fn table_and_schema_from_from_item(from_item: &ast::FromItem) -> Option<(Name, Option)> { + if let Some(name_ref_node) = from_item.name_ref() { + return Some((Name::from_node(&name_ref_node), None)); + } + + let field_expr = from_item.field_expr()?; + let table_name = Name::from_node(&field_expr.field()?); + let ast::Expr::NameRef(schema_name_ref) = field_expr.base()? else { + return None; + }; + let schema = Schema(Name::from_node(&schema_name_ref)); + Some((table_name, Some(schema))) +} + fn is_from_item_match(from_item: &ast::FromItem, qualifier: &Name) -> bool { if let Some(alias_name) = from_item.alias().and_then(|a| a.name()) && Name::from_node(&alias_name) == *qualifier @@ -1195,7 +1272,7 @@ fn find_from_item_in_join_expr( None } -fn find_from_item_in_from_clause( +pub(crate) fn find_from_item_in_from_clause( from_clause: &ast::FromClause, qualifier: &Name, ) -> Option { @@ -1280,8 +1357,7 @@ fn find_column_in_create_view( 0 }; - let query = create_view.query()?; - let select = match query { + let select = match create_view.query()? { ast::SelectVariant::Select(s) => s, ast::SelectVariant::ParenSelect(ps) => match ps.select()? { ast::SelectVariant::Select(s) => s, @@ -1340,6 +1416,8 @@ fn find_parent_with_clause(node: &SyntaxNode) -> Option { } fn resolve_cte_column( + binder: &Binder, + root: &SyntaxNode, name_ref: &ast::NameRef, cte_name: &Name, column_name: &Name, @@ -1387,28 +1465,26 @@ fn resolve_cte_column( continue; } - let select_variant = match query { - ast::WithQuery::Select(s) => ast::SelectVariant::Select(s), - ast::WithQuery::ParenSelect(ps) => ps.select()?, - ast::WithQuery::CompoundSelect(compound) => compound.lhs()?, - _ => continue, - }; - - let cte_select = match select_variant { - ast::SelectVariant::Select(s) => s, - ast::SelectVariant::CompoundSelect(compound) => match compound.lhs()? { - ast::SelectVariant::Select(s) => s, - _ => continue, - }, - _ => continue, + let Some(cte_select) = select_from_with_query(query) else { + continue; }; let select_clause = cte_select.select_clause()?; let target_list = select_clause.target_list()?; + let from_clause = cte_select.from_clause(); + let mut column_index: usize = 0; - for (idx, target) in target_list.targets().enumerate() { + for target in target_list.targets() { // Skip targets that are covered by the column list - if idx < column_list_len { + let target_column_count = from_clause + .as_ref() + .and_then(|from_clause| { + count_columns_for_target(binder, root, name_ref, &target, from_clause) + }) + .unwrap_or(1); + let column_list_end = column_index.saturating_add(target_column_count); + if column_list_end <= column_list_len { + column_index = column_list_end; continue; } @@ -1419,20 +1495,38 @@ fn resolve_cte_column( return Some(SyntaxNodePtr::new(&node)); } - if matches!(col_name, ColumnName::Star) { - if let Some(from_clause) = cte_select.from_clause() - && let Some(from_item) = from_clause.from_items().next() - && let Some(from_name_ref) = from_item.name_ref() - { - let from_table_name = Name::from_node(&from_name_ref); - // Skip recursive CTE lookup if the FROM table has the same name as the current CTE - // (CTEs don't shadow themselves in their own definition) - if from_table_name != *cte_name { - return resolve_cte_column(name_ref, &from_table_name, column_name); - } - } + if matches!(col_name, ColumnName::Star) + && let Some(from_clause) = &from_clause + && let Some(result) = resolve_from_clause_for_cte_star( + binder, + root, + name_ref, + cte_name, + column_name, + from_clause, + ) + { + return Some(result); } } + if let Some(expr) = target.expr() + && let ast::Expr::FieldExpr(field_expr) = expr + && let Some(table_name) = qualified_star_table_name(&field_expr) + && let Some(from_clause) = &from_clause + && let Some(result) = resolve_qualified_star_in_from_clause( + binder, + root, + name_ref, + cte_name, + column_name, + from_clause, + &table_name, + ) + { + return Some(result); + } + + column_index = column_list_end; } } } @@ -1442,6 +1536,7 @@ fn resolve_cte_column( fn resolve_subquery_column( binder: &Binder, + root: &SyntaxNode, paren_select: &ast::ParenSelect, name_ref: &ast::NameRef, column_name: &Name, @@ -1465,7 +1560,7 @@ fn resolve_subquery_column( if let Some(from_clause) = subquery_select.from_clause() { for from_item in from_clause.from_items() { if let Some(result) = - resolve_from_item_for_column(binder, &from_item, name_ref) + resolve_from_item_for_column(binder, root, &from_item, name_ref) { return Some(result); } @@ -1474,7 +1569,7 @@ fn resolve_subquery_column( for join_expr in from_clause.join_exprs() { if let Some(result) = resolve_from_join_expr(&join_expr, &|from_item: &ast::FromItem| { - resolve_from_item_for_column(binder, from_item, name_ref) + resolve_from_item_for_column(binder, root, from_item, name_ref) }) { return Some(result); @@ -1483,13 +1578,442 @@ fn resolve_subquery_column( } } } + + if let Some(expr) = target.expr() + && let ast::Expr::FieldExpr(field_expr) = expr + && let Some(table_name) = qualified_star_table_name(&field_expr) + && let Some(from_clause) = subquery_select.from_clause() + && let Some(from_item) = find_from_item_in_from_clause(&from_clause, &table_name) + && let Some(result) = resolve_from_item_for_column(binder, root, &from_item, name_ref) + { + return Some(result); + } } None } +fn qualified_star_table_name(field_expr: &ast::FieldExpr) -> Option { + field_expr.star_token()?; + + match field_expr.base()? { + ast::Expr::NameRef(name_ref) => Some(Name::from_node(&name_ref)), + ast::Expr::FieldExpr(inner_field_expr) => { + let field = inner_field_expr.field()?; + Some(Name::from_node(&field)) + } + _ => None, + } +} + +pub(crate) fn resolve_qualified_star_table( + binder: &Binder, + field_expr: &ast::FieldExpr, +) -> Option { + let table_name = qualified_star_table_name(field_expr)?; + let select = field_expr + .syntax() + .ancestors() + .find_map(ast::Select::cast)?; + let from_clause = select.from_clause()?; + let from_item = find_from_item_in_from_clause(&from_clause, &table_name)?; + let (table_name, schema) = table_and_schema_from_from_item(&from_item)?; + let position = field_expr.syntax().text_range().start(); + + if let Some(ptr) = resolve_table(binder, &table_name, &schema, position) { + return Some(ptr); + } + + if let Some(ptr) = resolve_view(binder, &table_name, &schema, position) { + return Some(ptr); + } + + if schema.is_none() + && let Some(name_ref) = from_item.name_ref() + { + return resolve_cte_table(&name_ref, &table_name); + } + + None +} + +pub(crate) fn resolve_unqualified_star_tables( + binder: &Binder, + target: &ast::Target, +) -> Option> { + target.star_token()?; + + let select = target.syntax().ancestors().find_map(ast::Select::cast)?; + let from_clause = select.from_clause()?; + let position = target.syntax().text_range().start(); + + let mut results = vec![]; + + for from_item in from_clause.from_items() { + collect_tables_from_item(binder, position, &from_item, &mut results); + } + + for join_expr in from_clause.join_exprs() { + collect_tables_from_join_expr(binder, position, &join_expr, &mut results); + } + + if results.is_empty() { + return None; + } + + Some(results) +} + +pub(crate) fn resolve_unqualified_star_tables_in_arg_list( + binder: &Binder, + arg_list: &ast::ArgList, +) -> Option> { + let select = arg_list.syntax().ancestors().find_map(ast::Select::cast)?; + let from_clause = select.from_clause()?; + let position = arg_list.syntax().text_range().start(); + + let mut results = vec![]; + + for from_item in from_clause.from_items() { + collect_tables_from_item(binder, position, &from_item, &mut results); + } + + for join_expr in from_clause.join_exprs() { + collect_tables_from_join_expr(binder, position, &join_expr, &mut results); + } + + if results.is_empty() { + return None; + } + + Some(results) +} + +fn collect_tables_from_join_expr( + binder: &Binder, + position: TextSize, + join_expr: &ast::JoinExpr, + results: &mut Vec, +) { + if let Some(nested) = join_expr.join_expr() { + collect_tables_from_join_expr(binder, position, &nested, results); + } + + if let Some(from_item) = join_expr.from_item() { + collect_tables_from_item(binder, position, &from_item, results); + } + + if let Some(join) = join_expr.join() + && let Some(from_item) = join.from_item() + { + collect_tables_from_item(binder, position, &from_item, results); + } +} + +fn collect_tables_from_item( + binder: &Binder, + position: TextSize, + from_item: &ast::FromItem, + results: &mut Vec, +) { + if let Some(paren_select) = from_item.paren_select() { + results.push(SyntaxNodePtr::new(paren_select.syntax())); + return; + } + + let Some((table_name, schema)) = table_and_schema_from_from_item(from_item) else { + return; + }; + + if let Some(ptr) = resolve_table(binder, &table_name, &schema, position) { + results.push(ptr); + return; + } + + if let Some(ptr) = resolve_view(binder, &table_name, &schema, position) { + results.push(ptr); + return; + } + + if schema.is_none() + && let Some(name_ref) = from_item.name_ref() + && let Some(cte_ptr) = resolve_cte_table(&name_ref, &table_name) + { + results.push(cte_ptr); + return; + } +} + +pub(crate) enum TableSource { + WithTable(ast::WithTable), + CreateView(ast::CreateView), + CreateTable(ast::CreateTable), +} + +pub(crate) fn find_table_source(node: &SyntaxNode) -> Option { + for ancestor in node.ancestors() { + if let Some(with_table) = ast::WithTable::cast(ancestor.clone()) { + return Some(TableSource::WithTable(with_table)); + } + + if let Some(create_view) = ast::CreateView::cast(ancestor.clone()) { + return Some(TableSource::CreateView(create_view)); + } + + if let Some(create_table) = ast::CreateTable::cast(ancestor) { + return Some(TableSource::CreateTable(create_table)); + } + } + + None +} + +pub(crate) fn select_from_with_query(query: ast::WithQuery) -> Option { + let select_variant = match query { + ast::WithQuery::Select(s) => ast::SelectVariant::Select(s), + ast::WithQuery::ParenSelect(ps) => ps.select()?, + ast::WithQuery::CompoundSelect(compound) => compound.lhs()?, + _ => return None, + }; + + match select_variant { + ast::SelectVariant::Select(s) => Some(s), + ast::SelectVariant::CompoundSelect(compound) => match compound.lhs()? { + ast::SelectVariant::Select(s) => Some(s), + _ => None, + }, + _ => None, + } +} + +fn count_columns_for_target( + binder: &Binder, + root: &SyntaxNode, + name_ref: &ast::NameRef, + target: &ast::Target, + from_clause: &ast::FromClause, +) -> Option { + if target.star_token().is_some() { + return count_columns_for_from_clause(binder, root, name_ref, from_clause); + } + + if let Some(expr) = target.expr() + && let ast::Expr::FieldExpr(field_expr) = expr + && let Some(table_name) = qualified_star_table_name(&field_expr) + && let Some(from_item) = find_from_item_in_from_clause(from_clause, &table_name) + { + return count_columns_for_from_item(binder, root, name_ref, &from_item); + } + + Some(1) +} + +fn count_columns_for_from_clause( + binder: &Binder, + root: &SyntaxNode, + name_ref: &ast::NameRef, + from_clause: &ast::FromClause, +) -> Option { + let mut total: usize = 0; + let mut found = false; + + for from_item in from_clause.from_items() { + if let Some(count) = count_columns_for_from_item(binder, root, name_ref, &from_item) { + total = total.saturating_add(count); + found = true; + } + } + + for join_expr in from_clause.join_exprs() { + if let Some(count) = count_columns_for_join_expr(binder, root, name_ref, &join_expr) { + total = total.saturating_add(count); + found = true; + } + } + + found.then_some(total) +} + +fn count_columns_for_join_expr( + binder: &Binder, + root: &SyntaxNode, + name_ref: &ast::NameRef, + join_expr: &ast::JoinExpr, +) -> Option { + let mut total: usize = 0; + let mut found = false; + + if let Some(nested) = join_expr.join_expr() + && let Some(count) = count_columns_for_join_expr(binder, root, name_ref, &nested) + { + total = total.saturating_add(count); + found = true; + } + + if let Some(from_item) = join_expr.from_item() + && let Some(count) = count_columns_for_from_item(binder, root, name_ref, &from_item) + { + total = total.saturating_add(count); + found = true; + } + + if let Some(join) = join_expr.join() + && let Some(from_item) = join.from_item() + && let Some(count) = count_columns_for_from_item(binder, root, name_ref, &from_item) + { + total = total.saturating_add(count); + found = true; + } + + found.then_some(total) +} + +fn count_columns_for_from_item( + binder: &Binder, + root: &SyntaxNode, + name_ref: &ast::NameRef, + from_item: &ast::FromItem, +) -> Option { + let (table_name, schema) = table_and_schema_from_from_item(from_item)?; + let position = name_ref.syntax().text_range().start(); + + if let Some(table_ptr) = resolve_table(binder, &table_name, &schema, position) { + let table_name_node = table_ptr.to_node(root); + + if let Some(create_table) = table_name_node.ancestors().find_map(ast::CreateTable::cast) { + let mut count: usize = 0; + if let Some(args) = create_table.table_arg_list() { + for arg in args.args() { + if matches!(arg, ast::TableArg::Column(_)) { + count = count.saturating_add(1); + } + } + } + return Some(count); + } + } + + if let Some(view_ptr) = resolve_view(binder, &table_name, &schema, position) { + let view_name_node = view_ptr.to_node(root); + + if let Some(create_view) = view_name_node.ancestors().find_map(ast::CreateView::cast) { + if let Some(column_list) = create_view.column_list() { + return Some(column_list.columns().count()); + } + + let select = match create_view.query()? { + ast::SelectVariant::Select(s) => s, + ast::SelectVariant::ParenSelect(ps) => match ps.select()? { + ast::SelectVariant::Select(s) => s, + _ => return None, + }, + _ => return None, + }; + + if let Some(target_list) = select.select_clause().and_then(|c| c.target_list()) { + return Some(target_list.targets().count()); + } + } + } + + if schema.is_none() + && let Some(cte_column_count) = count_columns_for_cte(name_ref, &table_name) + { + return Some(cte_column_count); + } + + None +} + +fn count_columns_for_cte(name_ref: &ast::NameRef, cte_name: &Name) -> Option { + let with_clause = find_parent_with_clause(name_ref.syntax())?; + + for with_table in with_clause.with_tables() { + if let Some(name) = with_table.name() + && Name::from_node(&name) == *cte_name + { + if with_table + .syntax() + .text_range() + .contains_range(name_ref.syntax().text_range()) + { + return None; + } + + if let Some(column_list) = with_table.column_list() { + return Some(column_list.columns().count()); + } + + let query = with_table.query()?; + + if let ast::WithQuery::Values(values) = query { + if let Some(row_list) = values.row_list() + && let Some(first_row) = row_list.rows().next() + { + return Some(first_row.exprs().count()); + } + return None; + } + + let select = select_from_with_query(query)?; + + if let Some(target_list) = select.select_clause().and_then(|c| c.target_list()) { + return Some(target_list.targets().count()); + } + } + } + + None +} + +fn resolve_from_clause_for_cte_star( + binder: &Binder, + root: &SyntaxNode, + name_ref: &ast::NameRef, + cte_name: &Name, + column_name: &Name, + from_clause: &ast::FromClause, +) -> Option { + for from_item in from_clause.from_items() { + if let Some(result) = resolve_from_item_for_cte_star( + binder, + root, + &from_item, + name_ref, + cte_name, + column_name, + ) { + return Some(result); + } + } + + for join_expr in from_clause.join_exprs() { + if let Some(result) = resolve_from_join_expr(&join_expr, &|from_item: &ast::FromItem| { + resolve_from_item_for_cte_star(binder, root, from_item, name_ref, cte_name, column_name) + }) { + return Some(result); + } + } + + None +} + +fn resolve_qualified_star_in_from_clause( + binder: &Binder, + root: &SyntaxNode, + name_ref: &ast::NameRef, + cte_name: &Name, + column_name: &Name, + from_clause: &ast::FromClause, + table_name: &Name, +) -> Option { + let from_item = find_from_item_in_from_clause(from_clause, table_name)?; + resolve_from_item_for_cte_star(binder, root, &from_item, name_ref, cte_name, column_name) +} + fn resolve_column_from_paren_expr( binder: &Binder, + root: &SyntaxNode, paren_expr: &ast::ParenExpr, name_ref: &ast::NameRef, column_name: &Name, @@ -1511,20 +2035,20 @@ fn resolve_column_from_paren_expr( } if let Some(ast::Expr::ParenExpr(paren_expr)) = paren_expr.expr() { - return resolve_column_from_paren_expr(binder, &paren_expr, name_ref, column_name); + return resolve_column_from_paren_expr(binder, root, &paren_expr, name_ref, column_name); } if let Some(from_item) = paren_expr.from_item() && let Some(paren_select) = from_item.paren_select() { - return resolve_subquery_column(binder, &paren_select, name_ref, column_name); + return resolve_subquery_column(binder, root, &paren_select, name_ref, column_name); } None } pub(crate) fn resolve_insert_create_table( - file: &ast::SourceFile, + root: &SyntaxNode, binder: &Binder, insert: &ast::Insert, ) -> Option { @@ -1534,7 +2058,6 @@ pub(crate) fn resolve_insert_create_table( let position = insert.syntax().text_range().start(); let table_ptr = resolve_table(binder, &table_name, &schema, position)?; - let root = file.syntax(); let table_name_node = table_ptr.to_node(root); table_name_node.ancestors().find_map(ast::CreateTable::cast) @@ -1577,6 +2100,80 @@ pub(crate) fn resolve_materialized_view_info( resolve_symbol_info(binder, path, SymbolKind::View) } +pub(crate) fn resolve_sequence_info(binder: &Binder, path: &ast::Path) -> Option<(Schema, String)> { + resolve_symbol_info(binder, path, SymbolKind::Sequence) +} + +pub(crate) fn collect_table_columns(create_table: &ast::CreateTable) -> Vec { + let mut columns = vec![]; + if let Some(arg_list) = create_table.table_arg_list() { + for arg in arg_list.args() { + if let ast::TableArg::Column(column) = arg { + columns.push(column); + } + } + } + columns +} + +pub(crate) fn collect_view_column_names(create_view: &ast::CreateView) -> Vec { + if let Some(column_list) = create_view.column_list() { + let columns = collect_column_names_from_column_list(&column_list); + if !columns.is_empty() { + return columns; + } + } + + let Some(select) = select_from_view_query(create_view) else { + return vec![]; + }; + let Some(select_clause) = select.select_clause() else { + return vec![]; + }; + let Some(target_list) = select_clause.target_list() else { + return vec![]; + }; + + collect_target_list_column_names(&target_list) +} + +pub(crate) fn collect_with_table_column_names(with_table: &ast::WithTable) -> Vec { + if let Some(column_list) = with_table.column_list() { + let columns = collect_column_names_from_column_list(&column_list); + if !columns.is_empty() { + return columns; + } + } + + let Some(query) = with_table.query() else { + return vec![]; + }; + + if let ast::WithQuery::Values(values) = query { + let mut results = vec![]; + if let Some(row_list) = values.row_list() + && let Some(first_row) = row_list.rows().next() + { + for (idx, _expr) in first_row.exprs().enumerate() { + results.push(Name::from_string(format!("column{}", idx + 1))); + } + } + return results; + } + + let Some(cte_select) = select_from_with_query(query) else { + return vec![]; + }; + let Some(select_clause) = cte_select.select_clause() else { + return vec![]; + }; + let Some(target_list) = select_clause.target_list() else { + return vec![]; + }; + + collect_target_list_column_names(&target_list) +} + fn resolve_symbol_info( binder: &Binder, path: &ast::Path, @@ -1612,6 +2209,40 @@ fn resolve_symbol_info( None } +fn collect_column_names_from_column_list(column_list: &ast::ColumnList) -> Vec { + let mut columns = vec![]; + for column in column_list.columns() { + if let Some(name) = column.name() { + columns.push(Name::from_node(&name)); + } + } + columns +} + +fn collect_target_list_column_names(target_list: &ast::TargetList) -> Vec { + let mut columns = vec![]; + for target in target_list.targets() { + if let Some((col_name, _node)) = ColumnName::from_target(target) + && let Some(col_name_str) = col_name.to_string() + { + columns.push(Name::from_string(col_name_str)); + } + } + columns +} + +fn select_from_view_query(create_view: &ast::CreateView) -> Option { + let query = create_view.query()?; + match query { + ast::SelectVariant::Select(select) => Some(select), + ast::SelectVariant::ParenSelect(paren_select) => match paren_select.select()? { + ast::SelectVariant::Select(select) => Some(select), + _ => None, + }, + _ => None, + } +} + fn extract_table_name_from_path(path: &ast::Path) -> Option { let segment = path.segment()?; if let Some(name_ref) = segment.name_ref() { @@ -1624,10 +2255,14 @@ fn extract_table_name_from_path(path: &ast::Path) -> Option { } fn extract_schema_from_path(path: &ast::Path) -> Option { - path.qualifier() - .and_then(|q| q.segment()) - .and_then(|s| s.name_ref()) - .map(|name_ref| name_ref.syntax().text().to_string()) + let segment = path.qualifier().and_then(|q| q.segment())?; + if let Some(name_ref) = segment.name_ref() { + return Some(name_ref.syntax().text().to_string()); + } + if let Some(name) = segment.name() { + return Some(name.syntax().text().to_string()); + } + None } fn extract_param_signature(node: &impl ast::HasParamList) -> Option> { @@ -1660,19 +2295,25 @@ fn unwrap_paren_expr(expr: ast::Expr) -> Option { None } -fn resolve_composite_type_field(binder: &Binder, name_ref: &ast::NameRef) -> Option { +fn resolve_composite_type_field( + binder: &Binder, + root: &SyntaxNode, + name_ref: &ast::NameRef, +) -> Option { let field_name = Name::from_node(name_ref); let field_expr = name_ref.syntax().parent().and_then(ast::FieldExpr::cast)?; let base = field_expr.base()?; let base_name_ref = unwrap_paren_expr(base)?; - let root = &name_ref.syntax().ancestors().last()?; + + let column_ptr = resolve_select_column(binder, root, &base_name_ref)?; + let column_node = column_ptr.to_node(root); let (type_name, schema) = - if let Some(type_info) = resolve_composite_type_from_column(binder, &base_name_ref, root) { + if let Some(type_info) = resolve_composite_type_from_column_node(&column_node) { type_info } else { - resolve_composite_type_from_cast(binder, &base_name_ref, root)? + resolve_composite_type_from_cast_node(&column_node)? }; let position = name_ref.syntax().text_range().start(); @@ -1693,25 +2334,17 @@ fn resolve_composite_type_field(binder: &Binder, name_ref: &ast::NameRef) -> Opt None } -fn resolve_composite_type_from_column( - binder: &Binder, - base_name_ref: &ast::NameRef, - root: &SyntaxNode, +fn resolve_composite_type_from_column_node( + column_node: &SyntaxNode, ) -> Option<(Name, Option)> { - let column_ptr = resolve_select_column(binder, base_name_ref)?; - let column_node = column_ptr.to_node(root); let column = column_node.ancestors().find_map(ast::Column::cast)?; let ty = column.ty()?; extract_type_name_and_schema(&ty) } -fn resolve_composite_type_from_cast( - binder: &Binder, - base_name_ref: &ast::NameRef, - root: &SyntaxNode, +fn resolve_composite_type_from_cast_node( + column_node: &SyntaxNode, ) -> Option<(Name, Option)> { - let column_ptr = resolve_select_column(binder, base_name_ref)?; - let column_node = column_ptr.to_node(root); let target = column_node.ancestors().find_map(ast::Target::cast)?; let ast::Expr::CastExpr(cast_expr) = target.expr()? else { return None;