Skip to content

Commit ddaa211

Browse files
authored
named: Port parameter style validation to SQL (#504)
1 parent ae3fe91 commit ddaa211

File tree

13 files changed

+248
-22
lines changed

13 files changed

+248
-22
lines changed

internal/compiler/compile.go

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,8 @@ func enumValueName(value string) string {
5555
}
5656

5757
// end copypasta
58-
func parseCatalog(p Parser, c *catalog.Catalog, schema []string) error {
59-
files, err := sqlpath.Glob(schema)
58+
func parseCatalog(p Parser, c *catalog.Catalog, schemas []string) error {
59+
files, err := sqlpath.Glob(schemas)
6060
if err != nil {
6161
return err
6262
}
@@ -86,6 +86,34 @@ func parseCatalog(p Parser, c *catalog.Catalog, schema []string) error {
8686
return nil
8787
}
8888

89+
func parseQueries(p Parser, c *catalog.Catalog, queries []string) (*Result, error) {
90+
merr := multierr.New()
91+
files, err := sqlpath.Glob(queries)
92+
if err != nil {
93+
return nil, err
94+
}
95+
for _, filename := range files {
96+
blob, err := ioutil.ReadFile(filename)
97+
if err != nil {
98+
merr.Add(filename, "", 0, err)
99+
continue
100+
}
101+
source := string(blob)
102+
stmts, err := p.Parse(strings.NewReader(source))
103+
if err != nil {
104+
merr.Add(filename, source, 0, err)
105+
continue
106+
}
107+
for _, stmt := range stmts {
108+
fmt.Println(stmt)
109+
}
110+
}
111+
if len(merr.Errs()) > 0 {
112+
return nil, merr
113+
}
114+
return &Result{}, nil
115+
}
116+
89117
func buildResult(c *catalog.Catalog) (*Result, error) {
90118
var structs []dinosql.GoStruct
91119
var enums []dinosql.GoEnum
@@ -130,7 +158,6 @@ func buildResult(c *catalog.Catalog) (*Result, error) {
130158
}
131159
}
132160
}
133-
134161
if len(structs) > 0 {
135162
sort.Slice(structs, func(i, j int) bool { return structs[i].Name < structs[j].Name })
136163
}

internal/compiler/engine.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ func (e *Engine) ParseCatalog(schema []string) error {
4444
}
4545

4646
func (e *Engine) ParseQueries(queries []string, opts dinosql.ParserOpts) error {
47-
r, err := buildResult(e.catalog)
47+
r, err := parseQueries(e.parser, e.catalog, e.conf.Queries)
4848
if err != nil {
4949
return err
5050
}

internal/compiler/runtime/parse.go

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
package runtime
2+
3+
import (
4+
"errors"
5+
6+
"github.com/kyleconroy/sqlc/internal/source"
7+
"github.com/kyleconroy/sqlc/internal/sql/ast"
8+
"github.com/kyleconroy/sqlc/internal/sql/ast/pg"
9+
"github.com/kyleconroy/sqlc/internal/sql/catalog"
10+
"github.com/kyleconroy/sqlc/internal/sql/validate"
11+
)
12+
13+
type Query struct {
14+
}
15+
16+
var ErrUnsupportedStatementType = errors.New("parseQuery: unsupported statement type")
17+
18+
func parseQuery(c *catalog.Catalog, stmt ast.Node, src string, rewriteParameters bool) (*Query, error) {
19+
if err := validate.ParamStyle(stmt); err != nil {
20+
return nil, err
21+
}
22+
if err := validate.ParamRef(stmt); err != nil {
23+
return nil, err
24+
}
25+
raw, ok := stmt.(*ast.RawStmt)
26+
if !ok {
27+
return nil, errors.New("node is not a statement")
28+
}
29+
switch n := raw.Stmt.(type) {
30+
case *pg.SelectStmt:
31+
case *pg.DeleteStmt:
32+
case *pg.InsertStmt:
33+
if err := validate.InsertStmt(n); err != nil {
34+
return nil, err
35+
}
36+
case *pg.TruncateStmt:
37+
case *pg.UpdateStmt:
38+
default:
39+
return nil, ErrUnsupportedStatementType
40+
}
41+
42+
rawSQL, err := source.Pluck(src, raw.StmtLocation, raw.StmtLen)
43+
if err != nil {
44+
return nil, err
45+
}
46+
if rawSQL == "" {
47+
return nil, errors.New("missing semicolon at end of file")
48+
}
49+
50+
return &Query{}, nil
51+
}

internal/dinosql/parser.go

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -148,24 +148,24 @@ func ParseQueries(c core.Catalog, queriesPaths []string, opts ParserOpts) (*Resu
148148
merr.Add(filename, "", 0, err)
149149
continue
150150
}
151-
source := string(blob)
152-
tree, err := pg.Parse(source)
151+
src := string(blob)
152+
tree, err := pg.Parse(src)
153153
if err != nil {
154-
merr.Add(filename, source, 0, err)
154+
merr.Add(filename, src, 0, err)
155155
continue
156156
}
157157
for _, stmt := range tree.Statements {
158-
query, err := parseQuery(c, stmt, source, opts.UsePositionalParameters)
158+
query, err := parseQuery(c, stmt, src, opts.UsePositionalParameters)
159159
if err == errUnsupportedStatementType {
160160
continue
161161
}
162162
if err != nil {
163-
merr.Add(filename, source, location(stmt), err)
163+
merr.Add(filename, src, location(stmt), err)
164164
continue
165165
}
166166
if query.Name != "" {
167167
if _, exists := set[query.Name]; exists {
168-
merr.Add(filename, source, location(stmt), fmt.Errorf("duplicate query name: %s", query.Name))
168+
merr.Add(filename, src, location(stmt), fmt.Errorf("duplicate query name: %s", query.Name))
169169
continue
170170
}
171171
set[query.Name] = struct{}{}
@@ -198,12 +198,6 @@ func location(node nodes.Node) int {
198198
return 0
199199
}
200200

201-
func pluckQuery(source string, n nodes.RawStmt) (string, error) {
202-
head := n.StmtLocation
203-
tail := n.StmtLocation + n.StmtLen
204-
return source[head:tail], nil
205-
}
206-
207201
func rangeVars(root nodes.Node) []nodes.RangeVar {
208202
var vars []nodes.RangeVar
209203
find := ast.VisitorFunc(func(node nodes.Node) {
@@ -303,7 +297,7 @@ func validateCmd(n nodes.Node, name, cmd string) error {
303297

304298
var errUnsupportedStatementType = errors.New("parseQuery: unsupported statement type")
305299

306-
func parseQuery(c core.Catalog, stmt nodes.Node, source string, rewriteParameters bool) (*Query, error) {
300+
func parseQuery(c core.Catalog, stmt nodes.Node, src string, rewriteParameters bool) (*Query, error) {
307301
if err := validate.ParamStyle(stmt); err != nil {
308302
return nil, err
309303
}
@@ -327,7 +321,7 @@ func parseQuery(c core.Catalog, stmt nodes.Node, source string, rewriteParameter
327321
return nil, errUnsupportedStatementType
328322
}
329323

330-
rawSQL, err := pluckQuery(source, raw)
324+
rawSQL, err := source.Pluck(src, raw.StmtLocation, raw.StmtLen)
331325
if err != nil {
332326
return nil, err
333327
}

internal/dinosql/parser_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ func TestPluck(t *testing.T) {
3737
for i, stmt := range tree.Statements {
3838
switch n := stmt.(type) {
3939
case nodes.RawStmt:
40-
q, err := pluckQuery(pluck, n)
40+
q, err := source.Pluck(pluck, n.StmtLocation, n.StmtLen)
4141
if err != nil {
4242
t.Error(err)
4343
continue

internal/postgresql/parse.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,11 @@ func (p *Parser) Parse(r io.Reader) ([]ast.Statement, error) {
179179
return nil, fmt.Errorf("unexpected nil node")
180180
}
181181
stmts = append(stmts, ast.Statement{
182-
Raw: &ast.RawStmt{Stmt: n},
182+
Raw: &ast.RawStmt{
183+
Stmt: n,
184+
StmtLocation: raw.StmtLocation,
185+
StmtLen: raw.StmtLen,
186+
},
183187
})
184188
}
185189
return stmts, nil

internal/source/code.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,3 +37,9 @@ func LineNumber(source string, head int) (int, int) {
3737
}
3838
return line + 1, col
3939
}
40+
41+
func Pluck(source string, location, length int) (string, error) {
42+
head := location
43+
tail := location + length
44+
return source[head:tail], nil
45+
}

internal/sql/ast/ast.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,13 @@ func (n *Statement) Pos() int {
1313
}
1414

1515
type RawStmt struct {
16-
Stmt Node
16+
Stmt Node
17+
StmtLocation int
18+
StmtLen int
1719
}
1820

1921
func (n *RawStmt) Pos() int {
20-
return 0
22+
return n.StmtLocation
2123
}
2224

2325
type TableName struct {

internal/sql/astutils/search.go

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
package astutils
2+
3+
import "github.com/kyleconroy/sqlc/internal/sql/ast"
4+
5+
type nodeSearch struct {
6+
list *ast.List
7+
check func(ast.Node) bool
8+
}
9+
10+
func (s *nodeSearch) Visit(node ast.Node) Visitor {
11+
if s.check(node) {
12+
s.list.Items = append(s.list.Items, node)
13+
}
14+
return s
15+
}
16+
17+
func Search(root ast.Node, f func(ast.Node) bool) *ast.List {
18+
ns := &nodeSearch{check: f, list: &ast.List{}}
19+
Walk(ns, root)
20+
return ns.list
21+
}

internal/sql/named/is.go

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
package named
2+
3+
import (
4+
"github.com/kyleconroy/sqlc/internal/sql/ast"
5+
"github.com/kyleconroy/sqlc/internal/sql/ast/pg"
6+
"github.com/kyleconroy/sqlc/internal/sql/astutils"
7+
)
8+
9+
func IsParamFunc(node ast.Node) bool {
10+
fun, ok := node.(*pg.FuncCall)
11+
return ok && astutils.Join(fun.Funcname, ".") == "sqlc.arg"
12+
}
13+
14+
func IsParamSign(node ast.Node) bool {
15+
expr, ok := node.(*pg.A_Expr)
16+
return ok && astutils.Join(expr.Name, ".") == "@"
17+
}

0 commit comments

Comments
 (0)