Skip to content

Commit 385d61f

Browse files
mrnuggetkyleconroy
authored andcommitted
Avoid duplicate import statements (#164)
Before this change it was possible to generate duplicate import statements. Take this scheme: { "version": "1", "packages": [{ "schema": "schema.sql", "queries": "queries.sql", "name": "queries", "path": "." }], "overrides": [ { "column": "toy.owner_id", "go_type": "database/sql.NullInt64" }, { "column": "toy.manufacturer_id", "go_type": "database/sql.NullInt64" } ] } That would generate a `models.go` file with two imports for `database/sql`. What the change here does is to make all imports, in the queries and the models file, unique.
1 parent ada53d9 commit 385d61f

File tree

1 file changed

+45
-23
lines changed

1 file changed

+45
-23
lines changed

internal/dinosql/gen.go

Lines changed: 45 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -195,44 +195,54 @@ func (r Result) Imports(settings PackageSettings) func(string) [][]string {
195195
}
196196

197197
func (r Result) ModelImports() [][]string {
198-
var std []string
198+
std := make(map[string]struct{})
199199
if r.UsesType("sql.Null") {
200-
std = append(std, "database/sql")
200+
std["database/sql"] = struct{}{}
201201
}
202202
if r.UsesType("json.RawMessage") {
203-
std = append(std, "encoding/json")
203+
std["encoding/json"] = struct{}{}
204204
}
205205
if r.UsesType("time.Time") {
206-
std = append(std, "time")
206+
std["time"] = struct{}{}
207207
}
208208
if r.UsesType("net.IP") {
209-
std = append(std, "net")
209+
std["net"] = struct{}{}
210210
}
211211

212212
// Custom imports
213-
var pkg []string
213+
pkg := make(map[string]struct{})
214214
overrideTypes := map[string]string{}
215215
for _, o := range append(r.Settings.Overrides, r.packageSettings.Overrides...) {
216216
overrideTypes[o.goTypeName] = o.goPackage
217217
}
218218

219219
_, overrideNullTime := overrideTypes["pq.NullTime"]
220220
if r.UsesType("pq.NullTime") && !overrideNullTime {
221-
pkg = append(pkg, "github.com/lib/pq")
221+
pkg["github.com/lib/pq"] = struct{}{}
222222
}
223223

224224
_, overrideUUID := overrideTypes["uuid.UUID"]
225225
if r.UsesType("uuid.UUID") && !overrideUUID {
226-
pkg = append(pkg, "github.com/google/uuid")
226+
pkg["github.com/google/uuid"] = struct{}{}
227227
}
228228

229229
for goType, importPath := range overrideTypes {
230-
if r.UsesType(goType) {
231-
pkg = append(pkg, importPath)
230+
if _, ok := std[importPath]; !ok && r.UsesType(goType) {
231+
pkg[importPath] = struct{}{}
232232
}
233233
}
234234

235-
return [][]string{std, pkg}
235+
pkgs := make([]string, 0, len(pkg))
236+
for p, _ := range pkg {
237+
pkgs = append(pkgs, p)
238+
}
239+
240+
stds := make([]string, 0, len(std))
241+
for s, _ := range std {
242+
stds = append(stds, s)
243+
}
244+
245+
return [][]string{stds, pkgs}
236246
}
237247

238248
func (r Result) QueryImports(filename string) [][]string {
@@ -314,46 +324,58 @@ func (r Result) QueryImports(filename string) [][]string {
314324
return false
315325
}
316326

317-
std := []string{"context"}
327+
std := map[string]struct{}{
328+
"context": struct{}{},
329+
}
318330
if uses("sql.Null") {
319-
std = append(std, "database/sql")
331+
std["database/sql"] = struct{}{}
320332
}
321333
if uses("json.RawMessage") {
322-
std = append(std, "encoding/json")
334+
std["encoding/json"] = struct{}{}
323335
}
324336
if uses("time.Time") {
325-
std = append(std, "time")
337+
std["time"] = struct{}{}
326338
}
327339
if uses("net.IP") {
328-
std = append(std, "net")
340+
std["net"] = struct{}{}
329341
}
330342

331-
var pkg []string
343+
pkg := make(map[string]struct{})
332344
overrideTypes := map[string]string{}
333345
for _, o := range append(r.Settings.Overrides, r.packageSettings.Overrides...) {
334346
overrideTypes[o.goTypeName] = o.goPackage
335347
}
336348

337349
if sliceScan() {
338-
pkg = append(pkg, "github.com/lib/pq")
350+
pkg["github.com/lib/pq"] = struct{}{}
339351
}
340352
_, overrideNullTime := overrideTypes["pq.NullTime"]
341353
if uses("pq.NullTime") && !overrideNullTime {
342-
pkg = append(pkg, "github.com/lib/pq")
354+
pkg["github.com/lib/pq"] = struct{}{}
343355
}
344356
_, overrideUUID := overrideTypes["uuid.UUID"]
345357
if uses("uuid.UUID") && !overrideUUID {
346-
pkg = append(pkg, "github.com/google/uuid")
358+
pkg["github.com/google/uuid"] = struct{}{}
347359
}
348360

349361
// Custom imports
350362
for goType, importPath := range overrideTypes {
351-
if uses(goType) {
352-
pkg = append(pkg, importPath)
363+
if _, ok := std[importPath]; !ok && uses(goType) {
364+
pkg[importPath] = struct{}{}
353365
}
354366
}
355367

356-
return [][]string{std, pkg}
368+
pkgs := make([]string, 0, len(pkg))
369+
for p, _ := range pkg {
370+
pkgs = append(pkgs, p)
371+
}
372+
373+
stds := make([]string, 0, len(std))
374+
for s, _ := range std {
375+
stds = append(stds, s)
376+
}
377+
378+
return [][]string{stds, pkgs}
357379
}
358380

359381
func (r Result) Enums() []GoEnum {

0 commit comments

Comments
 (0)