@@ -11,18 +11,20 @@ import (
11
11
"strings"
12
12
"unicode"
13
13
14
+ "github.com/davecgh/go-spew/spew"
15
+ pg "github.com/lfittl/pg_query_go"
16
+ nodes "github.com/lfittl/pg_query_go/nodes"
17
+
14
18
"github.com/kyleconroy/sqlc/internal/catalog"
15
19
"github.com/kyleconroy/sqlc/internal/migrations"
16
20
"github.com/kyleconroy/sqlc/internal/multierr"
17
21
core "github.com/kyleconroy/sqlc/internal/pg"
18
22
"github.com/kyleconroy/sqlc/internal/postgres"
19
23
"github.com/kyleconroy/sqlc/internal/postgresql/ast"
24
+ "github.com/kyleconroy/sqlc/internal/postgresql/rewrite"
20
25
"github.com/kyleconroy/sqlc/internal/postgresql/validate"
26
+ "github.com/kyleconroy/sqlc/internal/source"
21
27
"github.com/kyleconroy/sqlc/internal/sql/sqlpath"
22
-
23
- "github.com/davecgh/go-spew/spew"
24
- pg "github.com/lfittl/pg_query_go"
25
- nodes "github.com/lfittl/pg_query_go/nodes"
26
28
)
27
29
28
30
func keepSpew () {
@@ -344,7 +346,7 @@ func parseQuery(c core.Catalog, stmt nodes.Node, source string, rewriteParameter
344
346
}
345
347
346
348
// Re-write query AST
347
- raw , namedParams , edits := rewriteNamedParameters (raw )
349
+ raw , namedParams , edits := rewrite . NamedParameters (raw )
348
350
rvs := rangeVars (raw .Stmt )
349
351
refs := findParameters (raw .Stmt )
350
352
if rewriteParameters {
@@ -403,10 +405,10 @@ func parseQuery(c core.Catalog, stmt nodes.Node, source string, rewriteParameter
403
405
}, nil
404
406
}
405
407
406
- func rewriteNumberedParameters (refs []paramRef , raw nodes.RawStmt , sql string ) ([]edit , error ) {
407
- edits := make ([]edit , len (refs ))
408
+ func rewriteNumberedParameters (refs []paramRef , raw nodes.RawStmt , sql string ) ([]source. Edit , error ) {
409
+ edits := make ([]source. Edit , len (refs ))
408
410
for i , ref := range refs {
409
- edits [i ] = edit {
411
+ edits [i ] = source. Edit {
410
412
Location : ref .ref .Location - raw .StmtLocation ,
411
413
Old : fmt .Sprintf ("$%d" , ref .ref .Number ),
412
414
New : "?" ,
@@ -431,13 +433,7 @@ func stripComments(sql string) (string, []string, error) {
431
433
return strings .Join (lines , "\n " ), comments , s .Err ()
432
434
}
433
435
434
- type edit struct {
435
- Location int
436
- Old string
437
- New string
438
- }
439
-
440
- func expand (qc * QueryCatalog , raw nodes.RawStmt ) ([]edit , error ) {
436
+ func expand (qc * QueryCatalog , raw nodes.RawStmt ) ([]source.Edit , error ) {
441
437
list := ast .Search (raw , func (node nodes.Node ) bool {
442
438
switch node .(type ) {
443
439
case nodes.DeleteStmt :
@@ -452,7 +448,7 @@ func expand(qc *QueryCatalog, raw nodes.RawStmt) ([]edit, error) {
452
448
if len (list .Items ) == 0 {
453
449
return nil , nil
454
450
}
455
- var edits []edit
451
+ var edits []source. Edit
456
452
for _ , item := range list .Items {
457
453
edit , err := expandStmt (qc , raw , item )
458
454
if err != nil {
@@ -470,7 +466,7 @@ func quoteIdent(ident string) string {
470
466
return ident
471
467
}
472
468
473
- func expandStmt (qc * QueryCatalog , raw nodes.RawStmt , node nodes.Node ) ([]edit , error ) {
469
+ func expandStmt (qc * QueryCatalog , raw nodes.RawStmt , node nodes.Node ) ([]source. Edit , error ) {
474
470
tables , err := sourceTables (qc , node )
475
471
if err != nil {
476
472
return nil , err
@@ -490,7 +486,7 @@ func expandStmt(qc *QueryCatalog, raw nodes.RawStmt, node nodes.Node) ([]edit, e
490
486
return nil , fmt .Errorf ("outputColumns: unsupported node type: %T" , n )
491
487
}
492
488
493
- var edits []edit
489
+ var edits []source. Edit
494
490
for _ , target := range targets .Items {
495
491
res , ok := target .(nodes.ResTarget )
496
492
if ! ok {
@@ -548,7 +544,7 @@ func expandStmt(qc *QueryCatalog, raw nodes.RawStmt, node nodes.Node) ([]edit, e
548
544
for _ , p := range parts {
549
545
old = append (old , quoteIdent (p ))
550
546
}
551
- edits = append (edits , edit {
547
+ edits = append (edits , source. Edit {
552
548
Location : res .Location - raw .StmtLocation ,
553
549
Old : strings .Join (old , "." ),
554
550
New : strings .Join (cols , ", " ),
@@ -557,7 +553,7 @@ func expandStmt(qc *QueryCatalog, raw nodes.RawStmt, node nodes.Node) ([]edit, e
557
553
return edits , nil
558
554
}
559
555
560
- func editQuery (raw string , a []edit ) (string , error ) {
556
+ func editQuery (raw string , a []source. Edit ) (string , error ) {
561
557
if len (a ) == 0 {
562
558
return raw , nil
563
559
}
0 commit comments