Skip to content

Commit 9f21299

Browse files
authored
internal: Remove the PackageMap from settings (#295)
Create a new CombinedSettings struct that includes the global settings and the current package settings.
1 parent 0ae12c5 commit 9f21299

File tree

13 files changed

+114
-160
lines changed

13 files changed

+114
-160
lines changed

internal/cmd/cmd.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -129,14 +129,14 @@ var genCmd = &cobra.Command{
129129

130130
for _, pkg := range settings.Packages {
131131
name := pkg.Name
132-
132+
combo := dinosql.Combine(settings, pkg)
133133
var result dinosql.Generateable
134134

135135
switch pkg.Engine {
136136

137137
case dinosql.EngineMySQL:
138138
// Experimental MySQL support
139-
q, err := mysql.GeneratePkg(name, pkg.Schema, pkg.Queries, settings)
139+
q, err := mysql.GeneratePkg(name, pkg.Schema, pkg.Queries, combo)
140140
if err != nil {
141141
fmt.Fprintf(os.Stderr, "# package %s\n", name)
142142
if parserErr, ok := err.(*dinosql.ParserErr); ok {
@@ -183,7 +183,7 @@ var genCmd = &cobra.Command{
183183

184184
}
185185

186-
files, err := dinosql.Generate(result, settings)
186+
files, err := dinosql.Generate(result, combo)
187187
if err != nil {
188188
fmt.Fprintf(os.Stderr, "# package %s\n", name)
189189
fmt.Fprintf(os.Stderr, "error generating code: %s\n", err)

internal/dinosql/config.go

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,10 @@ The only supported version is "1".
2828
const errMessageNoPackages = `No packages are configured`
2929

3030
type GenerateSettings struct {
31-
Version string `json:"version"`
32-
Packages []PackageSettings `json:"packages"`
33-
Overrides []Override `json:"overrides,omitempty"`
34-
Rename map[string]string `json:"rename,omitempty"`
35-
PackageMap map[string]PackageSettings
31+
Version string `json:"version"`
32+
Packages []PackageSettings `json:"packages"`
33+
Overrides []Override `json:"overrides,omitempty"`
34+
Rename map[string]string `json:"rename,omitempty"`
3635
}
3736

3837
type Engine string
@@ -198,21 +197,19 @@ func ParseConfig(rd io.Reader) (GenerateSettings, error) {
198197
config.Packages[j].Engine = EnginePostgreSQL
199198
}
200199
}
201-
err := config.PopulatePkgMap()
202-
203-
return config, err
200+
return config, nil
204201
}
205202

206-
func (s *GenerateSettings) PopulatePkgMap() error {
207-
packageMap := make(map[string]PackageSettings)
203+
type CombinedSettings struct {
204+
Global GenerateSettings
205+
Package PackageSettings
206+
Overrides []Override
207+
}
208208

209-
for _, c := range s.Packages {
210-
if c.Name == "" {
211-
return ErrNoPackageName
212-
}
213-
packageMap[c.Name] = c
209+
func Combine(gen GenerateSettings, pkg PackageSettings) CombinedSettings {
210+
return CombinedSettings{
211+
Global: gen,
212+
Package: pkg,
213+
Overrides: append(gen.Overrides, pkg.Overrides...),
214214
}
215-
s.PackageMap = packageMap
216-
217-
return nil
218215
}

internal/dinosql/gen.go

Lines changed: 31 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ import (
66
"fmt"
77
"go/format"
88
"log"
9-
"path/filepath"
109
"regexp"
1110
"sort"
1211
"strings"
@@ -159,13 +158,12 @@ type GoQuery struct {
159158
}
160159

161160
type Generateable interface {
162-
Structs(settings GenerateSettings) []GoStruct
163-
PkgName() string
164-
GoQueries(settings GenerateSettings) []GoQuery
165-
Enums(settings GenerateSettings) []GoEnum
161+
Structs(settings CombinedSettings) []GoStruct
162+
GoQueries(settings CombinedSettings) []GoQuery
163+
Enums(settings CombinedSettings) []GoEnum
166164
}
167165

168-
func UsesType(r Generateable, typ string, settings GenerateSettings) bool {
166+
func UsesType(r Generateable, typ string, settings CombinedSettings) bool {
169167
for _, strct := range r.Structs(settings) {
170168
for _, f := range strct.Fields {
171169
fType := strings.TrimPrefix(f.Type, "[]")
@@ -177,7 +175,7 @@ func UsesType(r Generateable, typ string, settings GenerateSettings) bool {
177175
return false
178176
}
179177

180-
func UsesArrays(r Generateable, settings GenerateSettings) bool {
178+
func UsesArrays(r Generateable, settings CombinedSettings) bool {
181179
for _, strct := range r.Structs(settings) {
182180
for _, f := range strct.Fields {
183181
if strings.HasPrefix(f.Type, "[]") {
@@ -188,11 +186,11 @@ func UsesArrays(r Generateable, settings GenerateSettings) bool {
188186
return false
189187
}
190188

191-
func Imports(r Generateable, settings GenerateSettings) func(string) [][]string {
189+
func Imports(r Generateable, settings CombinedSettings) func(string) [][]string {
192190
return func(filename string) [][]string {
193191
if filename == "db.go" {
194192
imps := []string{"context", "database/sql"}
195-
if settings.PackageMap[r.PkgName()].EmitPreparedQueries {
193+
if settings.Package.EmitPreparedQueries {
196194
imps = append(imps, "fmt")
197195
}
198196
return [][]string{imps}
@@ -210,7 +208,7 @@ func Imports(r Generateable, settings GenerateSettings) func(string) [][]string
210208
}
211209
}
212210

213-
func InterfaceImports(r Generateable, settings GenerateSettings) [][]string {
211+
func InterfaceImports(r Generateable, settings CombinedSettings) [][]string {
214212
gq := r.GoQueries(settings)
215213
uses := func(name string) bool {
216214
for _, q := range gq {
@@ -246,7 +244,7 @@ func InterfaceImports(r Generateable, settings GenerateSettings) [][]string {
246244

247245
pkg := make(map[string]struct{})
248246
overrideTypes := map[string]string{}
249-
for _, o := range append(settings.Overrides, settings.PackageMap[r.PkgName()].Overrides...) {
247+
for _, o := range settings.Overrides {
250248
if o.goBasicType {
251249
continue
252250
}
@@ -284,7 +282,7 @@ func InterfaceImports(r Generateable, settings GenerateSettings) [][]string {
284282
return [][]string{stds, pkgs}
285283
}
286284

287-
func ModelImports(r Generateable, settings GenerateSettings) [][]string {
285+
func ModelImports(r Generateable, settings CombinedSettings) [][]string {
288286
std := make(map[string]struct{})
289287
if UsesType(r, "sql.Null", settings) {
290288
std["database/sql"] = struct{}{}
@@ -302,7 +300,7 @@ func ModelImports(r Generateable, settings GenerateSettings) [][]string {
302300
// Custom imports
303301
pkg := make(map[string]struct{})
304302
overrideTypes := map[string]string{}
305-
for _, o := range append(settings.Overrides, settings.PackageMap[r.PkgName()].Overrides...) {
303+
for _, o := range settings.Overrides {
306304
if o.goBasicType {
307305
continue
308306
}
@@ -340,7 +338,7 @@ func ModelImports(r Generateable, settings GenerateSettings) [][]string {
340338
return [][]string{stds, pkgs}
341339
}
342340

343-
func QueryImports(r Generateable, settings GenerateSettings, filename string) [][]string {
341+
func QueryImports(r Generateable, settings CombinedSettings, filename string) [][]string {
344342
// for _, strct := range r.Structs() {
345343
// for _, f := range strct.Fields {
346344
// if strings.HasPrefix(f.Type, "[]") {
@@ -437,7 +435,7 @@ func QueryImports(r Generateable, settings GenerateSettings, filename string) []
437435

438436
pkg := make(map[string]struct{})
439437
overrideTypes := map[string]string{}
440-
for _, o := range append(settings.Overrides, settings.PackageMap[r.PkgName()].Overrides...) {
438+
for _, o := range settings.Overrides {
441439
if o.goBasicType {
442440
continue
443441
}
@@ -490,7 +488,7 @@ func enumValueName(value string) string {
490488
return name
491489
}
492490

493-
func (r Result) Enums(settings GenerateSettings) []GoEnum {
491+
func (r Result) Enums(settings CombinedSettings) []GoEnum {
494492
var enums []GoEnum
495493
for name, schema := range r.Catalog.Schemas {
496494
if name == "pg_catalog" {
@@ -523,8 +521,8 @@ func (r Result) Enums(settings GenerateSettings) []GoEnum {
523521
return enums
524522
}
525523

526-
func StructName(name string, settings GenerateSettings) string {
527-
if rename := settings.Rename[name]; rename != "" {
524+
func StructName(name string, settings CombinedSettings) string {
525+
if rename := settings.Global.Rename[name]; rename != "" {
528526
return rename
529527
}
530528
out := ""
@@ -538,7 +536,7 @@ func StructName(name string, settings GenerateSettings) string {
538536
return out
539537
}
540538

541-
func (r Result) Structs(settings GenerateSettings) []GoStruct {
539+
func (r Result) Structs(settings CombinedSettings) []GoStruct {
542540
var structs []GoStruct
543541
for name, schema := range r.Catalog.Schemas {
544542
if name == "pg_catalog" {
@@ -573,9 +571,9 @@ func (r Result) Structs(settings GenerateSettings) []GoStruct {
573571
return structs
574572
}
575573

576-
func (r Result) goType(col core.Column, settings GenerateSettings) string {
574+
func (r Result) goType(col core.Column, settings CombinedSettings) string {
577575
// package overrides have a higher precedence
578-
for _, oride := range append(settings.Overrides, settings.PackageMap[r.PkgName()].Overrides...) {
576+
for _, oride := range settings.Overrides {
579577
if oride.Column != "" && oride.columnName == col.Name && oride.table == col.Table {
580578
return oride.goTypeName
581579
}
@@ -587,12 +585,12 @@ func (r Result) goType(col core.Column, settings GenerateSettings) string {
587585
return typ
588586
}
589587

590-
func (r Result) goInnerType(col core.Column, settings GenerateSettings) string {
588+
func (r Result) goInnerType(col core.Column, settings CombinedSettings) string {
591589
columnType := col.DataType
592590
notNull := col.NotNull || col.IsArray
593591

594592
// package overrides have a higher precedence
595-
for _, oride := range append(settings.Overrides, settings.PackageMap[r.PkgName()].Overrides...) {
593+
for _, oride := range settings.Overrides {
596594
if oride.PostgresType != "" && oride.PostgresType == columnType && oride.Null != notNull {
597595
return oride.goTypeName
598596
}
@@ -728,7 +726,7 @@ func (r Result) goInnerType(col core.Column, settings GenerateSettings) string {
728726
// JSON tags: count, count_2, count_2
729727
//
730728
// This is unlikely to happen, so don't fix it yet
731-
func (r Result) columnsToStruct(name string, columns []core.Column, settings GenerateSettings) *GoStruct {
729+
func (r Result) columnsToStruct(name string, columns []core.Column, settings CombinedSettings) *GoStruct {
732730
gs := GoStruct{
733731
Name: name,
734732
}
@@ -788,7 +786,7 @@ func compareFQN(a *core.FQN, b *core.FQN) bool {
788786
return a.Catalog == b.Catalog && a.Schema == b.Schema && a.Rel == b.Rel
789787
}
790788

791-
func (r Result) GoQueries(settings GenerateSettings) []GoQuery {
789+
func (r Result) GoQueries(settings CombinedSettings) []GoQuery {
792790
structs := r.Structs(settings)
793791

794792
qs := make([]GoQuery, 0, len(r.Queries))
@@ -1185,30 +1183,25 @@ func LowerTitle(s string) string {
11851183
return string(a)
11861184
}
11871185

1188-
func Generate(r Generateable, settings GenerateSettings) (map[string]string, error) {
1186+
func Generate(r Generateable, settings CombinedSettings) (map[string]string, error) {
11891187
funcMap := template.FuncMap{
11901188
"lowerTitle": LowerTitle,
11911189
"imports": Imports(r, settings),
11921190
}
11931191

1194-
pkgName := r.PkgName()
1195-
pkgConfig := settings.PackageMap[pkgName]
1196-
if pkgName == "" {
1197-
pkgName = filepath.Base(pkgConfig.Path)
1198-
}
1199-
12001192
dbFile := template.Must(template.New("table").Funcs(funcMap).Parse(dbTmpl))
12011193
modelsFile := template.Must(template.New("table").Funcs(funcMap).Parse(modelsTmpl))
12021194
sqlFile := template.Must(template.New("table").Funcs(funcMap).Parse(sqlTmpl))
12031195
ifaceFile := template.Must(template.New("table").Funcs(funcMap).Parse(ifaceTmpl))
12041196

1197+
pkg := settings.Package
12051198
tctx := tmplCtx{
1206-
Settings: settings,
1207-
EmitInterface: pkgConfig.EmitInterface,
1208-
EmitJSONTags: pkgConfig.EmitJSONTags,
1209-
EmitPreparedQueries: pkgConfig.EmitPreparedQueries,
1199+
Settings: settings.Global,
1200+
EmitInterface: pkg.EmitInterface,
1201+
EmitJSONTags: pkg.EmitJSONTags,
1202+
EmitPreparedQueries: pkg.EmitPreparedQueries,
12101203
Q: "`",
1211-
Package: pkgName,
1204+
Package: pkg.Name,
12121205
GoQueries: r.GoQueries(settings),
12131206
Enums: r.Enums(settings),
12141207
Structs: r.Structs(settings),
@@ -1243,7 +1236,7 @@ func Generate(r Generateable, settings GenerateSettings) (map[string]string, err
12431236
if err := execute("models.go", modelsFile); err != nil {
12441237
return nil, err
12451238
}
1246-
if pkgConfig.EmitInterface {
1239+
if pkg.EmitInterface {
12471240
if err := execute("querier.go", ifaceFile); err != nil {
12481241
return nil, err
12491242
}

0 commit comments

Comments
 (0)