diff --git a/CLAUDE.md b/CLAUDE.md index 43abb0d491..a0ec46dfc8 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -339,6 +339,25 @@ docker compose up -d 3. **Use specific package tests:** Faster iteration during development 4. **Start databases early:** `docker compose up -d` before running integration tests 5. **Read existing tests:** Good examples in `/internal/engine/postgresql/*_test.go` +6. **Always run go fmt:** Format code before committing (see Code Formatting below) + +## Code Formatting + +**Always run `go fmt` before committing changes.** This ensures consistent code style across the codebase. + +```bash +# Format specific packages +go fmt ./internal/codegen/golang/... +go fmt ./internal/poet/... + +# Format all packages +go fmt ./... +``` + +For the code generation packages specifically: +```bash +go fmt ./internal/codegen/golang/... ./internal/poet/... +``` ## Git Workflow diff --git a/internal/codegen/golang/gen.go b/internal/codegen/golang/gen.go index 7df56a0a41..7abdcdf691 100644 --- a/internal/codegen/golang/gen.go +++ b/internal/codegen/golang/gen.go @@ -1,17 +1,12 @@ package golang import ( - "bufio" - "bytes" "context" "errors" "fmt" - "go/format" "strings" - "text/template" "github.com/sqlc-dev/sqlc/internal/codegen/golang/opts" - "github.com/sqlc-dev/sqlc/internal/codegen/sdk" "github.com/sqlc-dev/sqlc/internal/metadata" "github.com/sqlc-dev/sqlc/internal/plugin" ) @@ -171,7 +166,7 @@ func generate(req *plugin.GenerateRequest, options *opts.Options, enums []Enum, Structs: structs, } - tctx := tmplCtx{ + tctx := &tmplCtx{ EmitInterface: options.EmitInterface, EmitJSONTags: options.EmitJsonTags, JsonTagsIDUppercase: options.JsonTagsIdUppercase, @@ -209,64 +204,9 @@ func generate(req *plugin.GenerateRequest, options *opts.Options, enums []Enum, return nil, errors.New(":batch* commands are only supported by pgx") } - funcMap := template.FuncMap{ - "lowerTitle": sdk.LowerTitle, - "comment": sdk.DoubleSlashComment, - "escape": sdk.EscapeBacktick, - "imports": i.Imports, - "hasImports": i.HasImports, - "hasPrefix": strings.HasPrefix, - - // These methods are Go specific, they do not belong in the codegen package - // (as that is language independent) - "dbarg": tctx.codegenDbarg, - "emitPreparedQueries": tctx.codegenEmitPreparedQueries, - "queryMethod": tctx.codegenQueryMethod, - "queryRetval": tctx.codegenQueryRetval, - } - - tmpl := template.Must( - template.New("table"). - Funcs(funcMap). - ParseFS( - templates, - "templates/*.tmpl", - "templates/*/*.tmpl", - ), - ) - output := map[string]string{} - execute := func(name, templateName string) error { - imports := i.Imports(name) - replacedQueries := replaceConflictedArg(imports, queries) - - var b bytes.Buffer - w := bufio.NewWriter(&b) - tctx.SourceName = name - tctx.GoQueries = replacedQueries - err := tmpl.ExecuteTemplate(w, templateName, &tctx) - w.Flush() - if err != nil { - return err - } - code, err := format.Source(b.Bytes()) - if err != nil { - fmt.Println(b.String()) - return fmt.Errorf("source error: %w", err) - } - - if templateName == "queryFile" && options.OutputFilesSuffix != "" { - name += options.OutputFilesSuffix - } - - if !strings.HasSuffix(name, ".go") { - name += ".go" - } - output[name] = string(code) - return nil - } - + // File names dbFileName := "db.go" if options.OutputDbFileName != "" { dbFileName = options.OutputDbFileName @@ -283,46 +223,89 @@ func generate(req *plugin.GenerateRequest, options *opts.Options, enums []Enum, if options.OutputCopyfromFileName != "" { copyfromFileName = options.OutputCopyfromFileName } - batchFileName := "batch.go" if options.OutputBatchFileName != "" { batchFileName = options.OutputBatchFileName } - if err := execute(dbFileName, "dbFile"); err != nil { - return nil, err + // Generate db.go + tctx.SourceName = dbFileName + tctx.GoQueries = replaceConflictedArg(i.Imports(dbFileName), queries) + gen := NewCodeGenerator(tctx, i) + + code, err := gen.GenerateDBFile() + if err != nil { + return nil, fmt.Errorf("db file error: %w", err) } - if err := execute(modelsFileName, "modelsFile"); err != nil { - return nil, err + output[dbFileName] = string(code) + + // Generate models.go + tctx.SourceName = modelsFileName + tctx.GoQueries = replaceConflictedArg(i.Imports(modelsFileName), queries) + code, err = gen.GenerateModelsFile() + if err != nil { + return nil, fmt.Errorf("models file error: %w", err) } + output[modelsFileName] = string(code) + + // Generate querier.go if options.EmitInterface { - if err := execute(querierFileName, "interfaceFile"); err != nil { - return nil, err + tctx.SourceName = querierFileName + tctx.GoQueries = replaceConflictedArg(i.Imports(querierFileName), queries) + code, err = gen.GenerateQuerierFile() + if err != nil { + return nil, fmt.Errorf("querier file error: %w", err) } + output[querierFileName] = string(code) } + + // Generate copyfrom.go if tctx.UsesCopyFrom { - if err := execute(copyfromFileName, "copyfromFile"); err != nil { - return nil, err + tctx.SourceName = copyfromFileName + tctx.GoQueries = replaceConflictedArg(i.Imports(copyfromFileName), queries) + code, err = gen.GenerateCopyFromFile() + if err != nil { + return nil, fmt.Errorf("copyfrom file error: %w", err) } + output[copyfromFileName] = string(code) } + + // Generate batch.go if tctx.UsesBatch { - if err := execute(batchFileName, "batchFile"); err != nil { - return nil, err + tctx.SourceName = batchFileName + tctx.GoQueries = replaceConflictedArg(i.Imports(batchFileName), queries) + code, err = gen.GenerateBatchFile() + if err != nil { + return nil, fmt.Errorf("batch file error: %w", err) } + output[batchFileName] = string(code) } - files := map[string]struct{}{} + // Generate query files + sourceFiles := map[string]struct{}{} for _, gq := range queries { - files[gq.SourceName] = struct{}{} + sourceFiles[gq.SourceName] = struct{}{} } - for source := range files { - if err := execute(source, "queryFile"); err != nil { - return nil, err + for source := range sourceFiles { + tctx.SourceName = source + tctx.GoQueries = replaceConflictedArg(i.Imports(source), queries) + code, err = gen.GenerateQueryFile(source) + if err != nil { + return nil, fmt.Errorf("query file error for %s: %w", source, err) } + + filename := source + if options.OutputFilesSuffix != "" { + filename += options.OutputFilesSuffix + } + if !strings.HasSuffix(filename, ".go") { + filename += ".go" + } + output[filename] = string(code) } - resp := plugin.GenerateResponse{} + resp := plugin.GenerateResponse{} for filename, code := range output { resp.Files = append(resp.Files, &plugin.File{ Name: filename, diff --git a/internal/codegen/golang/generator.go b/internal/codegen/golang/generator.go new file mode 100644 index 0000000000..5aeaa4e97f --- /dev/null +++ b/internal/codegen/golang/generator.go @@ -0,0 +1,2299 @@ +package golang + +import ( + "fmt" + "strings" + + "github.com/sqlc-dev/sqlc/internal/codegen/sdk" + "github.com/sqlc-dev/sqlc/internal/metadata" + "github.com/sqlc-dev/sqlc/internal/poet" +) + +// CodeGenerator generates Go source code for sqlc. +type CodeGenerator struct { + tctx *tmplCtx + i *importer +} + +// NewCodeGenerator creates a new code generator. +func NewCodeGenerator(tctx *tmplCtx, i *importer) *CodeGenerator { + return &CodeGenerator{tctx: tctx, i: i} +} + +// GenerateDBFile generates the db.go file content. +func (g *CodeGenerator) GenerateDBFile() ([]byte, error) { + f := g.newFile("") + if g.tctx.SQLDriver.IsPGX() { + g.addDBCodePGX(f) + } else { + g.addDBCodeStd(f) + } + return poet.Render(f) +} + +// GenerateModelsFile generates the models.go file content. +func (g *CodeGenerator) GenerateModelsFile() ([]byte, error) { + f := g.newFile("") + g.addModelsCode(f) + return poet.Render(f) +} + +// GenerateQuerierFile generates the querier.go file content. +func (g *CodeGenerator) GenerateQuerierFile() ([]byte, error) { + f := g.newFile("") + if g.tctx.SQLDriver.IsPGX() { + g.addInterfaceCodePGX(f) + } else { + g.addInterfaceCodeStd(f) + } + return poet.Render(f) +} + +// GenerateQueryFile generates a query source file content. +func (g *CodeGenerator) GenerateQueryFile(sourceName string) ([]byte, error) { + f := g.newFile(sourceName) + if g.tctx.SQLDriver.IsPGX() { + g.addQueryCodePGX(f, sourceName) + } else { + g.addQueryCodeStd(f, sourceName) + } + return poet.Render(f) +} + +// GenerateCopyFromFile generates the copyfrom.go file content. +func (g *CodeGenerator) GenerateCopyFromFile() ([]byte, error) { + f := g.newFile(g.tctx.SourceName) + if g.tctx.SQLDriver.IsPGX() { + g.addCopyFromCodePGX(f) + } else if g.tctx.SQLDriver.IsGoSQLDriverMySQL() { + g.addCopyFromCodeMySQL(f) + } + return poet.Render(f) +} + +// GenerateBatchFile generates the batch.go file content. +func (g *CodeGenerator) GenerateBatchFile() ([]byte, error) { + f := g.newFile(g.tctx.SourceName) + g.addBatchCodePGX(f) + return poet.Render(f) +} + +func (g *CodeGenerator) newFile(sourceComment string) *poet.File { + f := &poet.File{ + BuildTags: g.tctx.BuildTags, + Package: g.tctx.Package, + } + + // File comments + f.Comments = append(f.Comments, "// Code generated by sqlc. DO NOT EDIT.") + if !g.tctx.OmitSqlcVersion { + f.Comments = append(f.Comments, "// versions:") + f.Comments = append(f.Comments, "// sqlc "+g.tctx.SqlcVersion) + } + if sourceComment != "" { + f.Comments = append(f.Comments, "// source: "+sourceComment) + } + + // Imports - two groups: stdlib and third-party, separated by blank line + imports := g.i.Imports(g.tctx.SourceName) + var stdlibImports, thirdPartyImports []poet.Import + for _, imp := range imports[0] { + stdlibImports = append(stdlibImports, poet.Import{Path: imp.Path, Alias: imp.ID}) + } + for _, imp := range imports[1] { + thirdPartyImports = append(thirdPartyImports, poet.Import{Path: imp.Path, Alias: imp.ID}) + } + f.ImportGroups = [][]poet.Import{stdlibImports, thirdPartyImports} + + return f +} + +func (g *CodeGenerator) addDBCodeStd(f *poet.File) { + // DBTX interface + f.Decls = append(f.Decls, poet.TypeDef{ + Name: "DBTX", + Type: poet.Interface{ + Methods: []poet.Method{ + {Name: "ExecContext", Params: []poet.Param{{Name: "", Type: "context.Context"}, {Name: "", Type: "string"}, {Name: "", Type: "...interface{}"}}, Results: []poet.Param{{Type: "sql.Result"}, {Type: "error"}}}, + {Name: "PrepareContext", Params: []poet.Param{{Name: "", Type: "context.Context"}, {Name: "", Type: "string"}}, Results: []poet.Param{{Type: "*sql.Stmt"}, {Type: "error"}}}, + {Name: "QueryContext", Params: []poet.Param{{Name: "", Type: "context.Context"}, {Name: "", Type: "string"}, {Name: "", Type: "...interface{}"}}, Results: []poet.Param{{Type: "*sql.Rows"}, {Type: "error"}}}, + {Name: "QueryRowContext", Params: []poet.Param{{Name: "", Type: "context.Context"}, {Name: "", Type: "string"}, {Name: "", Type: "...interface{}"}}, Results: []poet.Param{{Type: "*sql.Row"}}}, + }, + }, + }) + + // New function + if g.tctx.EmitMethodsWithDBArgument { + f.Decls = append(f.Decls, poet.Func{ + Name: "New", + Results: []poet.Param{{Type: "Queries", Pointer: true}}, + Stmts: []poet.Stmt{poet.Return{Values: []string{ + poet.StructLit{Type: "Queries", Pointer: true}.Render(), + }}}, + }) + } else { + f.Decls = append(f.Decls, poet.Func{ + Name: "New", + Params: []poet.Param{{Name: "db", Type: "DBTX"}}, + Results: []poet.Param{{Type: "Queries", Pointer: true}}, + Stmts: []poet.Stmt{poet.Return{Values: []string{ + poet.StructLit{Type: "Queries", Pointer: true, Fields: [][2]string{ + {"db", "db"}, + }}.Render(), + }}}, + }) + } + + // Prepare and Close functions for prepared queries + if g.tctx.EmitPreparedQueries { + // Build Prepare function statements + var prepareStmts []poet.Stmt + prepareStmts = append(prepareStmts, poet.Assign{ + Left: []string{"q"}, + Op: ":=", + Right: []string{poet.StructLit{Type: "Queries", Fields: [][2]string{{"db", "db"}}}.Render()}, + }) + prepareStmts = append(prepareStmts, poet.VarDecl{Name: "err", Type: "error"}) + if len(g.tctx.GoQueries) == 0 { + prepareStmts = append(prepareStmts, poet.Assign{Left: []string{"_"}, Op: "=", Right: []string{"err"}}) + } + for _, query := range g.tctx.GoQueries { + prepareStmts = append(prepareStmts, poet.If{ + Init: fmt.Sprintf("q.%s, err = db.PrepareContext(ctx, %s)", query.FieldName, query.ConstantName), + Cond: "err != nil", + Body: []poet.Stmt{poet.Return{Values: []string{ + "nil", + fmt.Sprintf(`fmt.Errorf("error preparing query %s: %%w", err)`, query.MethodName), + }}}, + }) + } + prepareStmts = append(prepareStmts, poet.Return{Values: []string{"&q", "nil"}}) + + f.Decls = append(f.Decls, poet.Func{ + Name: "Prepare", + Params: []poet.Param{{Name: "ctx", Type: "context.Context"}, {Name: "db", Type: "DBTX"}}, + Results: []poet.Param{{Type: "*Queries"}, {Type: "error"}}, + Stmts: prepareStmts, + }) + + // Build Close function statements + var closeStmts []poet.Stmt + closeStmts = append(closeStmts, poet.VarDecl{Name: "err", Type: "error"}) + for _, query := range g.tctx.GoQueries { + closeStmts = append(closeStmts, poet.If{ + Cond: fmt.Sprintf("q.%s != nil", query.FieldName), + Body: []poet.Stmt{poet.If{ + Init: fmt.Sprintf("cerr := q.%s.Close()", query.FieldName), + Cond: "cerr != nil", + Body: []poet.Stmt{poet.Assign{ + Left: []string{"err"}, + Op: "=", + Right: []string{fmt.Sprintf(`fmt.Errorf("error closing %s: %%w", cerr)`, query.FieldName)}, + }}, + }}, + }) + } + closeStmts = append(closeStmts, poet.Return{Values: []string{"err"}}) + + f.Decls = append(f.Decls, poet.Func{ + Recv: &poet.Param{Name: "q", Type: "*Queries"}, + Name: "Close", + Results: []poet.Param{{Type: "error"}}, + Stmts: closeStmts, + }) + + // Helper functions + f.Decls = append(f.Decls, poet.Func{ + Recv: &poet.Param{Name: "q", Type: "Queries", Pointer: true}, + Name: "exec", + Params: []poet.Param{ + {Name: "ctx", Type: "context.Context"}, + {Name: "stmt", Type: "sql.Stmt", Pointer: true}, + {Name: "query", Type: "string"}, + {Name: "args", Type: "...interface{}"}, + }, + Results: []poet.Param{{Type: "sql.Result"}, {Type: "error"}}, + Stmts: []poet.Stmt{poet.Switch{ + Cases: []poet.Case{ + { + Values: []string{"stmt != nil && q.tx != nil"}, + Body: []poet.Stmt{poet.Return{ + Values: []string{ + "q.tx.StmtContext(ctx, stmt).ExecContext(ctx, args...)", + }, + }}, + }, + { + Values: []string{"stmt != nil"}, + Body: []poet.Stmt{poet.Return{ + Values: []string{"stmt.ExecContext(ctx, args...)"}, + }}, + }, + { + Body: []poet.Stmt{poet.Return{ + Values: []string{"q.db.ExecContext(ctx, query, args...)"}, + }}, + }, + }, + }}, + }) + + f.Decls = append(f.Decls, poet.Func{ + Recv: &poet.Param{Name: "q", Type: "Queries", Pointer: true}, + Name: "query", + Params: []poet.Param{ + {Name: "ctx", Type: "context.Context"}, + {Name: "stmt", Type: "sql.Stmt", Pointer: true}, + {Name: "query", Type: "string"}, + {Name: "args", Type: "...interface{}"}, + }, + Results: []poet.Param{{Type: "sql.Rows", Pointer: true}, {Type: "error"}}, + Stmts: []poet.Stmt{poet.Switch{ + Cases: []poet.Case{ + { + Values: []string{"stmt != nil && q.tx != nil"}, + Body: []poet.Stmt{poet.Return{ + Values: []string{ + "q.tx.StmtContext(ctx, stmt).QueryContext(ctx, args...)", + }, + }}, + }, + { + Values: []string{"stmt != nil"}, + Body: []poet.Stmt{poet.Return{ + Values: []string{"stmt.QueryContext(ctx, args...)"}, + }}, + }, + { + Body: []poet.Stmt{poet.Return{ + Values: []string{"q.db.QueryContext(ctx, query, args...)"}, + }}, + }, + }, + }}, + }) + + f.Decls = append(f.Decls, poet.Func{ + Recv: &poet.Param{Name: "q", Type: "Queries", Pointer: true}, + Name: "queryRow", + Params: []poet.Param{ + {Name: "ctx", Type: "context.Context"}, + {Name: "stmt", Type: "sql.Stmt", Pointer: true}, + {Name: "query", Type: "string"}, + {Name: "args", Type: "...interface{}"}, + }, + Results: []poet.Param{{Type: "sql.Row", Pointer: true}}, + Stmts: []poet.Stmt{poet.Switch{ + Cases: []poet.Case{ + { + Values: []string{"stmt != nil && q.tx != nil"}, + Body: []poet.Stmt{poet.Return{ + Values: []string{ + "q.tx.StmtContext(ctx, stmt).QueryRowContext(ctx, args...)", + }, + }}, + }, + { + Values: []string{"stmt != nil"}, + Body: []poet.Stmt{poet.Return{ + Values: []string{"stmt.QueryRowContext(ctx, args...)"}, + }}, + }, + { + Body: []poet.Stmt{poet.Return{ + Values: []string{"q.db.QueryRowContext(ctx, query, args...)"}, + }}, + }, + }, + }}, + }) + } + + // Queries struct + var fields []poet.Field + if !g.tctx.EmitMethodsWithDBArgument { + fields = append(fields, poet.Field{Name: "db", Type: "DBTX"}) + } + if g.tctx.EmitPreparedQueries { + fields = append(fields, poet.Field{Name: "tx", Type: "*sql.Tx"}) + for _, query := range g.tctx.GoQueries { + fields = append(fields, poet.Field{Name: query.FieldName, Type: "*sql.Stmt"}) + } + } + f.Decls = append(f.Decls, poet.TypeDef{ + Name: "Queries", + Type: poet.Struct{Fields: fields}, + }) + + // WithTx method + if !g.tctx.EmitMethodsWithDBArgument { + withTxFields := [][2]string{{"db", "tx"}} + if g.tctx.EmitPreparedQueries { + withTxFields = append(withTxFields, [2]string{"tx", "tx"}) + for _, query := range g.tctx.GoQueries { + withTxFields = append(withTxFields, [2]string{query.FieldName, "q." + query.FieldName}) + } + } + f.Decls = append(f.Decls, poet.Func{ + Recv: &poet.Param{Name: "q", Type: "*Queries"}, + Name: "WithTx", + Params: []poet.Param{{Name: "tx", Type: "*sql.Tx"}}, + Results: []poet.Param{{Type: "*Queries"}}, + Stmts: []poet.Stmt{poet.Return{Values: []string{ + poet.StructLit{Type: "Queries", Pointer: true, Multiline: true, Fields: withTxFields}.Render(), + }}}, + }) + } +} + +func (g *CodeGenerator) addDBCodePGX(f *poet.File) { + // DBTX interface + methods := []poet.Method{ + { + Name: "Exec", + Params: []poet.Param{{Type: "context.Context"}, {Type: "string"}, {Type: "...interface{}"}}, + Results: []poet.Param{{Type: "pgconn.CommandTag"}, {Type: "error"}}, + }, + { + Name: "Query", + Params: []poet.Param{{Type: "context.Context"}, {Type: "string"}, {Type: "...interface{}"}}, + Results: []poet.Param{{Type: "pgx.Rows"}, {Type: "error"}}, + }, + { + Name: "QueryRow", + Params: []poet.Param{{Type: "context.Context"}, {Type: "string"}, {Type: "...interface{}"}}, + Results: []poet.Param{{Type: "pgx.Row"}}, + }, + } + if g.tctx.UsesCopyFrom { + methods = append(methods, poet.Method{ + Name: "CopyFrom", + Params: []poet.Param{{Name: "ctx", Type: "context.Context"}, {Name: "tableName", Type: "pgx.Identifier"}, {Name: "columnNames", Type: "[]string"}, {Name: "rowSrc", Type: "pgx.CopyFromSource"}}, + Results: []poet.Param{{Type: "int64"}, {Type: "error"}}, + }) + } + if g.tctx.UsesBatch { + methods = append(methods, poet.Method{ + Name: "SendBatch", + Params: []poet.Param{{Name: "", Type: "context.Context"}, {Name: "", Type: "*pgx.Batch"}}, + Results: []poet.Param{{Type: "pgx.BatchResults"}}, + }) + } + f.Decls = append(f.Decls, poet.TypeDef{ + Name: "DBTX", + Type: poet.Interface{Methods: methods}, + }) + + // New function + if g.tctx.EmitMethodsWithDBArgument { + f.Decls = append(f.Decls, poet.Func{ + Name: "New", + Results: []poet.Param{{Type: "Queries", Pointer: true}}, + Stmts: []poet.Stmt{poet.Return{Values: []string{ + poet.StructLit{Type: "Queries", Pointer: true}.Render(), + }}}, + }) + } else { + f.Decls = append(f.Decls, poet.Func{ + Name: "New", + Params: []poet.Param{{Name: "db", Type: "DBTX"}}, + Results: []poet.Param{{Type: "Queries", Pointer: true}}, + Stmts: []poet.Stmt{poet.Return{Values: []string{ + poet.StructLit{Type: "Queries", Pointer: true, Fields: [][2]string{ + {"db", "db"}, + }}.Render(), + }}}, + }) + } + + // Queries struct + var fields []poet.Field + if !g.tctx.EmitMethodsWithDBArgument { + fields = append(fields, poet.Field{Name: "db", Type: "DBTX"}) + } + f.Decls = append(f.Decls, poet.TypeDef{ + Name: "Queries", + Type: poet.Struct{Fields: fields}, + }) + + // WithTx method + if !g.tctx.EmitMethodsWithDBArgument { + f.Decls = append(f.Decls, poet.Func{ + Recv: &poet.Param{Name: "q", Type: "Queries", Pointer: true}, + Name: "WithTx", + Params: []poet.Param{{Name: "tx", Type: "pgx.Tx"}}, + Results: []poet.Param{{Type: "Queries", Pointer: true}}, + Stmts: []poet.Stmt{poet.Return{Values: []string{ + poet.StructLit{Type: "Queries", Pointer: true, Multiline: true, Fields: [][2]string{ + {"db", "tx"}, + }}.Render(), + }}}, + }) + } +} + +func (g *CodeGenerator) addModelsCode(f *poet.File) { + // Enums + for _, enum := range g.tctx.Enums { + // Type alias + f.Decls = append(f.Decls, poet.TypeDef{ + Comment: enum.Comment, + Name: enum.Name, + Type: poet.TypeName{Name: "string"}, + }) + + // Constants + var consts []poet.Const + for _, c := range enum.Constants { + consts = append(consts, poet.Const{ + Name: c.Name, + Type: c.Type, + Value: fmt.Sprintf("%q", c.Value), + }) + } + f.Decls = append(f.Decls, poet.ConstBlock{Consts: consts}) + + // Scan method + typeCast := poet.TypeCast{Type: enum.Name, Value: "s"}.Render() + f.Decls = append(f.Decls, poet.Func{ + Recv: &poet.Param{Name: "e", Type: enum.Name, Pointer: true}, + Name: "Scan", + Params: []poet.Param{{Name: "src", Type: "interface{}"}}, + Results: []poet.Param{{Type: "error"}}, + Stmts: []poet.Stmt{ + poet.Switch{ + Expr: "s := src.(type)", + Cases: []poet.Case{ + { + Values: []string{"[]byte"}, + Body: []poet.Stmt{ + poet.Assign{Left: []string{"*e"}, Op: "=", Right: []string{typeCast}}, + }, + }, + { + Values: []string{"string"}, + Body: []poet.Stmt{ + poet.Assign{Left: []string{"*e"}, Op: "=", Right: []string{typeCast}}, + }, + }, + { + Body: []poet.Stmt{poet.Return{Values: []string{ + poet.CallExpr{ + Func: "fmt.Errorf", + Args: []string{fmt.Sprintf(`"unsupported scan type for %s: %%T"`, enum.Name), "src"}, + }.Render(), + }}}, + }, + }, + }, + poet.Return{Values: []string{"nil"}}, + }, + }) + + // Null type + var nullFields []poet.Field + if enum.NameTag() != "" { + nullFields = append(nullFields, poet.Field{ + Name: enum.Name, Type: enum.Name, Tag: enum.NameTag(), + }) + } else { + nullFields = append(nullFields, poet.Field{ + Name: enum.Name, Type: enum.Name, + }) + } + validComment := fmt.Sprintf("Valid is true if %s is not NULL", enum.Name) + if enum.ValidTag() != "" { + nullFields = append(nullFields, poet.Field{ + Name: "Valid", Type: "bool", Tag: enum.ValidTag(), + TrailingComment: validComment, + }) + } else { + nullFields = append(nullFields, poet.Field{ + Name: "Valid", Type: "bool", TrailingComment: validComment, + }) + } + f.Decls = append(f.Decls, poet.TypeDef{ + Name: "Null" + enum.Name, + Type: poet.Struct{Fields: nullFields}, + }) + + // Null Scan method + f.Decls = append(f.Decls, poet.Func{ + Comment: "Scan implements the Scanner interface.", + Recv: &poet.Param{Name: "ns", Type: "Null" + enum.Name, Pointer: true}, + Name: "Scan", + Params: []poet.Param{{Name: "value", Type: "interface{}"}}, + Results: []poet.Param{{Type: "error"}}, + Stmts: []poet.Stmt{ + poet.If{ + Cond: "value == nil", + Body: []poet.Stmt{ + poet.Assign{ + Left: []string{"ns." + enum.Name, "ns.Valid"}, + Op: "=", + Right: []string{`""`, "false"}, + }, + poet.Return{Values: []string{"nil"}}, + }, + }, + poet.Assign{Left: []string{"ns.Valid"}, Op: "=", Right: []string{"true"}}, + poet.Return{Values: []string{fmt.Sprintf("ns.%s.Scan(value)", enum.Name)}}, + }, + }) + + // Null Value method + f.Decls = append(f.Decls, poet.Func{ + Comment: "Value implements the driver Valuer interface.", + Recv: &poet.Param{Name: "ns", Type: "Null" + enum.Name}, + Name: "Value", + Results: []poet.Param{{Type: "driver.Value"}, {Type: "error"}}, + Stmts: []poet.Stmt{ + poet.If{ + Cond: "!ns.Valid", + Body: []poet.Stmt{poet.Return{Values: []string{"nil", "nil"}}}, + }, + poet.Return{Values: []string{ + poet.TypeCast{Type: "string", Value: "ns." + enum.Name}.Render(), + "nil", + }}, + }, + }) + + // Valid method + if g.tctx.EmitEnumValidMethod { + var caseValues []string + for _, c := range enum.Constants { + caseValues = append(caseValues, c.Name) + } + f.Decls = append(f.Decls, poet.Func{ + Recv: &poet.Param{Name: "e", Type: enum.Name}, + Name: "Valid", + Results: []poet.Param{{Type: "bool"}}, + Stmts: []poet.Stmt{ + poet.Switch{ + Expr: "e", + Cases: []poet.Case{ + {Values: caseValues, Body: []poet.Stmt{poet.Return{Values: []string{"true"}}}}, + }, + }, + poet.Return{Values: []string{"false"}}, + }, + }) + } + + // AllValues method + if g.tctx.EmitAllEnumValues { + var enumValues []string + for _, c := range enum.Constants { + enumValues = append(enumValues, c.Name) + } + f.Decls = append(f.Decls, poet.Func{ + Name: fmt.Sprintf("All%sValues", enum.Name), + Results: []poet.Param{{Type: "[]" + enum.Name}}, + Stmts: []poet.Stmt{poet.Return{Values: []string{ + poet.SliceLit{Type: enum.Name, Multiline: true, Values: enumValues}.Render(), + }}}, + }) + } + } + + // Structs + for _, s := range g.tctx.Structs { + var fields []poet.Field + for _, fld := range s.Fields { + fields = append(fields, poet.Field{ + Comment: fld.Comment, + Name: fld.Name, + Type: fld.Type, + Tag: fld.Tag(), + }) + } + f.Decls = append(f.Decls, poet.TypeDef{ + Comment: s.Comment, + Name: s.Name, + Type: poet.Struct{Fields: fields}, + }) + } +} + +func (g *CodeGenerator) addInterfaceCodeStd(f *poet.File) { + var methods []poet.Method + for _, q := range g.tctx.GoQueries { + m := g.buildInterfaceMethod(q, false) + if m != nil { + methods = append(methods, *m) + } + } + f.Decls = append(f.Decls, poet.TypeDef{ + Name: "Querier", + Type: poet.Interface{Methods: methods}, + }) + f.Decls = append(f.Decls, poet.Var{ + Name: "_", + Type: "Querier", + Value: "(*Queries)(nil)", + }) +} + +func (g *CodeGenerator) addInterfaceCodePGX(f *poet.File) { + var methods []poet.Method + for _, q := range g.tctx.GoQueries { + m := g.buildInterfaceMethod(q, true) + if m != nil { + methods = append(methods, *m) + } + } + f.Decls = append(f.Decls, poet.TypeDef{ + Name: "Querier", + Type: poet.Interface{Methods: methods}, + }) + f.Decls = append(f.Decls, poet.Var{ + Name: "_", + Type: "Querier", + Value: "(*Queries)(nil)", + }) +} + +func (g *CodeGenerator) buildInterfaceMethod(q Query, isPGX bool) *poet.Method { + var params string + var returnType string + + switch q.Cmd { + case ":one": + params = q.Arg.Pair() + returnType = fmt.Sprintf("(%s, error)", q.Ret.DefineType()) + case ":many": + params = q.Arg.Pair() + returnType = fmt.Sprintf("([]%s, error)", q.Ret.DefineType()) + case ":exec": + params = q.Arg.Pair() + returnType = "error" + case ":execrows": + params = q.Arg.Pair() + returnType = "(int64, error)" + case ":execlastid": + params = q.Arg.Pair() + returnType = "(int64, error)" + case ":execresult": + params = q.Arg.Pair() + if isPGX { + returnType = "(pgconn.CommandTag, error)" + } else { + returnType = "(sql.Result, error)" + } + case ":copyfrom": + params = q.Arg.SlicePair() + returnType = "(int64, error)" + case ":batchexec", ":batchmany", ":batchone": + params = q.Arg.SlicePair() + returnType = fmt.Sprintf("*%sBatchResults", q.MethodName) + default: + return nil + } + + if g.tctx.EmitMethodsWithDBArgument { + if params != "" { + params = "db DBTX, " + params + } else { + params = "db DBTX" + } + } + + comment := "" + for _, c := range q.Comments { + comment += "//" + c + "\n" + } + comment = strings.TrimSuffix(comment, "\n") + + // Build params list + var paramList []poet.Param + paramList = append(paramList, poet.Param{Name: "ctx", Type: "context.Context"}) + if params != "" { + paramList = append(paramList, poet.Param{Name: "", Type: params}) + } + + return &poet.Method{ + Comment: comment, + Name: q.MethodName, + Params: paramList, + Results: []poet.Param{{Type: returnType}}, + } +} + +func (g *CodeGenerator) addQueryCodeStd(f *poet.File, sourceName string) { + for _, q := range g.tctx.GoQueries { + if q.SourceName != sourceName { + continue + } + g.addQueryStd(f, q) + } +} + +func (g *CodeGenerator) addQueryStd(f *poet.File, q Query) { + // SQL constant + f.Decls = append(f.Decls, poet.Const{ + Name: q.ConstantName, + Value: fmt.Sprintf("`-- name: %s %s\n%s\n`", q.MethodName, q.Cmd, sdk.EscapeBacktick(q.SQL)), + }) + + // Arg struct if needed + if q.Arg.EmitStruct() { + var fields []poet.Field + for _, fld := range q.Arg.UniqueFields() { + fields = append(fields, poet.Field{Name: fld.Name, Type: fld.Type, Tag: fld.Tag()}) + } + f.Decls = append(f.Decls, poet.TypeDef{ + Name: q.Arg.Type(), + Type: poet.Struct{Fields: fields}, + }) + } + + // Ret struct if needed + if q.Ret.EmitStruct() { + var fields []poet.Field + for _, fld := range q.Ret.Struct.Fields { + fields = append(fields, poet.Field{Name: fld.Name, Type: fld.Type, Tag: fld.Tag()}) + } + f.Decls = append(f.Decls, poet.TypeDef{ + Name: q.Ret.Type(), + Type: poet.Struct{Fields: fields}, + }) + } + + // Method + switch q.Cmd { + case ":one": + g.addQueryOneStd(f, q) + case ":many": + g.addQueryManyStd(f, q) + case ":exec": + g.addQueryExecStd(f, q) + case ":execrows": + g.addQueryExecRowsStd(f, q) + case ":execlastid": + g.addQueryExecLastIDStd(f, q) + case ":execresult": + g.addQueryExecResultStd(f, q) + } +} + +func (g *CodeGenerator) queryComments(q Query) string { + var comment string + for _, c := range q.Comments { + comment += "//" + c + "\n" + } + return strings.TrimSuffix(comment, "\n") +} + +func (g *CodeGenerator) addQueryOneStd(f *poet.File, q Query) { + params := g.buildQueryParams(q) + + // Fall back to RawStmt for slice queries (complex handling) + if q.Arg.HasSqlcSlices() { + var stmts []poet.Stmt + + // Query exec call (complex dynamic SQL handling) + var queryExec strings.Builder + g.writeQueryExecStdCall(&queryExec, q, "row :=") + stmts = append(stmts, poet.RawStmt{Code: queryExec.String()}) + + // var (if arg and ret are different) + if q.Arg.Pair() != q.Ret.Pair() || q.Arg.DefineType() != q.Ret.DefineType() { + stmts = append(stmts, poet.VarDecl{Name: q.Ret.Name, Type: q.Ret.Type()}) + } + + // err := row.Scan() + stmts = append(stmts, poet.Assign{ + Left: []string{"err"}, + Op: ":=", + Right: []string{fmt.Sprintf("row.Scan(%s)", q.Ret.Scan())}, + }) + + // if err != nil { err = fmt.Errorf(...) } + if g.tctx.WrapErrors { + stmts = append(stmts, poet.If{ + Cond: "err != nil", + Body: []poet.Stmt{ + poet.Assign{ + Left: []string{"err"}, + Op: "=", + Right: []string{fmt.Sprintf(`fmt.Errorf("query %s: %%w", err)`, q.MethodName)}, + }, + }, + }) + } + + // return , err + stmts = append(stmts, poet.Return{Values: []string{q.Ret.ReturnName(), "err"}}) + + f.Decls = append(f.Decls, poet.Func{ + Comment: g.queryComments(q), + Recv: &poet.Param{Name: "q", Type: "Queries", Pointer: true}, + Name: q.MethodName, + Params: params, + Results: []poet.Param{{Type: q.Ret.DefineType()}, {Type: "error"}}, + Stmts: stmts, + }) + return + } + + var stmts []poet.Stmt + + // row := + stmts = append(stmts, poet.Assign{ + Left: []string{"row"}, + Op: ":=", + Right: []string{g.queryExecStdCallExpr(q)}, + }) + + // var (if arg and ret are different) + if q.Arg.Pair() != q.Ret.Pair() || q.Arg.DefineType() != q.Ret.DefineType() { + stmts = append(stmts, poet.VarDecl{Name: q.Ret.Name, Type: q.Ret.Type()}) + } + + // err := row.Scan() + stmts = append(stmts, poet.Assign{ + Left: []string{"err"}, + Op: ":=", + Right: []string{fmt.Sprintf("row.Scan(%s)", q.Ret.Scan())}, + }) + + // if err != nil { err = fmt.Errorf(...) } + if g.tctx.WrapErrors { + stmts = append(stmts, poet.If{ + Cond: "err != nil", + Body: []poet.Stmt{ + poet.Assign{ + Left: []string{"err"}, + Op: "=", + Right: []string{fmt.Sprintf(`fmt.Errorf("query %s: %%w", err)`, q.MethodName)}, + }, + }, + }) + } + + // return , err + stmts = append(stmts, poet.Return{Values: []string{q.Ret.ReturnName(), "err"}}) + + f.Decls = append(f.Decls, poet.Func{ + Comment: g.queryComments(q), + Recv: &poet.Param{Name: "q", Type: "Queries", Pointer: true}, + Name: q.MethodName, + Params: params, + Results: []poet.Param{{Type: q.Ret.DefineType()}, {Type: "error"}}, + Stmts: stmts, + }) +} + +func (g *CodeGenerator) addQueryManyStd(f *poet.File, q Query) { + params := g.buildQueryParams(q) + + // Fall back to RawStmt for slice queries (complex handling) + if q.Arg.HasSqlcSlices() { + var stmts []poet.Stmt + + // Query exec call (complex dynamic SQL handling) + var queryExec strings.Builder + g.writeQueryExecStdCall(&queryExec, q, "rows, err :=") + stmts = append(stmts, poet.RawStmt{Code: queryExec.String()}) + + // if err != nil { return nil, err } + errReturn := g.wrapErrorReturn(q, "nil") + stmts = append(stmts, poet.If{ + Cond: "err != nil", + Body: []poet.Stmt{poet.Return{Values: errReturn}}, + }) + + // defer rows.Close() + stmts = append(stmts, poet.Defer{Call: "rows.Close()"}) + + // var items [] or items := []{} + if g.tctx.EmitEmptySlices { + stmts = append(stmts, poet.Assign{ + Left: []string{"items"}, + Op: ":=", + Right: []string{fmt.Sprintf("[]%s{}", q.Ret.DefineType())}, + }) + } else { + stmts = append(stmts, poet.VarDecl{ + Name: "items", + Type: "[]" + q.Ret.DefineType(), + }) + } + + // for rows.Next() { ... } + scanErrReturn := g.wrapErrorReturn(q, "nil") + stmts = append(stmts, poet.For{ + Range: "rows.Next()", + Body: []poet.Stmt{ + poet.VarDecl{Name: q.Ret.Name, Type: q.Ret.Type()}, + poet.If{ + Init: fmt.Sprintf("err := rows.Scan(%s)", q.Ret.Scan()), + Cond: "err != nil", + Body: []poet.Stmt{poet.Return{Values: scanErrReturn}}, + }, + poet.Assign{ + Left: []string{"items"}, + Op: "=", + Right: []string{fmt.Sprintf("append(items, %s)", q.Ret.ReturnName())}, + }, + }, + }) + + // if err := rows.Close(); err != nil { return nil, err } + closeErrReturn := g.wrapErrorReturn(q, "nil") + stmts = append(stmts, poet.If{ + Init: "err := rows.Close()", + Cond: "err != nil", + Body: []poet.Stmt{poet.Return{Values: closeErrReturn}}, + }) + + // if err := rows.Err(); err != nil { return nil, err } + rowsErrReturn := g.wrapErrorReturn(q, "nil") + stmts = append(stmts, poet.If{ + Init: "err := rows.Err()", + Cond: "err != nil", + Body: []poet.Stmt{poet.Return{Values: rowsErrReturn}}, + }) + + // return items, nil + stmts = append(stmts, poet.Return{Values: []string{"items", "nil"}}) + + f.Decls = append(f.Decls, poet.Func{ + Comment: g.queryComments(q), + Recv: &poet.Param{Name: "q", Type: "Queries", Pointer: true}, + Name: q.MethodName, + Params: params, + Results: []poet.Param{{Type: "[]" + q.Ret.DefineType()}, {Type: "error"}}, + Stmts: stmts, + }) + return + } + + var stmts []poet.Stmt + + // rows, err := + stmts = append(stmts, poet.Assign{ + Left: []string{"rows", "err"}, + Op: ":=", + Right: []string{g.queryExecStdCallExpr(q)}, + }) + + // if err != nil { return nil, err } + errReturn := g.wrapErrorReturn(q, "nil") + stmts = append(stmts, poet.If{ + Cond: "err != nil", + Body: []poet.Stmt{poet.Return{Values: errReturn}}, + }) + + // defer rows.Close() + stmts = append(stmts, poet.Defer{Call: "rows.Close()"}) + + // var items [] or items := []{} + if g.tctx.EmitEmptySlices { + stmts = append(stmts, poet.Assign{ + Left: []string{"items"}, + Op: ":=", + Right: []string{fmt.Sprintf("[]%s{}", q.Ret.DefineType())}, + }) + } else { + stmts = append(stmts, poet.VarDecl{ + Name: "items", + Type: "[]" + q.Ret.DefineType(), + }) + } + + // for rows.Next() { ... } + scanErrReturn := g.wrapErrorReturn(q, "nil") + stmts = append(stmts, poet.For{ + Range: "rows.Next()", + Body: []poet.Stmt{ + poet.VarDecl{Name: q.Ret.Name, Type: q.Ret.Type()}, + poet.If{ + Init: fmt.Sprintf("err := rows.Scan(%s)", q.Ret.Scan()), + Cond: "err != nil", + Body: []poet.Stmt{poet.Return{Values: scanErrReturn}}, + }, + poet.Assign{ + Left: []string{"items"}, + Op: "=", + Right: []string{fmt.Sprintf("append(items, %s)", q.Ret.ReturnName())}, + }, + }, + }) + + // if err := rows.Close(); err != nil { return nil, err } + closeErrReturn := g.wrapErrorReturn(q, "nil") + stmts = append(stmts, poet.If{ + Init: "err := rows.Close()", + Cond: "err != nil", + Body: []poet.Stmt{poet.Return{Values: closeErrReturn}}, + }) + + // if err := rows.Err(); err != nil { return nil, err } + rowsErrReturn := g.wrapErrorReturn(q, "nil") + stmts = append(stmts, poet.If{ + Init: "err := rows.Err()", + Cond: "err != nil", + Body: []poet.Stmt{poet.Return{Values: rowsErrReturn}}, + }) + + // return items, nil + stmts = append(stmts, poet.Return{Values: []string{"items", "nil"}}) + + f.Decls = append(f.Decls, poet.Func{ + Comment: g.queryComments(q), + Recv: &poet.Param{Name: "q", Type: "Queries", Pointer: true}, + Name: q.MethodName, + Params: params, + Results: []poet.Param{{Type: "[]" + q.Ret.DefineType()}, {Type: "error"}}, + Stmts: stmts, + }) +} + +// wrapErrorReturn returns the return values for an error return. +// firstVal is the first value to return (e.g., "nil", "0"). +func (g *CodeGenerator) wrapErrorReturn(q Query, firstVal string) []string { + if g.tctx.WrapErrors { + return []string{firstVal, fmt.Sprintf(`fmt.Errorf("query %s: %%w", err)`, q.MethodName)} + } + return []string{firstVal, "err"} +} + +func (g *CodeGenerator) addQueryExecStd(f *poet.File, q Query) { + params := g.buildQueryParams(q) + + // Fall back to RawStmt for slice queries (complex handling) + if q.Arg.HasSqlcSlices() { + var stmts []poet.Stmt + + // Query exec call (complex dynamic SQL handling) + var queryExec strings.Builder + g.writeQueryExecStdCall(&queryExec, q, "_, err :=") + stmts = append(stmts, poet.RawStmt{Code: queryExec.String()}) + + // if err != nil { err = fmt.Errorf(...) } + if g.tctx.WrapErrors { + stmts = append(stmts, poet.If{ + Cond: "err != nil", + Body: []poet.Stmt{ + poet.Assign{ + Left: []string{"err"}, + Op: "=", + Right: []string{fmt.Sprintf(`fmt.Errorf("query %s: %%w", err)`, q.MethodName)}, + }, + }, + }) + } + + // return err + stmts = append(stmts, poet.Return{Values: []string{"err"}}) + + f.Decls = append(f.Decls, poet.Func{ + Comment: g.queryComments(q), + Recv: &poet.Param{Name: "q", Type: "Queries", Pointer: true}, + Name: q.MethodName, + Params: params, + Results: []poet.Param{{Type: "error"}}, + Stmts: stmts, + }) + return + } + + var stmts []poet.Stmt + stmts = append(stmts, poet.Assign{ + Left: []string{"_", "err"}, + Op: ":=", + Right: []string{g.queryExecStdCallExpr(q)}, + }) + + if g.tctx.WrapErrors { + stmts = append(stmts, poet.If{ + Cond: "err != nil", + Body: []poet.Stmt{ + poet.Assign{ + Left: []string{"err"}, + Op: "=", + Right: []string{fmt.Sprintf(`fmt.Errorf("query %s: %%w", err)`, q.MethodName)}, + }, + }, + }) + } + stmts = append(stmts, poet.Return{Values: []string{"err"}}) + + f.Decls = append(f.Decls, poet.Func{ + Comment: g.queryComments(q), + Recv: &poet.Param{Name: "q", Type: "Queries", Pointer: true}, + Name: q.MethodName, + Params: params, + Results: []poet.Param{{Type: "error"}}, + Stmts: stmts, + }) +} + +func (g *CodeGenerator) addQueryExecRowsStd(f *poet.File, q Query) { + params := g.buildQueryParams(q) + + // Fall back to RawStmt for slice queries + if q.Arg.HasSqlcSlices() { + var stmts []poet.Stmt + + // Query exec call (complex dynamic SQL handling) + var queryExec strings.Builder + g.writeQueryExecStdCall(&queryExec, q, "result, err :=") + stmts = append(stmts, poet.RawStmt{Code: queryExec.String()}) + + // if err != nil { return 0, err } + errReturn := g.wrapErrorReturn(q, "0") + stmts = append(stmts, poet.If{ + Cond: "err != nil", + Body: []poet.Stmt{poet.Return{Values: errReturn}}, + }) + + // return result.RowsAffected() + stmts = append(stmts, poet.Return{Values: []string{"result.RowsAffected()"}}) + + f.Decls = append(f.Decls, poet.Func{ + Comment: g.queryComments(q), + Recv: &poet.Param{Name: "q", Type: "Queries", Pointer: true}, + Name: q.MethodName, + Params: params, + Results: []poet.Param{{Type: "int64"}, {Type: "error"}}, + Stmts: stmts, + }) + return + } + + var stmts []poet.Stmt + + // result, err := + stmts = append(stmts, poet.Assign{ + Left: []string{"result", "err"}, + Op: ":=", + Right: []string{g.queryExecStdCallExpr(q)}, + }) + + // if err != nil { return 0, err } + errReturn := g.wrapErrorReturn(q, "0") + stmts = append(stmts, poet.If{ + Cond: "err != nil", + Body: []poet.Stmt{poet.Return{Values: errReturn}}, + }) + + // return result.RowsAffected() + stmts = append(stmts, poet.Return{Values: []string{"result.RowsAffected()"}}) + + f.Decls = append(f.Decls, poet.Func{ + Comment: g.queryComments(q), + Recv: &poet.Param{Name: "q", Type: "Queries", Pointer: true}, + Name: q.MethodName, + Params: params, + Results: []poet.Param{{Type: "int64"}, {Type: "error"}}, + Stmts: stmts, + }) +} + +func (g *CodeGenerator) addQueryExecLastIDStd(f *poet.File, q Query) { + params := g.buildQueryParams(q) + + // Fall back to RawStmt for slice queries + if q.Arg.HasSqlcSlices() { + var stmts []poet.Stmt + + // Query exec call (complex dynamic SQL handling) + var queryExec strings.Builder + g.writeQueryExecStdCall(&queryExec, q, "result, err :=") + stmts = append(stmts, poet.RawStmt{Code: queryExec.String()}) + + // if err != nil { return 0, err } + errReturn := g.wrapErrorReturn(q, "0") + stmts = append(stmts, poet.If{ + Cond: "err != nil", + Body: []poet.Stmt{poet.Return{Values: errReturn}}, + }) + + // return result.LastInsertId() + stmts = append(stmts, poet.Return{Values: []string{"result.LastInsertId()"}}) + + f.Decls = append(f.Decls, poet.Func{ + Comment: g.queryComments(q), + Recv: &poet.Param{Name: "q", Type: "Queries", Pointer: true}, + Name: q.MethodName, + Params: params, + Results: []poet.Param{{Type: "int64"}, {Type: "error"}}, + Stmts: stmts, + }) + return + } + + var stmts []poet.Stmt + + // result, err := + stmts = append(stmts, poet.Assign{ + Left: []string{"result", "err"}, + Op: ":=", + Right: []string{g.queryExecStdCallExpr(q)}, + }) + + // if err != nil { return 0, err } + errReturn := g.wrapErrorReturn(q, "0") + stmts = append(stmts, poet.If{ + Cond: "err != nil", + Body: []poet.Stmt{poet.Return{Values: errReturn}}, + }) + + // return result.LastInsertId() + stmts = append(stmts, poet.Return{Values: []string{"result.LastInsertId()"}}) + + f.Decls = append(f.Decls, poet.Func{ + Comment: g.queryComments(q), + Recv: &poet.Param{Name: "q", Type: "Queries", Pointer: true}, + Name: q.MethodName, + Params: params, + Results: []poet.Param{{Type: "int64"}, {Type: "error"}}, + Stmts: stmts, + }) +} + +func (g *CodeGenerator) addQueryExecResultStd(f *poet.File, q Query) { + params := g.buildQueryParams(q) + + // Fall back to RawStmt for slice queries + if q.Arg.HasSqlcSlices() { + var stmts []poet.Stmt + var queryExec strings.Builder + + if g.tctx.WrapErrors { + // result, err := + g.writeQueryExecStdCall(&queryExec, q, "result, err :=") + stmts = append(stmts, poet.RawStmt{Code: queryExec.String()}) + + // if err != nil { err = fmt.Errorf(...) } + stmts = append(stmts, poet.If{ + Cond: "err != nil", + Body: []poet.Stmt{ + poet.Assign{ + Left: []string{"err"}, + Op: "=", + Right: []string{fmt.Sprintf(`fmt.Errorf("query %s: %%w", err)`, q.MethodName)}, + }, + }, + }) + + // return result, err + stmts = append(stmts, poet.Return{Values: []string{"result", "err"}}) + } else { + // return + g.writeQueryExecStdCall(&queryExec, q, "return") + stmts = append(stmts, poet.RawStmt{Code: queryExec.String()}) + } + + f.Decls = append(f.Decls, poet.Func{ + Comment: g.queryComments(q), + Recv: &poet.Param{Name: "q", Type: "Queries", Pointer: true}, + Name: q.MethodName, + Params: params, + Results: []poet.Param{{Type: "sql.Result"}, {Type: "error"}}, + Stmts: stmts, + }) + return + } + + var stmts []poet.Stmt + + if g.tctx.WrapErrors { + // result, err := + stmts = append(stmts, poet.Assign{ + Left: []string{"result", "err"}, + Op: ":=", + Right: []string{g.queryExecStdCallExpr(q)}, + }) + + // if err != nil { err = fmt.Errorf(...) } + stmts = append(stmts, poet.If{ + Cond: "err != nil", + Body: []poet.Stmt{ + poet.Assign{ + Left: []string{"err"}, + Op: "=", + Right: []string{fmt.Sprintf(`fmt.Errorf("query %s: %%w", err)`, q.MethodName)}, + }, + }, + }) + + // return result, err + stmts = append(stmts, poet.Return{Values: []string{"result", "err"}}) + } else { + // return + stmts = append(stmts, poet.Return{Values: []string{g.queryExecStdCallExpr(q)}}) + } + + f.Decls = append(f.Decls, poet.Func{ + Comment: g.queryComments(q), + Recv: &poet.Param{Name: "q", Type: "Queries", Pointer: true}, + Name: q.MethodName, + Params: params, + Results: []poet.Param{{Type: "sql.Result"}, {Type: "error"}}, + Stmts: stmts, + }) +} + +func (g *CodeGenerator) buildQueryParams(q Query) []poet.Param { + var params []poet.Param + params = append(params, poet.Param{Name: "ctx", Type: "context.Context"}) + if g.tctx.EmitMethodsWithDBArgument { + params = append(params, poet.Param{Name: "db", Type: "DBTX"}) + } + if q.Arg.Pair() != "" { + // Parse the pair into name and type + pair := q.Arg.Pair() + if pair != "" { + params = append(params, poet.Param{Name: "", Type: pair}) + } + } + return params +} + +func (g *CodeGenerator) writeQueryExecStdCall(body *strings.Builder, q Query, retval string) { + db := "q.db" + if g.tctx.EmitMethodsWithDBArgument { + db = "db" + } + + if q.Arg.HasSqlcSlices() { + g.writeQuerySliceExec(body, q, retval, db, false) + return + } + + fmt.Fprintf(body, "\t%s %s\n", retval, g.queryExecStdCallExpr(q)) +} + +// queryExecStdCallExpr returns the method call expression for a query. +func (g *CodeGenerator) queryExecStdCallExpr(q Query) string { + db := "q.db" + if g.tctx.EmitMethodsWithDBArgument { + db = "db" + } + + var method string + switch q.Cmd { + case ":one": + if g.tctx.EmitPreparedQueries { + method = "q.queryRow" + } else { + method = db + ".QueryRowContext" + } + case ":many": + if g.tctx.EmitPreparedQueries { + method = "q.query" + } else { + method = db + ".QueryContext" + } + default: + if g.tctx.EmitPreparedQueries { + method = "q.exec" + } else { + method = db + ".ExecContext" + } + } + + params := q.Arg.Params() + if params != "" { + params = ", " + params + } + + if g.tctx.EmitPreparedQueries { + return fmt.Sprintf("%s(ctx, q.%s, %s%s)", method, q.FieldName, q.ConstantName, params) + } + return fmt.Sprintf("%s(ctx, %s%s)", method, q.ConstantName, params) +} + +func (g *CodeGenerator) buildQuerySliceExecStmts(q Query, retval, db string) []poet.Stmt { + var stmts []poet.Stmt + + stmts = append(stmts, poet.Assign{ + Left: []string{"query"}, Op: ":=", Right: []string{q.ConstantName}, + }) + stmts = append(stmts, poet.VarDecl{Name: "queryParams", Type: "[]interface{}"}) + + // Helper to build slice handling statements + buildSliceHandling := func(varName, colName string) poet.Stmt { + return poet.If{ + Cond: fmt.Sprintf("len(%s) > 0", varName), + Body: []poet.Stmt{ + poet.For{ + Range: fmt.Sprintf("_, v := range %s", varName), + Body: []poet.Stmt{ + poet.Assign{ + Left: []string{"queryParams"}, Op: "=", + Right: []string{"append(queryParams, v)"}, + }, + }, + }, + poet.Assign{ + Left: []string{"query"}, Op: "=", + Right: []string{fmt.Sprintf(`strings.Replace(query, "/*SLICE:%s*/?", strings.Repeat(",?", len(%s))[1:], 1)`, colName, varName)}, + }, + }, + Else: []poet.Stmt{ + poet.Assign{ + Left: []string{"query"}, Op: "=", + Right: []string{fmt.Sprintf(`strings.Replace(query, "/*SLICE:%s*/?", "NULL", 1)`, colName)}, + }, + }, + } + } + + if q.Arg.Struct != nil { + for _, fld := range q.Arg.Struct.Fields { + varName := q.Arg.VariableForField(fld) + if fld.HasSqlcSlice() { + stmts = append(stmts, buildSliceHandling(varName, fld.Column.Name)) + } else { + stmts = append(stmts, poet.Assign{ + Left: []string{"queryParams"}, Op: "=", + Right: []string{fmt.Sprintf("append(queryParams, %s)", varName)}, + }) + } + } + } else { + stmts = append(stmts, buildSliceHandling(q.Arg.Name, q.Arg.Column.Name)) + } + + var method string + switch q.Cmd { + case ":one": + if g.tctx.EmitPreparedQueries { + method = "q.queryRow" + } else { + method = db + ".QueryRowContext" + } + case ":many": + if g.tctx.EmitPreparedQueries { + method = "q.query" + } else { + method = db + ".QueryContext" + } + default: + if g.tctx.EmitPreparedQueries { + method = "q.exec" + } else { + method = db + ".ExecContext" + } + } + + var callExpr string + if g.tctx.EmitPreparedQueries { + callExpr = fmt.Sprintf("%s(ctx, nil, query, queryParams...)", method) + } else { + callExpr = fmt.Sprintf("%s(ctx, query, queryParams...)", method) + } + + // Parse retval to determine assignment type + parts := strings.SplitN(retval, " ", 2) + if len(parts) == 2 { + lhs := strings.Split(strings.TrimSpace(parts[0]), ",") + for i := range lhs { + lhs[i] = strings.TrimSpace(lhs[i]) + } + op := strings.TrimSpace(parts[1]) + stmts = append(stmts, poet.Assign{Left: lhs, Op: op, Right: []string{callExpr}}) + } else { + // Simple return or call + stmts = append(stmts, poet.RawStmt{Code: fmt.Sprintf("\t%s %s\n", retval, callExpr)}) + } + + return stmts +} + +func (g *CodeGenerator) writeQuerySliceExec(body *strings.Builder, q Query, retval, db string, isPGX bool) { + stmts := g.buildQuerySliceExecStmts(q, retval, db) + for _, stmt := range stmts { + body.WriteString(poet.RenderStmt(stmt, "\t")) + } +} + +func (g *CodeGenerator) addQueryCodePGX(f *poet.File, sourceName string) { + for _, q := range g.tctx.GoQueries { + if q.SourceName != sourceName { + continue + } + if strings.HasPrefix(q.Cmd, ":batch") { + // Batch queries are fully handled in batch.go + continue + } + if q.Cmd == metadata.CmdCopyFrom { + // For copyfrom, only emit the struct definition (implementation is in copyfrom.go) + if q.Arg.EmitStruct() { + var fields []poet.Field + for _, fld := range q.Arg.Struct.Fields { + fields = append(fields, poet.Field{Name: fld.Name, Type: fld.Type, Tag: fld.Tag()}) + } + f.Decls = append(f.Decls, poet.TypeDef{ + Name: q.Arg.Type(), + Type: poet.Struct{Fields: fields}, + }) + } + continue + } + g.addQueryPGX(f, q) + } +} + +func (g *CodeGenerator) addQueryPGX(f *poet.File, q Query) { + // SQL constant + f.Decls = append(f.Decls, poet.Const{ + Name: q.ConstantName, + Value: fmt.Sprintf("`-- name: %s %s\n%s\n`", q.MethodName, q.Cmd, sdk.EscapeBacktick(q.SQL)), + }) + + // Arg struct if needed + if q.Arg.EmitStruct() { + var fields []poet.Field + for _, fld := range q.Arg.Struct.Fields { + fields = append(fields, poet.Field{Name: fld.Name, Type: fld.Type, Tag: fld.Tag()}) + } + f.Decls = append(f.Decls, poet.TypeDef{ + Name: q.Arg.Type(), + Type: poet.Struct{Fields: fields}, + }) + } + + // Ret struct if needed + if q.Ret.EmitStruct() { + var fields []poet.Field + for _, fld := range q.Ret.Struct.Fields { + fields = append(fields, poet.Field{Name: fld.Name, Type: fld.Type, Tag: fld.Tag()}) + } + f.Decls = append(f.Decls, poet.TypeDef{ + Name: q.Ret.Type(), + Type: poet.Struct{Fields: fields}, + }) + } + + // Method + switch q.Cmd { + case ":one": + g.addQueryOnePGX(f, q) + case ":many": + g.addQueryManyPGX(f, q) + case ":exec": + g.addQueryExecPGX(f, q) + case ":execrows": + g.addQueryExecRowsPGX(f, q) + case ":execresult": + g.addQueryExecResultPGX(f, q) + } +} + +func (g *CodeGenerator) buildQueryParamsPGX(q Query) []poet.Param { + var params []poet.Param + params = append(params, poet.Param{Name: "ctx", Type: "context.Context"}) + if g.tctx.EmitMethodsWithDBArgument { + params = append(params, poet.Param{Name: "db", Type: "DBTX"}) + } + if q.Arg.Pair() != "" { + params = append(params, poet.Param{Name: "", Type: q.Arg.Pair()}) + } + return params +} + +func (g *CodeGenerator) addQueryOnePGX(f *poet.File, q Query) { + db := "q.db" + if g.tctx.EmitMethodsWithDBArgument { + db = "db" + } + + qParams := q.Arg.Params() + if qParams != "" { + qParams = ", " + qParams + } + + var stmts []poet.Stmt + stmts = append(stmts, poet.Assign{ + Left: []string{"row"}, + Op: ":=", + Right: []string{fmt.Sprintf("%s.QueryRow(ctx, %s%s)", db, q.ConstantName, qParams)}, + }) + + if q.Arg.Pair() != q.Ret.Pair() || q.Arg.DefineType() != q.Ret.DefineType() { + stmts = append(stmts, poet.VarDecl{Name: q.Ret.Name, Type: q.Ret.Type()}) + } + + stmts = append(stmts, poet.Assign{ + Left: []string{"err"}, + Op: ":=", + Right: []string{fmt.Sprintf("row.Scan(%s)", q.Ret.Scan())}, + }) + + if g.tctx.WrapErrors { + stmts = append(stmts, poet.If{ + Cond: "err != nil", + Body: []poet.Stmt{poet.Assign{ + Left: []string{"err"}, + Op: "=", + Right: []string{fmt.Sprintf("fmt.Errorf(\"query %s: %%w\", err)", q.MethodName)}, + }}, + }) + } + + stmts = append(stmts, poet.Return{Values: []string{q.Ret.ReturnName(), "err"}}) + + params := g.buildQueryParamsPGX(q) + f.Decls = append(f.Decls, poet.Func{ + Comment: g.queryComments(q), + Recv: &poet.Param{Name: "q", Type: "*Queries"}, + Name: q.MethodName, + Params: params, + Results: []poet.Param{{Type: q.Ret.DefineType()}, {Type: "error"}}, + Stmts: stmts, + }) +} + +func (g *CodeGenerator) addQueryManyPGX(f *poet.File, q Query) { + db := "q.db" + if g.tctx.EmitMethodsWithDBArgument { + db = "db" + } + + qParams := q.Arg.Params() + if qParams != "" { + qParams = ", " + qParams + } + + // Build error return value + var errReturn []string + if g.tctx.WrapErrors { + errReturn = []string{"nil", fmt.Sprintf("fmt.Errorf(\"query %s: %%w\", err)", q.MethodName)} + } else { + errReturn = []string{"nil", "err"} + } + + var stmts []poet.Stmt + stmts = append(stmts, poet.Assign{ + Left: []string{"rows", "err"}, + Op: ":=", + Right: []string{fmt.Sprintf("%s.Query(ctx, %s%s)", db, q.ConstantName, qParams)}, + }) + + stmts = append(stmts, poet.If{ + Cond: "err != nil", + Body: []poet.Stmt{poet.Return{Values: errReturn}}, + }) + + stmts = append(stmts, poet.Defer{Call: "rows.Close()"}) + + if g.tctx.EmitEmptySlices { + stmts = append(stmts, poet.Assign{ + Left: []string{"items"}, + Op: ":=", + Right: []string{fmt.Sprintf("[]%s{}", q.Ret.DefineType())}, + }) + } else { + stmts = append(stmts, poet.VarDecl{Name: "items", Type: "[]" + q.Ret.DefineType()}) + } + + // For loop body + var forBody []poet.Stmt + forBody = append(forBody, poet.VarDecl{Name: q.Ret.Name, Type: q.Ret.Type()}) + forBody = append(forBody, poet.If{ + Cond: fmt.Sprintf("err := rows.Scan(%s); err != nil", q.Ret.Scan()), + Body: []poet.Stmt{poet.Return{Values: errReturn}}, + }) + forBody = append(forBody, poet.Assign{ + Left: []string{"items"}, + Op: "=", + Right: []string{fmt.Sprintf("append(items, %s)", q.Ret.ReturnName())}, + }) + + stmts = append(stmts, poet.For{Cond: "rows.Next()", Body: forBody}) + + stmts = append(stmts, poet.If{ + Cond: "err := rows.Err(); err != nil", + Body: []poet.Stmt{poet.Return{Values: errReturn}}, + }) + + stmts = append(stmts, poet.Return{Values: []string{"items", "nil"}}) + + params := g.buildQueryParamsPGX(q) + f.Decls = append(f.Decls, poet.Func{ + Comment: g.queryComments(q), + Recv: &poet.Param{Name: "q", Type: "*Queries"}, + Name: q.MethodName, + Params: params, + Results: []poet.Param{{Type: "[]" + q.Ret.DefineType()}, {Type: "error"}}, + Stmts: stmts, + }) +} + +func (g *CodeGenerator) addQueryExecPGX(f *poet.File, q Query) { + db := "q.db" + if g.tctx.EmitMethodsWithDBArgument { + db = "db" + } + + qParams := q.Arg.Params() + if qParams != "" { + qParams = ", " + qParams + } + + var stmts []poet.Stmt + stmts = append(stmts, poet.Assign{ + Left: []string{"_", "err"}, + Op: ":=", + Right: []string{fmt.Sprintf("%s.Exec(ctx, %s%s)", db, q.ConstantName, qParams)}, + }) + + if g.tctx.WrapErrors { + stmts = append(stmts, poet.If{ + Cond: "err != nil", + Body: []poet.Stmt{poet.Return{Values: []string{fmt.Sprintf("fmt.Errorf(\"query %s: %%w\", err)", q.MethodName)}}}, + }) + stmts = append(stmts, poet.Return{Values: []string{"nil"}}) + } else { + stmts = append(stmts, poet.Return{Values: []string{"err"}}) + } + + params := g.buildQueryParamsPGX(q) + f.Decls = append(f.Decls, poet.Func{ + Comment: g.queryComments(q), + Recv: &poet.Param{Name: "q", Type: "*Queries"}, + Name: q.MethodName, + Params: params, + Results: []poet.Param{{Type: "error"}}, + Stmts: stmts, + }) +} + +func (g *CodeGenerator) addQueryExecRowsPGX(f *poet.File, q Query) { + db := "q.db" + if g.tctx.EmitMethodsWithDBArgument { + db = "db" + } + + qParams := q.Arg.Params() + if qParams != "" { + qParams = ", " + qParams + } + + // Build error return value + var errReturn []string + if g.tctx.WrapErrors { + errReturn = []string{"0", fmt.Sprintf("fmt.Errorf(\"query %s: %%w\", err)", q.MethodName)} + } else { + errReturn = []string{"0", "err"} + } + + var stmts []poet.Stmt + stmts = append(stmts, poet.Assign{ + Left: []string{"result", "err"}, + Op: ":=", + Right: []string{fmt.Sprintf("%s.Exec(ctx, %s%s)", db, q.ConstantName, qParams)}, + }) + + stmts = append(stmts, poet.If{ + Cond: "err != nil", + Body: []poet.Stmt{poet.Return{Values: errReturn}}, + }) + + stmts = append(stmts, poet.Return{Values: []string{"result.RowsAffected()", "nil"}}) + + params := g.buildQueryParamsPGX(q) + f.Decls = append(f.Decls, poet.Func{ + Comment: g.queryComments(q), + Recv: &poet.Param{Name: "q", Type: "*Queries"}, + Name: q.MethodName, + Params: params, + Results: []poet.Param{{Type: "int64"}, {Type: "error"}}, + Stmts: stmts, + }) +} + +func (g *CodeGenerator) addQueryExecResultPGX(f *poet.File, q Query) { + db := "q.db" + if g.tctx.EmitMethodsWithDBArgument { + db = "db" + } + + qParams := q.Arg.Params() + if qParams != "" { + qParams = ", " + qParams + } + + var stmts []poet.Stmt + if g.tctx.WrapErrors { + stmts = append(stmts, poet.Assign{ + Left: []string{"result", "err"}, + Op: ":=", + Right: []string{fmt.Sprintf("%s.Exec(ctx, %s%s)", db, q.ConstantName, qParams)}, + }) + stmts = append(stmts, poet.If{ + Cond: "err != nil", + Body: []poet.Stmt{poet.Assign{ + Left: []string{"err"}, + Op: "=", + Right: []string{fmt.Sprintf("fmt.Errorf(\"query %s: %%w\", err)", q.MethodName)}, + }}, + }) + stmts = append(stmts, poet.Return{Values: []string{"result", "err"}}) + } else { + stmts = append(stmts, poet.Return{Values: []string{fmt.Sprintf("%s.Exec(ctx, %s%s)", db, q.ConstantName, qParams)}}) + } + + params := g.buildQueryParamsPGX(q) + f.Decls = append(f.Decls, poet.Func{ + Comment: g.queryComments(q), + Recv: &poet.Param{Name: "q", Type: "*Queries"}, + Name: q.MethodName, + Params: params, + Results: []poet.Param{{Type: "pgconn.CommandTag"}, {Type: "error"}}, + Stmts: stmts, + }) +} + +func (g *CodeGenerator) addCopyFromCodePGX(f *poet.File) { + for _, q := range g.tctx.GoQueries { + if q.Cmd != metadata.CmdCopyFrom { + continue + } + + iterName := "iteratorFor" + q.MethodName + + // Iterator struct + f.Decls = append(f.Decls, poet.TypeDef{ + Comment: fmt.Sprintf("iteratorFor%s implements pgx.CopyFromSource.", q.MethodName), + Name: iterName, + Type: poet.Struct{ + Fields: []poet.Field{ + {Name: "rows", Type: "[]" + q.Arg.DefineType()}, + {Name: "skippedFirstNextCall", Type: "bool"}, + }, + }, + }) + + // Next method + f.Decls = append(f.Decls, poet.Func{ + Recv: &poet.Param{Name: "r", Type: iterName, Pointer: true}, + Name: "Next", + Results: []poet.Param{{Type: "bool"}}, + Stmts: []poet.Stmt{ + poet.If{ + Cond: "len(r.rows) == 0", + Body: []poet.Stmt{poet.Return{Values: []string{"false"}}}, + }, + poet.If{ + Cond: "!r.skippedFirstNextCall", + Body: []poet.Stmt{ + poet.Assign{ + Left: []string{"r.skippedFirstNextCall"}, Op: "=", + Right: []string{"true"}, + }, + poet.Return{Values: []string{"true"}}, + }, + }, + poet.Assign{Left: []string{"r.rows"}, Op: "=", Right: []string{"r.rows[1:]"}}, + poet.Return{Values: []string{"len(r.rows) > 0"}}, + }, + }) + + // Values method + var valuesBody strings.Builder + valuesBody.WriteString("\treturn []interface{}{\n") + if q.Arg.Struct != nil { + for _, fld := range q.Arg.Struct.Fields { + fmt.Fprintf(&valuesBody, "\t\tr.rows[0].%s,\n", fld.Name) + } + } else { + valuesBody.WriteString("\t\tr.rows[0],\n") + } + valuesBody.WriteString("\t}, nil\n") + + f.Decls = append(f.Decls, poet.Func{ + Recv: &poet.Param{Name: "r", Type: iterName}, + Name: "Values", + Results: []poet.Param{{Type: "[]interface{}"}, {Type: "error"}}, + Stmts: []poet.Stmt{poet.RawStmt{Code: valuesBody.String()}}, + }) + + // Err method + f.Decls = append(f.Decls, poet.Func{ + Recv: &poet.Param{Name: "r", Type: iterName}, + Name: "Err", + Results: []poet.Param{{Type: "error"}}, + Stmts: []poet.Stmt{poet.Return{Values: []string{"nil"}}}, + }) + + // Main method + db := "q.db" + var params []poet.Param + params = append(params, poet.Param{Name: "ctx", Type: "context.Context"}) + if g.tctx.EmitMethodsWithDBArgument { + db = "db" + params = append(params, poet.Param{Name: "db", Type: "DBTX"}) + } + params = append(params, poet.Param{Name: "", Type: q.Arg.SlicePair()}) + + body := fmt.Sprintf("\treturn %s.CopyFrom(ctx, %s, %s, &%s{rows: %s})\n", + db, q.TableIdentifierAsGoSlice(), q.Arg.ColumnNamesAsGoSlice(), iterName, q.Arg.Name) + + f.Decls = append(f.Decls, poet.Func{ + Comment: g.queryComments(q), + Recv: &poet.Param{Name: "q", Type: "*Queries"}, + Name: q.MethodName, + Params: params, + Results: []poet.Param{{Type: "int64"}, {Type: "error"}}, + Stmts: []poet.Stmt{poet.RawStmt{Code: body}}, + }) + } +} + +func (g *CodeGenerator) addCopyFromCodeMySQL(f *poet.File) { + for _, q := range g.tctx.GoQueries { + if q.Cmd != metadata.CmdCopyFrom { + continue + } + + // Reader handler sequence + f.Decls = append(f.Decls, poet.Var{ + Name: fmt.Sprintf("readerHandlerSequenceFor%s", q.MethodName), + Type: "uint32", + Value: "1", + }) + + // Convert rows function + var convertBody strings.Builder + fmt.Fprintf(&convertBody, "\te := mysqltsv.NewEncoder(w, %d, nil)\n", len(q.Arg.CopyFromMySQLFields())) + fmt.Fprintf(&convertBody, "\tfor _, row := range %s {\n", q.Arg.Name) + + for _, fld := range q.Arg.CopyFromMySQLFields() { + accessor := "row" + if q.Arg.Struct != nil { + accessor = "row." + fld.Name + } + switch fld.Type { + case "string": + fmt.Fprintf(&convertBody, "\t\te.AppendString(%s)\n", accessor) + case "[]byte", "json.RawMessage": + fmt.Fprintf(&convertBody, "\t\te.AppendBytes(%s)\n", accessor) + default: + fmt.Fprintf(&convertBody, "\t\te.AppendValue(%s)\n", accessor) + } + } + + convertBody.WriteString("\t}\n") + convertBody.WriteString("\tw.CloseWithError(e.Close())\n") + + f.Decls = append(f.Decls, poet.Func{ + Name: fmt.Sprintf("convertRowsFor%s", q.MethodName), + Params: []poet.Param{{Name: "w", Type: "*io.PipeWriter"}, {Name: "", Type: q.Arg.SlicePair()}}, + Stmts: []poet.Stmt{poet.RawStmt{Code: convertBody.String()}}, + }) + + // Main method + db := "q.db" + var params []poet.Param + params = append(params, poet.Param{Name: "ctx", Type: "context.Context"}) + if g.tctx.EmitMethodsWithDBArgument { + db = "db" + params = append(params, poet.Param{Name: "db", Type: "DBTX"}) + } + params = append(params, poet.Param{Name: "", Type: q.Arg.SlicePair()}) + + var colNames []string + for _, name := range q.Arg.ColumnNames() { + colNames = append(colNames, name) + } + colList := strings.Join(colNames, ", ") + + var mainStmts []poet.Stmt + mainStmts = append(mainStmts, poet.Assign{ + Left: []string{"pr", "pw"}, Op: ":=", Right: []string{"io.Pipe()"}, + }) + mainStmts = append(mainStmts, poet.Defer{Call: "pr.Close()"}) + mainStmts = append(mainStmts, poet.Assign{ + Left: []string{"rh"}, + Op: ":=", + Right: []string{fmt.Sprintf("fmt.Sprintf(\"%s_%%d\", atomic.AddUint32(&readerHandlerSequenceFor%s, 1))", q.MethodName, q.MethodName)}, + }) + mainStmts = append(mainStmts, poet.CallStmt{Call: "mysql.RegisterReaderHandler(rh, func() io.Reader { return pr })"}) + mainStmts = append(mainStmts, poet.Defer{Call: "mysql.DeregisterReaderHandler(rh)"}) + mainStmts = append(mainStmts, poet.GoStmt{Call: fmt.Sprintf("convertRowsFor%s(pw, %s)", q.MethodName, q.Arg.Name)}) + // Add comment explaining string interpolation requirement + mainStmts = append(mainStmts, poet.RawStmt{Code: "\t// The string interpolation is necessary because LOAD DATA INFILE requires\n\t// the file name to be given as a literal string.\n"}) + mainStmts = append(mainStmts, poet.Assign{ + Left: []string{"result", "err"}, + Op: ":=", + Right: []string{fmt.Sprintf("%s.ExecContext(ctx, fmt.Sprintf(\"LOAD DATA LOCAL INFILE '%%s' INTO TABLE %s %%s (%s)\", \"Reader::\"+rh, mysqltsv.Escaping))", db, q.TableIdentifierForMySQL(), colList)}, + }) + mainStmts = append(mainStmts, poet.If{ + Cond: "err != nil", + Body: []poet.Stmt{poet.Return{Values: []string{"0", "err"}}}, + }) + mainStmts = append(mainStmts, poet.Return{Values: []string{"result.RowsAffected()"}}) + + comment := g.queryComments(q) + comment += fmt.Sprintf("\n// %s uses MySQL's LOAD DATA LOCAL INFILE and is not atomic.", q.MethodName) + comment += "\n//\n// Errors and duplicate keys are treated as warnings and insertion will" + comment += "\n// continue, even without an error for some cases. Use this in a transaction" + comment += "\n// and use SHOW WARNINGS to check for any problems and roll back if you want to." + comment += "\n//\n// Check the documentation for more information:" + comment += "\n// https://dev.mysql.com/doc/refman/8.0/en/load-data.html#load-data-error-handling" + + f.Decls = append(f.Decls, poet.Func{ + Comment: comment, + Recv: &poet.Param{Name: "q", Type: "*Queries"}, + Name: q.MethodName, + Params: params, + Results: []poet.Param{{Type: "int64"}, {Type: "error"}}, + Stmts: mainStmts, + }) + } +} + +func (g *CodeGenerator) addBatchCodePGX(f *poet.File) { + // Error variable + f.Decls = append(f.Decls, poet.VarBlock{ + Vars: []poet.Var{ + {Name: "ErrBatchAlreadyClosed", Value: `errors.New("batch already closed")`}, + }, + }) + + for _, q := range g.tctx.GoQueries { + if !strings.HasPrefix(q.Cmd, ":batch") { + continue + } + + // SQL constant + f.Decls = append(f.Decls, poet.Const{ + Name: q.ConstantName, + Value: fmt.Sprintf("`-- name: %s %s\n%s\n`", q.MethodName, q.Cmd, sdk.EscapeBacktick(q.SQL)), + }) + + // BatchResults struct + f.Decls = append(f.Decls, poet.TypeDef{ + Name: q.MethodName + "BatchResults", + Type: poet.Struct{ + Fields: []poet.Field{ + {Name: "br", Type: "pgx.BatchResults"}, + {Name: "tot", Type: "int"}, + {Name: "closed", Type: "bool"}, + }, + }, + }) + + // Arg struct if needed + if q.Arg.Struct != nil { + var fields []poet.Field + for _, fld := range q.Arg.Struct.Fields { + fields = append(fields, poet.Field{Name: fld.Name, Type: fld.Type, Tag: fld.Tag()}) + } + f.Decls = append(f.Decls, poet.TypeDef{ + Name: q.Arg.Type(), + Type: poet.Struct{Fields: fields}, + }) + } + + // Ret struct if needed + if q.Ret.EmitStruct() { + var fields []poet.Field + for _, fld := range q.Ret.Struct.Fields { + fields = append(fields, poet.Field{Name: fld.Name, Type: fld.Type, Tag: fld.Tag()}) + } + f.Decls = append(f.Decls, poet.TypeDef{ + Name: q.Ret.Type(), + Type: poet.Struct{Fields: fields}, + }) + } + + // Main batch method + db := "q.db" + var params []poet.Param + params = append(params, poet.Param{Name: "ctx", Type: "context.Context"}) + if g.tctx.EmitMethodsWithDBArgument { + db = "db" + params = append(params, poet.Param{Name: "db", Type: "DBTX"}) + } + params = append(params, poet.Param{Name: "", Type: q.Arg.SlicePair()}) + + var mainBody strings.Builder + mainBody.WriteString("\tbatch := &pgx.Batch{}\n") + fmt.Fprintf(&mainBody, "\tfor _, a := range %s {\n", q.Arg.Name) + mainBody.WriteString("\t\tvals := []interface{}{\n") + if q.Arg.Struct != nil { + for _, fld := range q.Arg.Struct.Fields { + fmt.Fprintf(&mainBody, "\t\t\ta.%s,\n", fld.Name) + } + } else { + mainBody.WriteString("\t\t\ta,\n") + } + mainBody.WriteString("\t\t}\n") + fmt.Fprintf(&mainBody, "\t\tbatch.Queue(%s, vals...)\n", q.ConstantName) + mainBody.WriteString("\t}\n") + fmt.Fprintf(&mainBody, "\tbr := %s.SendBatch(ctx, batch)\n", db) + fmt.Fprintf(&mainBody, "\treturn &%sBatchResults{br, len(%s), false}\n", q.MethodName, q.Arg.Name) + + f.Decls = append(f.Decls, poet.Func{ + Comment: g.queryComments(q), + Recv: &poet.Param{Name: "q", Type: "*Queries"}, + Name: q.MethodName, + Params: params, + Results: []poet.Param{{Type: "*" + q.MethodName + "BatchResults"}}, + Stmts: []poet.Stmt{poet.RawStmt{Code: mainBody.String()}}, + }) + + // Result method based on command type + switch q.Cmd { + case ":batchexec": + var execForBody []poet.Stmt + execForBody = append(execForBody, poet.If{ + Cond: "b.closed", + Body: []poet.Stmt{ + poet.If{ + Cond: "f != nil", + Body: []poet.Stmt{poet.CallStmt{Call: "f(t, ErrBatchAlreadyClosed)"}}, + }, + poet.Continue{}, + }, + }) + execForBody = append(execForBody, poet.Assign{ + Left: []string{"_", "err"}, Op: ":=", Right: []string{"b.br.Exec()"}, + }) + execForBody = append(execForBody, poet.If{ + Cond: "f != nil", + Body: []poet.Stmt{poet.CallStmt{Call: "f(t, err)"}}, + }) + + f.Decls = append(f.Decls, poet.Func{ + Recv: &poet.Param{Name: "b", Type: "*" + q.MethodName + "BatchResults"}, + Name: "Exec", + Params: []poet.Param{{Name: "f", Type: "func(int, error)"}}, + Stmts: []poet.Stmt{ + poet.Defer{Call: "b.br.Close()"}, + poet.For{Init: "t := 0", Cond: "t < b.tot", Post: "t++", Body: execForBody}, + }, + }) + + case ":batchmany": + // Build inner function literal body + var innerFuncBody []poet.Stmt + innerFuncBody = append(innerFuncBody, poet.Assign{ + Left: []string{"rows", "err"}, Op: ":=", Right: []string{"b.br.Query()"}, + }) + innerFuncBody = append(innerFuncBody, poet.If{ + Cond: "err != nil", + Body: []poet.Stmt{poet.Return{Values: []string{"err"}}}, + }) + innerFuncBody = append(innerFuncBody, poet.Defer{Call: "rows.Close()"}) + + // Build rows loop body + var rowsLoopBody []poet.Stmt + rowsLoopBody = append(rowsLoopBody, poet.VarDecl{Name: q.Ret.Name, Type: q.Ret.Type()}) + rowsLoopBody = append(rowsLoopBody, poet.If{ + Cond: fmt.Sprintf("err := rows.Scan(%s); err != nil", q.Ret.Scan()), + Body: []poet.Stmt{poet.Return{Values: []string{"err"}}}, + }) + rowsLoopBody = append(rowsLoopBody, poet.Assign{ + Left: []string{"items"}, Op: "=", Right: []string{fmt.Sprintf("append(items, %s)", q.Ret.ReturnName())}, + }) + + innerFuncBody = append(innerFuncBody, poet.For{Cond: "rows.Next()", Body: rowsLoopBody}) + innerFuncBody = append(innerFuncBody, poet.Return{Values: []string{"rows.Err()"}}) + + // Build function literal with proper indentation (3 tabs for body inside for loop inside func) + innerFunc := poet.FuncLit{ + Results: []poet.Param{{Type: "error"}}, + Body: innerFuncBody, + Indent: "\t\t\t", + } + + // Build main for loop body + var manyForBody []poet.Stmt + if g.tctx.EmitEmptySlices { + manyForBody = append(manyForBody, poet.Assign{ + Left: []string{"items"}, Op: ":=", Right: []string{fmt.Sprintf("[]%s{}", q.Ret.DefineType())}, + }) + } else { + manyForBody = append(manyForBody, poet.VarDecl{Name: "items", Type: "[]" + q.Ret.DefineType()}) + } + manyForBody = append(manyForBody, poet.If{ + Cond: "b.closed", + Body: []poet.Stmt{ + poet.If{ + Cond: "f != nil", + Body: []poet.Stmt{poet.CallStmt{Call: "f(t, items, ErrBatchAlreadyClosed)"}}, + }, + poet.Continue{}, + }, + }) + manyForBody = append(manyForBody, poet.Assign{ + Left: []string{"err"}, Op: ":=", Right: []string{innerFunc.Render() + "()"}, + }) + manyForBody = append(manyForBody, poet.If{ + Cond: "f != nil", + Body: []poet.Stmt{poet.CallStmt{Call: "f(t, items, err)"}}, + }) + + f.Decls = append(f.Decls, poet.Func{ + Recv: &poet.Param{Name: "b", Type: "*" + q.MethodName + "BatchResults"}, + Name: "Query", + Params: []poet.Param{{Name: "f", Type: fmt.Sprintf("func(int, []%s, error)", q.Ret.DefineType())}}, + Stmts: []poet.Stmt{ + poet.Defer{Call: "b.br.Close()"}, + poet.For{Init: "t := 0", Cond: "t < b.tot", Post: "t++", Body: manyForBody}, + }, + }) + + case ":batchone": + // Build closed error value based on return type + closedRetVal := q.Ret.Name + if q.Ret.IsPointer() { + closedRetVal = "nil" + } + + // Build for loop body + var oneForBody []poet.Stmt + oneForBody = append(oneForBody, poet.VarDecl{Name: q.Ret.Name, Type: q.Ret.Type()}) + oneForBody = append(oneForBody, poet.If{ + Cond: "b.closed", + Body: []poet.Stmt{ + poet.If{ + Cond: "f != nil", + Body: []poet.Stmt{poet.CallStmt{Call: fmt.Sprintf("f(t, %s, ErrBatchAlreadyClosed)", closedRetVal)}}, + }, + poet.Continue{}, + }, + }) + oneForBody = append(oneForBody, poet.Assign{ + Left: []string{"row"}, Op: ":=", Right: []string{"b.br.QueryRow()"}, + }) + oneForBody = append(oneForBody, poet.Assign{ + Left: []string{"err"}, Op: ":=", Right: []string{fmt.Sprintf("row.Scan(%s)", q.Ret.Scan())}, + }) + oneForBody = append(oneForBody, poet.If{ + Cond: "f != nil", + Body: []poet.Stmt{poet.CallStmt{Call: fmt.Sprintf("f(t, %s, err)", q.Ret.ReturnName())}}, + }) + + f.Decls = append(f.Decls, poet.Func{ + Recv: &poet.Param{Name: "b", Type: "*" + q.MethodName + "BatchResults"}, + Name: "QueryRow", + Params: []poet.Param{{Name: "f", Type: fmt.Sprintf("func(int, %s, error)", q.Ret.DefineType())}}, + Stmts: []poet.Stmt{ + poet.Defer{Call: "b.br.Close()"}, + poet.For{Init: "t := 0", Cond: "t < b.tot", Post: "t++", Body: oneForBody}, + }, + }) + } + + // Close method + f.Decls = append(f.Decls, poet.Func{ + Recv: &poet.Param{Name: "b", Type: q.MethodName + "BatchResults", Pointer: true}, + Name: "Close", + Results: []poet.Param{{Type: "error"}}, + Stmts: []poet.Stmt{ + poet.Assign{Left: []string{"b.closed"}, Op: "=", Right: []string{"true"}}, + poet.Return{Values: []string{"b.br.Close()"}}, + }, + }) + } +} diff --git a/internal/codegen/golang/template.go b/internal/codegen/golang/template.go deleted file mode 100644 index 0aa7c9fa6a..0000000000 --- a/internal/codegen/golang/template.go +++ /dev/null @@ -1,7 +0,0 @@ -package golang - -import "embed" - -//go:embed templates/* -//go:embed templates/*/* -var templates embed.FS diff --git a/internal/codegen/golang/templates/go-sql-driver-mysql/copyfromCopy.tmpl b/internal/codegen/golang/templates/go-sql-driver-mysql/copyfromCopy.tmpl deleted file mode 100644 index e21475b148..0000000000 --- a/internal/codegen/golang/templates/go-sql-driver-mysql/copyfromCopy.tmpl +++ /dev/null @@ -1,52 +0,0 @@ -{{define "copyfromCodeGoSqlDriver"}} -{{range .GoQueries}} -{{if eq .Cmd ":copyfrom" }} -var readerHandlerSequenceFor{{.MethodName}} uint32 = 1 - -func convertRowsFor{{.MethodName}}(w *io.PipeWriter, {{.Arg.SlicePair}}) { - e := mysqltsv.NewEncoder(w, {{ len .Arg.CopyFromMySQLFields }}, nil) - for _, row := range {{.Arg.Name}} { -{{- with $arg := .Arg }} -{{- range $arg.CopyFromMySQLFields}} -{{- if eq .Type "string"}} - e.AppendString({{if $arg.Struct}}row.{{.Name}}{{else}}row{{end}}) -{{- else if or (eq .Type "[]byte") (eq .Type "json.RawMessage")}} - e.AppendBytes({{if $arg.Struct}}row.{{.Name}}{{else}}row{{end}}) -{{- else}} - e.AppendValue({{if $arg.Struct}}row.{{.Name}}{{else}}row{{end}}) -{{- end}} -{{- end}} -{{- end}} - } - w.CloseWithError(e.Close()) -} - -{{range .Comments}}//{{.}} -{{end -}} -// {{.MethodName}} uses MySQL's LOAD DATA LOCAL INFILE and is not atomic. -// -// Errors and duplicate keys are treated as warnings and insertion will -// continue, even without an error for some cases. Use this in a transaction -// and use SHOW WARNINGS to check for any problems and roll back if you want to. -// -// Check the documentation for more information: -// https://dev.mysql.com/doc/refman/8.0/en/load-data.html#load-data-error-handling -func (q *Queries) {{.MethodName}}(ctx context.Context{{if $.EmitMethodsWithDBArgument}}, db DBTX{{end}}, {{.Arg.SlicePair}}) (int64, error) { - pr, pw := io.Pipe() - defer pr.Close() - rh := fmt.Sprintf("{{.MethodName}}_%d", atomic.AddUint32(&readerHandlerSequenceFor{{.MethodName}}, 1)) - mysql.RegisterReaderHandler(rh, func() io.Reader { return pr }) - defer mysql.DeregisterReaderHandler(rh) - go convertRowsFor{{.MethodName}}(pw, {{.Arg.Name}}) - // The string interpolation is necessary because LOAD DATA INFILE requires - // the file name to be given as a literal string. - result, err := {{if (not $.EmitMethodsWithDBArgument)}}q.{{end}}db.ExecContext(ctx, fmt.Sprintf("LOAD DATA LOCAL INFILE '%s' INTO TABLE {{.TableIdentifierForMySQL}} %s ({{range $index, $name := .Arg.ColumnNames}}{{if gt $index 0}}, {{end}}{{$name}}{{end}})", "Reader::" + rh, mysqltsv.Escaping)) - if err != nil { - return 0, err - } - return result.RowsAffected() -} - -{{end}} -{{end}} -{{end}} diff --git a/internal/codegen/golang/templates/pgx/batchCode.tmpl b/internal/codegen/golang/templates/pgx/batchCode.tmpl deleted file mode 100644 index 35bd701bd3..0000000000 --- a/internal/codegen/golang/templates/pgx/batchCode.tmpl +++ /dev/null @@ -1,134 +0,0 @@ -{{define "batchCodePgx"}} - -var ( - ErrBatchAlreadyClosed = errors.New("batch already closed") -) - -{{range .GoQueries}} -{{if eq (hasPrefix .Cmd ":batch") true }} -const {{.ConstantName}} = {{$.Q}}-- name: {{.MethodName}} {{.Cmd}} -{{escape .SQL}} -{{$.Q}} - -type {{.MethodName}}BatchResults struct { - br pgx.BatchResults - tot int - closed bool -} - -{{if .Arg.Struct}} -type {{.Arg.Type}} struct { {{- range .Arg.Struct.Fields}} - {{.Name}} {{.Type}} {{if .Tag}}{{$.Q}}{{.Tag}}{{$.Q}}{{end}} - {{- end}} -} -{{end}} - -{{if .Ret.EmitStruct}} -type {{.Ret.Type}} struct { {{- range .Ret.Struct.Fields}} - {{.Name}} {{.Type}} {{if .Tag}}{{$.Q}}{{.Tag}}{{$.Q}}{{end}} - {{- end}} -} -{{end}} - -{{range .Comments}}//{{.}} -{{end -}} -func (q *Queries) {{.MethodName}}(ctx context.Context, {{ if $.EmitMethodsWithDBArgument}}db DBTX,{{end}} {{.Arg.SlicePair}}) *{{.MethodName}}BatchResults { - batch := &pgx.Batch{} - for _, a := range {{index .Arg.Name}} { - vals := []interface{}{ - {{- if .Arg.Struct }} - {{- range .Arg.Struct.Fields }} - a.{{.Name}}, - {{- end }} - {{- else }} - a, - {{- end }} - } - batch.Queue({{.ConstantName}}, vals...) - } - br := {{if not $.EmitMethodsWithDBArgument}}q.{{end}}db.SendBatch(ctx, batch) - return &{{.MethodName}}BatchResults{br,len({{.Arg.Name}}),false} -} - -{{if eq .Cmd ":batchexec"}} -func (b *{{.MethodName}}BatchResults) Exec(f func(int, error)) { - defer b.br.Close() - for t := 0; t < b.tot; t++ { - if b.closed { - if f != nil { - f(t, ErrBatchAlreadyClosed) - } - continue - } - _, err := b.br.Exec() - if f != nil { - f(t, err) - } - } -} -{{end}} - -{{if eq .Cmd ":batchmany"}} -func (b *{{.MethodName}}BatchResults) Query(f func(int, []{{.Ret.DefineType}}, error)) { - defer b.br.Close() - for t := 0; t < b.tot; t++ { - {{- if $.EmitEmptySlices}} - items := []{{.Ret.DefineType}}{} - {{else}} - var items []{{.Ret.DefineType}} - {{end -}} - if b.closed { - if f != nil { - f(t, items, ErrBatchAlreadyClosed) - } - continue - } - err := func() error { - rows, err := b.br.Query() - if err != nil { - return err - } - defer rows.Close() - for rows.Next() { - var {{.Ret.Name}} {{.Ret.Type}} - if err := rows.Scan({{.Ret.Scan}}); err != nil { - return err - } - items = append(items, {{.Ret.ReturnName}}) - } - return rows.Err() - }() - if f != nil { - f(t, items, err) - } - } -} -{{end}} - -{{if eq .Cmd ":batchone"}} -func (b *{{.MethodName}}BatchResults) QueryRow(f func(int, {{.Ret.DefineType}}, error)) { - defer b.br.Close() - for t := 0; t < b.tot; t++ { - var {{.Ret.Name}} {{.Ret.Type}} - if b.closed { - if f != nil { - f(t, {{if .Ret.IsPointer}}nil{{else}}{{.Ret.Name}}{{end}}, ErrBatchAlreadyClosed) - } - continue - } - row := b.br.QueryRow() - err := row.Scan({{.Ret.Scan}}) - if f != nil { - f(t, {{.Ret.ReturnName}}, err) - } - } -} -{{end}} - -func (b *{{.MethodName}}BatchResults) Close() error { - b.closed = true - return b.br.Close() -} -{{end}} -{{end}} -{{end}} diff --git a/internal/codegen/golang/templates/pgx/copyfromCopy.tmpl b/internal/codegen/golang/templates/pgx/copyfromCopy.tmpl deleted file mode 100644 index c1cfa68d1d..0000000000 --- a/internal/codegen/golang/templates/pgx/copyfromCopy.tmpl +++ /dev/null @@ -1,51 +0,0 @@ -{{define "copyfromCodePgx"}} -{{range .GoQueries}} -{{if eq .Cmd ":copyfrom" }} -// iteratorFor{{.MethodName}} implements pgx.CopyFromSource. -type iteratorFor{{.MethodName}} struct { - rows []{{.Arg.DefineType}} - skippedFirstNextCall bool -} - -func (r *iteratorFor{{.MethodName}}) Next() bool { - if len(r.rows) == 0 { - return false - } - if !r.skippedFirstNextCall { - r.skippedFirstNextCall = true - return true - } - r.rows = r.rows[1:] - return len(r.rows) > 0 -} - -func (r iteratorFor{{.MethodName}}) Values() ([]interface{}, error) { - return []interface{}{ -{{- if .Arg.Struct }} -{{- range .Arg.Struct.Fields }} - r.rows[0].{{.Name}}, -{{- end }} -{{- else }} - r.rows[0], -{{- end }} - }, nil -} - -func (r iteratorFor{{.MethodName}}) Err() error { - return nil -} - -{{range .Comments}}//{{.}} -{{end -}} -{{- if $.EmitMethodsWithDBArgument -}} -func (q *Queries) {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.SlicePair}}) (int64, error) { - return db.CopyFrom(ctx, {{.TableIdentifierAsGoSlice}}, {{.Arg.ColumnNamesAsGoSlice}}, &iteratorFor{{.MethodName}}{rows: {{.Arg.Name}}}) -{{- else -}} -func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.SlicePair}}) (int64, error) { - return q.db.CopyFrom(ctx, {{.TableIdentifierAsGoSlice}}, {{.Arg.ColumnNamesAsGoSlice}}, &iteratorFor{{.MethodName}}{rows: {{.Arg.Name}}}) -{{- end}} -} - -{{end}} -{{end}} -{{end}} diff --git a/internal/codegen/golang/templates/pgx/dbCode.tmpl b/internal/codegen/golang/templates/pgx/dbCode.tmpl deleted file mode 100644 index 236554d9f2..0000000000 --- a/internal/codegen/golang/templates/pgx/dbCode.tmpl +++ /dev/null @@ -1,37 +0,0 @@ -{{define "dbCodeTemplatePgx"}} - -type DBTX interface { - Exec(context.Context, string, ...interface{}) (pgconn.CommandTag, error) - Query(context.Context, string, ...interface{}) (pgx.Rows, error) - QueryRow(context.Context, string, ...interface{}) pgx.Row -{{- if .UsesCopyFrom }} - CopyFrom(ctx context.Context, tableName pgx.Identifier, columnNames []string, rowSrc pgx.CopyFromSource) (int64, error) -{{- end }} -{{- if .UsesBatch }} - SendBatch(context.Context, *pgx.Batch) pgx.BatchResults -{{- end }} -} - -{{ if .EmitMethodsWithDBArgument}} -func New() *Queries { - return &Queries{} -{{- else -}} -func New(db DBTX) *Queries { - return &Queries{db: db} -{{- end}} -} - -type Queries struct { - {{if not .EmitMethodsWithDBArgument}} - db DBTX - {{end}} -} - -{{if not .EmitMethodsWithDBArgument}} -func (q *Queries) WithTx(tx pgx.Tx) *Queries { - return &Queries{ - db: tx, - } -} -{{end}} -{{end}} diff --git a/internal/codegen/golang/templates/pgx/interfaceCode.tmpl b/internal/codegen/golang/templates/pgx/interfaceCode.tmpl deleted file mode 100644 index cf7cd36cb9..0000000000 --- a/internal/codegen/golang/templates/pgx/interfaceCode.tmpl +++ /dev/null @@ -1,73 +0,0 @@ -{{define "interfaceCodePgx"}} - type Querier interface { - {{- $dbtxParam := .EmitMethodsWithDBArgument -}} - {{- range .GoQueries}} - {{- if and (eq .Cmd ":one") ($dbtxParam) }} - {{range .Comments}}//{{.}} - {{end -}} - {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) ({{.Ret.DefineType}}, error) - {{- else if eq .Cmd ":one" }} - {{range .Comments}}//{{.}} - {{end -}} - {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) ({{.Ret.DefineType}}, error) - {{- end}} - {{- if and (eq .Cmd ":many") ($dbtxParam) }} - {{range .Comments}}//{{.}} - {{end -}} - {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) ([]{{.Ret.DefineType}}, error) - {{- else if eq .Cmd ":many" }} - {{range .Comments}}//{{.}} - {{end -}} - {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) ([]{{.Ret.DefineType}}, error) - {{- end}} - {{- if and (eq .Cmd ":exec") ($dbtxParam) }} - {{range .Comments}}//{{.}} - {{end -}} - {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) error - {{- else if eq .Cmd ":exec" }} - {{range .Comments}}//{{.}} - {{end -}} - {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) error - {{- end}} - {{- if and (eq .Cmd ":execrows") ($dbtxParam) }} - {{range .Comments}}//{{.}} - {{end -}} - {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) (int64, error) - {{- else if eq .Cmd ":execrows" }} - {{range .Comments}}//{{.}} - {{end -}} - {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) (int64, error) - {{- end}} - {{- if and (eq .Cmd ":execresult") ($dbtxParam) }} - {{range .Comments}}//{{.}} - {{end -}} - {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) (pgconn.CommandTag, error) - {{- else if eq .Cmd ":execresult" }} - {{range .Comments}}//{{.}} - {{end -}} - {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) (pgconn.CommandTag, error) - {{- end}} - {{- if and (eq .Cmd ":copyfrom") ($dbtxParam) }} - {{range .Comments}}//{{.}} - {{end -}} - {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.SlicePair}}) (int64, error) - {{- else if eq .Cmd ":copyfrom" }} - {{range .Comments}}//{{.}} - {{end -}} - {{.MethodName}}(ctx context.Context, {{.Arg.SlicePair}}) (int64, error) - {{- end}} - {{- if and (or (eq .Cmd ":batchexec") (eq .Cmd ":batchmany") (eq .Cmd ":batchone")) ($dbtxParam) }} - {{range .Comments}}//{{.}} - {{end -}} - {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.SlicePair}}) *{{.MethodName}}BatchResults - {{- else if or (eq .Cmd ":batchexec") (eq .Cmd ":batchmany") (eq .Cmd ":batchone") }} - {{range .Comments}}//{{.}} - {{end -}} - {{.MethodName}}(ctx context.Context, {{.Arg.SlicePair}}) *{{.MethodName}}BatchResults - {{- end}} - - {{- end}} - } - - var _ Querier = (*Queries)(nil) -{{end}} diff --git a/internal/codegen/golang/templates/pgx/queryCode.tmpl b/internal/codegen/golang/templates/pgx/queryCode.tmpl deleted file mode 100644 index 59a88c880a..0000000000 --- a/internal/codegen/golang/templates/pgx/queryCode.tmpl +++ /dev/null @@ -1,142 +0,0 @@ -{{define "queryCodePgx"}} -{{range .GoQueries}} -{{if $.OutputQuery .SourceName}} -{{if and (ne .Cmd ":copyfrom") (ne (hasPrefix .Cmd ":batch") true)}} -const {{.ConstantName}} = {{$.Q}}-- name: {{.MethodName}} {{.Cmd}} -{{escape .SQL}} -{{$.Q}} -{{end}} - -{{if ne (hasPrefix .Cmd ":batch") true}} -{{if .Arg.EmitStruct}} -type {{.Arg.Type}} struct { {{- range .Arg.Struct.Fields}} - {{.Name}} {{.Type}} {{if .Tag}}{{$.Q}}{{.Tag}}{{$.Q}}{{end}} - {{- end}} -} -{{end}} - -{{if .Ret.EmitStruct}} -type {{.Ret.Type}} struct { {{- range .Ret.Struct.Fields}} - {{.Name}} {{.Type}} {{if .Tag}}{{$.Q}}{{.Tag}}{{$.Q}}{{end}} - {{- end}} -} -{{end}} -{{end}} - -{{if eq .Cmd ":one"}} -{{range .Comments}}//{{.}} -{{end -}} -{{- if $.EmitMethodsWithDBArgument -}} -func (q *Queries) {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) ({{.Ret.DefineType}}, error) { - row := db.QueryRow(ctx, {{.ConstantName}}, {{.Arg.Params}}) -{{- else -}} -func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) ({{.Ret.DefineType}}, error) { - row := q.db.QueryRow(ctx, {{.ConstantName}}, {{.Arg.Params}}) -{{- end}} - {{- if or (ne .Arg.Pair .Ret.Pair) (ne .Arg.DefineType .Ret.DefineType) }} - var {{.Ret.Name}} {{.Ret.Type}} - {{- end}} - err := row.Scan({{.Ret.Scan}}) - {{- if $.WrapErrors}} - if err != nil { - err = fmt.Errorf("query {{.MethodName}}: %w", err) - } - {{- end}} - return {{.Ret.ReturnName}}, err -} -{{end}} - -{{if eq .Cmd ":many"}} -{{range .Comments}}//{{.}} -{{end -}} -{{- if $.EmitMethodsWithDBArgument -}} -func (q *Queries) {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) ([]{{.Ret.DefineType}}, error) { - rows, err := db.Query(ctx, {{.ConstantName}}, {{.Arg.Params}}) -{{- else -}} -func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) ([]{{.Ret.DefineType}}, error) { - rows, err := q.db.Query(ctx, {{.ConstantName}}, {{.Arg.Params}}) -{{- end}} - if err != nil { - return nil, {{if $.WrapErrors}}fmt.Errorf("query {{.MethodName}}: %w", err){{else}}err{{end}} - } - defer rows.Close() - {{- if $.EmitEmptySlices}} - items := []{{.Ret.DefineType}}{} - {{else}} - var items []{{.Ret.DefineType}} - {{end -}} - for rows.Next() { - var {{.Ret.Name}} {{.Ret.Type}} - if err := rows.Scan({{.Ret.Scan}}); err != nil { - return nil, {{if $.WrapErrors}}fmt.Errorf("query {{.MethodName}}: %w", err){{else}}err{{end}} - } - items = append(items, {{.Ret.ReturnName}}) - } - if err := rows.Err(); err != nil { - return nil, {{if $.WrapErrors}}fmt.Errorf("query {{.MethodName}}: %w", err){{else}}err{{end}} - } - return items, nil -} -{{end}} - -{{if eq .Cmd ":exec"}} -{{range .Comments}}//{{.}} -{{end -}} -{{- if $.EmitMethodsWithDBArgument -}} -func (q *Queries) {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) error { - _, err := db.Exec(ctx, {{.ConstantName}}, {{.Arg.Params}}) -{{- else -}} -func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) error { - _, err := q.db.Exec(ctx, {{.ConstantName}}, {{.Arg.Params}}) -{{- end}} - {{- if $.WrapErrors }} - if err != nil { - return fmt.Errorf("query {{.MethodName}}: %w", err) - } - return nil - {{- else }} - return err - {{- end }} -} -{{end}} - -{{if eq .Cmd ":execrows"}} -{{range .Comments}}//{{.}} -{{end -}} -{{if $.EmitMethodsWithDBArgument -}} -func (q *Queries) {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) (int64, error) { - result, err := db.Exec(ctx, {{.ConstantName}}, {{.Arg.Params}}) -{{- else -}} -func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) (int64, error) { - result, err := q.db.Exec(ctx, {{.ConstantName}}, {{.Arg.Params}}) -{{- end}} - if err != nil { - return 0, {{if $.WrapErrors}}fmt.Errorf("query {{.MethodName}}: %w", err){{else}}err{{end}} - } - return result.RowsAffected(), nil -} -{{end}} - -{{if eq .Cmd ":execresult"}} -{{range .Comments}}//{{.}} -{{end -}} -{{- if $.EmitMethodsWithDBArgument -}} -func (q *Queries) {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) (pgconn.CommandTag, error) { - {{queryRetval .}} db.Exec(ctx, {{.ConstantName}}, {{.Arg.Params}}) -{{- else -}} -func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) (pgconn.CommandTag, error) { - {{queryRetval .}} q.db.Exec(ctx, {{.ConstantName}}, {{.Arg.Params}}) -{{- end}} - {{- if $.WrapErrors}} - if err != nil { - err = fmt.Errorf("query {{.MethodName}}: %w", err) - } - return result, err - {{- end}} -} -{{end}} - - -{{end}} -{{end}} -{{end}} diff --git a/internal/codegen/golang/templates/stdlib/dbCode.tmpl b/internal/codegen/golang/templates/stdlib/dbCode.tmpl deleted file mode 100644 index 7433d522f6..0000000000 --- a/internal/codegen/golang/templates/stdlib/dbCode.tmpl +++ /dev/null @@ -1,105 +0,0 @@ -{{define "dbCodeTemplateStd"}} -type DBTX interface { - ExecContext(context.Context, string, ...interface{}) (sql.Result, error) - PrepareContext(context.Context, string) (*sql.Stmt, error) - QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) - QueryRowContext(context.Context, string, ...interface{}) *sql.Row -} - -{{ if .EmitMethodsWithDBArgument}} -func New() *Queries { - return &Queries{} -{{- else -}} -func New(db DBTX) *Queries { - return &Queries{db: db} -{{- end}} -} - -{{if .EmitPreparedQueries}} -func Prepare(ctx context.Context, db DBTX) (*Queries, error) { - q := Queries{db: db} - var err error - {{- if eq (len .GoQueries) 0 }} - _ = err - {{- end }} - {{- range .GoQueries }} - if q.{{.FieldName}}, err = db.PrepareContext(ctx, {{.ConstantName}}); err != nil { - return nil, fmt.Errorf("error preparing query {{.MethodName}}: %w", err) - } - {{- end}} - return &q, nil -} - -func (q *Queries) Close() error { - var err error - {{- range .GoQueries }} - if q.{{.FieldName}} != nil { - if cerr := q.{{.FieldName}}.Close(); cerr != nil { - err = fmt.Errorf("error closing {{.FieldName}}: %w", cerr) - } - } - {{- end}} - return err -} - -func (q *Queries) exec(ctx context.Context, stmt *sql.Stmt, query string, args ...interface{}) (sql.Result, error) { - switch { - case stmt != nil && q.tx != nil: - return q.tx.StmtContext(ctx, stmt).ExecContext(ctx, args...) - case stmt != nil: - return stmt.ExecContext(ctx, args...) - default: - return q.db.ExecContext(ctx, query, args...) - } -} - -func (q *Queries) query(ctx context.Context, stmt *sql.Stmt, query string, args ...interface{}) (*sql.Rows, error) { - switch { - case stmt != nil && q.tx != nil: - return q.tx.StmtContext(ctx, stmt).QueryContext(ctx, args...) - case stmt != nil: - return stmt.QueryContext(ctx, args...) - default: - return q.db.QueryContext(ctx, query, args...) - } -} - -func (q *Queries) queryRow(ctx context.Context, stmt *sql.Stmt, query string, args ...interface{}) (*sql.Row) { - switch { - case stmt != nil && q.tx != nil: - return q.tx.StmtContext(ctx, stmt).QueryRowContext(ctx, args...) - case stmt != nil: - return stmt.QueryRowContext(ctx, args...) - default: - return q.db.QueryRowContext(ctx, query, args...) - } -} -{{end}} - -type Queries struct { - {{- if not .EmitMethodsWithDBArgument}} - db DBTX - {{- end}} - - {{- if .EmitPreparedQueries}} - tx *sql.Tx - {{- range .GoQueries}} - {{.FieldName}} *sql.Stmt - {{- end}} - {{- end}} -} - -{{if not .EmitMethodsWithDBArgument}} -func (q *Queries) WithTx(tx *sql.Tx) *Queries { - return &Queries{ - db: tx, - {{- if .EmitPreparedQueries}} - tx: tx, - {{- range .GoQueries}} - {{.FieldName}}: q.{{.FieldName}}, - {{- end}} - {{- end}} - } -} -{{end}} -{{end}} diff --git a/internal/codegen/golang/templates/stdlib/interfaceCode.tmpl b/internal/codegen/golang/templates/stdlib/interfaceCode.tmpl deleted file mode 100644 index 3cbefe6df4..0000000000 --- a/internal/codegen/golang/templates/stdlib/interfaceCode.tmpl +++ /dev/null @@ -1,63 +0,0 @@ -{{define "interfaceCodeStd"}} - type Querier interface { - {{- $dbtxParam := .EmitMethodsWithDBArgument -}} - {{- range .GoQueries}} - {{- if and (eq .Cmd ":one") ($dbtxParam) }} - {{range .Comments}}//{{.}} - {{end -}} - {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) ({{.Ret.DefineType}}, error) - {{- else if eq .Cmd ":one"}} - {{range .Comments}}//{{.}} - {{end -}} - {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) ({{.Ret.DefineType}}, error) - {{- end}} - {{- if and (eq .Cmd ":many") ($dbtxParam) }} - {{range .Comments}}//{{.}} - {{end -}} - {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) ([]{{.Ret.DefineType}}, error) - {{- else if eq .Cmd ":many"}} - {{range .Comments}}//{{.}} - {{end -}} - {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) ([]{{.Ret.DefineType}}, error) - {{- end}} - {{- if and (eq .Cmd ":exec") ($dbtxParam) }} - {{range .Comments}}//{{.}} - {{end -}} - {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) error - {{- else if eq .Cmd ":exec"}} - {{range .Comments}}//{{.}} - {{end -}} - {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) error - {{- end}} - {{- if and (eq .Cmd ":execrows") ($dbtxParam) }} - {{range .Comments}}//{{.}} - {{end -}} - {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) (int64, error) - {{- else if eq .Cmd ":execrows"}} - {{range .Comments}}//{{.}} - {{end -}} - {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) (int64, error) - {{- end}} - {{- if and (eq .Cmd ":execlastid") ($dbtxParam) }} - {{range .Comments}}//{{.}} - {{end -}} - {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) (int64, error) - {{- else if eq .Cmd ":execlastid"}} - {{range .Comments}}//{{.}} - {{end -}} - {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) (int64, error) - {{- end}} - {{- if and (eq .Cmd ":execresult") ($dbtxParam) }} - {{range .Comments}}//{{.}} - {{end -}} - {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) (sql.Result, error) - {{- else if eq .Cmd ":execresult"}} - {{range .Comments}}//{{.}} - {{end -}} - {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) (sql.Result, error) - {{- end}} - {{- end}} - } - - var _ Querier = (*Queries)(nil) -{{end}} diff --git a/internal/codegen/golang/templates/stdlib/queryCode.tmpl b/internal/codegen/golang/templates/stdlib/queryCode.tmpl deleted file mode 100644 index 1e7f4e22a4..0000000000 --- a/internal/codegen/golang/templates/stdlib/queryCode.tmpl +++ /dev/null @@ -1,171 +0,0 @@ -{{define "queryCodeStd"}} -{{range .GoQueries}} -{{if $.OutputQuery .SourceName}} -const {{.ConstantName}} = {{$.Q}}-- name: {{.MethodName}} {{.Cmd}} -{{escape .SQL}} -{{$.Q}} - -{{if .Arg.EmitStruct}} -type {{.Arg.Type}} struct { {{- range .Arg.UniqueFields}} - {{.Name}} {{.Type}} {{if .Tag}}{{$.Q}}{{.Tag}}{{$.Q}}{{end}} - {{- end}} -} -{{end}} - -{{if .Ret.EmitStruct}} -type {{.Ret.Type}} struct { {{- range .Ret.Struct.Fields}} - {{.Name}} {{.Type}} {{if .Tag}}{{$.Q}}{{.Tag}}{{$.Q}}{{end}} - {{- end}} -} -{{end}} - -{{if eq .Cmd ":one"}} -{{range .Comments}}//{{.}} -{{end -}} -func (q *Queries) {{.MethodName}}(ctx context.Context, {{ dbarg }} {{.Arg.Pair}}) ({{.Ret.DefineType}}, error) { - {{- template "queryCodeStdExec" . }} - {{- if or (ne .Arg.Pair .Ret.Pair) (ne .Arg.DefineType .Ret.DefineType) }} - var {{.Ret.Name}} {{.Ret.Type}} - {{- end}} - err := row.Scan({{.Ret.Scan}}) - {{- if $.WrapErrors}} - if err != nil { - err = fmt.Errorf("query {{.MethodName}}: %w", err) - } - {{- end}} - return {{.Ret.ReturnName}}, err -} -{{end}} - -{{if eq .Cmd ":many"}} -{{range .Comments}}//{{.}} -{{end -}} -func (q *Queries) {{.MethodName}}(ctx context.Context, {{ dbarg }} {{.Arg.Pair}}) ([]{{.Ret.DefineType}}, error) { - {{- template "queryCodeStdExec" . }} - if err != nil { - return nil, {{if $.WrapErrors}}fmt.Errorf("query {{.MethodName}}: %w", err){{else}}err{{end}} - } - defer rows.Close() - {{- if $.EmitEmptySlices}} - items := []{{.Ret.DefineType}}{} - {{else}} - var items []{{.Ret.DefineType}} - {{end -}} - for rows.Next() { - var {{.Ret.Name}} {{.Ret.Type}} - if err := rows.Scan({{.Ret.Scan}}); err != nil { - return nil, {{if $.WrapErrors}}fmt.Errorf("query {{.MethodName}}: %w", err){{else}}err{{end}} - } - items = append(items, {{.Ret.ReturnName}}) - } - if err := rows.Close(); err != nil { - return nil, {{if $.WrapErrors}}fmt.Errorf("query {{.MethodName}}: %w", err){{else}}err{{end}} - } - if err := rows.Err(); err != nil { - return nil, {{if $.WrapErrors}}fmt.Errorf("query {{.MethodName}}: %w", err){{else}}err{{end}} - } - return items, nil -} -{{end}} - -{{if eq .Cmd ":exec"}} -{{range .Comments}}//{{.}} -{{end -}} -func (q *Queries) {{.MethodName}}(ctx context.Context, {{ dbarg }} {{.Arg.Pair}}) error { - {{- template "queryCodeStdExec" . }} - {{- if $.WrapErrors}} - if err != nil { - err = fmt.Errorf("query {{.MethodName}}: %w", err) - } - {{- end}} - return err -} -{{end}} - -{{if eq .Cmd ":execrows"}} -{{range .Comments}}//{{.}} -{{end -}} -func (q *Queries) {{.MethodName}}(ctx context.Context, {{ dbarg }} {{.Arg.Pair}}) (int64, error) { - {{- template "queryCodeStdExec" . }} - if err != nil { - return 0, {{if $.WrapErrors}}fmt.Errorf("query {{.MethodName}}: %w", err){{else}}err{{end}} - } - return result.RowsAffected() -} -{{end}} - -{{if eq .Cmd ":execlastid"}} -{{range .Comments}}//{{.}} -{{end -}} -func (q *Queries) {{.MethodName}}(ctx context.Context, {{ dbarg }} {{.Arg.Pair}}) (int64, error) { - {{- template "queryCodeStdExec" . }} - if err != nil { - return 0, {{if $.WrapErrors}}fmt.Errorf("query {{.MethodName}}: %w", err){{else}}err{{end}} - } - return result.LastInsertId() -} -{{end}} - -{{if eq .Cmd ":execresult"}} -{{range .Comments}}//{{.}} -{{end -}} -func (q *Queries) {{.MethodName}}(ctx context.Context, {{ dbarg }} {{.Arg.Pair}}) (sql.Result, error) { - {{- template "queryCodeStdExec" . }} - {{- if $.WrapErrors}} - if err != nil { - err = fmt.Errorf("query {{.MethodName}}: %w", err) - } - return result, err - {{- end}} -} -{{end}} - -{{end}} -{{end}} -{{end}} - -{{define "queryCodeStdExec"}} - {{- if .Arg.HasSqlcSlices }} - query := {{.ConstantName}} - var queryParams []interface{} - {{- if .Arg.Struct }} - {{- $arg := .Arg }} - {{- range .Arg.Struct.Fields }} - {{- if .HasSqlcSlice }} - if len({{$arg.VariableForField .}}) > 0 { - for _, v := range {{$arg.VariableForField .}} { - queryParams = append(queryParams, v) - } - query = strings.Replace(query, "/*SLICE:{{.Column.Name}}*/?", strings.Repeat(",?", len({{$arg.VariableForField .}}))[1:], 1) - } else { - query = strings.Replace(query, "/*SLICE:{{.Column.Name}}*/?", "NULL", 1) - } - {{- else }} - queryParams = append(queryParams, {{$arg.VariableForField .}}) - {{- end }} - {{- end }} - {{- else }} - {{- /* Single argument parameter to this goroutine (they are not packed - in a struct), because .Arg.HasSqlcSlices further up above was true, - this section is 100% a slice (impossible to get here otherwise). - */}} - if len({{.Arg.Name}}) > 0 { - for _, v := range {{.Arg.Name}} { - queryParams = append(queryParams, v) - } - query = strings.Replace(query, "/*SLICE:{{.Arg.Column.Name}}*/?", strings.Repeat(",?", len({{.Arg.Name}}))[1:], 1) - } else { - query = strings.Replace(query, "/*SLICE:{{.Arg.Column.Name}}*/?", "NULL", 1) - } - {{- end }} - {{- if emitPreparedQueries }} - {{ queryRetval . }} {{ queryMethod . }}(ctx, nil, query, queryParams...) - {{- else}} - {{ queryRetval . }} {{ queryMethod . }}(ctx, query, queryParams...) - {{- end -}} - {{- else if emitPreparedQueries }} - {{- queryRetval . }} {{ queryMethod . }}(ctx, q.{{.FieldName}}, {{.ConstantName}}, {{.Arg.Params}}) - {{- else}} - {{- queryRetval . }} {{ queryMethod . }}(ctx, {{.ConstantName}}, {{.Arg.Params}}) - {{- end -}} -{{end}} diff --git a/internal/codegen/golang/templates/template.tmpl b/internal/codegen/golang/templates/template.tmpl deleted file mode 100644 index afd50c01ac..0000000000 --- a/internal/codegen/golang/templates/template.tmpl +++ /dev/null @@ -1,254 +0,0 @@ -{{define "dbFile"}} -{{if .BuildTags}} -//go:build {{.BuildTags}} - -{{end}}// Code generated by sqlc. DO NOT EDIT. -{{if not .OmitSqlcVersion}}// versions: -// sqlc {{.SqlcVersion}} -{{end}} - -package {{.Package}} - -{{ if hasImports .SourceName }} -import ( - {{range imports .SourceName}} - {{range .}}{{.}} - {{end}} - {{end}} -) -{{end}} - -{{template "dbCode" . }} -{{end}} - -{{define "dbCode"}} - -{{if .SQLDriver.IsPGX }} - {{- template "dbCodeTemplatePgx" .}} -{{else}} - {{- template "dbCodeTemplateStd" .}} -{{end}} - -{{end}} - -{{define "interfaceFile"}} -{{if .BuildTags}} -//go:build {{.BuildTags}} - -{{end}}// Code generated by sqlc. DO NOT EDIT. -{{if not .OmitSqlcVersion}}// versions: -// sqlc {{.SqlcVersion}} -{{end}} - -package {{.Package}} - -{{ if hasImports .SourceName }} -import ( - {{range imports .SourceName}} - {{range .}}{{.}} - {{end}} - {{end}} -) -{{end}} - -{{template "interfaceCode" . }} -{{end}} - -{{define "interfaceCode"}} - {{if .SQLDriver.IsPGX }} - {{- template "interfaceCodePgx" .}} - {{else}} - {{- template "interfaceCodeStd" .}} - {{end}} -{{end}} - -{{define "modelsFile"}} -{{if .BuildTags}} -//go:build {{.BuildTags}} - -{{end}}// Code generated by sqlc. DO NOT EDIT. -{{if not .OmitSqlcVersion}}// versions: -// sqlc {{.SqlcVersion}} -{{end}} - -package {{.Package}} - -{{ if hasImports .SourceName }} -import ( - {{range imports .SourceName}} - {{range .}}{{.}} - {{end}} - {{end}} -) -{{end}} - -{{template "modelsCode" . }} -{{end}} - -{{define "modelsCode"}} -{{range .Enums}} -{{if .Comment}}{{comment .Comment}}{{end}} -type {{.Name}} string - -const ( - {{- range .Constants}} - {{.Name}} {{.Type}} = "{{.Value}}" - {{- end}} -) - -func (e *{{.Name}}) Scan(src interface{}) error { - switch s := src.(type) { - case []byte: - *e = {{.Name}}(s) - case string: - *e = {{.Name}}(s) - default: - return fmt.Errorf("unsupported scan type for {{.Name}}: %T", src) - } - return nil -} - -type Null{{.Name}} struct { - {{.Name}} {{.Name}} {{if .NameTag}}{{$.Q}}{{.NameTag}}{{$.Q}}{{end}} - Valid bool {{if .ValidTag}}{{$.Q}}{{.ValidTag}}{{$.Q}}{{end}} // Valid is true if {{.Name}} is not NULL -} - -// Scan implements the Scanner interface. -func (ns *Null{{.Name}}) Scan(value interface{}) error { - if value == nil { - ns.{{.Name}}, ns.Valid = "", false - return nil - } - ns.Valid = true - return ns.{{.Name}}.Scan(value) -} - -// Value implements the driver Valuer interface. -func (ns Null{{.Name}}) Value() (driver.Value, error) { - if !ns.Valid { - return nil, nil - } - return string(ns.{{.Name}}), nil -} - - -{{ if $.EmitEnumValidMethod }} -func (e {{.Name}}) Valid() bool { - switch e { - case {{ range $idx, $name := .Constants }}{{ if ne $idx 0 }},{{ "\n" }}{{ end }}{{ .Name }}{{ end }}: - return true - } - return false -} -{{ end }} - -{{ if $.EmitAllEnumValues }} -func All{{ .Name }}Values() []{{ .Name }} { - return []{{ .Name }}{ {{ range .Constants}}{{ "\n" }}{{ .Name }},{{ end }} - } -} -{{ end }} -{{end}} - -{{range .Structs}} -{{if .Comment}}{{comment .Comment}}{{end}} -type {{.Name}} struct { {{- range .Fields}} - {{- if .Comment}} - {{comment .Comment}}{{else}} - {{- end}} - {{.Name}} {{.Type}} {{if .Tag}}{{$.Q}}{{.Tag}}{{$.Q}}{{end}} - {{- end}} -} -{{end}} -{{end}} - -{{define "queryFile"}} -{{if .BuildTags}} -//go:build {{.BuildTags}} - -{{end}}// Code generated by sqlc. DO NOT EDIT. -{{if not .OmitSqlcVersion}}// versions: -// sqlc {{.SqlcVersion}} -{{end}}// source: {{.SourceName}} - -package {{.Package}} - -{{ if hasImports .SourceName }} -import ( - {{range imports .SourceName}} - {{range .}}{{.}} - {{end}} - {{end}} -) -{{end}} - -{{template "queryCode" . }} -{{end}} - -{{define "queryCode"}} -{{if .SQLDriver.IsPGX }} - {{- template "queryCodePgx" .}} -{{else}} - {{- template "queryCodeStd" .}} -{{end}} -{{end}} - -{{define "copyfromFile"}} -{{if .BuildTags}} -//go:build {{.BuildTags}} - -{{end}}// Code generated by sqlc. DO NOT EDIT. -{{if not .OmitSqlcVersion}}// versions: -// sqlc {{.SqlcVersion}} -{{end}}// source: {{.SourceName}} - -package {{.Package}} - -{{ if hasImports .SourceName }} -import ( - {{range imports .SourceName}} - {{range .}}{{.}} - {{end}} - {{end}} -) -{{end}} - -{{template "copyfromCode" . }} -{{end}} - -{{define "copyfromCode"}} -{{if .SQLDriver.IsPGX }} - {{- template "copyfromCodePgx" .}} -{{else if .SQLDriver.IsGoSQLDriverMySQL }} - {{- template "copyfromCodeGoSqlDriver" .}} -{{end}} -{{end}} - -{{define "batchFile"}} -{{if .BuildTags}} -//go:build {{.BuildTags}} - -{{end}}// Code generated by sqlc. DO NOT EDIT. -{{if not .OmitSqlcVersion}}// versions: -// sqlc {{.SqlcVersion}} -{{end}}// source: {{.SourceName}} - -package {{.Package}} - -{{ if hasImports .SourceName }} -import ( - {{range imports .SourceName}} - {{range .}}{{.}} - {{end}} - {{end}} -) -{{end}} - -{{template "batchCode" . }} -{{end}} - -{{define "batchCode"}} -{{if .SQLDriver.IsPGX }} - {{- template "batchCodePgx" .}} -{{end}} -{{end}} diff --git a/internal/poet/ast.go b/internal/poet/ast.go new file mode 100644 index 0000000000..409ea33ec2 --- /dev/null +++ b/internal/poet/ast.go @@ -0,0 +1,385 @@ +// Package poet provides Go code generation with custom AST nodes +// that properly support comment placement. +package poet + +import "strings" + +// File represents a Go source file. +type File struct { + BuildTags string + Comments []string // File-level comments + Package string + ImportGroups [][]Import // Groups separated by blank lines + Decls []Decl +} + +// Import represents an import statement. +type Import struct { + Alias string // Optional alias + Path string +} + +// Decl represents a declaration. +type Decl interface { + isDecl() +} + +// Raw is raw Go code (escape hatch). +type Raw struct { + Code string +} + +func (Raw) isDecl() {} + +// Const represents a const declaration. +type Const struct { + Comment string + Name string + Type string + Value string +} + +func (Const) isDecl() {} + +// ConstBlock represents a const block. +type ConstBlock struct { + Consts []Const +} + +func (ConstBlock) isDecl() {} + +// Var represents a var declaration. +type Var struct { + Comment string + Name string + Type string + Value string +} + +func (Var) isDecl() {} + +// VarBlock represents a var block. +type VarBlock struct { + Vars []Var +} + +func (VarBlock) isDecl() {} + +// TypeDef represents a type declaration. +type TypeDef struct { + Comment string + Name string + Type TypeExpr +} + +func (TypeDef) isDecl() {} + +// Func represents a function declaration. +type Func struct { + Comment string + Recv *Param // nil for non-methods + Name string + Params []Param + Results []Param + Stmts []Stmt +} + +func (Func) isDecl() {} + +// Param represents a function parameter or result. +type Param struct { + Name string + Type string + Pointer bool // If true, type is rendered as *Type +} + +// TypeExpr represents a type expression. +type TypeExpr interface { + isTypeExpr() +} + +// Struct represents a struct type. +type Struct struct { + Fields []Field +} + +func (Struct) isTypeExpr() {} + +// Field represents a struct field. +type Field struct { + Comment string // Leading comment (above the field) + Name string + Type string + Tag string + TrailingComment string // Trailing comment (on same line) +} + +// Interface represents an interface type. +type Interface struct { + Methods []Method +} + +func (Interface) isTypeExpr() {} + +// Method represents an interface method. +type Method struct { + Comment string + Name string + Params []Param + Results []Param +} + +// TypeName represents a type alias or named type. +type TypeName struct { + Name string +} + +func (TypeName) isTypeExpr() {} + +// Stmt represents a statement in a function body. +type Stmt interface { + isStmt() +} + +// RawStmt is raw Go code as a statement. +type RawStmt struct { + Code string +} + +func (RawStmt) isStmt() {} + +// Return represents a return statement. +type Return struct { + Values []string // Expressions to return +} + +func (Return) isStmt() {} + +// For represents a for loop. +type For struct { + Init string // e.g., "i := 0" + Cond string // e.g., "i < 10" + Post string // e.g., "i++" + Range string // If set, renders as "for Range {" (e.g., "_, v := range items") + Body []Stmt +} + +func (For) isStmt() {} + +// If represents an if statement. +type If struct { + Init string // Optional init statement (e.g., "err := foo()") + Cond string // Condition expression + Body []Stmt + Else []Stmt // Optional else body +} + +func (If) isStmt() {} + +// Switch represents a switch statement. +type Switch struct { + Init string // Optional init statement + Expr string // Expression to switch on (empty for type switch or bool switch) + Cases []Case +} + +func (Switch) isStmt() {} + +// Case represents a case clause in a switch statement. +type Case struct { + Values []string // Case values (empty for default case) + Body []Stmt +} + +// Defer represents a defer statement. +type Defer struct { + Call string // The function call to defer +} + +func (Defer) isStmt() {} + +// Assign represents an assignment statement. +type Assign struct { + Left []string // Left-hand side (variable names) + Op string // Assignment operator: "=", ":=", "+=", etc. + Right []string // Right-hand side (expressions) +} + +func (Assign) isStmt() {} + +// CallStmt represents a function call as a statement. +type CallStmt struct { + Call string // The function call expression +} + +func (CallStmt) isStmt() {} + +// VarDecl represents a variable declaration statement. +type VarDecl struct { + Name string // Variable name + Type string // Type (optional if Value is set) + Value string // Initial value (optional) +} + +func (VarDecl) isStmt() {} + +// GoStmt represents a go statement (goroutine). +type GoStmt struct { + Call string // The function call to run as a goroutine +} + +func (GoStmt) isStmt() {} + +// Continue represents a continue statement. +type Continue struct { + Label string // Optional label +} + +func (Continue) isStmt() {} + +// Break represents a break statement. +type Break struct { + Label string // Optional label +} + +func (Break) isStmt() {} + +// Expr is an interface for expression types that can be rendered to strings. +// These can be used in Return.Values, Assign.Right, etc. +type Expr interface { + Render() string +} + +// CallExpr represents a function or method call expression. +type CallExpr struct { + Func string // Function name or receiver.method + Args []string // Arguments +} + +func (c CallExpr) Render() string { + var b strings.Builder + b.WriteString(c.Func) + b.WriteString("(") + for i, arg := range c.Args { + if i > 0 { + b.WriteString(", ") + } + b.WriteString(arg) + } + b.WriteString(")") + return b.String() +} + +// StructLit represents a struct literal expression. +type StructLit struct { + Type string // Type name (e.g., "Queries") + Pointer bool // If true, prefix with & + Multiline bool // If true, always use multi-line format + Fields [][2]string // Field name-value pairs (use slice to preserve order) +} + +func (s StructLit) Render() string { + var b strings.Builder + if s.Pointer { + b.WriteString("&") + } + b.WriteString(s.Type) + b.WriteString("{") + if len(s.Fields) <= 2 && !s.Multiline { + // Compact format for small struct literals + for i, f := range s.Fields { + if i > 0 { + b.WriteString(", ") + } + b.WriteString(f[0]) + b.WriteString(": ") + b.WriteString(f[1]) + } + } else if len(s.Fields) > 0 { + // Multi-line format for larger struct literals or when explicitly requested + b.WriteString("\n") + for _, f := range s.Fields { + b.WriteString("\t\t") + b.WriteString(f[0]) + b.WriteString(": ") + b.WriteString(f[1]) + b.WriteString(",\n") + } + b.WriteString("\t") + } + b.WriteString("}") + return b.String() +} + +// SliceLit represents a slice literal expression. +type SliceLit struct { + Type string // Element type (e.g., "interface{}") + Multiline bool // If true, always use multi-line format + Values []string // Elements +} + +func (s SliceLit) Render() string { + var b strings.Builder + b.WriteString("[]") + b.WriteString(s.Type) + b.WriteString("{") + if len(s.Values) <= 3 && !s.Multiline { + // Compact format for small slice literals + for i, v := range s.Values { + if i > 0 { + b.WriteString(", ") + } + b.WriteString(v) + } + } else if len(s.Values) > 0 { + // Multi-line format for larger slice literals or when explicitly requested + b.WriteString("\n") + for _, v := range s.Values { + b.WriteString("\t\t") + b.WriteString(v) + b.WriteString(",\n") + } + b.WriteString("\t") + } + b.WriteString("}") + return b.String() +} + +// TypeCast represents a type conversion expression. +type TypeCast struct { + Type string // Target type + Value string // Value to convert +} + +func (t TypeCast) Render() string { + return t.Type + "(" + t.Value + ")" +} + +// FuncLit represents an anonymous function literal. +type FuncLit struct { + Params []Param + Results []Param + Body []Stmt + Indent string // Base indentation for body statements (default: "\t") +} + +// Note: FuncLit.Render() is implemented in render.go since it needs renderStmts + +// Selector represents a field or method selector expression (a.b.c). +type Selector struct { + Parts []string // e.g., ["r", "rows", "0", "Field"] for r.rows[0].Field +} + +func (s Selector) Render() string { + return strings.Join(s.Parts, ".") +} + +// Index represents an index or slice expression. +type Index struct { + Expr string // Base expression + Index string // Index value (or "start:end" for slice) +} + +func (i Index) Render() string { + return i.Expr + "[" + i.Index + "]" +} diff --git a/internal/poet/render.go b/internal/poet/render.go new file mode 100644 index 0000000000..b884d3d7bf --- /dev/null +++ b/internal/poet/render.go @@ -0,0 +1,540 @@ +package poet + +import ( + "go/format" + "strings" +) + +// Render converts a File to formatted Go source code. +func Render(f *File) ([]byte, error) { + var b strings.Builder + renderFile(&b, f) + return format.Source([]byte(b.String())) +} + +func renderFile(b *strings.Builder, f *File) { + // Build tags + if f.BuildTags != "" { + b.WriteString("//go:build ") + b.WriteString(f.BuildTags) + b.WriteString("\n\n") + } + + // File comments + for _, c := range f.Comments { + b.WriteString(c) + b.WriteString("\n") + } + + // Package + if len(f.Comments) > 0 { + b.WriteString("\n") + } + b.WriteString("package ") + b.WriteString(f.Package) + b.WriteString("\n") + + // Imports + hasImports := false + for _, group := range f.ImportGroups { + if len(group) > 0 { + hasImports = true + break + } + } + if hasImports { + b.WriteString("\nimport (\n") + first := true + for _, group := range f.ImportGroups { + if len(group) == 0 { + continue + } + if !first { + b.WriteString("\n") + } + first = false + for _, imp := range group { + b.WriteString("\t") + if imp.Alias != "" { + b.WriteString(imp.Alias) + b.WriteString(" ") + } + b.WriteString("\"") + b.WriteString(imp.Path) + b.WriteString("\"\n") + } + } + b.WriteString(")\n") + } + + // Declarations + for _, d := range f.Decls { + b.WriteString("\n") + renderDecl(b, d) + } +} + +func renderDecl(b *strings.Builder, d Decl) { + switch d := d.(type) { + case Raw: + b.WriteString(d.Code) + case Const: + renderConst(b, d, "") + case ConstBlock: + renderConstBlock(b, d) + case Var: + renderVar(b, d, "") + case VarBlock: + renderVarBlock(b, d) + case TypeDef: + renderTypeDef(b, d) + case Func: + renderFunc(b, d) + } +} + +func renderConst(b *strings.Builder, c Const, indent string) { + if c.Comment != "" { + writeComment(b, c.Comment, indent) + } + b.WriteString(indent) + if indent == "" { + b.WriteString("const ") + } + b.WriteString(c.Name) + if c.Type != "" { + b.WriteString(" ") + b.WriteString(c.Type) + } + if c.Value != "" { + b.WriteString(" = ") + b.WriteString(c.Value) + } + b.WriteString("\n") +} + +func renderConstBlock(b *strings.Builder, cb ConstBlock) { + b.WriteString("const (\n") + for _, c := range cb.Consts { + renderConst(b, c, "\t") + } + b.WriteString(")\n") +} + +func renderVar(b *strings.Builder, v Var, indent string) { + if v.Comment != "" { + writeComment(b, v.Comment, indent) + } + b.WriteString(indent) + if indent == "" { + b.WriteString("var ") + } + b.WriteString(v.Name) + if v.Type != "" { + b.WriteString(" ") + b.WriteString(v.Type) + } + if v.Value != "" { + b.WriteString(" = ") + b.WriteString(v.Value) + } + b.WriteString("\n") +} + +func renderVarBlock(b *strings.Builder, vb VarBlock) { + b.WriteString("var (\n") + for _, v := range vb.Vars { + renderVar(b, v, "\t") + } + b.WriteString(")\n") +} + +func renderTypeDef(b *strings.Builder, t TypeDef) { + if t.Comment != "" { + writeComment(b, t.Comment, "") + } + b.WriteString("type ") + b.WriteString(t.Name) + b.WriteString(" ") + renderTypeExpr(b, t.Type) + b.WriteString("\n") +} + +func renderTypeExpr(b *strings.Builder, t TypeExpr) { + switch t := t.(type) { + case Struct: + renderStruct(b, t) + case Interface: + renderInterface(b, t) + case TypeName: + b.WriteString(t.Name) + } +} + +func renderStruct(b *strings.Builder, s Struct) { + b.WriteString("struct {\n") + for _, f := range s.Fields { + if f.Comment != "" { + writeComment(b, f.Comment, "\t") + } + b.WriteString("\t") + b.WriteString(f.Name) + b.WriteString(" ") + b.WriteString(f.Type) + if f.Tag != "" { + b.WriteString(" `") + b.WriteString(f.Tag) + b.WriteString("`") + } + if f.TrailingComment != "" { + b.WriteString(" // ") + b.WriteString(f.TrailingComment) + } + b.WriteString("\n") + } + b.WriteString("}") +} + +func renderInterface(b *strings.Builder, iface Interface) { + b.WriteString("interface {\n") + for _, m := range iface.Methods { + if m.Comment != "" { + writeComment(b, m.Comment, "\t") + } + b.WriteString("\t") + b.WriteString(m.Name) + b.WriteString("(") + renderParams(b, m.Params) + b.WriteString(")") + if len(m.Results) > 0 { + b.WriteString(" ") + if len(m.Results) == 1 && m.Results[0].Name == "" { + b.WriteString(m.Results[0].Type) + } else { + b.WriteString("(") + renderParams(b, m.Results) + b.WriteString(")") + } + } + b.WriteString("\n") + } + b.WriteString("}") +} + +func renderFunc(b *strings.Builder, f Func) { + if f.Comment != "" { + writeComment(b, f.Comment, "") + } + b.WriteString("func ") + if f.Recv != nil { + b.WriteString("(") + b.WriteString(f.Recv.Name) + b.WriteString(" ") + if f.Recv.Pointer { + b.WriteString("*") + } + b.WriteString(f.Recv.Type) + b.WriteString(") ") + } + b.WriteString(f.Name) + b.WriteString("(") + renderParams(b, f.Params) + b.WriteString(")") + if len(f.Results) > 0 { + b.WriteString(" ") + if len(f.Results) == 1 && f.Results[0].Name == "" { + if f.Results[0].Pointer { + b.WriteString("*") + } + b.WriteString(f.Results[0].Type) + } else { + b.WriteString("(") + renderParams(b, f.Results) + b.WriteString(")") + } + } + b.WriteString(" {\n") + renderStmts(b, f.Stmts, "\t") + b.WriteString("}\n") +} + +func renderParams(b *strings.Builder, params []Param) { + for i, p := range params { + if i > 0 { + b.WriteString(", ") + } + if p.Name != "" { + b.WriteString(p.Name) + b.WriteString(" ") + } + if p.Pointer { + b.WriteString("*") + } + b.WriteString(p.Type) + } +} + +func writeComment(b *strings.Builder, comment, indent string) { + lines := strings.Split(comment, "\n") + for _, line := range lines { + b.WriteString(indent) + // If line already starts with //, write as-is + if strings.HasPrefix(line, "//") { + b.WriteString(line) + } else { + b.WriteString("// ") + b.WriteString(line) + } + b.WriteString("\n") + } +} + +func renderStmts(b *strings.Builder, stmts []Stmt, indent string) { + for _, s := range stmts { + renderStmt(b, s, indent) + } +} + +// RenderStmt renders a single statement to a string with the given indentation. +func RenderStmt(s Stmt, indent string) string { + var b strings.Builder + renderStmt(&b, s, indent) + return b.String() +} + +func renderStmt(b *strings.Builder, s Stmt, indent string) { + switch s := s.(type) { + case RawStmt: + b.WriteString(s.Code) + case Return: + renderReturn(b, s, indent) + case For: + renderFor(b, s, indent) + case If: + renderIf(b, s, indent) + case Switch: + renderSwitch(b, s, indent) + case Defer: + renderDefer(b, s, indent) + case Assign: + renderAssign(b, s, indent) + case CallStmt: + renderCallStmt(b, s, indent) + case VarDecl: + renderVarDecl(b, s, indent) + case GoStmt: + renderGoStmt(b, s, indent) + case Continue: + renderContinue(b, s, indent) + case Break: + renderBreak(b, s, indent) + } +} + +func renderReturn(b *strings.Builder, r Return, indent string) { + b.WriteString(indent) + b.WriteString("return") + if len(r.Values) > 0 { + b.WriteString(" ") + for i, v := range r.Values { + if i > 0 { + b.WriteString(", ") + } + b.WriteString(v) + } + } + b.WriteString("\n") +} + +func renderFor(b *strings.Builder, f For, indent string) { + b.WriteString(indent) + b.WriteString("for ") + if f.Range != "" { + b.WriteString(f.Range) + } else { + if f.Init != "" { + b.WriteString(f.Init) + } + b.WriteString("; ") + b.WriteString(f.Cond) + b.WriteString("; ") + if f.Post != "" { + b.WriteString(f.Post) + } + } + b.WriteString(" {\n") + renderStmts(b, f.Body, indent+"\t") + b.WriteString(indent) + b.WriteString("}\n") +} + +func renderIf(b *strings.Builder, i If, indent string) { + b.WriteString(indent) + b.WriteString("if ") + if i.Init != "" { + b.WriteString(i.Init) + b.WriteString("; ") + } + b.WriteString(i.Cond) + b.WriteString(" {\n") + renderStmts(b, i.Body, indent+"\t") + b.WriteString(indent) + b.WriteString("}") + if len(i.Else) > 0 { + b.WriteString(" else {\n") + renderStmts(b, i.Else, indent+"\t") + b.WriteString(indent) + b.WriteString("}") + } + b.WriteString("\n") +} + +func renderSwitch(b *strings.Builder, s Switch, indent string) { + b.WriteString(indent) + b.WriteString("switch ") + if s.Init != "" { + b.WriteString(s.Init) + b.WriteString("; ") + } + b.WriteString(s.Expr) + b.WriteString(" {\n") + for _, c := range s.Cases { + b.WriteString(indent) + if len(c.Values) == 0 { + b.WriteString("default:\n") + } else { + b.WriteString("case ") + if len(c.Values) == 1 { + b.WriteString(c.Values[0]) + } else { + // Multiple values: put each on its own line + for i, v := range c.Values { + if i > 0 { + b.WriteString(",\n") + b.WriteString(indent) + b.WriteString("\t") + } + b.WriteString(v) + } + } + b.WriteString(":\n") + } + renderStmts(b, c.Body, indent+"\t") + } + b.WriteString(indent) + b.WriteString("}\n") +} + +func renderDefer(b *strings.Builder, d Defer, indent string) { + b.WriteString(indent) + b.WriteString("defer ") + b.WriteString(d.Call) + b.WriteString("\n") +} + +func renderAssign(b *strings.Builder, a Assign, indent string) { + b.WriteString(indent) + for i, l := range a.Left { + if i > 0 { + b.WriteString(", ") + } + b.WriteString(l) + } + b.WriteString(" ") + b.WriteString(a.Op) + b.WriteString(" ") + for i, r := range a.Right { + if i > 0 { + b.WriteString(", ") + } + b.WriteString(r) + } + b.WriteString("\n") +} + +func renderCallStmt(b *strings.Builder, c CallStmt, indent string) { + b.WriteString(indent) + b.WriteString(c.Call) + b.WriteString("\n") +} + +func renderVarDecl(b *strings.Builder, v VarDecl, indent string) { + b.WriteString(indent) + b.WriteString("var ") + b.WriteString(v.Name) + if v.Type != "" { + b.WriteString(" ") + b.WriteString(v.Type) + } + if v.Value != "" { + b.WriteString(" = ") + b.WriteString(v.Value) + } + b.WriteString("\n") +} + +func renderGoStmt(b *strings.Builder, g GoStmt, indent string) { + b.WriteString(indent) + b.WriteString("go ") + b.WriteString(g.Call) + b.WriteString("\n") +} + +func renderContinue(b *strings.Builder, c Continue, indent string) { + b.WriteString(indent) + b.WriteString("continue") + if c.Label != "" { + b.WriteString(" ") + b.WriteString(c.Label) + } + b.WriteString("\n") +} + +func renderBreak(b *strings.Builder, br Break, indent string) { + b.WriteString(indent) + b.WriteString("break") + if br.Label != "" { + b.WriteString(" ") + b.WriteString(br.Label) + } + b.WriteString("\n") +} + +// RenderFuncLit renders a function literal to a string. +// This is used by FuncLit.Render() and can also be called directly. +func RenderFuncLit(f FuncLit) string { + var b strings.Builder + b.WriteString("func(") + renderParams(&b, f.Params) + b.WriteString(")") + if len(f.Results) > 0 { + b.WriteString(" ") + if len(f.Results) == 1 && f.Results[0].Name == "" { + if f.Results[0].Pointer { + b.WriteString("*") + } + b.WriteString(f.Results[0].Type) + } else { + b.WriteString("(") + renderParams(&b, f.Results) + b.WriteString(")") + } + } + b.WriteString(" {\n") + indent := f.Indent + if indent == "" { + indent = "\t" + } + renderStmts(&b, f.Body, indent) + // Write closing brace with one less tab than body content + if len(indent) > 0 { + b.WriteString(indent[:len(indent)-1]) + } + b.WriteString("}") + return b.String() +} + +// Render implements the Expr interface for FuncLit. +func (f FuncLit) Render() string { + return RenderFuncLit(f) +}