Skip to content

Commit fd329cb

Browse files
authored
compiler: Port bottom of parseQuery (#510)
1 parent b28e3da commit fd329cb

File tree

4 files changed

+85
-50
lines changed

4 files changed

+85
-50
lines changed

internal/compiler/parse.go

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package compiler
22

33
import (
44
"errors"
5+
"fmt"
56
"strings"
67

78
"github.com/kyleconroy/sqlc/internal/metadata"
@@ -13,6 +14,10 @@ import (
1314
)
1415

1516
type Query struct {
17+
SQL string
18+
Name string
19+
Cmd string // TODO: Pick a better name. One of: one, many, exec, execrows
20+
Comments []string
1621
}
1722

1823
var ErrUnsupportedStatementType = errors.New("parseQuery: unsupported statement type")
@@ -59,5 +64,30 @@ func parseQuery(p Parser, c *catalog.Catalog, stmt ast.Node, src string, rewrite
5964
return nil, err
6065
}
6166

62-
return &Query{}, nil
67+
// TODO: Then a miracle occurs
68+
69+
var edits []source.Edit
70+
expanded, err := source.Mutate(rawSQL, edits)
71+
if err != nil {
72+
return nil, err
73+
}
74+
75+
// If the query string was edited, make sure the syntax is valid
76+
if expanded != rawSQL {
77+
if _, err := p.Parse(strings.NewReader(expanded)); err != nil {
78+
return nil, fmt.Errorf("edited query syntax is invalid: %w", err)
79+
}
80+
}
81+
82+
trimmed, comments, err := source.StripComments(expanded)
83+
if err != nil {
84+
return nil, err
85+
}
86+
87+
return &Query{
88+
Cmd: cmd,
89+
Comments: comments,
90+
Name: name,
91+
SQL: trimmed,
92+
}, nil
6393
}

internal/dinosql/parser.go

Lines changed: 2 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package dinosql
22

33
import (
4-
"bufio"
54
"errors"
65
"fmt"
76
"io/ioutil"
@@ -311,7 +310,7 @@ func parseQuery(c core.Catalog, stmt nodes.Node, src string, rewriteParameters b
311310
}
312311
edits = append(edits, expandEdits...)
313312

314-
expanded, err := editQuery(rawSQL, edits)
313+
expanded, err := source.Mutate(rawSQL, edits)
315314
if err != nil {
316315
return nil, err
317316
}
@@ -323,7 +322,7 @@ func parseQuery(c core.Catalog, stmt nodes.Node, src string, rewriteParameters b
323322
}
324323
}
325324

326-
trimmed, comments, err := stripComments(strings.TrimSpace(expanded))
325+
trimmed, comments, err := source.StripComments(expanded)
327326
if err != nil {
328327
return nil, err
329328
}
@@ -350,22 +349,6 @@ func rewriteNumberedParameters(refs []paramRef, raw nodes.RawStmt, sql string) (
350349
return edits, nil
351350
}
352351

353-
func stripComments(sql string) (string, []string, error) {
354-
s := bufio.NewScanner(strings.NewReader(sql))
355-
var lines, comments []string
356-
for s.Scan() {
357-
if strings.HasPrefix(s.Text(), "-- name:") {
358-
continue
359-
}
360-
if strings.HasPrefix(s.Text(), "--") {
361-
comments = append(comments, strings.TrimPrefix(s.Text(), "--"))
362-
continue
363-
}
364-
lines = append(lines, s.Text())
365-
}
366-
return strings.Join(lines, "\n"), comments, s.Err()
367-
}
368-
369352
func expand(qc *QueryCatalog, raw nodes.RawStmt) ([]source.Edit, error) {
370353
list := ast.Search(raw, func(node nodes.Node) bool {
371354
switch node.(type) {
@@ -486,33 +469,6 @@ func expandStmt(qc *QueryCatalog, raw nodes.RawStmt, node nodes.Node) ([]source.
486469
return edits, nil
487470
}
488471

489-
func editQuery(raw string, a []source.Edit) (string, error) {
490-
if len(a) == 0 {
491-
return raw, nil
492-
}
493-
sort.Slice(a, func(i, j int) bool { return a[i].Location > a[j].Location })
494-
s := raw
495-
for _, edit := range a {
496-
start := edit.Location
497-
if start > len(s) {
498-
return "", fmt.Errorf("edit start location is out of bounds")
499-
}
500-
if len(edit.New) <= 0 {
501-
return "", fmt.Errorf("empty edit contents")
502-
}
503-
if len(edit.Old) <= 0 {
504-
return "", fmt.Errorf("empty edit contents")
505-
}
506-
stop := edit.Location + len(edit.Old) - 1 // Assumes edit.New is non-empty
507-
if stop < len(s) {
508-
s = s[:start] + edit.New + s[stop+1:]
509-
} else {
510-
s = s[:start] + edit.New
511-
}
512-
}
513-
return s, nil
514-
}
515-
516472
type QueryCatalog struct {
517473
catalog core.Catalog
518474
ctes map[string]core.Table

internal/dinosql/parser_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ func TestRewriteParameters(t *testing.T) {
147147
if err != nil {
148148
t.Error(err)
149149
}
150-
rewritten, err := editQuery(q.orig, edits)
150+
rewritten, err := source.Mutate(q.orig, edits)
151151
if err != nil {
152152
t.Error(err)
153153
}
@@ -167,7 +167,7 @@ func TestExpand(t *testing.T) {
167167
{10, "*", "a, b"},
168168
{13, "foo.*", "foo.a, foo.b"},
169169
}
170-
actual, err := editQuery(raw, edits)
170+
actual, err := source.Mutate(raw, edits)
171171
if err != nil {
172172
t.Error(err)
173173
}

internal/source/code.go

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
11
package source
22

3-
import "unicode"
3+
import (
4+
"bufio"
5+
"fmt"
6+
"sort"
7+
"strings"
8+
"unicode"
9+
)
410

511
type Edit struct {
612
Location int
@@ -43,3 +49,46 @@ func Pluck(source string, location, length int) (string, error) {
4349
tail := location + length
4450
return source[head:tail], nil
4551
}
52+
53+
func Mutate(raw string, a []Edit) (string, error) {
54+
if len(a) == 0 {
55+
return raw, nil
56+
}
57+
sort.Slice(a, func(i, j int) bool { return a[i].Location > a[j].Location })
58+
s := raw
59+
for _, edit := range a {
60+
start := edit.Location
61+
if start > len(s) {
62+
return "", fmt.Errorf("edit start location is out of bounds")
63+
}
64+
if len(edit.New) <= 0 {
65+
return "", fmt.Errorf("empty edit contents")
66+
}
67+
if len(edit.Old) <= 0 {
68+
return "", fmt.Errorf("empty edit contents")
69+
}
70+
stop := edit.Location + len(edit.Old) - 1 // Assumes edit.New is non-empty
71+
if stop < len(s) {
72+
s = s[:start] + edit.New + s[stop+1:]
73+
} else {
74+
s = s[:start] + edit.New
75+
}
76+
}
77+
return s, nil
78+
}
79+
80+
func StripComments(sql string) (string, []string, error) {
81+
s := bufio.NewScanner(strings.NewReader(strings.TrimSpace(sql)))
82+
var lines, comments []string
83+
for s.Scan() {
84+
if strings.HasPrefix(s.Text(), "-- name:") {
85+
continue
86+
}
87+
if strings.HasPrefix(s.Text(), "--") {
88+
comments = append(comments, strings.TrimPrefix(s.Text(), "--"))
89+
continue
90+
}
91+
lines = append(lines, s.Text())
92+
}
93+
return strings.Join(lines, "\n"), comments, s.Err()
94+
}

0 commit comments

Comments
 (0)