Skip to content

Commit 0ae12c5

Browse files
authored
internal/dinosql: Import needed types for Querier (#285)
Before this change, sqlc would generate a Querier interface without importing the correct types. This is now fixed.
1 parent eadf951 commit 0ae12c5

File tree

3 files changed

+118
-17
lines changed

3 files changed

+118
-17
lines changed

examples/ondeck/db.go

Lines changed: 0 additions & 15 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

examples/ondeck/querier.go

Lines changed: 22 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

internal/dinosql/gen.go

Lines changed: 96 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -202,10 +202,88 @@ func Imports(r Generateable, settings GenerateSettings) func(string) [][]string
202202
return ModelImports(r, settings)
203203
}
204204

205+
if filename == "querier.go" {
206+
return InterfaceImports(r, settings)
207+
}
208+
205209
return QueryImports(r, settings, filename)
206210
}
207211
}
208212

213+
func InterfaceImports(r Generateable, settings GenerateSettings) [][]string {
214+
gq := r.GoQueries(settings)
215+
uses := func(name string) bool {
216+
for _, q := range gq {
217+
if !q.Ret.isEmpty() {
218+
if strings.HasPrefix(q.Ret.Type(), name) {
219+
return true
220+
}
221+
}
222+
if !q.Arg.isEmpty() {
223+
if strings.HasPrefix(q.Arg.Type(), name) {
224+
return true
225+
}
226+
}
227+
}
228+
return false
229+
}
230+
231+
std := map[string]struct{}{
232+
"context": struct{}{},
233+
}
234+
if uses("sql.Null") {
235+
std["database/sql"] = struct{}{}
236+
}
237+
if uses("json.RawMessage") {
238+
std["encoding/json"] = struct{}{}
239+
}
240+
if uses("time.Time") {
241+
std["time"] = struct{}{}
242+
}
243+
if uses("net.IP") {
244+
std["net"] = struct{}{}
245+
}
246+
247+
pkg := make(map[string]struct{})
248+
overrideTypes := map[string]string{}
249+
for _, o := range append(settings.Overrides, settings.PackageMap[r.PkgName()].Overrides...) {
250+
if o.goBasicType {
251+
continue
252+
}
253+
overrideTypes[o.goTypeName] = o.goPackage
254+
}
255+
256+
_, overrideNullTime := overrideTypes["pq.NullTime"]
257+
if uses("pq.NullTime") && !overrideNullTime {
258+
pkg["github.com/lib/pq"] = struct{}{}
259+
}
260+
_, overrideUUID := overrideTypes["uuid.UUID"]
261+
if uses("uuid.UUID") && !overrideUUID {
262+
pkg["github.com/google/uuid"] = struct{}{}
263+
}
264+
265+
// Custom imports
266+
for goType, importPath := range overrideTypes {
267+
if _, ok := std[importPath]; !ok && uses(goType) {
268+
pkg[importPath] = struct{}{}
269+
}
270+
}
271+
272+
pkgs := make([]string, 0, len(pkg))
273+
for p, _ := range pkg {
274+
pkgs = append(pkgs, p)
275+
}
276+
277+
stds := make([]string, 0, len(std))
278+
for s, _ := range std {
279+
stds = append(stds, s)
280+
}
281+
282+
sort.Strings(stds)
283+
sort.Strings(pkgs)
284+
return [][]string{stds, pkgs}
285+
}
286+
209287
func ModelImports(r Generateable, settings GenerateSettings) [][]string {
210288
std := make(map[string]struct{})
211289
if UsesType(r, "sql.Null", settings) {
@@ -903,8 +981,19 @@ func (q *Queries) WithTx(tx *sql.Tx) *Queries {
903981
{{- end}}
904982
}
905983
}
984+
`
985+
986+
var ifaceTmpl = `// Code generated by sqlc. DO NOT EDIT.
987+
988+
package {{.Package}}
989+
990+
import (
991+
{{range imports .SourceName}}
992+
{{range .}}"{{.}}"
993+
{{end}}
994+
{{end}}
995+
)
906996
907-
{{if .EmitInterface }}
908997
type Querier interface {
909998
{{- range .GoQueries}}
910999
{{- if eq .Cmd ":one"}}
@@ -923,7 +1012,6 @@ type Querier interface {
9231012
}
9241013
9251014
var _ Querier = (*Queries)(nil)
926-
{{end}}
9271015
`
9281016

9291017
var modelsTmpl = `// Code generated by sqlc. DO NOT EDIT.
@@ -1112,6 +1200,7 @@ func Generate(r Generateable, settings GenerateSettings) (map[string]string, err
11121200
dbFile := template.Must(template.New("table").Funcs(funcMap).Parse(dbTmpl))
11131201
modelsFile := template.Must(template.New("table").Funcs(funcMap).Parse(modelsTmpl))
11141202
sqlFile := template.Must(template.New("table").Funcs(funcMap).Parse(sqlTmpl))
1203+
ifaceFile := template.Must(template.New("table").Funcs(funcMap).Parse(ifaceTmpl))
11151204

11161205
tctx := tmplCtx{
11171206
Settings: settings,
@@ -1154,6 +1243,11 @@ func Generate(r Generateable, settings GenerateSettings) (map[string]string, err
11541243
if err := execute("models.go", modelsFile); err != nil {
11551244
return nil, err
11561245
}
1246+
if pkgConfig.EmitInterface {
1247+
if err := execute("querier.go", ifaceFile); err != nil {
1248+
return nil, err
1249+
}
1250+
}
11571251

11581252
files := map[string]struct{}{}
11591253
for _, gq := range r.GoQueries(settings) {

0 commit comments

Comments
 (0)