diff --git a/examples/authors/sqlc.yaml b/examples/authors/sqlc.yaml index 143cb608b0..c06607c6fc 100644 --- a/examples/authors/sqlc.yaml +++ b/examples/authors/sqlc.yaml @@ -52,6 +52,7 @@ sql: package: authors out: ydb emit_json_tags: true + ydb_retry_idempotent: true rules: diff --git a/examples/authors/ydb/db.go b/examples/authors/ydb/db.go index e2b0a86b13..86681a3f78 100644 --- a/examples/authors/ydb/db.go +++ b/examples/authors/ydb/db.go @@ -7,8 +7,12 @@ package authors import ( "context" "database/sql" + + "github.com/ydb-platform/ydb-go-sdk/v3/retry" ) +type Retrier func(ctx context.Context, op func(ctx context.Context, db DBTX) error) error + type DBTX interface { ExecContext(context.Context, string, ...interface{}) (sql.Result, error) PrepareContext(context.Context, string) (*sql.Stmt, error) @@ -16,16 +20,55 @@ type DBTX interface { QueryRowContext(context.Context, string, ...interface{}) *sql.Row } -func New(db DBTX) *Queries { - return &Queries{db: db} +func New(db *sql.DB) *Queries { + return &Queries{ + db: db, + retrier: func(ctx context.Context, op func(ctx context.Context, db DBTX) error) error { + return retry.Do(ctx, db, func(ctx context.Context, conn *sql.Conn) error { + return op(ctx, conn) + }, retry.WithIdempotent(true)) + }, + } +} + +func (q *Queries) WithRetryOptions( /* opts ...retry.Option */ ) *Queries { + q.retrier = func(ctx context.Context, op func(ctx context.Context, db DBTX) error) error { + return retry.Do(ctx, q.db, func(ctx context.Context, conn *sql.Conn) error { + return op(ctx, conn) + }, retry.WithIdempotent(true) /* , opts... */) + } + return q } type Queries struct { - db DBTX + retrier Retrier + db *sql.DB +} + +func NewTx(db *sql.DB) *Queries { + return &Queries{ + db: db, + retrier: func(ctx context.Context, op func(ctx context.Context, db DBTX) error) error { + return retry.DoTx(ctx, db, func(ctx context.Context, tx *sql.Tx) error { + return op(ctx, tx) + }, retry.WithIdempotent(true)) + }, + } +} + +func (q *Queries) WithTxRetryOptions( /* opts ...retry.Option */ ) *Queries { + q.retrier = func(ctx context.Context, op func(ctx context.Context, db DBTX) error) error { + return retry.DoTx(ctx, q.db, func(ctx context.Context, tx *sql.Tx) error { + return op(ctx, tx) + }, retry.WithIdempotent(true) /* , opts... */) + } + return q } func (q *Queries) WithTx(tx *sql.Tx) *Queries { return &Queries{ - db: tx, + retrier: func(ctx context.Context, op func(ctx context.Context, db DBTX) error) error { + return op(ctx, tx) + }, } } diff --git a/examples/authors/ydb/query.sql b/examples/authors/ydb/query.sql index bf672042c5..84f51e79d3 100644 --- a/examples/authors/ydb/query.sql +++ b/examples/authors/ydb/query.sql @@ -13,10 +13,10 @@ WHERE name = $p0; SELECT * FROM authors WHERE bio IS NULL; --- name: Count :one +-- name: CountAuthors :one SELECT COUNT(*) FROM authors; --- name: COALESCE :many +-- name: Coalesce :many SELECT id, name, COALESCE(bio, 'Null value!') FROM authors; -- name: CreateOrUpdateAuthor :execresult diff --git a/examples/authors/ydb/query.sql.go b/examples/authors/ydb/query.sql.go index 45d86c96fd..71287c4713 100644 --- a/examples/authors/ydb/query.sql.go +++ b/examples/authors/ydb/query.sql.go @@ -10,47 +10,56 @@ import ( "database/sql" ) -const cOALESCE = `-- name: COALESCE :many +const coalesce = `-- name: Coalesce :many SELECT id, name, COALESCE(bio, 'Null value!') FROM authors ` -type COALESCERow struct { +type CoalesceRow struct { ID uint64 `json:"id"` Name string `json:"name"` Bio string `json:"bio"` } -func (q *Queries) COALESCE(ctx context.Context) ([]COALESCERow, error) { - rows, err := q.db.QueryContext(ctx, cOALESCE) - if err != nil { - return nil, err - } - defer rows.Close() - var items []COALESCERow - for rows.Next() { - var i COALESCERow - if err := rows.Scan(&i.ID, &i.Name, &i.Bio); err != nil { - return nil, err +func (q *Queries) Coalesce(ctx context.Context) ([]CoalesceRow, error) { + var items []CoalesceRow + + err := q.retrier(ctx, func(ctx context.Context, db DBTX) error { + items = nil + rows, err := db.QueryContext(ctx, coalesce) + if err != nil { + return err } - items = append(items, i) - } - if err := rows.Close(); err != nil { - return nil, err - } - if err := rows.Err(); err != nil { + defer rows.Close() + for rows.Next() { + var i CoalesceRow + if err := rows.Scan(&i.ID, &i.Name, &i.Bio); err != nil { + return err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return err + } + + return rows.Err() + }) + if err != nil { return nil, err } return items, nil } -const count = `-- name: Count :one +const countAuthors = `-- name: CountAuthors :one SELECT COUNT(*) FROM authors ` -func (q *Queries) Count(ctx context.Context) (uint64, error) { - row := q.db.QueryRowContext(ctx, count) +func (q *Queries) CountAuthors(ctx context.Context) (uint64, error) { var count uint64 - err := row.Scan(&count) + err := q.retrier(ctx, func(ctx context.Context, db DBTX) error { + row := db.QueryRowContext(ctx, countAuthors) + err := row.Scan(&count) + return err + }) return count, err } @@ -65,7 +74,15 @@ type CreateOrUpdateAuthorParams struct { } func (q *Queries) CreateOrUpdateAuthor(ctx context.Context, arg CreateOrUpdateAuthorParams) (sql.Result, error) { - return q.db.ExecContext(ctx, createOrUpdateAuthor, arg.P0, arg.P1, arg.P2) + var sqlResult sql.Result + + err := q.retrier(ctx, func(ctx context.Context, db DBTX) error { + result, err := db.ExecContext(ctx, createOrUpdateAuthor, arg.P0, arg.P1, arg.P2) + + sqlResult = result + return err + }) + return sqlResult, err } const createOrUpdateAuthorReturningBio = `-- name: CreateOrUpdateAuthorReturningBio :one @@ -79,9 +96,12 @@ type CreateOrUpdateAuthorReturningBioParams struct { } func (q *Queries) CreateOrUpdateAuthorReturningBio(ctx context.Context, arg CreateOrUpdateAuthorReturningBioParams) (*string, error) { - row := q.db.QueryRowContext(ctx, createOrUpdateAuthorReturningBio, arg.P0, arg.P1, arg.P2) var bio *string - err := row.Scan(&bio) + err := q.retrier(ctx, func(ctx context.Context, db DBTX) error { + row := db.QueryRowContext(ctx, createOrUpdateAuthorReturningBio, arg.P0, arg.P1, arg.P2) + err := row.Scan(&bio) + return err + }) return bio, err } @@ -90,7 +110,10 @@ DELETE FROM authors WHERE id = $p0 ` func (q *Queries) DeleteAuthor(ctx context.Context, p0 uint64) error { - _, err := q.db.ExecContext(ctx, deleteAuthor, p0) + err := q.retrier(ctx, func(ctx context.Context, db DBTX) error { + _, err := db.ExecContext(ctx, deleteAuthor, p0) + return err + }) return err } @@ -100,9 +123,12 @@ WHERE id = $p0 ` func (q *Queries) GetAuthor(ctx context.Context, p0 uint64) (Author, error) { - row := q.db.QueryRowContext(ctx, getAuthor, p0) var i Author - err := row.Scan(&i.ID, &i.Name, &i.Bio) + err := q.retrier(ctx, func(ctx context.Context, db DBTX) error { + row := db.QueryRowContext(ctx, getAuthor, p0) + err := row.Scan(&i.ID, &i.Name, &i.Bio) + return err + }) return i, err } @@ -112,23 +138,29 @@ WHERE name = $p0 ` func (q *Queries) GetAuthorsByName(ctx context.Context, p0 string) ([]Author, error) { - rows, err := q.db.QueryContext(ctx, getAuthorsByName, p0) - if err != nil { - return nil, err - } - defer rows.Close() var items []Author - for rows.Next() { - var i Author - if err := rows.Scan(&i.ID, &i.Name, &i.Bio); err != nil { - return nil, err + + err := q.retrier(ctx, func(ctx context.Context, db DBTX) error { + items = nil + rows, err := db.QueryContext(ctx, getAuthorsByName, p0) + if err != nil { + return err } - items = append(items, i) - } - if err := rows.Close(); err != nil { - return nil, err - } - if err := rows.Err(); err != nil { + defer rows.Close() + for rows.Next() { + var i Author + if err := rows.Scan(&i.ID, &i.Name, &i.Bio); err != nil { + return err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return err + } + + return rows.Err() + }) + if err != nil { return nil, err } return items, nil @@ -139,23 +171,29 @@ SELECT id, name, bio FROM authors ` func (q *Queries) ListAuthors(ctx context.Context) ([]Author, error) { - rows, err := q.db.QueryContext(ctx, listAuthors) - if err != nil { - return nil, err - } - defer rows.Close() var items []Author - for rows.Next() { - var i Author - if err := rows.Scan(&i.ID, &i.Name, &i.Bio); err != nil { - return nil, err + + err := q.retrier(ctx, func(ctx context.Context, db DBTX) error { + items = nil + rows, err := db.QueryContext(ctx, listAuthors) + if err != nil { + return err } - items = append(items, i) - } - if err := rows.Close(); err != nil { - return nil, err - } - if err := rows.Err(); err != nil { + defer rows.Close() + for rows.Next() { + var i Author + if err := rows.Scan(&i.ID, &i.Name, &i.Bio); err != nil { + return err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return err + } + + return rows.Err() + }) + if err != nil { return nil, err } return items, nil @@ -167,23 +205,29 @@ WHERE bio IS NULL ` func (q *Queries) ListAuthorsWithNullBio(ctx context.Context) ([]Author, error) { - rows, err := q.db.QueryContext(ctx, listAuthorsWithNullBio) - if err != nil { - return nil, err - } - defer rows.Close() var items []Author - for rows.Next() { - var i Author - if err := rows.Scan(&i.ID, &i.Name, &i.Bio); err != nil { - return nil, err + + err := q.retrier(ctx, func(ctx context.Context, db DBTX) error { + items = nil + rows, err := db.QueryContext(ctx, listAuthorsWithNullBio) + if err != nil { + return err } - items = append(items, i) - } - if err := rows.Close(); err != nil { - return nil, err - } - if err := rows.Err(); err != nil { + defer rows.Close() + for rows.Next() { + var i Author + if err := rows.Scan(&i.ID, &i.Name, &i.Bio); err != nil { + return err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return err + } + + return rows.Err() + }) + if err != nil { return nil, err } return items, nil @@ -200,8 +244,11 @@ type UpdateAuthorByIDParams struct { } func (q *Queries) UpdateAuthorByID(ctx context.Context, arg UpdateAuthorByIDParams) (Author, error) { - row := q.db.QueryRowContext(ctx, updateAuthorByID, arg.P0, arg.P1, arg.P2) var i Author - err := row.Scan(&i.ID, &i.Name, &i.Bio) + err := q.retrier(ctx, func(ctx context.Context, db DBTX) error { + row := db.QueryRowContext(ctx, updateAuthorByID, arg.P0, arg.P1, arg.P2) + err := row.Scan(&i.ID, &i.Name, &i.Bio) + return err + }) return i, err } diff --git a/go.mod b/go.mod index bd44f45665..d7677e293c 100644 --- a/go.mod +++ b/go.mod @@ -24,11 +24,11 @@ require ( github.com/tetratelabs/wazero v1.9.0 github.com/wasilibs/go-pgquery v0.0.0-20250409022910-10ac41983c07 github.com/xeipuuv/gojsonschema v1.2.0 - github.com/ydb-platform/ydb-go-sdk/v3 v3.108.0 + github.com/ydb-platform/ydb-go-sdk/v3 v3.113.5 github.com/ydb-platform/yql-parsers v0.0.0-20250309001738-7d693911f333 golang.org/x/sync v0.16.0 google.golang.org/grpc v1.74.2 - google.golang.org/protobuf v1.36.6 + google.golang.org/protobuf v1.36.7 gopkg.in/yaml.v3 v3.0.1 modernc.org/sqlite v1.38.2 ) @@ -48,7 +48,7 @@ require ( github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect github.com/jackc/pgtype v1.14.0 // indirect github.com/jackc/puddle/v2 v2.2.2 // indirect - github.com/jonboulle/clockwork v0.3.0 // indirect + github.com/jonboulle/clockwork v0.5.0 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/ncruces/go-strftime v0.1.9 // indirect github.com/pingcap/errors v0.11.5-0.20240311024730-e056997136bb // indirect diff --git a/go.sum b/go.sum index 39baee97e9..06f0a4e4d7 100644 --- a/go.sum +++ b/go.sum @@ -144,8 +144,8 @@ github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= -github.com/jonboulle/clockwork v0.3.0 h1:9BSCMi8C+0qdApAp4auwX0RkLGUjs956h0EkuQymUhg= -github.com/jonboulle/clockwork v0.3.0/go.mod h1:Pkfl5aHPm1nk2H9h0bjmnJD/BcgbGXUBGnn1kMkgxc8= +github.com/jonboulle/clockwork v0.5.0 h1:Hyh9A8u51kptdkR+cqRpT1EebBwTn1oK9YfGYbdFz6I= +github.com/jonboulle/clockwork v0.5.0/go.mod h1:3mZlmanh0g2NDKO5TWZVJAfofYk64M7XN3SzBPjZF60= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= @@ -240,8 +240,8 @@ github.com/xeipuuv/gojsonschema v1.2.0 h1:LhYJRs+L4fBtjZUfuSZIKGeVu0QRy8e5Xi7D17 github.com/xeipuuv/gojsonschema v1.2.0/go.mod h1:anYRn/JVcOK2ZgGU+IjEV4nwlhoK5sQluxsYJ78Id3Y= github.com/ydb-platform/ydb-go-genproto v0.0.0-20241112172322-ea1f63298f77 h1:LY6cI8cP4B9rrpTleZk95+08kl2gF4rixG7+V/dwL6Q= github.com/ydb-platform/ydb-go-genproto v0.0.0-20241112172322-ea1f63298f77/go.mod h1:Er+FePu1dNUieD+XTMDduGpQuCPssK5Q4BjF+IIXJ3I= -github.com/ydb-platform/ydb-go-sdk/v3 v3.108.0 h1:TwWSp3gRMcja/hRpOofncLvgxAXCmzpz5cGtmdaoITw= -github.com/ydb-platform/ydb-go-sdk/v3 v3.108.0/go.mod h1:l5sSv153E18VvYcsmr51hok9Sjc16tEC8AXGbwrk+ho= +github.com/ydb-platform/ydb-go-sdk/v3 v3.113.5 h1:olAAZfpMnFYChJNgZJ16G4jqoelRNx7Kx4tW50XcMv0= +github.com/ydb-platform/ydb-go-sdk/v3 v3.113.5/go.mod h1:Pp1w2xxUoLQ3NCNAwV7pvDq0TVQOdtAqs+ZiC+i8r14= github.com/ydb-platform/yql-parsers v0.0.0-20250309001738-7d693911f333 h1:KFtJwlPdOxWjCKXX0jFJ8k1FlbqbRbUW3k/kYSZX7SA= github.com/ydb-platform/yql-parsers v0.0.0-20250309001738-7d693911f333/go.mod h1:vrPJPS8cdPSV568YcXhB4bUwhyV8bmWKqmQ5c5Xi99o= github.com/zenazn/goji v0.9.0/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q= @@ -409,8 +409,8 @@ google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQ google.golang.org/protobuf v1.27.1/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= -google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY= -google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY= +google.golang.org/protobuf v1.36.7 h1:IgrO7UwFQGJdRNXH/sQux4R1Dj1WAKcLElzeeRaXV2A= +google.golang.org/protobuf v1.36.7/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= diff --git a/internal/codegen/golang/gen.go b/internal/codegen/golang/gen.go index ac91cc537f..9fdce30dda 100644 --- a/internal/codegen/golang/gen.go +++ b/internal/codegen/golang/gen.go @@ -42,6 +42,8 @@ type tmplCtx struct { OmitSqlcVersion bool BuildTags string WrapErrors bool + EnableYDBRetry bool + YDBRetryIdempotent bool } func (t *tmplCtx) OutputQuery(sourceName string) bool { @@ -50,6 +52,9 @@ func (t *tmplCtx) OutputQuery(sourceName string) bool { func (t *tmplCtx) codegenDbarg() string { if t.EmitMethodsWithDBArgument { + if t.EnableYDBRetry { + return "db *sql.DB, " + } return "db DBTX, " } return "" @@ -63,7 +68,7 @@ func (t *tmplCtx) codegenEmitPreparedQueries() bool { func (t *tmplCtx) codegenQueryMethod(q Query) string { db := "q.db" - if t.EmitMethodsWithDBArgument { + if t.EmitMethodsWithDBArgument || t.EnableYDBRetry { db = "db" } @@ -99,7 +104,7 @@ func (t *tmplCtx) codegenQueryRetval(q Query) (string, error) { case ":execrows", ":execlastid": return "result, err :=", nil case ":execresult": - if t.WrapErrors { + if t.WrapErrors || t.EnableYDBRetry { return "result, err :=", nil } return "return", nil @@ -192,6 +197,8 @@ func generate(req *plugin.GenerateRequest, options *opts.Options, enums []Enum, BuildTags: options.BuildTags, OmitSqlcVersion: options.OmitSqlcVersion, WrapErrors: options.WrapErrors, + EnableYDBRetry: options.EnableYDBRetry, + YDBRetryIdempotent: options.YDBRetryIdempotent, } if tctx.UsesCopyFrom && !tctx.SQLDriver.IsPGX() && options.SqlDriver != opts.SQLDriverGoSQLDriverMySQL { diff --git a/internal/codegen/golang/imports.go b/internal/codegen/golang/imports.go index ccca4f603c..9d12e7ee73 100644 --- a/internal/codegen/golang/imports.go +++ b/internal/codegen/golang/imports.go @@ -138,6 +138,10 @@ func (i *importer) dbImports() fileImports { std = append(std, ImportSpec{Path: "fmt"}) } } + + if i.Options.EnableYDBRetry { + pkg = append(pkg, ImportSpec{Path: "github.com/ydb-platform/ydb-go-sdk/v3/retry"}) + } sort.Slice(std, func(i, j int) bool { return std[i].Path < std[j].Path }) sort.Slice(pkg, func(i, j int) bool { return pkg[i].Path < pkg[j].Path }) diff --git a/internal/codegen/golang/opts/options.go b/internal/codegen/golang/opts/options.go index 0d5d51c2dd..da00fd67c6 100644 --- a/internal/codegen/golang/opts/options.go +++ b/internal/codegen/golang/opts/options.go @@ -46,6 +46,10 @@ type Options struct { BuildTags string `json:"build_tags,omitempty" yaml:"build_tags"` Initialisms *[]string `json:"initialisms,omitempty" yaml:"initialisms"` + // YDB specific + EnableYDBRetry bool `json:"enable_ydb_retry,omitempty" yaml:"enable_ydb_retry"` + YDBRetryIdempotent bool `json:"ydb_retry_idempotent,omitempty" yaml:"ydb_retry_idempotent"` + InitialismsMap map[string]struct{} `json:"-" yaml:"-"` } @@ -72,6 +76,9 @@ func Parse(req *plugin.GenerateRequest) (*Options, error) { } maps.Copy(options.Rename, global.Rename) } + if req.Settings.Engine == "ydb" && (options.SqlPackage == "" || options.SqlPackage == "database/sql") { + options.EnableYDBRetry = true + } return options, nil } @@ -151,6 +158,8 @@ func ValidateOpts(opts *Options) error { if *opts.QueryParameterLimit < 0 { return fmt.Errorf("invalid options: query parameter limit must not be negative") } - + if !opts.EnableYDBRetry && opts.YDBRetryIdempotent { + return fmt.Errorf("invalid options: ydb_retry_idempotent requires enable_ydb_retry to be set") + } return nil } diff --git a/internal/codegen/golang/templates/stdlib/dbCode.tmpl b/internal/codegen/golang/templates/stdlib/dbCode.tmpl index 7433d522f6..eb76ec6cba 100644 --- a/internal/codegen/golang/templates/stdlib/dbCode.tmpl +++ b/internal/codegen/golang/templates/stdlib/dbCode.tmpl @@ -1,4 +1,9 @@ {{define "dbCodeTemplateStd"}} + +{{- if $.EnableYDBRetry }} +type Retrier func(ctx context.Context, op func(ctx context.Context, db DBTX) error) error +{{- end }} + type DBTX interface { ExecContext(context.Context, string, ...interface{}) (sql.Result, error) PrepareContext(context.Context, string) (*sql.Stmt, error) @@ -6,6 +11,30 @@ type DBTX interface { QueryRowContext(context.Context, string, ...interface{}) *sql.Row } +{{ if $.EnableYDBRetry }} +func New(db *sql.DB) *Queries { + return &Queries{ + {{- if not .EmitMethodsWithDBArgument}} + db: db, + {{- end}} + retrier: func(ctx context.Context, op func(ctx context.Context, db DBTX) error) error { + return retry.Do(ctx, db, func(ctx context.Context, conn *sql.Conn) error { + return op(ctx, conn) + } + {{- if $.YDBRetryIdempotent }}, retry.WithIdempotent(true) {{- end }}) + }, + } +} + +func (q *Queries) WithRetryOptions( /* opts ...retry.Option */ ) *Queries { + q.retrier = func(ctx context.Context, op func(ctx context.Context, db DBTX) error) error { + return retry.Do(ctx, q.db, func(ctx context.Context, conn *sql.Conn) error { + return op(ctx, conn) + } {{- if $.YDBRetryIdempotent }}, retry.WithIdempotent(true){{- end }} /* , opts... */) + } + return q + +{{- else -}} {{ if .EmitMethodsWithDBArgument}} func New() *Queries { return &Queries{} @@ -13,11 +42,26 @@ func New() *Queries { func New(db DBTX) *Queries { return &Queries{db: db} {{- end}} +{{- end}} } {{if .EmitPreparedQueries}} +{{- if $.EnableYDBRetry }} +func Prepare(ctx context.Context, db *sql.DB) (*Queries, error) { + q := Queries{ + {{- if not .EmitMethodsWithDBArgument}} + db: db, + {{- end}} + retrier: func(ctx context.Context, op func(ctx context.Context, db DBTX) error) error { + return retry.Do(ctx, db, func(ctx context.Context, conn *sql.Conn) error { + return op(ctx, conn) + }{{- if $.YDBRetryIdempotent }}, retry.WithIdempotent(true){{- end }}) + }, + } +{{- else }} func Prepare(ctx context.Context, db DBTX) (*Queries, error) { q := Queries{db: db} + {{- end }} var err error {{- if eq (len .GoQueries) 0 }} _ = err @@ -43,6 +87,19 @@ func (q *Queries) Close() error { } func (q *Queries) exec(ctx context.Context, stmt *sql.Stmt, query string, args ...interface{}) (sql.Result, error) { + {{- if $.EnableYDBRetry }} + var result sql.Result + err := q.retrier(ctx, func(ctx context.Context, db DBTX) error { + var err error + if stmt != nil { + result, err = stmt.ExecContext(ctx, args...) + } else { + result, err = db.ExecContext(ctx, query, args...) + } + return err + }) + return result, err + {{- else }} switch { case stmt != nil && q.tx != nil: return q.tx.StmtContext(ctx, stmt).ExecContext(ctx, args...) @@ -51,9 +108,23 @@ func (q *Queries) exec(ctx context.Context, stmt *sql.Stmt, query string, args . default: return q.db.ExecContext(ctx, query, args...) } + {{- end}} } func (q *Queries) query(ctx context.Context, stmt *sql.Stmt, query string, args ...interface{}) (*sql.Rows, error) { + {{- if $.EnableYDBRetry }} + var rows *sql.Rows + err := q.retrier(ctx, func(ctx context.Context, db DBTX) error { + var err error + if stmt != nil { + rows, err = stmt.QueryContext(ctx, args...) + } else { + rows, err = db.QueryContext(ctx, query, args...) + } + return err + }) + return rows, err + {{- else }} switch { case stmt != nil && q.tx != nil: return q.tx.StmtContext(ctx, stmt).QueryContext(ctx, args...) @@ -62,9 +133,22 @@ func (q *Queries) query(ctx context.Context, stmt *sql.Stmt, query string, args default: return q.db.QueryContext(ctx, query, args...) } + {{- end}} } func (q *Queries) queryRow(ctx context.Context, stmt *sql.Stmt, query string, args ...interface{}) (*sql.Row) { + {{- if $.EnableYDBRetry }} + var row *sql.Row + _ = q.retrier(ctx, func(ctx context.Context, db DBTX) error { + if stmt != nil { + row = stmt.QueryRowContext(ctx, args...) + } else { + row = db.QueryRowContext(ctx, query, args...) + } + return nil + }) + return row + {{- else }} switch { case stmt != nil && q.tx != nil: return q.tx.StmtContext(ctx, stmt).QueryRowContext(ctx, args...) @@ -73,22 +157,69 @@ func (q *Queries) queryRow(ctx context.Context, stmt *sql.Stmt, query string, ar default: return q.db.QueryRowContext(ctx, query, args...) } + {{- end}} } {{end}} type Queries struct { + {{- if $.EnableYDBRetry }} + retrier Retrier + {{- if not .EmitMethodsWithDBArgument}} + db *sql.DB + {{- end}} + {{- else }} {{- if not .EmitMethodsWithDBArgument}} db DBTX {{- end}} + {{- end}} {{- if .EmitPreparedQueries}} + {{- if not $.EnableYDBRetry }} tx *sql.Tx + {{- end}} {{- range .GoQueries}} {{.FieldName}} *sql.Stmt {{- end}} {{- end}} } +{{ if $.EnableYDBRetry}} + +func NewTx(db *sql.DB) *Queries { + return &Queries{ + {{- if not .EmitMethodsWithDBArgument}} + db: db, + {{- end}} + retrier: func(ctx context.Context, op func(ctx context.Context, db DBTX) error) error { + return retry.DoTx(ctx, db, func(ctx context.Context, tx *sql.Tx) error { + return op(ctx, tx) + }{{- if $.YDBRetryIdempotent }}, retry.WithIdempotent(true){{- end }}) + }, + } +} + +func (q *Queries) WithTxRetryOptions( /* opts ...retry.Option */ ) *Queries { + q.retrier = func(ctx context.Context, op func(ctx context.Context, db DBTX) error) error { + return retry.DoTx(ctx, q.db, func(ctx context.Context, tx *sql.Tx) error { + return op(ctx, tx) + } {{- if $.YDBRetryIdempotent }}, retry.WithIdempotent(true){{- end }} /* , opts... */) + } + return q +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + {{- if .EmitPreparedQueries}} + {{- range .GoQueries}} + {{.FieldName}}: q.{{.FieldName}}, + {{- end}} + {{- end}} + retrier: func(ctx context.Context, op func(ctx context.Context, db DBTX) error) error { + return op(ctx, tx) + }, + } +} +{{- else -}} {{if not .EmitMethodsWithDBArgument}} func (q *Queries) WithTx(tx *sql.Tx) *Queries { return &Queries{ @@ -101,5 +232,6 @@ func (q *Queries) WithTx(tx *sql.Tx) *Queries { {{- end}} } } +{{- end}} {{end}} {{end}} diff --git a/internal/codegen/golang/templates/stdlib/queryCode.tmpl b/internal/codegen/golang/templates/stdlib/queryCode.tmpl index 1e7f4e22a4..5c10090335 100644 --- a/internal/codegen/golang/templates/stdlib/queryCode.tmpl +++ b/internal/codegen/golang/templates/stdlib/queryCode.tmpl @@ -23,17 +23,34 @@ type {{.Ret.Type}} struct { {{- range .Ret.Struct.Fields}} {{range .Comments}}//{{.}} {{end -}} func (q *Queries) {{.MethodName}}(ctx context.Context, {{ dbarg }} {{.Arg.Pair}}) ({{.Ret.DefineType}}, error) { + {{- if $.EnableYDBRetry }} + {{- if or (ne .Arg.Pair .Ret.Pair) (ne .Arg.DefineType .Ret.DefineType) }} + var {{.Ret.Name}} {{.Ret.Type}} + {{- end}} + err := q.retrier(ctx, func(ctx context.Context, db DBTX) error { + {{- template "queryCodeStdExec" . }} + err := row.Scan({{.Ret.Scan}}) + return err + }) + {{- if $.WrapErrors}} + if err != nil { + err = fmt.Errorf("query {{.MethodName}}: %w", err) + } + {{- end}} + return {{.Ret.ReturnName}}, err + {{- else}} {{- template "queryCodeStdExec" . }} - {{- if or (ne .Arg.Pair .Ret.Pair) (ne .Arg.DefineType .Ret.DefineType) }} - var {{.Ret.Name}} {{.Ret.Type}} - {{- end}} - err := row.Scan({{.Ret.Scan}}) - {{- if $.WrapErrors}} - if err != nil { - err = fmt.Errorf("query {{.MethodName}}: %w", err) - } - {{- end}} - return {{.Ret.ReturnName}}, err + {{- if or (ne .Arg.Pair .Ret.Pair) (ne .Arg.DefineType .Ret.DefineType) }} + var {{.Ret.Name}} {{.Ret.Type}} + {{- end}} + err := row.Scan({{.Ret.Scan}}) + {{- if $.WrapErrors}} + if err != nil { + err = fmt.Errorf("query {{.MethodName}}: %w", err) + } + {{- end}} + return {{.Ret.ReturnName}}, err + {{- end }} } {{end}} @@ -41,6 +58,39 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{ dbarg }} {{.Arg.Pair}} {{range .Comments}}//{{.}} {{end -}} func (q *Queries) {{.MethodName}}(ctx context.Context, {{ dbarg }} {{.Arg.Pair}}) ([]{{.Ret.DefineType}}, error) { + {{- if $.EnableYDBRetry }} + var items []{{.Ret.DefineType}} + + err := q.retrier(ctx, func(ctx context.Context, db DBTX) error { + {{- if $.EmitEmptySlices}} + items = []{{.Ret.DefineType}}{} + {{else}} + items = nil + {{end -}} + + {{- template "queryCodeStdExec" . }} + if err != nil { + return err + } + defer rows.Close() + for rows.Next() { + var {{.Ret.Name}} {{.Ret.Type}} + if err := rows.Scan({{.Ret.Scan}}); err != nil { + return err + } + items = append(items, {{.Ret.ReturnName}}) + } + if err := rows.Close(); err != nil { + return err + } + + return rows.Err() + }) + if err != nil { + return nil, {{if $.WrapErrors}}fmt.Errorf("query {{.MethodName}}: %w", err){{else}}err{{end}} + } + return items, nil + {{- else}} {{- template "queryCodeStdExec" . }} if err != nil { return nil, {{if $.WrapErrors}}fmt.Errorf("query {{.MethodName}}: %w", err){{else}}err{{end}} @@ -65,6 +115,7 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{ dbarg }} {{.Arg.Pair}} return nil, {{if $.WrapErrors}}fmt.Errorf("query {{.MethodName}}: %w", err){{else}}err{{end}} } return items, nil + {{- end }} } {{end}} @@ -72,6 +123,18 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{ dbarg }} {{.Arg.Pair}} {{range .Comments}}//{{.}} {{end -}} func (q *Queries) {{.MethodName}}(ctx context.Context, {{ dbarg }} {{.Arg.Pair}}) error { + {{- if $.EnableYDBRetry }} + err := q.retrier(ctx, func(ctx context.Context, db DBTX) error { + {{- template "queryCodeStdExec" . }} + return err + }) + {{- if $.WrapErrors}} + if err != nil { + err = fmt.Errorf("query {{.MethodName}}: %w", err) + } + {{- end}} + return err + {{- else}} {{- template "queryCodeStdExec" . }} {{- if $.WrapErrors}} if err != nil { @@ -79,6 +142,7 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{ dbarg }} {{.Arg.Pair}} } {{- end}} return err + {{- end}} } {{end}} @@ -86,11 +150,31 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{ dbarg }} {{.Arg.Pair}} {{range .Comments}}//{{.}} {{end -}} func (q *Queries) {{.MethodName}}(ctx context.Context, {{ dbarg }} {{.Arg.Pair}}) (int64, error) { + {{- if $.EnableYDBRetry }} + var rowsAffected int64 + + err := q.retrier(ctx, func(ctx context.Context, db DBTX) error { + {{- template "queryCodeStdExec" . }} + if err != nil { + return err + } + + rowsAffected, err = result.RowsAffected() + return err + }) + + if err != nil { + return 0, {{if $.WrapErrors}}fmt.Errorf("query {{.MethodName}}: %w", err){{else}}err{{end}} + } + + return rowsAffected, nil + {{- else}} {{- template "queryCodeStdExec" . }} if err != nil { return 0, {{if $.WrapErrors}}fmt.Errorf("query {{.MethodName}}: %w", err){{else}}err{{end}} } return result.RowsAffected() + {{- end}} } {{end}} @@ -98,11 +182,31 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{ dbarg }} {{.Arg.Pair}} {{range .Comments}}//{{.}} {{end -}} func (q *Queries) {{.MethodName}}(ctx context.Context, {{ dbarg }} {{.Arg.Pair}}) (int64, error) { + {{- if $.EnableYDBRetry }} + var lastID int64 + + err := q.retrier(ctx, func(ctx context.Context, db DBTX) error { + {{- template "queryCodeStdExec" . }} + if err != nil { + return err + } + + lastID, err = result.LastInsertId() + return err + }) + + if err != nil { + return 0, {{if $.WrapErrors}}fmt.Errorf("query {{.MethodName}}: %w", err){{else}}err{{end}} + } + + return lastID, nil + {{- else}} {{- template "queryCodeStdExec" . }} if err != nil { return 0, {{if $.WrapErrors}}fmt.Errorf("query {{.MethodName}}: %w", err){{else}}err{{end}} } return result.LastInsertId() + {{- end}} } {{end}} @@ -110,6 +214,23 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{ dbarg }} {{.Arg.Pair}} {{range .Comments}}//{{.}} {{end -}} func (q *Queries) {{.MethodName}}(ctx context.Context, {{ dbarg }} {{.Arg.Pair}}) (sql.Result, error) { + {{- if $.EnableYDBRetry }} + var sqlResult sql.Result + + err := q.retrier(ctx, func(ctx context.Context, db DBTX) error { + {{- template "queryCodeStdExec" . }} + + sqlResult = result + return err + }) + + {{- if $.WrapErrors}} + if err != nil { + return nil, fmt.Errorf("query {{.MethodName}}: %w", err) + } + {{- end}} + return sqlResult, err + {{- else}} {{- template "queryCodeStdExec" . }} {{- if $.WrapErrors}} if err != nil { @@ -117,6 +238,7 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{ dbarg }} {{.Arg.Pair}} } return result, err {{- end}} + {{- end}} } {{end}} diff --git a/internal/sqltest/local/ydb.go b/internal/sqltest/local/ydb.go index 8703b170b5..850f48a716 100644 --- a/internal/sqltest/local/ydb.go +++ b/internal/sqltest/local/ydb.go @@ -14,6 +14,7 @@ import ( migrate "github.com/sqlc-dev/sqlc/internal/migrations" "github.com/sqlc-dev/sqlc/internal/sql/sqlpath" "github.com/ydb-platform/ydb-go-sdk/v3" + "github.com/ydb-platform/ydb-go-sdk/v3/retry" ) func init() { @@ -108,9 +109,15 @@ func link_YDB(t *testing.T, migrations []string, rw bool) TestYDB { schemeCtx := ydb.WithQueryMode(ctx, ydb.SchemeQueryMode) for _, stmt := range seed { - _, err := db.ExecContext(schemeCtx, stmt) + err := retry.Do(schemeCtx, db, func(ctx context.Context, conn *sql.Conn) error { + _, err := conn.ExecContext(ctx, stmt) + return err + }, + retry.WithIdempotent(true), + retry.WithLabel("ApplySchemaMigration"), + ) if err != nil { - t.Fatalf("failed to apply migration: %s\nSQL: %s", err, stmt) + t.Fatalf("failed to apply migration with retry: %v\nSQL: %s", err, stmt) } } return TestYDB{DB: db, Prefix: prefix}