Skip to content

Commit 5d28068

Browse files
authored
sqlpath: Move ReadSQLFiles into a separate package (#495)
1 parent 9761692 commit 5d28068

File tree

6 files changed

+105
-85
lines changed

6 files changed

+105
-85
lines changed

internal/dinosql/parser.go

Lines changed: 3 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ import (
55
"errors"
66
"fmt"
77
"io/ioutil"
8-
"os"
98
"path/filepath"
109
"sort"
1110
"strconv"
@@ -17,6 +16,7 @@ import (
1716
core "github.com/kyleconroy/sqlc/internal/pg"
1817
"github.com/kyleconroy/sqlc/internal/postgres"
1918
"github.com/kyleconroy/sqlc/internal/postgresql/ast"
19+
"github.com/kyleconroy/sqlc/internal/sql/sqlpath"
2020

2121
"github.com/davecgh/go-spew/spew"
2222
pg "github.com/lfittl/pg_query_go"
@@ -60,43 +60,8 @@ func (e *ParserErr) Error() string {
6060
return fmt.Sprintf("multiple errors: %d errors", len(e.Errs))
6161
}
6262

63-
func ReadSQLFiles(paths []string) ([]string, error) {
64-
var files []string
65-
for _, path := range paths {
66-
f, err := os.Stat(path)
67-
if err != nil {
68-
return nil, fmt.Errorf("path %s does not exist", path)
69-
}
70-
if f.IsDir() {
71-
listing, err := ioutil.ReadDir(path)
72-
if err != nil {
73-
return nil, err
74-
}
75-
for _, f := range listing {
76-
files = append(files, filepath.Join(path, f.Name()))
77-
}
78-
} else {
79-
files = append(files, path)
80-
}
81-
}
82-
var sqlFiles []string
83-
for _, file := range files {
84-
if !strings.HasSuffix(file, ".sql") {
85-
continue
86-
}
87-
if strings.HasPrefix(filepath.Base(file), ".") {
88-
continue
89-
}
90-
if migrations.IsDown(filepath.Base(file)) {
91-
continue
92-
}
93-
sqlFiles = append(sqlFiles, file)
94-
}
95-
return sqlFiles, nil
96-
}
97-
9863
func ParseCatalog(schemas []string) (core.Catalog, error) {
99-
files, err := ReadSQLFiles(schemas)
64+
files, err := sqlpath.Glob(schemas)
10065
if err != nil {
10166
return core.Catalog{}, err
10267
}
@@ -202,7 +167,7 @@ func ParseQueries(c core.Catalog, queriesPaths []string, opts ParserOpts) (*Resu
202167
var q []*Query
203168

204169
set := map[string]struct{}{}
205-
files, err := ReadSQLFiles(queriesPaths)
170+
files, err := sqlpath.Glob(queriesPaths)
206171
if err != nil {
207172
return nil, err
208173
}

internal/dinosql/parser_test.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ import (
66
"path"
77
"testing"
88

9+
"github.com/kyleconroy/sqlc/internal/sql/sqlpath"
10+
911
"github.com/google/go-cmp/cmp"
1012
pg "github.com/lfittl/pg_query_go"
1113
nodes "github.com/lfittl/pg_query_go/nodes"
@@ -235,7 +237,7 @@ func TestReadFiles(t *testing.T) {
235237
path.Join(subdir2, "include-me.up.sql"),
236238
}
237239

238-
filesRead, err := ReadSQLFiles(input)
240+
filesRead, err := sqlpath.Glob(input)
239241
if err != nil {
240242
t.Error(err)
241243
}

internal/mysql/parse.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212
"github.com/kyleconroy/sqlc/internal/config"
1313
"github.com/kyleconroy/sqlc/internal/dinosql"
1414
"github.com/kyleconroy/sqlc/internal/migrations"
15+
"github.com/kyleconroy/sqlc/internal/sql/sqlpath"
1516
)
1617

1718
// Query holds the data for walking and validating mysql querys
@@ -32,7 +33,7 @@ type Column struct {
3233
}
3334

3435
func parsePath(sqlPath []string, generator PackageGenerator) (*Result, error) {
35-
files, err := dinosql.ReadSQLFiles(sqlPath)
36+
files, err := sqlpath.Glob(sqlPath)
3637
if err != nil {
3738
return nil, err
3839
}

internal/sql/catalog/catalog.go

Lines changed: 47 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -217,51 +217,55 @@ func New(def string) *Catalog {
217217

218218
func (c *Catalog) Build(stmts []ast.Statement) error {
219219
for i := range stmts {
220-
if stmts[i].Raw == nil {
221-
continue
222-
}
223-
var err error
224-
switch n := stmts[i].Raw.Stmt.(type) {
225-
case *ast.AlterTableStmt:
226-
err = c.alterTable(n)
227-
case *ast.AlterTableSetSchemaStmt:
228-
err = c.alterTableSetSchema(n)
229-
case *ast.AlterTypeAddValueStmt:
230-
err = c.alterTypeAddValue(n)
231-
case *ast.AlterTypeRenameValueStmt:
232-
err = c.alterTypeRenameValue(n)
233-
case *ast.CommentOnColumnStmt:
234-
err = c.commentOnColumn(n)
235-
case *ast.CommentOnSchemaStmt:
236-
err = c.commentOnSchema(n)
237-
case *ast.CommentOnTableStmt:
238-
err = c.commentOnTable(n)
239-
case *ast.CommentOnTypeStmt:
240-
err = c.commentOnType(n)
241-
case *ast.CreateEnumStmt:
242-
err = c.createEnum(n)
243-
case *ast.CreateFunctionStmt:
244-
err = c.createFunction(n)
245-
case *ast.CreateSchemaStmt:
246-
err = c.createSchema(n)
247-
case *ast.CreateTableStmt:
248-
err = c.createTable(n)
249-
case *ast.DropFunctionStmt:
250-
err = c.dropFunction(n)
251-
case *ast.DropSchemaStmt:
252-
err = c.dropSchema(n)
253-
case *ast.DropTableStmt:
254-
err = c.dropTable(n)
255-
case *ast.DropTypeStmt:
256-
err = c.dropType(n)
257-
case *ast.RenameColumnStmt:
258-
err = c.renameColumn(n)
259-
case *ast.RenameTableStmt:
260-
err = c.renameTable(n)
261-
}
262-
if err != nil {
220+
if err := c.Update(stmts[i]); err != nil {
263221
return err
264222
}
265223
}
266224
return nil
267225
}
226+
227+
func (c *Catalog) Update(stmt ast.Statement) error {
228+
if stmt.Raw == nil {
229+
return nil
230+
}
231+
var err error
232+
switch n := stmt.Raw.Stmt.(type) {
233+
case *ast.AlterTableStmt:
234+
err = c.alterTable(n)
235+
case *ast.AlterTableSetSchemaStmt:
236+
err = c.alterTableSetSchema(n)
237+
case *ast.AlterTypeAddValueStmt:
238+
err = c.alterTypeAddValue(n)
239+
case *ast.AlterTypeRenameValueStmt:
240+
err = c.alterTypeRenameValue(n)
241+
case *ast.CommentOnColumnStmt:
242+
err = c.commentOnColumn(n)
243+
case *ast.CommentOnSchemaStmt:
244+
err = c.commentOnSchema(n)
245+
case *ast.CommentOnTableStmt:
246+
err = c.commentOnTable(n)
247+
case *ast.CommentOnTypeStmt:
248+
err = c.commentOnType(n)
249+
case *ast.CreateEnumStmt:
250+
err = c.createEnum(n)
251+
case *ast.CreateFunctionStmt:
252+
err = c.createFunction(n)
253+
case *ast.CreateSchemaStmt:
254+
err = c.createSchema(n)
255+
case *ast.CreateTableStmt:
256+
err = c.createTable(n)
257+
case *ast.DropFunctionStmt:
258+
err = c.dropFunction(n)
259+
case *ast.DropSchemaStmt:
260+
err = c.dropSchema(n)
261+
case *ast.DropTableStmt:
262+
err = c.dropTable(n)
263+
case *ast.DropTypeStmt:
264+
err = c.dropType(n)
265+
case *ast.RenameColumnStmt:
266+
err = c.renameColumn(n)
267+
case *ast.RenameTableStmt:
268+
err = c.renameTable(n)
269+
}
270+
return err
271+
}

internal/sql/sqlpath/read.go

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
package sqlpath
2+
3+
import (
4+
"fmt"
5+
"io/ioutil"
6+
"os"
7+
"path/filepath"
8+
"strings"
9+
10+
"github.com/kyleconroy/sqlc/internal/migrations"
11+
)
12+
13+
// Return a list of SQL files in the listed paths. Only includes files ending
14+
// in .sql. Omits hidden files, directories, and migrations.
15+
func Glob(paths []string) ([]string, error) {
16+
var files []string
17+
for _, path := range paths {
18+
f, err := os.Stat(path)
19+
if err != nil {
20+
return nil, fmt.Errorf("path %s does not exist", path)
21+
}
22+
if f.IsDir() {
23+
listing, err := ioutil.ReadDir(path)
24+
if err != nil {
25+
return nil, err
26+
}
27+
for _, f := range listing {
28+
files = append(files, filepath.Join(path, f.Name()))
29+
}
30+
} else {
31+
files = append(files, path)
32+
}
33+
}
34+
var sqlFiles []string
35+
for _, file := range files {
36+
if !strings.HasSuffix(file, ".sql") {
37+
continue
38+
}
39+
if strings.HasPrefix(filepath.Base(file), ".") {
40+
continue
41+
}
42+
if migrations.IsDown(filepath.Base(file)) {
43+
continue
44+
}
45+
sqlFiles = append(sqlFiles, file)
46+
}
47+
return sqlFiles, nil
48+
}

internal/sqltest/postgres.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ import (
1010
"testing"
1111
"time"
1212

13-
"github.com/kyleconroy/sqlc/internal/dinosql"
13+
"github.com/kyleconroy/sqlc/internal/sql/sqlpath"
1414

1515
_ "github.com/lib/pq"
1616
)
@@ -79,7 +79,7 @@ func PostgreSQL(t *testing.T, migrations []string) (*sql.DB, func()) {
7979
t.Fatal(err)
8080
}
8181

82-
files, err := dinosql.ReadSQLFiles(migrations)
82+
files, err := sqlpath.Glob(migrations)
8383
if err != nil {
8484
t.Fatal(err)
8585
}

0 commit comments

Comments
 (0)