77 "errors"
88 "fmt"
99 "go/format"
10+ "path/filepath"
1011 "strings"
1112 "text/template"
1213
@@ -122,7 +123,7 @@ func Generate(ctx context.Context, req *plugin.GenerateRequest) (*plugin.Generat
122123 }
123124
124125 if options .OmitUnusedStructs {
125- enums , structs = filterUnusedStructs (enums , structs , queries )
126+ enums , structs = filterUnusedStructs (options , enums , structs , queries )
126127 }
127128
128129 if err := validate (options , enums , structs , queries ); err != nil {
@@ -211,6 +212,7 @@ func generate(req *plugin.GenerateRequest, options *opts.Options, enums []Enum,
211212 "imports" : i .Imports ,
212213 "hasImports" : i .HasImports ,
213214 "hasPrefix" : strings .HasPrefix ,
215+ "trimPrefix" : strings .TrimPrefix ,
214216
215217 // These methods are Go specific, they do not belong in the codegen package
216218 // (as that is language independent)
@@ -232,14 +234,15 @@ func generate(req *plugin.GenerateRequest, options *opts.Options, enums []Enum,
232234
233235 output := map [string ]string {}
234236
235- execute := func (name , templateName string ) error {
237+ execute := func (name , packageName , templateName string ) error {
236238 imports := i .Imports (name )
237239 replacedQueries := replaceConflictedArg (imports , queries )
238240
239241 var b bytes.Buffer
240242 w := bufio .NewWriter (& b )
241243 tctx .SourceName = name
242244 tctx .GoQueries = replacedQueries
245+ tctx .Package = packageName
243246 err := tmpl .ExecuteTemplate (w , templateName , & tctx )
244247 w .Flush ()
245248 if err != nil {
@@ -251,8 +254,13 @@ func generate(req *plugin.GenerateRequest, options *opts.Options, enums []Enum,
251254 return fmt .Errorf ("source error: %w" , err )
252255 }
253256
254- if templateName == "queryFile" && options .OutputFilesSuffix != "" {
255- name += options .OutputFilesSuffix
257+ if templateName == "queryFile" {
258+ if options .OutputQueryFilesDirectory != "" {
259+ name = filepath .Join (options .OutputQueryFilesDirectory , name )
260+ }
261+ if options .OutputFilesSuffix != "" {
262+ name += options .OutputFilesSuffix
263+ }
256264 }
257265
258266 if ! strings .HasSuffix (name , ".go" ) {
@@ -284,24 +292,29 @@ func generate(req *plugin.GenerateRequest, options *opts.Options, enums []Enum,
284292 batchFileName = options .OutputBatchFileName
285293 }
286294
287- if err := execute (dbFileName , "dbFile" ); err != nil {
295+ modelsPackageName := options .Package
296+ if options .OutputModelsPackage != "" {
297+ modelsPackageName = options .OutputModelsPackage
298+ }
299+
300+ if err := execute (dbFileName , options .Package , "dbFile" ); err != nil {
288301 return nil , err
289302 }
290- if err := execute (modelsFileName , "modelsFile" ); err != nil {
303+ if err := execute (modelsFileName , modelsPackageName , "modelsFile" ); err != nil {
291304 return nil , err
292305 }
293306 if options .EmitInterface {
294- if err := execute (querierFileName , "interfaceFile" ); err != nil {
307+ if err := execute (querierFileName , options . Package , "interfaceFile" ); err != nil {
295308 return nil , err
296309 }
297310 }
298311 if tctx .UsesCopyFrom {
299- if err := execute (copyfromFileName , "copyfromFile" ); err != nil {
312+ if err := execute (copyfromFileName , options . Package , "copyfromFile" ); err != nil {
300313 return nil , err
301314 }
302315 }
303316 if tctx .UsesBatch {
304- if err := execute (batchFileName , "batchFile" ); err != nil {
317+ if err := execute (batchFileName , options . Package , "batchFile" ); err != nil {
305318 return nil , err
306319 }
307320 }
@@ -312,7 +325,7 @@ func generate(req *plugin.GenerateRequest, options *opts.Options, enums []Enum,
312325 }
313326
314327 for source := range files {
315- if err := execute (source , "queryFile" ); err != nil {
328+ if err := execute (source , options . Package , "queryFile" ); err != nil {
316329 return nil , err
317330 }
318331 }
@@ -362,7 +375,7 @@ func checkNoTimesForMySQLCopyFrom(queries []Query) error {
362375 return nil
363376}
364377
365- func filterUnusedStructs (enums []Enum , structs []Struct , queries []Query ) ([]Enum , []Struct ) {
378+ func filterUnusedStructs (options * opts. Options , enums []Enum , structs []Struct , queries []Query ) ([]Enum , []Struct ) {
366379 keepTypes := make (map [string ]struct {})
367380
368381 for _ , query := range queries {
@@ -389,16 +402,23 @@ func filterUnusedStructs(enums []Enum, structs []Struct, queries []Query) ([]Enu
389402
390403 keepEnums := make ([]Enum , 0 , len (enums ))
391404 for _ , enum := range enums {
392- _ , keep := keepTypes [enum .Name ]
393- _ , keepNull := keepTypes ["Null" + enum .Name ]
405+ var enumType string
406+ if options .ModelsPackageImportPath != "" {
407+ enumType = options .OutputModelsPackage + "." + enum .Name
408+ } else {
409+ enumType = enum .Name
410+ }
411+
412+ _ , keep := keepTypes [enumType ]
413+ _ , keepNull := keepTypes ["Null" + enumType ]
394414 if keep || keepNull {
395415 keepEnums = append (keepEnums , enum )
396416 }
397417 }
398418
399419 keepStructs := make ([]Struct , 0 , len (structs ))
400420 for _ , st := range structs {
401- if _ , ok := keepTypes [st .Name ]; ok {
421+ if _ , ok := keepTypes [st .Type () ]; ok {
402422 keepStructs = append (keepStructs , st )
403423 }
404424 }
0 commit comments