Skip to content

Commit 98994e3

Browse files
authored
ide: add document symbols (#777)
1 parent f858af2 commit 98994e3

File tree

4 files changed

+314
-14
lines changed

4 files changed

+314
-14
lines changed
Lines changed: 234 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,234 @@
1+
use rowan::TextRange;
2+
use squawk_syntax::ast::{self, AstNode};
3+
4+
use crate::binder;
5+
use crate::resolve::{resolve_function_info, resolve_table_info};
6+
7+
pub enum DocumentSymbolKind {
8+
Table,
9+
Function,
10+
}
11+
12+
pub struct DocumentSymbol {
13+
pub name: String,
14+
pub detail: Option<String>,
15+
pub kind: DocumentSymbolKind,
16+
pub range: TextRange,
17+
pub selection_range: TextRange,
18+
}
19+
20+
pub fn document_symbols(file: &ast::SourceFile) -> Vec<DocumentSymbol> {
21+
let binder = binder::bind(file);
22+
let mut symbols = vec![];
23+
24+
for stmt in file.stmts() {
25+
match stmt {
26+
ast::Stmt::CreateTable(create_table) => {
27+
if let Some(symbol) = create_table_symbol(&binder, create_table) {
28+
symbols.push(symbol);
29+
}
30+
}
31+
ast::Stmt::CreateFunction(create_function) => {
32+
if let Some(symbol) = create_function_symbol(&binder, create_function) {
33+
symbols.push(symbol);
34+
}
35+
}
36+
_ => {}
37+
}
38+
}
39+
40+
symbols
41+
}
42+
43+
fn create_table_symbol(
44+
binder: &binder::Binder,
45+
create_table: ast::CreateTable,
46+
) -> Option<DocumentSymbol> {
47+
let path = create_table.path()?;
48+
let segment = path.segment()?;
49+
let name_node = segment.name()?;
50+
51+
let (schema, table_name) = resolve_table_info(binder, &path)?;
52+
let name = format!("{}.{}", schema.0, table_name);
53+
54+
let range = create_table.syntax().text_range();
55+
let selection_range = name_node.syntax().text_range();
56+
57+
Some(DocumentSymbol {
58+
name,
59+
detail: None,
60+
kind: DocumentSymbolKind::Table,
61+
range,
62+
selection_range,
63+
})
64+
}
65+
66+
fn create_function_symbol(
67+
binder: &binder::Binder,
68+
create_function: ast::CreateFunction,
69+
) -> Option<DocumentSymbol> {
70+
let path = create_function.path()?;
71+
let segment = path.segment()?;
72+
let name_node = segment.name()?;
73+
74+
let (schema, function_name) = resolve_function_info(binder, &path)?;
75+
let name = format!("{}.{}", schema.0, function_name);
76+
77+
let range = create_function.syntax().text_range();
78+
let selection_range = name_node.syntax().text_range();
79+
80+
Some(DocumentSymbol {
81+
name,
82+
detail: None,
83+
kind: DocumentSymbolKind::Function,
84+
range,
85+
selection_range,
86+
})
87+
}
88+
89+
#[cfg(test)]
90+
mod tests {
91+
use super::*;
92+
use annotate_snippets::{AnnotationKind, Level, Renderer, Snippet, renderer::DecorStyle};
93+
use insta::assert_snapshot;
94+
95+
fn symbols_not_found(sql: &str) {
96+
let parse = ast::SourceFile::parse(sql);
97+
let file = parse.tree();
98+
let symbols = document_symbols(&file);
99+
if !symbols.is_empty() {
100+
panic!("Symbols found. If this is expected, use `symbols` instead.")
101+
}
102+
}
103+
104+
fn symbols(sql: &str) -> String {
105+
let parse = ast::SourceFile::parse(sql);
106+
let file = parse.tree();
107+
let symbols = document_symbols(&file);
108+
if symbols.is_empty() {
109+
panic!("No symbols found. If this is expected, use `symbols_not_found` instead.")
110+
}
111+
112+
let mut groups = vec![];
113+
for symbol in symbols {
114+
let kind = match symbol.kind {
115+
DocumentSymbolKind::Table => "table",
116+
DocumentSymbolKind::Function => "function",
117+
};
118+
let title = format!("{}: {}", kind, symbol.name);
119+
let group = Level::INFO.primary_title(title).element(
120+
Snippet::source(sql)
121+
.fold(true)
122+
.annotation(
123+
AnnotationKind::Primary
124+
.span(symbol.selection_range.into())
125+
.label("name"),
126+
)
127+
.annotation(
128+
AnnotationKind::Context
129+
.span(symbol.range.into())
130+
.label("select range"),
131+
),
132+
);
133+
groups.push(group);
134+
}
135+
136+
let renderer = Renderer::plain().decor_style(DecorStyle::Unicode);
137+
renderer.render(&groups).to_string()
138+
}
139+
140+
#[test]
141+
fn create_table() {
142+
assert_snapshot!(symbols("create table users (id int);"), @r"
143+
info: table: public.users
144+
╭▸
145+
1 │ create table users (id int);
146+
│ ┬────────────┯━━━━─────────
147+
│ │ │
148+
│ │ name
149+
╰╴select range
150+
");
151+
}
152+
153+
#[test]
154+
fn create_function() {
155+
assert_snapshot!(
156+
symbols("create function hello() returns void as $$ select 1; $$ language sql;"),
157+
@r"
158+
info: function: public.hello
159+
╭▸
160+
1 │ create function hello() returns void as $$ select 1; $$ language sql;
161+
│ ┬───────────────┯━━━━───────────────────────────────────────────────
162+
│ │ │
163+
│ │ name
164+
╰╴select range
165+
"
166+
);
167+
}
168+
169+
#[test]
170+
fn multiple_symbols() {
171+
assert_snapshot!(symbols("
172+
create table users (id int);
173+
create table posts (id int);
174+
create function get_user(user_id int) returns void as $$ select 1; $$ language sql;
175+
"), @r"
176+
info: table: public.users
177+
╭▸
178+
2 │ create table users (id int);
179+
│ ┬────────────┯━━━━─────────
180+
│ │ │
181+
│ │ name
182+
│ select range
183+
╰╴
184+
info: table: public.posts
185+
╭▸
186+
3 │ create table posts (id int);
187+
│ ┬────────────┯━━━━─────────
188+
│ │ │
189+
│ │ name
190+
╰╴select range
191+
info: function: public.get_user
192+
╭▸
193+
4 │ create function get_user(user_id int) returns void as $$ select 1; $$ language sql;
194+
│ ┬───────────────┯━━━━━━━──────────────────────────────────────────────────────────
195+
│ │ │
196+
│ │ name
197+
╰╴select range
198+
");
199+
}
200+
201+
#[test]
202+
fn qualified_names() {
203+
assert_snapshot!(symbols("
204+
create table public.users (id int);
205+
create function my_schema.hello() returns void as $$ select 1; $$ language sql;
206+
"), @r"
207+
info: table: public.users
208+
╭▸
209+
2 │ create table public.users (id int);
210+
│ ┬───────────────────┯━━━━─────────
211+
│ │ │
212+
│ │ name
213+
│ select range
214+
╰╴
215+
info: function: my_schema.hello
216+
╭▸
217+
3 │ create function my_schema.hello() returns void as $$ select 1; $$ language sql;
218+
│ ┬─────────────────────────┯━━━━───────────────────────────────────────────────
219+
│ │ │
220+
│ │ name
221+
╰╴select range
222+
");
223+
}
224+
225+
#[test]
226+
fn empty_file() {
227+
symbols_not_found("")
228+
}
229+
230+
#[test]
231+
fn non_create_statements() {
232+
symbols_not_found("select * from users;")
233+
}
234+
}

crates/squawk_ide/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
mod binder;
22
pub mod code_actions;
33
pub mod column_name;
4+
pub mod document_symbols;
45
pub mod expand_selection;
56
pub mod find_references;
67
mod generated;

crates/squawk_ide/src/resolve.rs

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -779,30 +779,42 @@ pub(crate) fn resolve_insert_table_columns(
779779
}
780780

781781
pub(crate) fn resolve_table_info(binder: &Binder, path: &ast::Path) -> Option<(Schema, String)> {
782-
let table_name_str = extract_table_name_from_path(path)?;
782+
resolve_symbol_info(binder, path, SymbolKind::Table)
783+
}
784+
785+
pub(crate) fn resolve_function_info(binder: &Binder, path: &ast::Path) -> Option<(Schema, String)> {
786+
resolve_symbol_info(binder, path, SymbolKind::Function)
787+
}
788+
789+
fn resolve_symbol_info(
790+
binder: &Binder,
791+
path: &ast::Path,
792+
kind: SymbolKind,
793+
) -> Option<(Schema, String)> {
794+
let name_str = extract_table_name_from_path(path)?;
783795
let schema = extract_schema_from_path(path);
784796

785-
let table_name_normalized = Name::new(table_name_str.clone());
786-
let symbols = binder.scopes[binder.root_scope()].get(&table_name_normalized)?;
797+
let name_normalized = Name::new(name_str.clone());
798+
let symbols = binder.scopes[binder.root_scope()].get(&name_normalized)?;
787799

788800
if let Some(schema_name) = schema {
789801
let schema_normalized = Schema::new(schema_name);
790802
let symbol_id = symbols.iter().copied().find(|id| {
791803
let symbol = &binder.symbols[*id];
792-
symbol.kind == SymbolKind::Table && symbol.schema == schema_normalized
804+
symbol.kind == kind && symbol.schema == schema_normalized
793805
})?;
794806
let symbol = &binder.symbols[symbol_id];
795-
return Some((symbol.schema.clone(), table_name_str));
807+
return Some((symbol.schema.clone(), name_str));
796808
} else {
797809
let position = path.syntax().text_range().start();
798810
let search_path = binder.search_path_at(position);
799811
for search_schema in search_path {
800812
if let Some(symbol_id) = symbols.iter().copied().find(|id| {
801813
let symbol = &binder.symbols[*id];
802-
symbol.kind == SymbolKind::Table && &symbol.schema == search_schema
814+
symbol.kind == kind && &symbol.schema == search_schema
803815
}) {
804816
let symbol = &binder.symbols[symbol_id];
805-
return Some((symbol.schema.clone(), table_name_str));
817+
return Some((symbol.schema.clone(), name_str));
806818
}
807819
}
808820
}

crates/squawk_server/src/lib.rs

Lines changed: 60 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,23 +6,24 @@ use lsp_types::{
66
CodeAction, CodeActionKind, CodeActionOptions, CodeActionOrCommand, CodeActionParams,
77
CodeActionProviderCapability, CodeActionResponse, Command, Diagnostic,
88
DidChangeTextDocumentParams, DidCloseTextDocumentParams, DidOpenTextDocumentParams,
9-
GotoDefinitionParams, GotoDefinitionResponse, Hover, HoverContents, HoverParams,
10-
HoverProviderCapability, InitializeParams, InlayHint, InlayHintKind, InlayHintLabel,
11-
InlayHintLabelPart, InlayHintParams, LanguageString, Location, MarkedString, OneOf,
12-
PublishDiagnosticsParams, ReferenceParams, SelectionRangeParams,
13-
SelectionRangeProviderCapability, ServerCapabilities, TextDocumentSyncCapability,
9+
DocumentSymbol, DocumentSymbolParams, GotoDefinitionParams, GotoDefinitionResponse, Hover,
10+
HoverContents, HoverParams, HoverProviderCapability, InitializeParams, InlayHint,
11+
InlayHintKind, InlayHintLabel, InlayHintLabelPart, InlayHintParams, LanguageString, Location,
12+
MarkedString, OneOf, PublishDiagnosticsParams, ReferenceParams, SelectionRangeParams,
13+
SelectionRangeProviderCapability, ServerCapabilities, SymbolKind, TextDocumentSyncCapability,
1414
TextDocumentSyncKind, Url, WorkDoneProgressOptions, WorkspaceEdit,
1515
notification::{
1616
DidChangeTextDocument, DidCloseTextDocument, DidOpenTextDocument, Notification as _,
1717
PublishDiagnostics,
1818
},
1919
request::{
20-
CodeActionRequest, GotoDefinition, HoverRequest, InlayHintRequest, References, Request,
21-
SelectionRangeRequest,
20+
CodeActionRequest, DocumentSymbolRequest, GotoDefinition, HoverRequest, InlayHintRequest,
21+
References, Request, SelectionRangeRequest,
2222
},
2323
};
2424
use rowan::TextRange;
2525
use squawk_ide::code_actions::code_actions;
26+
use squawk_ide::document_symbols::{DocumentSymbolKind, document_symbols};
2627
use squawk_ide::find_references::find_references;
2728
use squawk_ide::goto_definition::goto_definition;
2829
use squawk_ide::hover::hover;
@@ -67,6 +68,7 @@ pub fn run() -> Result<()> {
6768
definition_provider: Some(OneOf::Left(true)),
6869
hover_provider: Some(HoverProviderCapability::Simple(true)),
6970
inlay_hint_provider: Some(OneOf::Left(true)),
71+
document_symbol_provider: Some(OneOf::Left(true)),
7072
..Default::default()
7173
})
7274
.unwrap();
@@ -119,6 +121,9 @@ fn main_loop(connection: Connection, params: serde_json::Value) -> Result<()> {
119121
InlayHintRequest::METHOD => {
120122
handle_inlay_hints(&connection, req, &documents)?;
121123
}
124+
DocumentSymbolRequest::METHOD => {
125+
handle_document_symbol(&connection, req, &documents)?;
126+
}
122127
"squawk/syntaxTree" => {
123128
handle_syntax_tree(&connection, req, &documents)?;
124129
}
@@ -296,6 +301,54 @@ fn handle_inlay_hints(
296301
Ok(())
297302
}
298303

304+
fn handle_document_symbol(
305+
connection: &Connection,
306+
req: lsp_server::Request,
307+
documents: &HashMap<Url, DocumentState>,
308+
) -> Result<()> {
309+
let params: DocumentSymbolParams = serde_json::from_value(req.params)?;
310+
let uri = params.text_document.uri;
311+
312+
let content = documents.get(&uri).map_or("", |doc| &doc.content);
313+
let parse = SourceFile::parse(content);
314+
let file = parse.tree();
315+
let line_index = LineIndex::new(content);
316+
317+
let symbols = document_symbols(&file);
318+
319+
let lsp_symbols: Vec<DocumentSymbol> = symbols
320+
.into_iter()
321+
.map(|sym| {
322+
let range = lsp_utils::range(&line_index, sym.range);
323+
let selection_range = lsp_utils::range(&line_index, sym.selection_range);
324+
325+
DocumentSymbol {
326+
name: sym.name,
327+
detail: sym.detail,
328+
kind: match sym.kind {
329+
DocumentSymbolKind::Table => SymbolKind::STRUCT,
330+
DocumentSymbolKind::Function => SymbolKind::FUNCTION,
331+
},
332+
tags: None,
333+
range,
334+
selection_range,
335+
children: None,
336+
#[allow(deprecated)]
337+
deprecated: None,
338+
}
339+
})
340+
.collect();
341+
342+
let resp = Response {
343+
id: req.id,
344+
result: Some(serde_json::to_value(&lsp_symbols).unwrap()),
345+
error: None,
346+
};
347+
348+
connection.sender.send(Message::Response(resp))?;
349+
Ok(())
350+
}
351+
299352
fn handle_selection_range(
300353
connection: &Connection,
301354
req: lsp_server::Request,

0 commit comments

Comments
 (0)