@@ -202,10 +202,88 @@ func Imports(r Generateable, settings GenerateSettings) func(string) [][]string
202
202
return ModelImports (r , settings )
203
203
}
204
204
205
+ if filename == "querier.go" {
206
+ return InterfaceImports (r , settings )
207
+ }
208
+
205
209
return QueryImports (r , settings , filename )
206
210
}
207
211
}
208
212
213
+ func InterfaceImports (r Generateable , settings GenerateSettings ) [][]string {
214
+ gq := r .GoQueries (settings )
215
+ uses := func (name string ) bool {
216
+ for _ , q := range gq {
217
+ if ! q .Ret .isEmpty () {
218
+ if strings .HasPrefix (q .Ret .Type (), name ) {
219
+ return true
220
+ }
221
+ }
222
+ if ! q .Arg .isEmpty () {
223
+ if strings .HasPrefix (q .Arg .Type (), name ) {
224
+ return true
225
+ }
226
+ }
227
+ }
228
+ return false
229
+ }
230
+
231
+ std := map [string ]struct {}{
232
+ "context" : struct {}{},
233
+ }
234
+ if uses ("sql.Null" ) {
235
+ std ["database/sql" ] = struct {}{}
236
+ }
237
+ if uses ("json.RawMessage" ) {
238
+ std ["encoding/json" ] = struct {}{}
239
+ }
240
+ if uses ("time.Time" ) {
241
+ std ["time" ] = struct {}{}
242
+ }
243
+ if uses ("net.IP" ) {
244
+ std ["net" ] = struct {}{}
245
+ }
246
+
247
+ pkg := make (map [string ]struct {})
248
+ overrideTypes := map [string ]string {}
249
+ for _ , o := range append (settings .Overrides , settings .PackageMap [r .PkgName ()].Overrides ... ) {
250
+ if o .goBasicType {
251
+ continue
252
+ }
253
+ overrideTypes [o .goTypeName ] = o .goPackage
254
+ }
255
+
256
+ _ , overrideNullTime := overrideTypes ["pq.NullTime" ]
257
+ if uses ("pq.NullTime" ) && ! overrideNullTime {
258
+ pkg ["github.com/lib/pq" ] = struct {}{}
259
+ }
260
+ _ , overrideUUID := overrideTypes ["uuid.UUID" ]
261
+ if uses ("uuid.UUID" ) && ! overrideUUID {
262
+ pkg ["github.com/google/uuid" ] = struct {}{}
263
+ }
264
+
265
+ // Custom imports
266
+ for goType , importPath := range overrideTypes {
267
+ if _ , ok := std [importPath ]; ! ok && uses (goType ) {
268
+ pkg [importPath ] = struct {}{}
269
+ }
270
+ }
271
+
272
+ pkgs := make ([]string , 0 , len (pkg ))
273
+ for p , _ := range pkg {
274
+ pkgs = append (pkgs , p )
275
+ }
276
+
277
+ stds := make ([]string , 0 , len (std ))
278
+ for s , _ := range std {
279
+ stds = append (stds , s )
280
+ }
281
+
282
+ sort .Strings (stds )
283
+ sort .Strings (pkgs )
284
+ return [][]string {stds , pkgs }
285
+ }
286
+
209
287
func ModelImports (r Generateable , settings GenerateSettings ) [][]string {
210
288
std := make (map [string ]struct {})
211
289
if UsesType (r , "sql.Null" , settings ) {
@@ -903,8 +981,19 @@ func (q *Queries) WithTx(tx *sql.Tx) *Queries {
903
981
{{- end}}
904
982
}
905
983
}
984
+ `
985
+
986
+ var ifaceTmpl = `// Code generated by sqlc. DO NOT EDIT.
987
+
988
+ package {{.Package}}
989
+
990
+ import (
991
+ {{range imports .SourceName}}
992
+ {{range .}}"{{.}}"
993
+ {{end}}
994
+ {{end}}
995
+ )
906
996
907
- {{if .EmitInterface }}
908
997
type Querier interface {
909
998
{{- range .GoQueries}}
910
999
{{- if eq .Cmd ":one"}}
@@ -923,7 +1012,6 @@ type Querier interface {
923
1012
}
924
1013
925
1014
var _ Querier = (*Queries)(nil)
926
- {{end}}
927
1015
`
928
1016
929
1017
var modelsTmpl = `// Code generated by sqlc. DO NOT EDIT.
@@ -1112,6 +1200,7 @@ func Generate(r Generateable, settings GenerateSettings) (map[string]string, err
1112
1200
dbFile := template .Must (template .New ("table" ).Funcs (funcMap ).Parse (dbTmpl ))
1113
1201
modelsFile := template .Must (template .New ("table" ).Funcs (funcMap ).Parse (modelsTmpl ))
1114
1202
sqlFile := template .Must (template .New ("table" ).Funcs (funcMap ).Parse (sqlTmpl ))
1203
+ ifaceFile := template .Must (template .New ("table" ).Funcs (funcMap ).Parse (ifaceTmpl ))
1115
1204
1116
1205
tctx := tmplCtx {
1117
1206
Settings : settings ,
@@ -1154,6 +1243,11 @@ func Generate(r Generateable, settings GenerateSettings) (map[string]string, err
1154
1243
if err := execute ("models.go" , modelsFile ); err != nil {
1155
1244
return nil , err
1156
1245
}
1246
+ if pkgConfig .EmitInterface {
1247
+ if err := execute ("querier.go" , ifaceFile ); err != nil {
1248
+ return nil , err
1249
+ }
1250
+ }
1157
1251
1158
1252
files := map [string ]struct {}{}
1159
1253
for _ , gq := range r .GoQueries (settings ) {
0 commit comments