diff --git a/crates/squawk_ide/src/binder.rs b/crates/squawk_ide/src/binder.rs index dde81b6b..a7694a6a 100644 --- a/crates/squawk_ide/src/binder.rs +++ b/crates/squawk_ide/src/binder.rs @@ -116,7 +116,7 @@ fn bind_create_index(b: &mut Binder, create_index: ast::CreateIndex) { return; }; - let index_name = Name::new(name.syntax().text().to_string()); + let index_name = Name::from_node(&name); let name_ptr = SyntaxNodePtr::new(name.syntax()); let Some(schema) = b.current_search_path().first().cloned() else { @@ -195,7 +195,7 @@ fn bind_create_schema(b: &mut Binder, create_schema: ast::CreateSchema) { return; }; - let schema_name = Name::new(schema_name_node.syntax().text().to_string()); + let schema_name = Name::from_node(&schema_name_node); let name_ptr = SyntaxNodePtr::new(schema_name_node.syntax()); let schema_id = b.symbols.alloc(Symbol { @@ -213,10 +213,10 @@ fn item_name(path: &ast::Path) -> Option { let segment = path.segment()?; if let Some(name) = segment.name() { - return Some(Name::new(name.syntax().text().to_string())); + return Some(Name::from_node(&name)); } if let Some(name) = segment.name_ref() { - return Some(Name::new(name.syntax().text().to_string())); + return Some(Name::from_node(&name)); } None @@ -240,7 +240,7 @@ fn schema_name(b: &Binder, path: &ast::Path, is_temp: bool) -> Option { .and_then(|q| q.segment()) .and_then(|s| s.name_ref()) { - return Some(Schema(Name::new(name_ref.syntax().text().to_string()))); + return Some(Schema(Name::from_node(&name_ref))); } if is_temp { @@ -339,7 +339,7 @@ fn extract_param_signature(param_list: Option) -> Option Opti resolve_table(binder, &table_name, &schema, position) } NameRefContext::SelectFromTable => { - let table_name = Name::new(name_ref.syntax().text().to_string()); + let table_name = Name::from_node(name_ref); let schema = if let Some(parent) = name_ref.syntax().parent() && let Some(field_expr) = ast::FieldExpr::cast(parent) && let Some(base) = field_expr.base() && let Some(schema_name_ref) = ast::NameRef::cast(base.syntax().clone()) { - let schema_text = schema_name_ref.syntax().text().to_string(); - Some(Schema(Name::new(schema_text))) + Some(Schema(Name::from_node(&schema_name_ref))) } else { None }; @@ -106,7 +105,7 @@ pub(crate) fn resolve_name_ref(binder: &Binder, name_ref: &ast::NameRef) -> Opti ) } NameRefContext::DropSchema | NameRefContext::SchemaQualifier => { - let schema_name = Name::new(name_ref.syntax().text().to_string()); + let schema_name = Name::from_node(name_ref); resolve_schema(binder, &schema_name) } NameRefContext::SelectFunctionCall => { @@ -115,12 +114,11 @@ pub(crate) fn resolve_name_ref(binder: &Binder, name_ref: &ast::NameRef) -> Opti { let base = field_expr.base()?; let schema_name_ref = ast::NameRef::cast(base.syntax().clone())?; - let schema_text = schema_name_ref.syntax().text().to_string(); - Some(Schema(Name::new(schema_text))) + Some(Schema(Name::from_node(&schema_name_ref))) } else { None }; - let function_name = Name::new(name_ref.syntax().text().to_string()); + let function_name = Name::from_node(name_ref); let position = name_ref.syntax().text_range().start(); // functions take precedence @@ -451,7 +449,7 @@ fn resolve_schema(binder: &Binder, schema_name: &Name) -> Option } fn resolve_create_index_column(binder: &Binder, name_ref: &ast::NameRef) -> Option { - let column_name = Name::new(name_ref.syntax().text().to_string()); + let column_name = Name::from_node(name_ref); let create_index = name_ref .syntax() @@ -476,7 +474,7 @@ fn resolve_create_index_column(binder: &Binder, name_ref: &ast::NameRef) -> Opti for arg in create_table.table_arg_list()?.args() { if let ast::TableArg::Column(column) = arg && let Some(col_name) = column.name() - && Name::new(col_name.syntax().text().to_string()) == column_name + && Name::from_node(&col_name) == column_name { return Some(SyntaxNodePtr::new(col_name.syntax())); } @@ -486,7 +484,7 @@ fn resolve_create_index_column(binder: &Binder, name_ref: &ast::NameRef) -> Opti } fn resolve_insert_column(binder: &Binder, name_ref: &ast::NameRef) -> Option { - let column_name = Name::new(name_ref.syntax().text().to_string()); + let column_name = Name::from_node(name_ref); let insert = name_ref.syntax().ancestors().find_map(ast::Insert::cast)?; let path = insert.path()?; @@ -507,7 +505,7 @@ fn resolve_insert_column(binder: &Binder, name_ref: &ast::NameRef) -> Option Option { - let table_name = Name::new(name_ref.syntax().text().to_string()); + let table_name = Name::from_node(name_ref); let field_expr = name_ref.syntax().parent().and_then(ast::FieldExpr::cast)?; @@ -530,9 +528,7 @@ fn resolve_select_qualified_column_table( { // if we're at the field `bar` in `foo.bar` if let ast::Expr::NameRef(schema_name_ref) = field_expr.base()? { - Some(Schema(Name::new( - schema_name_ref.syntax().text().to_string(), - ))) + Some(Schema(Name::from_node(&schema_name_ref))) } else { None } @@ -542,9 +538,7 @@ fn resolve_select_qualified_column_table( && let ast::Expr::NameRef(schema_name_ref) = inner_base { // if we're at the field `foo` in `foo.buzz.bar` - Some(Schema(Name::new( - schema_name_ref.syntax().text().to_string(), - ))) + Some(Schema(Name::from_node(&schema_name_ref))) } else { None }; @@ -560,7 +554,7 @@ fn resolve_select_qualified_column_table( let (table_name, schema) = if let Some(name_ref_node) = from_item.name_ref() { // `from foo` - let from_table_name = Name::new(name_ref_node.syntax().text().to_string()); + let from_table_name = Name::from_node(&name_ref_node); if from_table_name == table_name { (from_table_name, None) } else { @@ -569,14 +563,14 @@ fn resolve_select_qualified_column_table( } else { // `from bar.foo` let from_field_expr = from_item.field_expr()?; - let from_table_name = Name::new(from_field_expr.field()?.syntax().text().to_string()); + let from_table_name = Name::from_node(&from_field_expr.field()?); if from_table_name != table_name { return None; } let ast::Expr::NameRef(schema_name_ref) = from_field_expr.base()? else { return None; }; - let schema = Schema(Name::new(schema_name_ref.syntax().text().to_string())); + let schema = Schema(Name::from_node(&schema_name_ref)); (from_table_name, Some(schema)) }; @@ -588,7 +582,7 @@ fn resolve_select_qualified_column( binder: &Binder, name_ref: &ast::NameRef, ) -> Option { - let column_name = Name::new(name_ref.syntax().text().to_string()); + let column_name = Name::from_node(name_ref); let field_expr = name_ref.syntax().parent().and_then(ast::FieldExpr::cast)?; @@ -597,7 +591,7 @@ fn resolve_select_qualified_column( if let Some(base) = field_expr.base() && let ast::Expr::NameRef(table_name_ref) = base { - (Name::new(table_name_ref.syntax().text().to_string()), None) + (Name::from_node(&table_name_ref), None) // we have `foo.bar.buzz` } else if let Some(base) = field_expr.base() && let ast::Expr::FieldExpr(inner_field_expr) = base @@ -606,9 +600,9 @@ fn resolve_select_qualified_column( && let ast::Expr::NameRef(schema_name_ref) = inner_base { ( - Name::new(table_field.syntax().text().to_string()), - Some(Schema(Name::new( - schema_name_ref.syntax().text().to_string(), + Name::from_node(&table_field), + Some(Schema(Name::from_node( + &schema_name_ref ))), ) } else { @@ -626,7 +620,7 @@ fn resolve_select_qualified_column( if let Some(name_ref_node) = from_item.name_ref() { // `from bar` - let from_table_name = Name::new(name_ref_node.syntax().text().to_string()); + let from_table_name = Name::from_node(&name_ref_node); if from_table_name == column_table_name { (from_table_name, None) } else { @@ -635,14 +629,14 @@ fn resolve_select_qualified_column( } else { // `from foo.bar` let from_field_expr = from_item.field_expr()?; - let from_table_name = Name::new(from_field_expr.field()?.syntax().text().to_string()); + let from_table_name = Name::from_node(&from_field_expr.field()?); if from_table_name != column_table_name { return None; } let ast::Expr::NameRef(schema_name_ref) = from_field_expr.base()? else { return None; }; - let schema = Schema(Name::new(schema_name_ref.syntax().text().to_string())); + let schema = Schema(Name::from_node(&schema_name_ref)); (from_table_name, Some(schema)) } }; @@ -659,7 +653,7 @@ fn resolve_select_qualified_column( for arg in create_table.table_arg_list()?.args() { if let ast::TableArg::Column(column) = arg && let Some(col_name) = column.name() - && Name::new(col_name.syntax().text().to_string()) == column_name + && Name::from_node(&col_name) == column_name { return Some(SyntaxNodePtr::new(col_name.syntax())); } @@ -671,7 +665,7 @@ fn resolve_select_qualified_column( } fn resolve_select_column(binder: &Binder, name_ref: &ast::NameRef) -> Option { - let column_name = Name::new(name_ref.syntax().text().to_string()); + let column_name = Name::from_node(name_ref); let select = name_ref.syntax().ancestors().find_map(ast::Select::cast)?; let from_clause = select.from_clause()?; @@ -686,14 +680,14 @@ fn resolve_select_column(binder: &Binder, name_ref: &ast::NameRef) -> Option Option Option Option { - let column_name = Name::new(name_ref.syntax().text().to_string()); + let column_name = Name::from_node(name_ref); let delete = name_ref.syntax().ancestors().find_map(ast::Delete::cast)?; let relation_name = delete.relation_name()?; @@ -757,7 +751,7 @@ fn resolve_delete_where_column(binder: &Binder, name_ref: &ast::NameRef) -> Opti for arg in create_table.table_arg_list()?.args() { if let ast::TableArg::Column(column) = arg && let Some(col_name) = column.name() - && Name::new(col_name.syntax().text().to_string()) == column_name + && Name::from_node(&col_name) == column_name { return Some(SyntaxNodePtr::new(col_name.syntax())); } @@ -770,7 +764,7 @@ fn resolve_function_call_style_column( binder: &Binder, name_ref: &ast::NameRef, ) -> Option { - let column_name = Name::new(name_ref.syntax().text().to_string()); + let column_name = Name::from_node(name_ref); // function call syntax for columns is only valid if there is one argument let call_expr = name_ref @@ -788,14 +782,14 @@ fn resolve_function_call_style_column( // get the table name and schema from the FROM clause let (table_name, schema) = if let Some(name_ref_node) = from_item.name_ref() { - (Name::new(name_ref_node.syntax().text().to_string()), None) + (Name::from_node(&name_ref_node), None) } else { let field_expr = from_item.field_expr()?; - let table_name = Name::new(field_expr.field()?.syntax().text().to_string()); + let table_name = Name::from_node(&field_expr.field()?); let ast::Expr::NameRef(schema_name_ref) = field_expr.base()? else { return None; }; - let schema = Schema(Name::new(schema_name_ref.syntax().text().to_string())); + let schema = Schema(Name::from_node(&schema_name_ref)); (table_name, Some(schema)) }; @@ -810,7 +804,7 @@ fn resolve_function_call_style_column( for arg in create_table.table_arg_list()?.args() { if let ast::TableArg::Column(column) = arg && let Some(col_name) = column.name() - && Name::new(col_name.syntax().text().to_string()) == column_name + && Name::from_node(&col_name) == column_name { return Some(SyntaxNodePtr::new(col_name.syntax())); } @@ -831,24 +825,24 @@ fn find_containing_path(name_ref: &ast::NameRef) -> Option { fn extract_table_name(path: &ast::Path) -> Option { let segment = path.segment()?; let name_ref = segment.name_ref()?; - Some(Name::new(name_ref.syntax().text().to_string())) + Some(Name::from_node(&name_ref)) } fn extract_schema_name(path: &ast::Path) -> Option { path.qualifier() .and_then(|q| q.segment()) .and_then(|s| s.name_ref()) - .map(|name_ref| Schema(Name::new(name_ref.syntax().text().to_string()))) + .map(|name_ref| Schema(Name::from_node(&name_ref))) } pub(crate) fn extract_column_name(col: &ast::Column) -> Option { - let text = if let Some(name_ref) = col.name_ref() { - name_ref.syntax().text().to_string() + let name = if let Some(name_ref) = col.name_ref() { + Name::from_node(&name_ref) } else { let name = col.name()?; - name.syntax().text().to_string() + Name::from_node(&name) }; - Some(Name::new(text)) + Some(name) } pub(crate) fn find_column_in_table( @@ -858,7 +852,7 @@ pub(crate) fn find_column_in_table( table_arg_list.args().find_map(|arg| { if let ast::TableArg::Column(column) = arg && let Some(name) = column.name() - && Name::new(name.syntax().text().to_string()) == *col_name + && Name::from_node(&name) == *col_name { Some(name.syntax().text_range()) } else { @@ -873,7 +867,7 @@ fn resolve_cte_table(name_ref: &ast::NameRef, cte_name: &Name) -> Option Option> { && let Some(segment) = path.segment() && let Some(name_ref) = segment.name_ref() { - params.push(Name::new(name_ref.syntax().text().to_string())); + params.push(Name::from_node(&name_ref)); } } (!params.is_empty()).then_some(params) diff --git a/crates/squawk_ide/src/symbols.rs b/crates/squawk_ide/src/symbols.rs index e4b400e3..20db7787 100644 --- a/crates/squawk_ide/src/symbols.rs +++ b/crates/squawk_ide/src/symbols.rs @@ -1,6 +1,6 @@ use la_arena::Idx; use smol_str::SmolStr; -use squawk_syntax::SyntaxNodePtr; +use squawk_syntax::{SyntaxNodePtr, ast}; use std::fmt; use crate::quote::normalize_identifier; @@ -13,7 +13,7 @@ pub(crate) struct Schema(pub(crate) Name); impl Schema { pub(crate) fn new(name: impl Into) -> Self { - Schema(Name::new(name)) + Schema(Name::from_string(name)) } } @@ -24,11 +24,16 @@ impl fmt::Display for Schema { } impl Name { - pub(crate) fn new(text: impl Into) -> Self { + pub(crate) fn from_string(text: impl Into) -> Self { let text = text.into(); let normalized = normalize_identifier(&text); Name(normalized.into()) } + pub(crate) fn from_node(node: &impl ast::NameLike) -> Self { + let text = node.syntax().text().to_string(); + let normalized = normalize_identifier(&text); + Name(normalized.into()) + } } impl fmt::Display for Name { @@ -61,11 +66,11 @@ mod test { use super::*; #[test] fn name_case_insensitive_compare() { - assert_eq!(Name::new("foo"), Name::new("FOO")); + assert_eq!(Name::from_string("foo"), Name::from_string("FOO")); } #[test] fn name_quote_comparing() { - assert_eq!(Name::new(r#""foo""#), Name::new("foo")); + assert_eq!(Name::from_string(r#""foo""#), Name::from_string("foo")); } } diff --git a/crates/squawk_syntax/src/ast.rs b/crates/squawk_syntax/src/ast.rs index 2c5e059e..993196da 100644 --- a/crates/squawk_syntax/src/ast.rs +++ b/crates/squawk_syntax/src/ast.rs @@ -55,6 +55,7 @@ pub use self::{ // HasGenericParams, HasLoopBody, HasName, HasParamList, + NameLike, }, }; diff --git a/crates/squawk_syntax/src/ast/node_ext.rs b/crates/squawk_syntax/src/ast/node_ext.rs index 400bb148..62fc90c1 100644 --- a/crates/squawk_syntax/src/ast/node_ext.rs +++ b/crates/squawk_syntax/src/ast/node_ext.rs @@ -231,6 +231,8 @@ pub(crate) fn text_of_first_token(node: &SyntaxNode) -> TokenText<'_> { impl ast::HasParamList for ast::FunctionSig {} impl ast::HasParamList for ast::Aggregate {} +impl ast::NameLike for ast::Name {} +impl ast::NameLike for ast::NameRef {} #[test] fn index_expr() { diff --git a/crates/squawk_syntax/src/ast/traits.rs b/crates/squawk_syntax/src/ast/traits.rs index 7011d72c..087310f9 100644 --- a/crates/squawk_syntax/src/ast/traits.rs +++ b/crates/squawk_syntax/src/ast/traits.rs @@ -9,6 +9,8 @@ pub trait HasName: AstNode { } } +pub trait NameLike: AstNode {} + pub trait HasArgList: AstNode { fn arg_list(&self) -> Option { support::child(self.syntax())