Skip to content

Commit 6826ce0

Browse files
authored
rewrite: Move parameter rewrite to package (#499)
1 parent 80f0219 commit 6826ce0

File tree

4 files changed

+33
-30
lines changed

4 files changed

+33
-30
lines changed

internal/dinosql/parser.go

Lines changed: 16 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -11,18 +11,20 @@ import (
1111
"strings"
1212
"unicode"
1313

14+
"github.com/davecgh/go-spew/spew"
15+
pg "github.com/lfittl/pg_query_go"
16+
nodes "github.com/lfittl/pg_query_go/nodes"
17+
1418
"github.com/kyleconroy/sqlc/internal/catalog"
1519
"github.com/kyleconroy/sqlc/internal/migrations"
1620
"github.com/kyleconroy/sqlc/internal/multierr"
1721
core "github.com/kyleconroy/sqlc/internal/pg"
1822
"github.com/kyleconroy/sqlc/internal/postgres"
1923
"github.com/kyleconroy/sqlc/internal/postgresql/ast"
24+
"github.com/kyleconroy/sqlc/internal/postgresql/rewrite"
2025
"github.com/kyleconroy/sqlc/internal/postgresql/validate"
26+
"github.com/kyleconroy/sqlc/internal/source"
2127
"github.com/kyleconroy/sqlc/internal/sql/sqlpath"
22-
23-
"github.com/davecgh/go-spew/spew"
24-
pg "github.com/lfittl/pg_query_go"
25-
nodes "github.com/lfittl/pg_query_go/nodes"
2628
)
2729

2830
func keepSpew() {
@@ -344,7 +346,7 @@ func parseQuery(c core.Catalog, stmt nodes.Node, source string, rewriteParameter
344346
}
345347

346348
// Re-write query AST
347-
raw, namedParams, edits := rewriteNamedParameters(raw)
349+
raw, namedParams, edits := rewrite.NamedParameters(raw)
348350
rvs := rangeVars(raw.Stmt)
349351
refs := findParameters(raw.Stmt)
350352
if rewriteParameters {
@@ -403,10 +405,10 @@ func parseQuery(c core.Catalog, stmt nodes.Node, source string, rewriteParameter
403405
}, nil
404406
}
405407

406-
func rewriteNumberedParameters(refs []paramRef, raw nodes.RawStmt, sql string) ([]edit, error) {
407-
edits := make([]edit, len(refs))
408+
func rewriteNumberedParameters(refs []paramRef, raw nodes.RawStmt, sql string) ([]source.Edit, error) {
409+
edits := make([]source.Edit, len(refs))
408410
for i, ref := range refs {
409-
edits[i] = edit{
411+
edits[i] = source.Edit{
410412
Location: ref.ref.Location - raw.StmtLocation,
411413
Old: fmt.Sprintf("$%d", ref.ref.Number),
412414
New: "?",
@@ -431,13 +433,7 @@ func stripComments(sql string) (string, []string, error) {
431433
return strings.Join(lines, "\n"), comments, s.Err()
432434
}
433435

434-
type edit struct {
435-
Location int
436-
Old string
437-
New string
438-
}
439-
440-
func expand(qc *QueryCatalog, raw nodes.RawStmt) ([]edit, error) {
436+
func expand(qc *QueryCatalog, raw nodes.RawStmt) ([]source.Edit, error) {
441437
list := ast.Search(raw, func(node nodes.Node) bool {
442438
switch node.(type) {
443439
case nodes.DeleteStmt:
@@ -452,7 +448,7 @@ func expand(qc *QueryCatalog, raw nodes.RawStmt) ([]edit, error) {
452448
if len(list.Items) == 0 {
453449
return nil, nil
454450
}
455-
var edits []edit
451+
var edits []source.Edit
456452
for _, item := range list.Items {
457453
edit, err := expandStmt(qc, raw, item)
458454
if err != nil {
@@ -470,7 +466,7 @@ func quoteIdent(ident string) string {
470466
return ident
471467
}
472468

473-
func expandStmt(qc *QueryCatalog, raw nodes.RawStmt, node nodes.Node) ([]edit, error) {
469+
func expandStmt(qc *QueryCatalog, raw nodes.RawStmt, node nodes.Node) ([]source.Edit, error) {
474470
tables, err := sourceTables(qc, node)
475471
if err != nil {
476472
return nil, err
@@ -490,7 +486,7 @@ func expandStmt(qc *QueryCatalog, raw nodes.RawStmt, node nodes.Node) ([]edit, e
490486
return nil, fmt.Errorf("outputColumns: unsupported node type: %T", n)
491487
}
492488

493-
var edits []edit
489+
var edits []source.Edit
494490
for _, target := range targets.Items {
495491
res, ok := target.(nodes.ResTarget)
496492
if !ok {
@@ -548,7 +544,7 @@ func expandStmt(qc *QueryCatalog, raw nodes.RawStmt, node nodes.Node) ([]edit, e
548544
for _, p := range parts {
549545
old = append(old, quoteIdent(p))
550546
}
551-
edits = append(edits, edit{
547+
edits = append(edits, source.Edit{
552548
Location: res.Location - raw.StmtLocation,
553549
Old: strings.Join(old, "."),
554550
New: strings.Join(cols, ", "),
@@ -557,7 +553,7 @@ func expandStmt(qc *QueryCatalog, raw nodes.RawStmt, node nodes.Node) ([]edit, e
557553
return edits, nil
558554
}
559555

560-
func editQuery(raw string, a []edit) (string, error) {
556+
func editQuery(raw string, a []source.Edit) (string, error) {
561557
if len(a) == 0 {
562558
return raw, nil
563559
}

internal/dinosql/parser_test.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,12 @@ import (
66
"path"
77
"testing"
88

9-
"github.com/kyleconroy/sqlc/internal/source"
10-
"github.com/kyleconroy/sqlc/internal/sql/sqlpath"
11-
129
"github.com/google/go-cmp/cmp"
1310
pg "github.com/lfittl/pg_query_go"
1411
nodes "github.com/lfittl/pg_query_go/nodes"
12+
13+
"github.com/kyleconroy/sqlc/internal/source"
14+
"github.com/kyleconroy/sqlc/internal/sql/sqlpath"
1515
)
1616

1717
const pluck = `
@@ -177,7 +177,7 @@ func TestExpand(t *testing.T) {
177177
// pretend that foo has two columns, a and b
178178
raw := `SELECT *, *, foo.* FROM foo`
179179
expected := `SELECT a, b, a, b, foo.a, foo.b FROM foo`
180-
edits := []edit{
180+
edits := []source.Edit{
181181
{7, "*", "a, b"},
182182
{10, "*", "a, b"},
183183
{13, "foo.*", "foo.a, foo.b"},

internal/dinosql/rewrite.go renamed to internal/postgresql/rewrite/parameters.go

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
package dinosql
1+
package rewrite
22

33
import (
44
"fmt"
@@ -7,6 +7,7 @@ import (
77

88
"github.com/kyleconroy/sqlc/internal/postgresql"
99
"github.com/kyleconroy/sqlc/internal/postgresql/ast"
10+
"github.com/kyleconroy/sqlc/internal/source"
1011
)
1112

1213
// Given an AST node, return the string representation of names
@@ -40,7 +41,7 @@ func isNamedParamSignCast(node nodes.Node) bool {
4041
return ast.Join(expr.Name, ".") == "@" && cast
4142
}
4243

43-
func rewriteNamedParameters(raw nodes.RawStmt) (nodes.RawStmt, map[int]string, []edit) {
44+
func NamedParameters(raw nodes.RawStmt) (nodes.RawStmt, map[int]string, []source.Edit) {
4445
foundFunc := ast.Search(raw, postgresql.IsNamedParamFunc)
4546
foundSign := ast.Search(raw, postgresql.IsNamedParamSign)
4647
if len(foundFunc.Items)+len(foundSign.Items) == 0 {
@@ -49,7 +50,7 @@ func rewriteNamedParameters(raw nodes.RawStmt) (nodes.RawStmt, map[int]string, [
4950

5051
args := map[string]int{}
5152
argn := 0
52-
var edits []edit
53+
var edits []source.Edit
5354
node := ast.Apply(raw, func(cr *ast.Cursor) bool {
5455
node := cr.Node()
5556
switch {
@@ -77,7 +78,7 @@ func rewriteNamedParameters(raw nodes.RawStmt) (nodes.RawStmt, map[int]string, [
7778
} else {
7879
old = fmt.Sprintf("sqlc.arg(%s)", param)
7980
}
80-
edits = append(edits, edit{
81+
edits = append(edits, source.Edit{
8182
Location: fun.Location - raw.StmtLocation,
8283
Old: old,
8384
New: fmt.Sprintf("$%d", args[param]),
@@ -104,7 +105,7 @@ func rewriteNamedParameters(raw nodes.RawStmt) (nodes.RawStmt, map[int]string, [
104105
cr.Replace(cast)
105106
}
106107
// TODO: This code assumes that @foo::bool is on a single line
107-
edits = append(edits, edit{
108+
edits = append(edits, source.Edit{
108109
Location: expr.Location - raw.StmtLocation,
109110
Old: fmt.Sprintf("@%s", param),
110111
New: fmt.Sprintf("$%d", args[param]),
@@ -128,7 +129,7 @@ func rewriteNamedParameters(raw nodes.RawStmt) (nodes.RawStmt, map[int]string, [
128129
})
129130
}
130131
// TODO: This code assumes that @foo is on a single line
131-
edits = append(edits, edit{
132+
edits = append(edits, source.Edit{
132133
Location: expr.Location - raw.StmtLocation,
133134
Old: fmt.Sprintf("@%s", param),
134135
New: fmt.Sprintf("$%d", args[param]),

internal/source/code.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,12 @@ package source
22

33
import "unicode"
44

5+
type Edit struct {
6+
Location int
7+
Old string
8+
New string
9+
}
10+
511
func LineNumber(source string, head int) (int, int) {
612
// Calculate the true line and column number for a query, ignoring spaces
713
var comment bool

0 commit comments

Comments
 (0)