|
| 1 | +use rowan::TextRange; |
| 2 | +use squawk_syntax::ast::{self, AstNode}; |
| 3 | + |
| 4 | +use crate::binder; |
| 5 | +use crate::resolve::{resolve_function_info, resolve_table_info}; |
| 6 | + |
| 7 | +pub enum DocumentSymbolKind { |
| 8 | + Table, |
| 9 | + Function, |
| 10 | +} |
| 11 | + |
| 12 | +pub struct DocumentSymbol { |
| 13 | + pub name: String, |
| 14 | + pub detail: Option<String>, |
| 15 | + pub kind: DocumentSymbolKind, |
| 16 | + pub range: TextRange, |
| 17 | + pub selection_range: TextRange, |
| 18 | +} |
| 19 | + |
| 20 | +pub fn document_symbols(file: &ast::SourceFile) -> Vec<DocumentSymbol> { |
| 21 | + let binder = binder::bind(file); |
| 22 | + let mut symbols = vec![]; |
| 23 | + |
| 24 | + for stmt in file.stmts() { |
| 25 | + match stmt { |
| 26 | + ast::Stmt::CreateTable(create_table) => { |
| 27 | + if let Some(symbol) = create_table_symbol(&binder, create_table) { |
| 28 | + symbols.push(symbol); |
| 29 | + } |
| 30 | + } |
| 31 | + ast::Stmt::CreateFunction(create_function) => { |
| 32 | + if let Some(symbol) = create_function_symbol(&binder, create_function) { |
| 33 | + symbols.push(symbol); |
| 34 | + } |
| 35 | + } |
| 36 | + _ => {} |
| 37 | + } |
| 38 | + } |
| 39 | + |
| 40 | + symbols |
| 41 | +} |
| 42 | + |
| 43 | +fn create_table_symbol( |
| 44 | + binder: &binder::Binder, |
| 45 | + create_table: ast::CreateTable, |
| 46 | +) -> Option<DocumentSymbol> { |
| 47 | + let path = create_table.path()?; |
| 48 | + let segment = path.segment()?; |
| 49 | + let name_node = segment.name()?; |
| 50 | + |
| 51 | + let (schema, table_name) = resolve_table_info(binder, &path)?; |
| 52 | + let name = format!("{}.{}", schema.0, table_name); |
| 53 | + |
| 54 | + let range = create_table.syntax().text_range(); |
| 55 | + let selection_range = name_node.syntax().text_range(); |
| 56 | + |
| 57 | + Some(DocumentSymbol { |
| 58 | + name, |
| 59 | + detail: None, |
| 60 | + kind: DocumentSymbolKind::Table, |
| 61 | + range, |
| 62 | + selection_range, |
| 63 | + }) |
| 64 | +} |
| 65 | + |
| 66 | +fn create_function_symbol( |
| 67 | + binder: &binder::Binder, |
| 68 | + create_function: ast::CreateFunction, |
| 69 | +) -> Option<DocumentSymbol> { |
| 70 | + let path = create_function.path()?; |
| 71 | + let segment = path.segment()?; |
| 72 | + let name_node = segment.name()?; |
| 73 | + |
| 74 | + let (schema, function_name) = resolve_function_info(binder, &path)?; |
| 75 | + let name = format!("{}.{}", schema.0, function_name); |
| 76 | + |
| 77 | + let range = create_function.syntax().text_range(); |
| 78 | + let selection_range = name_node.syntax().text_range(); |
| 79 | + |
| 80 | + Some(DocumentSymbol { |
| 81 | + name, |
| 82 | + detail: None, |
| 83 | + kind: DocumentSymbolKind::Function, |
| 84 | + range, |
| 85 | + selection_range, |
| 86 | + }) |
| 87 | +} |
| 88 | + |
| 89 | +#[cfg(test)] |
| 90 | +mod tests { |
| 91 | + use super::*; |
| 92 | + use annotate_snippets::{AnnotationKind, Level, Renderer, Snippet, renderer::DecorStyle}; |
| 93 | + use insta::assert_snapshot; |
| 94 | + |
| 95 | + fn symbols_not_found(sql: &str) { |
| 96 | + let parse = ast::SourceFile::parse(sql); |
| 97 | + let file = parse.tree(); |
| 98 | + let symbols = document_symbols(&file); |
| 99 | + if !symbols.is_empty() { |
| 100 | + panic!("Symbols found. If this is expected, use `symbols` instead.") |
| 101 | + } |
| 102 | + } |
| 103 | + |
| 104 | + fn symbols(sql: &str) -> String { |
| 105 | + let parse = ast::SourceFile::parse(sql); |
| 106 | + let file = parse.tree(); |
| 107 | + let symbols = document_symbols(&file); |
| 108 | + if symbols.is_empty() { |
| 109 | + panic!("No symbols found. If this is expected, use `symbols_not_found` instead.") |
| 110 | + } |
| 111 | + |
| 112 | + let mut groups = vec![]; |
| 113 | + for symbol in symbols { |
| 114 | + let kind = match symbol.kind { |
| 115 | + DocumentSymbolKind::Table => "table", |
| 116 | + DocumentSymbolKind::Function => "function", |
| 117 | + }; |
| 118 | + let title = format!("{}: {}", kind, symbol.name); |
| 119 | + let group = Level::INFO.primary_title(title).element( |
| 120 | + Snippet::source(sql) |
| 121 | + .fold(true) |
| 122 | + .annotation( |
| 123 | + AnnotationKind::Primary |
| 124 | + .span(symbol.selection_range.into()) |
| 125 | + .label("name"), |
| 126 | + ) |
| 127 | + .annotation( |
| 128 | + AnnotationKind::Context |
| 129 | + .span(symbol.range.into()) |
| 130 | + .label("select range"), |
| 131 | + ), |
| 132 | + ); |
| 133 | + groups.push(group); |
| 134 | + } |
| 135 | + |
| 136 | + let renderer = Renderer::plain().decor_style(DecorStyle::Unicode); |
| 137 | + renderer.render(&groups).to_string() |
| 138 | + } |
| 139 | + |
| 140 | + #[test] |
| 141 | + fn create_table() { |
| 142 | + assert_snapshot!(symbols("create table users (id int);"), @r" |
| 143 | + info: table: public.users |
| 144 | + ╭▸ |
| 145 | + 1 │ create table users (id int); |
| 146 | + │ ┬────────────┯━━━━───────── |
| 147 | + │ │ │ |
| 148 | + │ │ name |
| 149 | + ╰╴select range |
| 150 | + "); |
| 151 | + } |
| 152 | + |
| 153 | + #[test] |
| 154 | + fn create_function() { |
| 155 | + assert_snapshot!( |
| 156 | + symbols("create function hello() returns void as $$ select 1; $$ language sql;"), |
| 157 | + @r" |
| 158 | + info: function: public.hello |
| 159 | + ╭▸ |
| 160 | + 1 │ create function hello() returns void as $$ select 1; $$ language sql; |
| 161 | + │ ┬───────────────┯━━━━─────────────────────────────────────────────── |
| 162 | + │ │ │ |
| 163 | + │ │ name |
| 164 | + ╰╴select range |
| 165 | + " |
| 166 | + ); |
| 167 | + } |
| 168 | + |
| 169 | + #[test] |
| 170 | + fn multiple_symbols() { |
| 171 | + assert_snapshot!(symbols(" |
| 172 | +create table users (id int); |
| 173 | +create table posts (id int); |
| 174 | +create function get_user(user_id int) returns void as $$ select 1; $$ language sql; |
| 175 | +"), @r" |
| 176 | + info: table: public.users |
| 177 | + ╭▸ |
| 178 | + 2 │ create table users (id int); |
| 179 | + │ ┬────────────┯━━━━───────── |
| 180 | + │ │ │ |
| 181 | + │ │ name |
| 182 | + │ select range |
| 183 | + ╰╴ |
| 184 | + info: table: public.posts |
| 185 | + ╭▸ |
| 186 | + 3 │ create table posts (id int); |
| 187 | + │ ┬────────────┯━━━━───────── |
| 188 | + │ │ │ |
| 189 | + │ │ name |
| 190 | + ╰╴select range |
| 191 | + info: function: public.get_user |
| 192 | + ╭▸ |
| 193 | + 4 │ create function get_user(user_id int) returns void as $$ select 1; $$ language sql; |
| 194 | + │ ┬───────────────┯━━━━━━━────────────────────────────────────────────────────────── |
| 195 | + │ │ │ |
| 196 | + │ │ name |
| 197 | + ╰╴select range |
| 198 | + "); |
| 199 | + } |
| 200 | + |
| 201 | + #[test] |
| 202 | + fn qualified_names() { |
| 203 | + assert_snapshot!(symbols(" |
| 204 | +create table public.users (id int); |
| 205 | +create function my_schema.hello() returns void as $$ select 1; $$ language sql; |
| 206 | +"), @r" |
| 207 | + info: table: public.users |
| 208 | + ╭▸ |
| 209 | + 2 │ create table public.users (id int); |
| 210 | + │ ┬───────────────────┯━━━━───────── |
| 211 | + │ │ │ |
| 212 | + │ │ name |
| 213 | + │ select range |
| 214 | + ╰╴ |
| 215 | + info: function: my_schema.hello |
| 216 | + ╭▸ |
| 217 | + 3 │ create function my_schema.hello() returns void as $$ select 1; $$ language sql; |
| 218 | + │ ┬─────────────────────────┯━━━━─────────────────────────────────────────────── |
| 219 | + │ │ │ |
| 220 | + │ │ name |
| 221 | + ╰╴select range |
| 222 | + "); |
| 223 | + } |
| 224 | + |
| 225 | + #[test] |
| 226 | + fn empty_file() { |
| 227 | + symbols_not_found("") |
| 228 | + } |
| 229 | + |
| 230 | + #[test] |
| 231 | + fn non_create_statements() { |
| 232 | + symbols_not_found("select * from users;") |
| 233 | + } |
| 234 | +} |
0 commit comments