Skip to content

Commit d248d0f

Browse files
committed
feat(codegen/golang): allow exporting models file to separate package
1 parent 6b2ed20 commit d248d0f

File tree

9 files changed

+125
-28
lines changed

9 files changed

+125
-28
lines changed

internal/codegen/golang/gen.go

Lines changed: 34 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"errors"
88
"fmt"
99
"go/format"
10+
"path/filepath"
1011
"strings"
1112
"text/template"
1213

@@ -122,7 +123,7 @@ func Generate(ctx context.Context, req *plugin.GenerateRequest) (*plugin.Generat
122123
}
123124

124125
if options.OmitUnusedStructs {
125-
enums, structs = filterUnusedStructs(enums, structs, queries)
126+
enums, structs = filterUnusedStructs(options, enums, structs, queries)
126127
}
127128

128129
if err := validate(options, enums, structs, queries); err != nil {
@@ -211,6 +212,7 @@ func generate(req *plugin.GenerateRequest, options *opts.Options, enums []Enum,
211212
"imports": i.Imports,
212213
"hasImports": i.HasImports,
213214
"hasPrefix": strings.HasPrefix,
215+
"trimPrefix": strings.TrimPrefix,
214216

215217
// These methods are Go specific, they do not belong in the codegen package
216218
// (as that is language independent)
@@ -232,14 +234,15 @@ func generate(req *plugin.GenerateRequest, options *opts.Options, enums []Enum,
232234

233235
output := map[string]string{}
234236

235-
execute := func(name, templateName string) error {
237+
execute := func(name, packageName, templateName string) error {
236238
imports := i.Imports(name)
237239
replacedQueries := replaceConflictedArg(imports, queries)
238240

239241
var b bytes.Buffer
240242
w := bufio.NewWriter(&b)
241243
tctx.SourceName = name
242244
tctx.GoQueries = replacedQueries
245+
tctx.Package = packageName
243246
err := tmpl.ExecuteTemplate(w, templateName, &tctx)
244247
w.Flush()
245248
if err != nil {
@@ -251,8 +254,13 @@ func generate(req *plugin.GenerateRequest, options *opts.Options, enums []Enum,
251254
return fmt.Errorf("source error: %w", err)
252255
}
253256

254-
if templateName == "queryFile" && options.OutputFilesSuffix != "" {
255-
name += options.OutputFilesSuffix
257+
if templateName == "queryFile" {
258+
if options.OutputQueryFilesDirectory != "" {
259+
name = filepath.Join(options.OutputQueryFilesDirectory, name)
260+
}
261+
if options.OutputFilesSuffix != "" {
262+
name += options.OutputFilesSuffix
263+
}
256264
}
257265

258266
if !strings.HasSuffix(name, ".go") {
@@ -284,24 +292,29 @@ func generate(req *plugin.GenerateRequest, options *opts.Options, enums []Enum,
284292
batchFileName = options.OutputBatchFileName
285293
}
286294

287-
if err := execute(dbFileName, "dbFile"); err != nil {
295+
modelsPackageName := options.Package
296+
if options.OutputModelsPackage != "" {
297+
modelsPackageName = options.OutputModelsPackage
298+
}
299+
300+
if err := execute(dbFileName, options.Package, "dbFile"); err != nil {
288301
return nil, err
289302
}
290-
if err := execute(modelsFileName, "modelsFile"); err != nil {
303+
if err := execute(modelsFileName, modelsPackageName, "modelsFile"); err != nil {
291304
return nil, err
292305
}
293306
if options.EmitInterface {
294-
if err := execute(querierFileName, "interfaceFile"); err != nil {
307+
if err := execute(querierFileName, options.Package, "interfaceFile"); err != nil {
295308
return nil, err
296309
}
297310
}
298311
if tctx.UsesCopyFrom {
299-
if err := execute(copyfromFileName, "copyfromFile"); err != nil {
312+
if err := execute(copyfromFileName, options.Package, "copyfromFile"); err != nil {
300313
return nil, err
301314
}
302315
}
303316
if tctx.UsesBatch {
304-
if err := execute(batchFileName, "batchFile"); err != nil {
317+
if err := execute(batchFileName, options.Package, "batchFile"); err != nil {
305318
return nil, err
306319
}
307320
}
@@ -312,7 +325,7 @@ func generate(req *plugin.GenerateRequest, options *opts.Options, enums []Enum,
312325
}
313326

314327
for source := range files {
315-
if err := execute(source, "queryFile"); err != nil {
328+
if err := execute(source, options.Package, "queryFile"); err != nil {
316329
return nil, err
317330
}
318331
}
@@ -362,7 +375,7 @@ func checkNoTimesForMySQLCopyFrom(queries []Query) error {
362375
return nil
363376
}
364377

365-
func filterUnusedStructs(enums []Enum, structs []Struct, queries []Query) ([]Enum, []Struct) {
378+
func filterUnusedStructs(options *opts.Options, enums []Enum, structs []Struct, queries []Query) ([]Enum, []Struct) {
366379
keepTypes := make(map[string]struct{})
367380

368381
for _, query := range queries {
@@ -389,16 +402,23 @@ func filterUnusedStructs(enums []Enum, structs []Struct, queries []Query) ([]Enu
389402

390403
keepEnums := make([]Enum, 0, len(enums))
391404
for _, enum := range enums {
392-
_, keep := keepTypes[enum.Name]
393-
_, keepNull := keepTypes["Null"+enum.Name]
405+
var enumType string
406+
if options.ModelsPackageImportPath != "" {
407+
enumType = options.OutputModelsPackage + "." + enum.Name
408+
} else {
409+
enumType = enum.Name
410+
}
411+
412+
_, keep := keepTypes[enumType]
413+
_, keepNull := keepTypes["Null"+enumType]
394414
if keep || keepNull {
395415
keepEnums = append(keepEnums, enum)
396416
}
397417
}
398418

399419
keepStructs := make([]Struct, 0, len(structs))
400420
for _, st := range structs {
401-
if _, ok := keepTypes[st.Name]; ok {
421+
if _, ok := keepTypes[st.Type()]; ok {
402422
keepStructs = append(keepStructs, st)
403423
}
404424
}

internal/codegen/golang/imports.go

Lines changed: 47 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ var pqtypeTypes = map[string]struct{}{
160160
"pqtype.NullRawMessage": {},
161161
}
162162

163-
func buildImports(options *opts.Options, queries []Query, uses func(string) bool) (map[string]struct{}, map[ImportSpec]struct{}) {
163+
func buildImports(options *opts.Options, queries []Query, outputFile OutputFile, uses func(string) bool) (map[string]struct{}, map[ImportSpec]struct{}) {
164164
pkg := make(map[ImportSpec]struct{})
165165
std := make(map[string]struct{})
166166

@@ -243,11 +243,52 @@ func buildImports(options *opts.Options, queries []Query, uses func(string) bool
243243
}
244244
}
245245

246+
requiresModelsPackageImport := func() bool {
247+
if options.ModelsPackageImportPath == "" {
248+
return false
249+
}
250+
251+
for _, q := range queries {
252+
// Check if the return type is from models package (possibly a model struct or an enum)
253+
if q.hasRetType() && strings.HasPrefix(q.Ret.Type(), options.OutputModelsPackage+".") {
254+
return true
255+
}
256+
257+
// Check if the return type struct contains a type from models package (possibly an enum field or an embedded struct)
258+
if outputFile != OutputFileInterface && q.hasRetType() && q.Ret.IsStruct() {
259+
for _, f := range q.Ret.Struct.Fields {
260+
if strings.HasPrefix(f.Type, options.OutputModelsPackage+".") {
261+
return true
262+
}
263+
}
264+
}
265+
266+
// Check if the argument type is from models package (possibly an enum)
267+
if !q.Arg.isEmpty() && strings.HasPrefix(q.Arg.Type(), options.OutputModelsPackage+".") {
268+
return true
269+
}
270+
271+
// Check if the argument struct contains a type from models package (possibly an enum field)
272+
if outputFile != OutputFileInterface && !q.Arg.isEmpty() && q.Arg.IsStruct() {
273+
for _, f := range q.Arg.Struct.Fields {
274+
if strings.HasPrefix(f.Type, options.OutputModelsPackage+".") {
275+
return true
276+
}
277+
}
278+
}
279+
280+
}
281+
return false
282+
}
283+
if requiresModelsPackageImport() {
284+
pkg[ImportSpec{Path: options.ModelsPackageImportPath}] = struct{}{}
285+
}
286+
246287
return std, pkg
247288
}
248289

249290
func (i *importer) interfaceImports() fileImports {
250-
std, pkg := buildImports(i.Options, i.Queries, func(name string) bool {
291+
std, pkg := buildImports(i.Options, i.Queries, OutputFileInterface, func(name string) bool {
251292
for _, q := range i.Queries {
252293
if q.hasRetType() {
253294
if usesBatch([]Query{q}) {
@@ -272,7 +313,7 @@ func (i *importer) interfaceImports() fileImports {
272313
}
273314

274315
func (i *importer) modelImports() fileImports {
275-
std, pkg := buildImports(i.Options, nil, i.usesType)
316+
std, pkg := buildImports(i.Options, nil, OutputFileModel, i.usesType)
276317

277318
if len(i.Enums) > 0 {
278319
std["fmt"] = struct{}{}
@@ -311,7 +352,7 @@ func (i *importer) queryImports(filename string) fileImports {
311352
}
312353
}
313354

314-
std, pkg := buildImports(i.Options, gq, func(name string) bool {
355+
std, pkg := buildImports(i.Options, gq, OutputFileQuery, func(name string) bool {
315356
for _, q := range gq {
316357
if q.hasRetType() {
317358
if q.Ret.EmitStruct() {
@@ -412,7 +453,7 @@ func (i *importer) copyfromImports() fileImports {
412453
copyFromQueries = append(copyFromQueries, q)
413454
}
414455
}
415-
std, pkg := buildImports(i.Options, copyFromQueries, func(name string) bool {
456+
std, pkg := buildImports(i.Options, copyFromQueries, OutputFileCopyfrom, func(name string) bool {
416457
for _, q := range copyFromQueries {
417458
if q.hasRetType() {
418459
if strings.HasPrefix(q.Ret.Type(), name) {
@@ -447,7 +488,7 @@ func (i *importer) batchImports() fileImports {
447488
batchQueries = append(batchQueries, q)
448489
}
449490
}
450-
std, pkg := buildImports(i.Options, batchQueries, func(name string) bool {
491+
std, pkg := buildImports(i.Options, batchQueries, OutputFileBatch, func(name string) bool {
451492
for _, q := range batchQueries {
452493
if q.hasRetType() {
453494
if q.Ret.EmitStruct() {

internal/codegen/golang/opts/options.go

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,11 @@ type Options struct {
3535
OutputBatchFileName string `json:"output_batch_file_name,omitempty" yaml:"output_batch_file_name"`
3636
OutputDbFileName string `json:"output_db_file_name,omitempty" yaml:"output_db_file_name"`
3737
OutputModelsFileName string `json:"output_models_file_name,omitempty" yaml:"output_models_file_name"`
38+
OutputModelsPackage string `json:"output_models_package,omitempty" yaml:"output_models_package"`
39+
ModelsPackageImportPath string `json:"models_package_import_path,omitempty" yaml:"models_package_import_path"`
3840
OutputQuerierFileName string `json:"output_querier_file_name,omitempty" yaml:"output_querier_file_name"`
3941
OutputCopyfromFileName string `json:"output_copyfrom_file_name,omitempty" yaml:"output_copyfrom_file_name"`
42+
OutputQueryFilesDirectory string `json:"output_query_files_directory,omitempty" yaml:"output_query_files_directory"`
4043
OutputFilesSuffix string `json:"output_files_suffix,omitempty" yaml:"output_files_suffix"`
4144
InflectionExcludeTableNames []string `json:"inflection_exclude_table_names,omitempty" yaml:"inflection_exclude_table_names"`
4245
QueryParameterLimit *int32 `json:"query_parameter_limit,omitempty" yaml:"query_parameter_limit"`
@@ -150,6 +153,11 @@ func ValidateOpts(opts *Options) error {
150153
if *opts.QueryParameterLimit < 0 {
151154
return fmt.Errorf("invalid options: query parameter limit must not be negative")
152155
}
153-
156+
if opts.OutputModelsPackage != "" && opts.ModelsPackageImportPath == "" {
157+
return fmt.Errorf("invalid options: models_package_import_path must be set when output_models_package is used")
158+
}
159+
if opts.ModelsPackageImportPath != "" && opts.OutputModelsPackage == "" {
160+
return fmt.Errorf("invalid options: output_models_package must be set when models_package_import_path is used")
161+
}
154162
return nil
155163
}
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
package golang
2+
3+
type OutputFile string
4+
5+
const (
6+
OutputFileModel OutputFile = "modelFile"
7+
OutputFileQuery OutputFile = "queryFile"
8+
OutputFileDb OutputFile = "dbFile"
9+
OutputFileInterface OutputFile = "interfaceFile"
10+
OutputFileCopyfrom OutputFile = "copyfromFile"
11+
OutputFileBatch OutputFile = "batchFile"
12+
)

internal/codegen/golang/postgresql_type.go

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -571,17 +571,24 @@ func postgresType(req *plugin.GenerateRequest, options *opts.Options, col *plugi
571571

572572
for _, enum := range schema.Enums {
573573
if rel.Name == enum.Name && rel.Schema == schema.Name {
574+
enumName := ""
574575
if notNull {
575576
if schema.Name == req.Catalog.DefaultSchema {
576-
return StructName(enum.Name, options)
577+
enumName = StructName(enum.Name, options)
578+
} else {
579+
enumName = StructName(schema.Name+"_"+enum.Name, options)
577580
}
578-
return StructName(schema.Name+"_"+enum.Name, options)
579581
} else {
580582
if schema.Name == req.Catalog.DefaultSchema {
581-
return "Null" + StructName(enum.Name, options)
583+
enumName = "Null" + StructName(enum.Name, options)
584+
} else {
585+
enumName = "Null" + StructName(schema.Name+"_"+enum.Name, options)
582586
}
583-
return "Null" + StructName(schema.Name+"_"+enum.Name, options)
584587
}
588+
if options.ModelsPackageImportPath != "" {
589+
return options.OutputModelsPackage + "." + enumName
590+
}
591+
return enumName
585592
}
586593
}
587594

internal/codegen/golang/query.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ func (v QueryValue) Type() string {
8888
return v.Typ
8989
}
9090
if v.Struct != nil {
91-
return v.Struct.Name
91+
return v.Struct.Type()
9292
}
9393
panic("no type for QueryValue: " + v.Name)
9494
}

internal/codegen/golang/result.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ func buildStructs(req *plugin.GenerateRequest, options *opts.Options) []Struct {
8383
s := Struct{
8484
Table: &plugin.Identifier{Schema: schema.Name, Name: table.Rel.Name},
8585
Name: StructName(structName, options),
86+
Package: options.OutputModelsPackage,
8687
Comment: table.Comment,
8788
}
8889
for _, column := range table.Columns {
@@ -146,7 +147,7 @@ func newGoEmbed(embed *plugin.Identifier, structs []Struct, defaultSchema string
146147
}
147148

148149
return &goEmbed{
149-
modelType: s.Name,
150+
modelType: s.Type(),
150151
modelName: s.Name,
151152
fields: fields,
152153
}

internal/codegen/golang/struct.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,18 @@ import (
1212
type Struct struct {
1313
Table *plugin.Identifier
1414
Name string
15+
Package string
1516
Fields []Field
1617
Comment string
1718
}
1819

20+
func (s Struct) Type() string {
21+
if s.Package != "" {
22+
return s.Package + "." + s.Name
23+
}
24+
return s.Name
25+
}
26+
1927
func StructName(name string, options *opts.Options) string {
2028
if rename := options.Rename[name]; rename != "" {
2129
return rename

internal/codegen/golang/templates/template.tmpl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ type {{.Name}} struct { {{- range .Fields}}
156156
{{- if .Comment}}
157157
{{comment .Comment}}{{else}}
158158
{{- end}}
159-
{{.Name}} {{.Type}} {{if .Tag}}{{$.Q}}{{.Tag}}{{$.Q}}{{end}}
159+
{{.Name}} {{trimPrefix .Type (printf "%s%s" $.Package ".") }} {{if .Tag}}{{$.Q}}{{.Tag}}{{$.Q}}{{end}}
160160
{{- end}}
161161
}
162162
{{end}}

0 commit comments

Comments
 (0)