Skip to content

Commit d2541ce

Browse files
authored
compiler: Support calling functions with defaults (#635)
Fix a number of bugs when resolving function calls. Switch to using functions generated from a default PostgreSQL instance. Add a new test case ripped from the PostgreSQL docs.
1 parent ca47f5e commit d2541ce

File tree

15 files changed

+290
-75
lines changed

15 files changed

+290
-75
lines changed

internal/codegen/golang/postgresql_type.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,19 +13,19 @@ func postgresType(r *compiler.Result, col *compiler.Column, settings config.Comb
1313
notNull := col.NotNull || col.IsArray
1414

1515
switch columnType {
16-
case "serial", "pg_catalog.serial4":
16+
case "serial", "serial4", "pg_catalog.serial4":
1717
if notNull {
1818
return "int32"
1919
}
2020
return "sql.NullInt32"
2121

22-
case "bigserial", "pg_catalog.serial8":
22+
case "bigserial", "serial8", "pg_catalog.serial8":
2323
if notNull {
2424
return "int64"
2525
}
2626
return "sql.NullInt64"
2727

28-
case "smallserial", "pg_catalog.serial2":
28+
case "smallserial", "serial2", "pg_catalog.serial2":
2929
return "int16"
3030

3131
case "integer", "int", "int4", "pg_catalog.int4":
@@ -43,19 +43,19 @@ func postgresType(r *compiler.Result, col *compiler.Column, settings config.Comb
4343
case "smallint", "int2", "pg_catalog.int2":
4444
return "int16"
4545

46-
case "float", "double precision", "pg_catalog.float8":
46+
case "float", "double precision", "float8", "pg_catalog.float8":
4747
if notNull {
4848
return "float64"
4949
}
5050
return "sql.NullFloat64"
5151

52-
case "real", "pg_catalog.float4":
52+
case "real", "float4", "pg_catalog.float4":
5353
if notNull {
5454
return "float32"
5555
}
5656
return "sql.NullFloat64" // TODO: Change to sql.NullFloat32 after updating the go.mod file
5757

58-
case "pg_catalog.numeric", "money":
58+
case "numeric", "pg_catalog.numeric", "money":
5959
// Since the Go standard library does not have a decimal type, lib/pq
6060
// returns numerics as strings.
6161
//
@@ -121,7 +121,7 @@ func postgresType(r *compiler.Result, col *compiler.Column, settings config.Comb
121121
}
122122
return "sql.NullString"
123123

124-
case "pg_catalog.interval":
124+
case "interval", "pg_catalog.interval":
125125
if notNull {
126126
return "int64"
127127
}

internal/compiler/output_columns.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ func outputColumns(qc *QueryCatalog, node ast.Node) ([]*Column, error) {
149149
if res.Name != nil {
150150
name = *res.Name
151151
}
152-
fun, err := qc.catalog.GetFuncN(rel, len(n.Args.Items))
152+
fun, err := qc.catalog.ResolveFuncCall(n)
153153
if err == nil {
154154
cols = append(cols, &Column{Name: name, DataType: dataType(fun.ReturnType), NotNull: true})
155155
} else {

internal/compiler/resolve.go

Lines changed: 31 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -170,24 +170,25 @@ func resolveCatalogRefs(c *catalog.Catalog, rvs []*pg.RangeVar, args []paramRef,
170170
}
171171

172172
case *ast.FuncCall:
173-
fun, err := c.GetFuncN(n.Func, len(n.Args.Items))
173+
fun, err := c.ResolveFuncCall(n)
174174
if err != nil {
175+
// Synthesize a function on the fly to avoid returning with an error
176+
// for an unknown Postgres function (e.g. defined in an extension)
175177
var args []*catalog.Argument
176178
for range n.Args.Items {
177179
args = append(args, &catalog.Argument{
178180
Type: &ast.TypeName{Name: "any"},
179181
})
180182
}
181-
// Synthesize a function on the fly to avoid returning with an error
182-
// for an unknown Postgres function (e.g. defined in an extension)
183-
fun = catalog.Function{
183+
fun = &catalog.Function{
184184
Name: n.Func.Name,
185185
Args: args,
186186
ReturnType: &ast.TypeName{Name: "any"},
187187
}
188188
}
189189
for i, item := range n.Args.Items {
190-
defaultName := fun.Name
190+
funcName := fun.Name
191+
var argName string
191192
switch inode := item.(type) {
192193
case *pg.ParamRef:
193194
if inode.Number != ref.ref.Number {
@@ -210,11 +211,15 @@ func resolveCatalogRefs(c *catalog.Catalog, rvs []*pg.RangeVar, args []paramRef,
210211
continue
211212
}
212213
if inode.Name != nil {
213-
defaultName = *inode.Name
214+
argName = *inode.Name
214215
}
215216
}
216217

217218
if fun.Args == nil {
219+
defaultName := funcName
220+
if argName != "" {
221+
defaultName = argName
222+
}
218223
a = append(a, Parameter{
219224
Number: ref.ref.Number,
220225
Column: &Column{
@@ -225,19 +230,31 @@ func resolveCatalogRefs(c *catalog.Catalog, rvs []*pg.RangeVar, args []paramRef,
225230
continue
226231
}
227232

228-
if i >= len(fun.Args) {
229-
return nil, fmt.Errorf("incorrect number of arguments to %s", fun.Name)
233+
var paramName string
234+
var paramType *ast.TypeName
235+
if argName == "" {
236+
paramName = fun.Args[i].Name
237+
paramType = fun.Args[i].Type
238+
} else {
239+
paramName = argName
240+
for _, arg := range fun.Args {
241+
if arg.Name == argName {
242+
paramType = arg.Type
243+
}
244+
}
245+
if paramType == nil {
246+
panic(fmt.Sprintf("named argument %s has no type", paramName))
247+
}
230248
}
231-
arg := fun.Args[i]
232-
name := arg.Name
233-
if name == "" {
234-
name = defaultName
249+
if paramName == "" {
250+
paramName = funcName
235251
}
252+
236253
a = append(a, Parameter{
237254
Number: ref.ref.Number,
238255
Column: &Column{
239-
Name: parameterName(ref.ref.Number, name),
240-
DataType: dataType(arg.Type),
256+
Name: parameterName(ref.ref.Number, paramName),
257+
DataType: dataType(paramType),
241258
NotNull: true,
242259
},
243260
})

internal/endtoend/testdata/func_args/go/query.sql.go

Lines changed: 8 additions & 8 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

internal/endtoend/testdata/generate_series/go/query.sql.go

Lines changed: 3 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

internal/endtoend/testdata/pg_advisory_xact_lock/go/query.sql.go

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

internal/endtoend/testdata/sql_syntax_calling_funcs/go/db.go

Lines changed: 29 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

internal/endtoend/testdata/sql_syntax_calling_funcs/go/models.go

Lines changed: 5 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

internal/endtoend/testdata/sql_syntax_calling_funcs/go/query.sql.go

Lines changed: 74 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
-- https://www.postgresql.org/docs/current/sql-syntax-calling-funcs.html
2+
CREATE FUNCTION concat_lower_or_upper(a text, b text, uppercase boolean DEFAULT false)
3+
RETURNS text
4+
AS
5+
$$
6+
SELECT CASE
7+
WHEN $3 THEN UPPER($1 || ' ' || $2)
8+
ELSE LOWER($1 || ' ' || $2)
9+
END;
10+
$$
11+
LANGUAGE SQL IMMUTABLE STRICT;
12+
13+
-- name: PositionalNotation :one
14+
SELECT concat_lower_or_upper('Hello', 'World', true);
15+
16+
-- name: PositionalNoDefaault :one
17+
SELECT concat_lower_or_upper('Hello', 'World');
18+
19+
-- name: NamedNotation :one
20+
SELECT concat_lower_or_upper(a => 'Hello', b => 'World');
21+
22+
-- name: NamedAnyOrder :one
23+
SELECT concat_lower_or_upper(a => 'Hello', b => 'World', uppercase => true);
24+
25+
-- name: NamedOtherOrder :one
26+
SELECT concat_lower_or_upper(a => 'Hello', uppercase => true, b => 'World');
27+
28+
-- name: MixedNotation :one
29+
SELECT concat_lower_or_upper('Hello', 'World', uppercase => true);

0 commit comments

Comments
 (0)