diff --git a/PLAN.md b/PLAN.md index b56d5ecd..ba431566 100644 --- a/PLAN.md +++ b/PLAN.md @@ -590,7 +590,7 @@ FROM ( WHERE total_amount > 1000; ``` -### Rule: aggregate free having condition +### Rule: aggregate free `having` condition ```sql select a from t group by a having a > 10; @@ -600,6 +600,21 @@ select a from t group by a having a > 10; select a from t where a > 10 group by a; ``` +### Rule: conflicting function and aggregate definitions + +```sql +create function foo(int) returns int as $$ + select $1 * 2; +$$ language sql; + +create aggregate foo(int) ( + sfunc = int4pl, + stype = int, + initcond = '0' +); +-- Query 1 ERROR at Line 1: : ERROR: function "foo" already exists with same argument types +``` + ### Rule: order direction is redundent ```sql diff --git a/crates/squawk_ide/src/hover.rs b/crates/squawk_ide/src/hover.rs index 0f597360..d3f69a34 100644 --- a/crates/squawk_ide/src/hover.rs +++ b/crates/squawk_ide/src/hover.rs @@ -44,6 +44,10 @@ pub fn hover(file: &ast::SourceFile, offset: TextSize) -> Option { return hover_function(file, &name_ref, &binder); } + if is_aggregate_ref(&name_ref) { + return hover_aggregate(file, &name_ref, &binder); + } + if is_select_function_call(&name_ref) { // Try function first, but fall back to column if no function found // (handles function-call-style column access like `select a(t)`) @@ -85,6 +89,14 @@ pub fn hover(file: &ast::SourceFile, offset: TextSize) -> Option { return format_create_function(&create_function, &binder); } + if let Some(create_aggregate) = name + .syntax() + .ancestors() + .find_map(ast::CreateAggregate::cast) + { + return format_create_aggregate(&create_aggregate, &binder); + } + if let Some(create_schema) = name.syntax().ancestors().find_map(ast::CreateSchema::cast) { return format_create_schema(&create_schema); } @@ -353,6 +365,15 @@ fn is_function_ref(name_ref: &ast::NameRef) -> bool { false } +fn is_aggregate_ref(name_ref: &ast::NameRef) -> bool { + for ancestor in name_ref.syntax().ancestors() { + if ast::DropAggregate::can_cast(ancestor.kind()) { + return true; + } + } + false +} + fn is_select_function_call(name_ref: &ast::NameRef) -> bool { let mut in_call_expr = false; let mut in_arg_list = false; @@ -498,6 +519,53 @@ fn function_schema( search_path.first().map(|s| s.to_string()) } +fn hover_aggregate( + file: &ast::SourceFile, + name_ref: &ast::NameRef, + binder: &binder::Binder, +) -> Option { + let aggregate_ptr = resolve::resolve_name_ref(binder, name_ref)?; + + let root = file.syntax(); + let aggregate_name_node = aggregate_ptr.to_node(root); + + let create_aggregate = aggregate_name_node + .ancestors() + .find_map(ast::CreateAggregate::cast)?; + + format_create_aggregate(&create_aggregate, binder) +} + +fn format_create_aggregate( + create_aggregate: &ast::CreateAggregate, + binder: &binder::Binder, +) -> Option { + let path = create_aggregate.path()?; + let segment = path.segment()?; + let name = segment.name()?; + let aggregate_name = name.syntax().text().to_string(); + + let schema = if let Some(qualifier) = path.qualifier() { + qualifier.syntax().text().to_string() + } else { + aggregate_schema(create_aggregate, binder)? + }; + + let param_list = create_aggregate.param_list()?; + let params = param_list.syntax().text().to_string(); + + Some(format!("aggregate {}.{}{}", schema, aggregate_name, params)) +} + +fn aggregate_schema( + create_aggregate: &ast::CreateAggregate, + binder: &binder::Binder, +) -> Option { + let position = create_aggregate.syntax().text_range().start(); + let search_path = binder.search_path_at(position); + search_path.first().map(|s| s.to_string()) +} + #[cfg(test)] mod test { use crate::hover::hover; @@ -1003,6 +1071,114 @@ drop function foo$0(); "); } + #[test] + fn hover_on_drop_function_overloaded() { + assert_snapshot!(check_hover(" +create function add(complex) returns complex as $$ select null $$ language sql; +create function add(bigint) returns bigint as $$ select 1 $$ language sql; +drop function add$0(complex); +"), @r" + hover: function public.add(complex) returns complex + ╭▸ + 4 │ drop function add(complex); + ╰╴ ─ hover + "); + } + + #[test] + fn hover_on_drop_function_second_overload() { + assert_snapshot!(check_hover(" +create function add(complex) returns complex as $$ select null $$ language sql; +create function add(bigint) returns bigint as $$ select 1 $$ language sql; +drop function add$0(bigint); +"), @r" + hover: function public.add(bigint) returns bigint + ╭▸ + 4 │ drop function add(bigint); + ╰╴ ─ hover + "); + } + + #[test] + fn hover_on_drop_aggregate() { + assert_snapshot!(check_hover(" +create aggregate myavg(int) (sfunc = int4_avg_accum, stype = _int8); +drop aggregate myavg$0(int); +"), @r" + hover: aggregate public.myavg(int) + ╭▸ + 3 │ drop aggregate myavg(int); + ╰╴ ─ hover + "); + } + + #[test] + fn hover_on_drop_aggregate_with_schema() { + assert_snapshot!(check_hover(" +create aggregate myschema.myavg(int) (sfunc = int4_avg_accum, stype = _int8); +drop aggregate myschema.myavg$0(int); +"), @r" + hover: aggregate myschema.myavg(int) + ╭▸ + 3 │ drop aggregate myschema.myavg(int); + ╰╴ ─ hover + "); + } + + #[test] + fn hover_on_create_aggregate_definition() { + assert_snapshot!(check_hover(" +create aggregate myavg$0(int) (sfunc = int4_avg_accum, stype = _int8); +"), @r" + hover: aggregate public.myavg(int) + ╭▸ + 2 │ create aggregate myavg(int) (sfunc = int4_avg_accum, stype = _int8); + ╰╴ ─ hover + "); + } + + #[test] + fn hover_on_drop_aggregate_with_search_path() { + assert_snapshot!(check_hover(r#" +set search_path to myschema; +create aggregate myavg(int) (sfunc = int4_avg_accum, stype = _int8); +drop aggregate myavg$0(int); +"#), @r" + hover: aggregate myschema.myavg(int) + ╭▸ + 4 │ drop aggregate myavg(int); + ╰╴ ─ hover + "); + } + + #[test] + fn hover_on_drop_aggregate_overloaded() { + assert_snapshot!(check_hover(" +create aggregate sum(complex) (sfunc = complex_add, stype = complex, initcond = '(0,0)'); +create aggregate sum(bigint) (sfunc = bigint_add, stype = bigint, initcond = '0'); +drop aggregate sum$0(complex); +"), @r" + hover: aggregate public.sum(complex) + ╭▸ + 4 │ drop aggregate sum(complex); + ╰╴ ─ hover + "); + } + + #[test] + fn hover_on_drop_aggregate_second_overload() { + assert_snapshot!(check_hover(" +create aggregate sum(complex) (sfunc = complex_add, stype = complex, initcond = '(0,0)'); +create aggregate sum(bigint) (sfunc = bigint_add, stype = bigint, initcond = '0'); +drop aggregate sum$0(bigint); +"), @r" + hover: aggregate public.sum(bigint) + ╭▸ + 4 │ drop aggregate sum(bigint); + ╰╴ ─ hover + "); + } + #[test] fn hover_on_select_function_call() { assert_snapshot!(check_hover("