Skip to content

Commit f94a933

Browse files
authored
compiler: Validate function calls (#505)
1 parent ddaa211 commit f94a933

File tree

11 files changed

+142
-40
lines changed

11 files changed

+142
-40
lines changed

internal/compiler/runtime/parse.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,9 @@ func parseQuery(c *catalog.Catalog, stmt ast.Node, src string, rewriteParameters
4646
if rawSQL == "" {
4747
return nil, errors.New("missing semicolon at end of file")
4848
}
49+
if err := validate.FuncCall(c, raw); err != nil {
50+
return nil, err
51+
}
4952

5053
return &Query{}, nil
5154
}

internal/postgresql/convert.go

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1538,11 +1538,17 @@ func convertFromExpr(n *nodes.FromExpr) *pg.FromExpr {
15381538
}
15391539
}
15401540

1541-
func convertFuncCall(n *nodes.FuncCall) *pg.FuncCall {
1541+
func convertFuncCall(n *nodes.FuncCall) *ast.FuncCall {
15421542
if n == nil {
15431543
return nil
15441544
}
1545-
return &pg.FuncCall{
1545+
fn, err := parseFuncName(n.Funcname)
1546+
if err != nil {
1547+
// TODO: How should we handle errors?
1548+
panic(err)
1549+
}
1550+
return &ast.FuncCall{
1551+
Func: fn,
15461552
Funcname: convertList(n.Funcname),
15471553
Args: convertList(n.Args),
15481554
AggOrder: convertList(n.AggOrder),
@@ -2839,11 +2845,11 @@ func convertWindowClause(n *nodes.WindowClause) *pg.WindowClause {
28392845
}
28402846
}
28412847

2842-
func convertWindowDef(n *nodes.WindowDef) *pg.WindowDef {
2848+
func convertWindowDef(n *nodes.WindowDef) *ast.WindowDef {
28432849
if n == nil {
28442850
return nil
28452851
}
2846-
return &pg.WindowDef{
2852+
return &ast.WindowDef{
28472853
Name: n.Name,
28482854
Refname: n.Refname,
28492855
PartitionClause: convertList(n.PartitionClause),

internal/postgresql/rewrite_test.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"strings"
55
"testing"
66

7+
"github.com/kyleconroy/sqlc/internal/sql/ast"
78
"github.com/kyleconroy/sqlc/internal/sql/ast/pg"
89
"github.com/kyleconroy/sqlc/internal/sql/astutils"
910

@@ -24,7 +25,7 @@ func TestApply(t *testing.T) {
2425

2526
expect := &output[0]
2627
actual := astutils.Apply(&input[0], func(cr *astutils.Cursor) bool {
27-
fun, ok := cr.Node().(*pg.FuncCall)
28+
fun, ok := cr.Node().(*ast.FuncCall)
2829
if !ok {
2930
return true
3031
}

internal/sql/ast/pg/func_call.go renamed to internal/sql/ast/func_call.go

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,11 @@
1-
package pg
2-
3-
import (
4-
"github.com/kyleconroy/sqlc/internal/sql/ast"
5-
)
1+
package ast
62

73
type FuncCall struct {
8-
Funcname *ast.List
9-
Args *ast.List
10-
AggOrder *ast.List
11-
AggFilter ast.Node
4+
Func *FuncName
5+
Funcname *List
6+
Args *List
7+
AggOrder *List
8+
AggFilter Node
129
AggWithinGroup bool
1310
AggStar bool
1411
AggDistinct bool

internal/sql/ast/pg/window_def.go

Lines changed: 0 additions & 20 deletions
This file was deleted.

internal/sql/ast/window_def.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
package ast
2+
3+
type WindowDef struct {
4+
Name *string
5+
Refname *string
6+
PartitionClause *List
7+
OrderClause *List
8+
FrameOptions int
9+
StartOffset Node
10+
EndOffset Node
11+
Location int
12+
}
13+
14+
func (n *WindowDef) Pos() int {
15+
return n.Location
16+
}

internal/sql/astutils/rewrite.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -760,7 +760,8 @@ func (a *application) apply(parent ast.Node, name string, iter *iterator, n ast.
760760
a.apply(n, "Fromlist", nil, n.Fromlist)
761761
a.apply(n, "Quals", nil, n.Quals)
762762

763-
case *pg.FuncCall:
763+
case *ast.FuncCall:
764+
a.apply(n, "Func", nil, n.Func)
764765
a.apply(n, "Funcname", nil, n.Funcname)
765766
a.apply(n, "Args", nil, n.Args)
766767
a.apply(n, "AggOrder", nil, n.AggOrder)
@@ -1206,7 +1207,7 @@ func (a *application) apply(parent ast.Node, name string, iter *iterator, n ast.
12061207
a.apply(n, "StartOffset", nil, n.StartOffset)
12071208
a.apply(n, "EndOffset", nil, n.EndOffset)
12081209

1209-
case *pg.WindowDef:
1210+
case *ast.WindowDef:
12101211
a.apply(n, "PartitionClause", nil, n.PartitionClause)
12111212
a.apply(n, "OrderClause", nil, n.OrderClause)
12121213
a.apply(n, "StartOffset", nil, n.StartOffset)

internal/sql/astutils/walk.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -645,7 +645,8 @@ func Walk(f Visitor, node ast.Node) {
645645
walkn(f, n.Fromlist)
646646
walkn(f, n.Quals)
647647

648-
case *pg.FuncCall:
648+
case *ast.FuncCall:
649+
walkn(f, n.Func)
649650
walkn(f, n.Funcname)
650651
walkn(f, n.Args)
651652
walkn(f, n.AggOrder)
@@ -1091,7 +1092,7 @@ func Walk(f Visitor, node ast.Node) {
10911092
walkn(f, n.StartOffset)
10921093
walkn(f, n.EndOffset)
10931094

1094-
case *pg.WindowDef:
1095+
case *ast.WindowDef:
10951096
walkn(f, n.PartitionClause)
10961097
walkn(f, n.OrderClause)
10971098
walkn(f, n.StartOffset)

internal/sql/catalog/public.go

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
package catalog
2+
3+
import (
4+
"github.com/kyleconroy/sqlc/internal/sql/ast"
5+
)
6+
7+
// TODO: Decide on a real, exported interface
8+
func (c *Catalog) ListFuncsByName(rel *ast.FuncName) ([]Function, error) {
9+
var funcs []Function
10+
11+
ns := rel.Schema
12+
if ns == "" {
13+
ns = c.DefaultSchema
14+
}
15+
s, err := c.getSchema(ns)
16+
if err != nil {
17+
return nil, err
18+
}
19+
for i := range s.Funcs {
20+
if s.Funcs[i].Name == rel.Name {
21+
funcs = append(funcs, *s.Funcs[i])
22+
}
23+
}
24+
return funcs, nil
25+
}

internal/sql/named/is.go

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,14 @@ import (
77
)
88

99
func IsParamFunc(node ast.Node) bool {
10-
fun, ok := node.(*pg.FuncCall)
11-
return ok && astutils.Join(fun.Funcname, ".") == "sqlc.arg"
10+
call, ok := node.(*ast.FuncCall)
11+
if !ok {
12+
return false
13+
}
14+
if call.Func == nil {
15+
return false
16+
}
17+
return call.Func.Schema == "sqlc" && call.Func.Name == "arg"
1218
}
1319

1420
func IsParamSign(node ast.Node) bool {

0 commit comments

Comments
 (0)