diff --git a/crates/squawk_ide/src/binder.rs b/crates/squawk_ide/src/binder.rs index 242e87bc..0053e1a1 100644 --- a/crates/squawk_ide/src/binder.rs +++ b/crates/squawk_ide/src/binder.rs @@ -85,6 +85,9 @@ fn bind_stmt(b: &mut Binder, stmt: ast::Stmt) { ast::Stmt::CreateSchema(create_schema) => bind_create_schema(b, create_schema), ast::Stmt::CreateType(create_type) => bind_create_type(b, create_type), ast::Stmt::CreateView(create_view) => bind_create_view(b, create_view), + ast::Stmt::CreateMaterializedView(create_view) => { + bind_create_materialized_view(b, create_view) + } ast::Stmt::Set(set) => bind_set(b, set), _ => {} } @@ -222,13 +225,22 @@ fn bind_create_procedure(b: &mut Binder, create_procedure: ast::CreateProcedure) } fn bind_create_schema(b: &mut Binder, create_schema: ast::CreateSchema) { - let Some(schema_name_node) = create_schema.name() else { + let (schema_name, name_ptr) = if let Some(schema_name_node) = create_schema.name() { + let schema_name = Name::from_node(&schema_name_node); + let name_ptr = SyntaxNodePtr::new(schema_name_node.syntax()); + (schema_name, name_ptr) + } else if let Some(schema_name_ref) = create_schema + .schema_authorization() + .and_then(|authorization| authorization.role()) + .and_then(|role| role.name_ref()) + { + let schema_name = Name::from_node(&schema_name_ref); + let name_ptr = SyntaxNodePtr::new(schema_name_ref.syntax()); + (schema_name, name_ptr) + } else { return; }; - let schema_name = Name::from_node(&schema_name_node); - let name_ptr = SyntaxNodePtr::new(schema_name_node.syntax()); - let schema_id = b.symbols.alloc(Symbol { kind: SymbolKind::Schema, ptr: name_ptr, @@ -293,6 +305,32 @@ fn bind_create_view(b: &mut Binder, create_view: ast::CreateView) { b.scopes[root].insert(view_name, view_id); } +fn bind_create_materialized_view(b: &mut Binder, create_view: ast::CreateMaterializedView) { + let Some(path) = create_view.path() else { + return; + }; + + let Some(view_name) = item_name(&path) else { + return; + }; + + let name_ptr = path_to_ptr(&path); + + let Some(schema) = schema_name(b, &path, false) else { + return; + }; + + let view_id = b.symbols.alloc(Symbol { + kind: SymbolKind::View, + ptr: name_ptr, + schema, + params: None, + }); + + let root = b.root_scope(); + b.scopes[root].insert(view_name, view_id); +} + fn item_name(path: &ast::Path) -> Option { let segment = path.segment()?; diff --git a/crates/squawk_ide/src/classify.rs b/crates/squawk_ide/src/classify.rs index c37b9bb7..f0b78ca3 100644 --- a/crates/squawk_ide/src/classify.rs +++ b/crates/squawk_ide/src/classify.rs @@ -7,12 +7,14 @@ pub(crate) enum NameRefClass { DropIndex, DropType, DropView, + DropMaterializedView, DropFunction, DropAggregate, DropProcedure, DropRoutine, CallProcedure, DropSchema, + CreateSchema, CreateIndex, CreateIndexColumn, SelectFunctionCall, @@ -139,6 +141,7 @@ pub(crate) fn classify_name_ref(name_ref: &ast::NameRef) -> Option } let mut in_type = false; + let mut in_schema_authorization = false; for ancestor in name_ref.syntax().ancestors() { if ast::PathType::can_cast(ancestor.kind()) || ast::ExprType::can_cast(ancestor.kind()) { in_type = true; @@ -146,6 +149,9 @@ pub(crate) fn classify_name_ref(name_ref: &ast::NameRef) -> Option if in_type { return Some(NameRefClass::TypeReference); } + if ast::SchemaAuthorization::can_cast(ancestor.kind()) { + in_schema_authorization = true; + } if ast::DropTable::can_cast(ancestor.kind()) { return Some(NameRefClass::DropTable); } @@ -161,6 +167,9 @@ pub(crate) fn classify_name_ref(name_ref: &ast::NameRef) -> Option if ast::DropView::can_cast(ancestor.kind()) { return Some(NameRefClass::DropView); } + if ast::DropMaterializedView::can_cast(ancestor.kind()) { + return Some(NameRefClass::DropMaterializedView); + } if ast::CastExpr::can_cast(ancestor.kind()) && in_type { return Some(NameRefClass::TypeReference); } @@ -182,6 +191,12 @@ pub(crate) fn classify_name_ref(name_ref: &ast::NameRef) -> Option if ast::DropSchema::can_cast(ancestor.kind()) { return Some(NameRefClass::DropSchema); } + if in_schema_authorization + && let Some(create_schema) = ast::CreateSchema::cast(ancestor.clone()) + && create_schema.name().is_none() + { + return Some(NameRefClass::CreateSchema); + } if ast::PartitionItem::can_cast(ancestor.kind()) { in_partition_item = true; } diff --git a/crates/squawk_ide/src/document_symbols.rs b/crates/squawk_ide/src/document_symbols.rs index 6ef6b45b..615a5c96 100644 --- a/crates/squawk_ide/src/document_symbols.rs +++ b/crates/squawk_ide/src/document_symbols.rs @@ -3,14 +3,18 @@ use squawk_syntax::ast::{self, AstNode}; use crate::binder::{self, extract_string_literal}; use crate::resolve::{ - resolve_function_info, resolve_table_info, resolve_type_info, resolve_view_info, + resolve_aggregate_info, resolve_function_info, resolve_materialized_view_info, + resolve_procedure_info, resolve_table_info, resolve_type_info, resolve_view_info, }; #[derive(Debug)] pub enum DocumentSymbolKind { Table, View, + MaterializedView, Function, + Aggregate, + Procedure, Type, Enum, Column, @@ -46,6 +50,16 @@ pub fn document_symbols(file: &ast::SourceFile) -> Vec { symbols.push(symbol); } } + ast::Stmt::CreateAggregate(create_aggregate) => { + if let Some(symbol) = create_aggregate_symbol(&binder, create_aggregate) { + symbols.push(symbol); + } + } + ast::Stmt::CreateProcedure(create_procedure) => { + if let Some(symbol) = create_procedure_symbol(&binder, create_procedure) { + symbols.push(symbol); + } + } ast::Stmt::CreateType(create_type) => { if let Some(symbol) = create_type_symbol(&binder, create_type) { symbols.push(symbol); @@ -56,6 +70,11 @@ pub fn document_symbols(file: &ast::SourceFile) -> Vec { symbols.push(symbol); } } + ast::Stmt::CreateMaterializedView(create_view) => { + if let Some(symbol) = create_materialized_view_symbol(&binder, create_view) { + symbols.push(symbol); + } + } ast::Stmt::Select(select) => { symbols.extend(cte_table_symbols(select)); } @@ -184,6 +203,39 @@ fn create_view_symbol( }) } +fn create_materialized_view_symbol( + binder: &binder::Binder, + create_view: ast::CreateMaterializedView, +) -> Option { + let path = create_view.path()?; + let segment = path.segment()?; + let name_node = segment.name()?; + + let (schema, view_name) = resolve_materialized_view_info(binder, &path)?; + let name = format!("{}.{}", schema.0, view_name); + + let full_range = create_view.syntax().text_range(); + let focus_range = name_node.syntax().text_range(); + + let mut children = vec![]; + if let Some(column_list) = create_view.column_list() { + for column in column_list.columns() { + if let Some(column_symbol) = create_column_symbol(column) { + children.push(column_symbol); + } + } + } + + Some(DocumentSymbol { + name, + detail: None, + kind: DocumentSymbolKind::MaterializedView, + full_range, + focus_range, + children, + }) +} + fn create_function_symbol( binder: &binder::Binder, create_function: ast::CreateFunction, @@ -208,6 +260,54 @@ fn create_function_symbol( }) } +fn create_aggregate_symbol( + binder: &binder::Binder, + create_aggregate: ast::CreateAggregate, +) -> Option { + let path = create_aggregate.path()?; + let segment = path.segment()?; + let name_node = segment.name()?; + + let (schema, aggregate_name) = resolve_aggregate_info(binder, &path)?; + let name = format!("{}.{}", schema.0, aggregate_name); + + let full_range = create_aggregate.syntax().text_range(); + let focus_range = name_node.syntax().text_range(); + + Some(DocumentSymbol { + name, + detail: None, + kind: DocumentSymbolKind::Aggregate, + full_range, + focus_range, + children: vec![], + }) +} + +fn create_procedure_symbol( + binder: &binder::Binder, + create_procedure: ast::CreateProcedure, +) -> Option { + let path = create_procedure.path()?; + let segment = path.segment()?; + let name_node = segment.name()?; + + let (schema, procedure_name) = resolve_procedure_info(binder, &path)?; + let name = format!("{}.{}", schema.0, procedure_name); + + let full_range = create_procedure.syntax().text_range(); + let focus_range = name_node.syntax().text_range(); + + Some(DocumentSymbol { + name, + detail: None, + kind: DocumentSymbolKind::Procedure, + full_range, + focus_range, + children: vec![], + }) +} + fn create_type_symbol( binder: &binder::Binder, create_type: ast::CreateType, @@ -327,7 +427,10 @@ mod tests { let kind = match symbol.kind { DocumentSymbolKind::Table => "table", DocumentSymbolKind::View => "view", + DocumentSymbolKind::MaterializedView => "materialized view", DocumentSymbolKind::Function => "function", + DocumentSymbolKind::Aggregate => "aggregate", + DocumentSymbolKind::Procedure => "procedure", DocumentSymbolKind::Type => "type", DocumentSymbolKind::Enum => "enum", DocumentSymbolKind::Column => "column", @@ -443,6 +546,54 @@ create table users ( ); } + #[test] + fn create_materialized_view() { + assert_snapshot!( + symbols("create materialized view reports as select 1;"), + @r" + info: materialized view: public.reports + ╭▸ + 1 │ create materialized view reports as select 1; + │ ┬────────────────────────┯━━━━━━──────────── + │ │ │ + │ │ focus range + ╰╴full range + " + ); + } + + #[test] + fn create_aggregate() { + assert_snapshot!( + symbols("create aggregate myavg(int) (sfunc = int4_avg_accum, stype = _int8);"), + @r" + info: aggregate: public.myavg + ╭▸ + 1 │ create aggregate myavg(int) (sfunc = int4_avg_accum, stype = _int8); + │ ┬────────────────┯━━━━───────────────────────────────────────────── + │ │ │ + │ │ focus range + ╰╴full range + " + ); + } + + #[test] + fn create_procedure() { + assert_snapshot!( + symbols("create procedure hello() language sql as $$ select 1; $$;"), + @r" + info: procedure: public.hello + ╭▸ + 1 │ create procedure hello() language sql as $$ select 1; $$; + │ ┬────────────────┯━━━━────────────────────────────────── + │ │ │ + │ │ focus range + ╰╴full range + " + ); + } + #[test] fn multiple_symbols() { assert_snapshot!(symbols(" diff --git a/crates/squawk_ide/src/goto_definition.rs b/crates/squawk_ide/src/goto_definition.rs index 417d1081..dfaf76db 100644 --- a/crates/squawk_ide/src/goto_definition.rs +++ b/crates/squawk_ide/src/goto_definition.rs @@ -543,6 +543,20 @@ drop view v$0; "); } + #[test] + fn goto_drop_materialized_view() { + assert_snapshot!(goto(" +create materialized view v as select 1; +drop materialized view v$0; +"), @r" + ╭▸ + 2 │ create materialized view v as select 1; + │ ─ 2. destination + 3 │ drop materialized view v; + ╰╴ ─ 1. source + "); + } + #[test] fn goto_drop_view_with_schema() { assert_snapshot!(goto(" @@ -603,6 +617,20 @@ select * from v$0; "); } + #[test] + fn goto_select_from_materialized_view() { + assert_snapshot!(goto(" +create materialized view v as select 1; +select * from v$0; +"), @r" + ╭▸ + 2 │ create materialized view v as select 1; + │ ─ 2. destination + 3 │ select * from v; + ╰╴ ─ 1. source + "); + } + #[test] fn goto_select_from_view_with_schema() { assert_snapshot!(goto(" @@ -1557,6 +1585,26 @@ select foo$0(); "); } + #[test] + fn goto_select_aggregate_call() { + assert_snapshot!(goto(" +create aggregate foo(int) ( + sfunc = int4pl, + stype = int, + initcond = '0' +); + +select foo$0(1); +"), @r" + ╭▸ + 2 │ create aggregate foo(int) ( + │ ─── 2. destination + ‡ + 8 │ select foo(1); + ╰╴ ─ 1. source + "); + } + #[test] fn goto_select_function_call_with_schema() { assert_snapshot!(goto(" @@ -2491,6 +2539,34 @@ drop schema foo$0; "); } + #[test] + fn goto_create_schema_authorization() { + assert_snapshot!(goto(" +create schema authorization foo$0; +"), @r" + ╭▸ + 2 │ create schema authorization foo; + │ ┬─┬ + │ │ │ + │ │ 1. source + ╰╴ 2. destination + "); + } + + #[test] + fn goto_drop_schema_authorization() { + assert_snapshot!(goto(" +create schema authorization foo; +drop schema foo$0; +"), @r" + ╭▸ + 2 │ create schema authorization foo; + │ ─── 2. destination + 3 │ drop schema foo; + ╰╴ ─ 1. source + "); + } + #[test] fn goto_drop_schema_defined_after() { assert_snapshot!(goto(" diff --git a/crates/squawk_ide/src/hover.rs b/crates/squawk_ide/src/hover.rs index 4bbbf52c..d073cab7 100644 --- a/crates/squawk_ide/src/hover.rs +++ b/crates/squawk_ide/src/hover.rs @@ -40,6 +40,7 @@ pub fn hover(file: &ast::SourceFile, offset: TextSize) -> Option { NameRefClass::Table | NameRefClass::DropTable | NameRefClass::DropView + | NameRefClass::DropMaterializedView | NameRefClass::CreateIndex | NameRefClass::InsertTable | NameRefClass::DeleteTable @@ -64,7 +65,9 @@ pub fn hover(file: &ast::SourceFile, offset: TextSize) -> Option { } return hover_column(file, &name_ref, &binder); } - NameRefClass::SchemaQualifier | NameRefClass::DropSchema => { + NameRefClass::SchemaQualifier + | NameRefClass::DropSchema + | NameRefClass::CreateSchema => { return hover_schema(file, &name_ref, &binder); } } @@ -456,13 +459,27 @@ fn hover_schema( let root = file.syntax(); let schema_name_node = schema_ptr.to_node(root); - let create_schema = ast::CreateSchema::cast(schema_name_node.parent()?)?; + let create_schema = schema_name_node + .ancestors() + .find_map(ast::CreateSchema::cast)?; format_create_schema(&create_schema) } +fn create_schema_name(create_schema: &ast::CreateSchema) -> Option { + if let Some(schema_name) = create_schema.name() { + return Some(schema_name.syntax().text().to_string()); + } + + create_schema + .schema_authorization() + .and_then(|authorization| authorization.role()) + .and_then(|role| role.name_ref()) + .map(|name_ref| name_ref.syntax().text().to_string()) +} + fn format_create_schema(create_schema: &ast::CreateSchema) -> Option { - let schema_name = create_schema.name()?.syntax().text().to_string(); + let schema_name = create_schema_name(create_schema)?; Some(format!("schema {}", schema_name)) } @@ -1897,6 +1914,31 @@ create schema foo$0; "); } + #[test] + fn hover_on_create_schema_authorization() { + assert_snapshot!(check_hover(" +create schema authorization foo$0; +"), @r" + hover: schema foo + ╭▸ + 2 │ create schema authorization foo; + ╰╴ ─ hover + "); + } + + #[test] + fn hover_on_drop_schema_authorization() { + assert_snapshot!(check_hover(" +create schema authorization foo; +drop schema foo$0; +"), @r" + hover: schema foo + ╭▸ + 3 │ drop schema foo; + ╰╴ ─ hover + "); + } + #[test] fn hover_on_drop_schema() { assert_snapshot!(check_hover(" diff --git a/crates/squawk_ide/src/resolve.rs b/crates/squawk_ide/src/resolve.rs index db04495c..b7a859ad 100644 --- a/crates/squawk_ide/src/resolve.rs +++ b/crates/squawk_ide/src/resolve.rs @@ -84,7 +84,7 @@ pub(crate) fn resolve_name_ref(binder: &Binder, name_ref: &ast::NameRef) -> Opti let position = name_ref.syntax().text_range().start(); resolve_type(binder, &type_name, &schema, position) } - NameRefClass::DropView => { + NameRefClass::DropView | NameRefClass::DropMaterializedView => { let path = find_containing_path(name_ref)?; let view_name = extract_table_name(&path)?; let schema = extract_schema_name(&path); @@ -172,7 +172,7 @@ pub(crate) fn resolve_name_ref(binder: &Binder, name_ref: &ast::NameRef) -> Opti let position = name_ref.syntax().text_range().start(); resolve_procedure(binder, &procedure_name, &schema, None, position) } - NameRefClass::DropSchema | NameRefClass::SchemaQualifier => { + NameRefClass::DropSchema | NameRefClass::SchemaQualifier | NameRefClass::CreateSchema => { let schema_name = Name::from_node(name_ref); resolve_schema(binder, &schema_name) } @@ -194,6 +194,11 @@ pub(crate) fn resolve_name_ref(binder: &Binder, name_ref: &ast::NameRef) -> Opti return Some(ptr); } + // aggregates take precedence over function-call-style column access + if let Some(ptr) = resolve_aggregate(binder, &function_name, &schema, None, position) { + return Some(ptr); + } + // if no function found, check if this is function-call-style column access // ```sql // create table t(a int, b int); @@ -1266,6 +1271,20 @@ pub(crate) fn resolve_function_info(binder: &Binder, path: &ast::Path) -> Option resolve_symbol_info(binder, path, SymbolKind::Function) } +pub(crate) fn resolve_aggregate_info( + binder: &Binder, + path: &ast::Path, +) -> Option<(Schema, String)> { + resolve_symbol_info(binder, path, SymbolKind::Aggregate) +} + +pub(crate) fn resolve_procedure_info( + binder: &Binder, + path: &ast::Path, +) -> Option<(Schema, String)> { + resolve_symbol_info(binder, path, SymbolKind::Procedure) +} + pub(crate) fn resolve_type_info(binder: &Binder, path: &ast::Path) -> Option<(Schema, String)> { resolve_symbol_info(binder, path, SymbolKind::Type) } @@ -1274,6 +1293,13 @@ pub(crate) fn resolve_view_info(binder: &Binder, path: &ast::Path) -> Option<(Sc resolve_symbol_info(binder, path, SymbolKind::View) } +pub(crate) fn resolve_materialized_view_info( + binder: &Binder, + path: &ast::Path, +) -> Option<(Schema, String)> { + resolve_symbol_info(binder, path, SymbolKind::View) +} + fn resolve_symbol_info( binder: &Binder, path: &ast::Path, diff --git a/crates/squawk_server/src/lib.rs b/crates/squawk_server/src/lib.rs index 1f459fa5..916057c2 100644 --- a/crates/squawk_server/src/lib.rs +++ b/crates/squawk_server/src/lib.rs @@ -337,7 +337,10 @@ fn handle_document_symbol( kind: match sym.kind { DocumentSymbolKind::Table => SymbolKind::STRUCT, DocumentSymbolKind::View => SymbolKind::STRUCT, + DocumentSymbolKind::MaterializedView => SymbolKind::STRUCT, DocumentSymbolKind::Function => SymbolKind::FUNCTION, + DocumentSymbolKind::Aggregate => SymbolKind::FUNCTION, + DocumentSymbolKind::Procedure => SymbolKind::FUNCTION, DocumentSymbolKind::Type => SymbolKind::CLASS, DocumentSymbolKind::Enum => SymbolKind::ENUM, DocumentSymbolKind::Column => SymbolKind::FIELD, diff --git a/crates/squawk_wasm/src/lib.rs b/crates/squawk_wasm/src/lib.rs index 984677c3..982d6736 100644 --- a/crates/squawk_wasm/src/lib.rs +++ b/crates/squawk_wasm/src/lib.rs @@ -394,7 +394,12 @@ fn convert_document_symbol( kind: match symbol.kind { squawk_ide::document_symbols::DocumentSymbolKind::Table => "table", squawk_ide::document_symbols::DocumentSymbolKind::View => "view", + squawk_ide::document_symbols::DocumentSymbolKind::MaterializedView => { + "materialized_view" + } squawk_ide::document_symbols::DocumentSymbolKind::Function => "function", + squawk_ide::document_symbols::DocumentSymbolKind::Aggregate => "aggregate", + squawk_ide::document_symbols::DocumentSymbolKind::Procedure => "procedure", squawk_ide::document_symbols::DocumentSymbolKind::Type => "type", squawk_ide::document_symbols::DocumentSymbolKind::Enum => "enum", squawk_ide::document_symbols::DocumentSymbolKind::Column => "column",