Skip to content

Commit 08be8c5

Browse files
committed
fix call tx
1 parent b51a47c commit 08be8c5

File tree

3 files changed

+29
-6
lines changed

3 files changed

+29
-6
lines changed

internal/xsql/conn.go

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@ type conn struct {
5454

5555
dataOpts []options.ExecuteDataQueryOption
5656
scanOpts []options.ExecuteScanQueryOption
57+
58+
currentTx *tx
5759
}
5860

5961
var (
@@ -134,17 +136,22 @@ func (c *conn) BeginTx(ctx context.Context, txOptions driver.TxOptions) (driver.
134136
if err != nil {
135137
return nil, c.checkClosed(err)
136138
}
137-
return &tx{
139+
c.currentTx = &tx{
138140
conn: c,
139141
transaction: t,
140-
}, nil
142+
}
143+
return c.currentTx, nil
141144
}
142145

143146
func (c *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
144147
if c.isClosed() {
145148
return nil, errClosedConn
146149
}
147-
switch m := queryModeFromContext(ctx, c.defaultQueryMode); m {
150+
m := queryModeFromContext(ctx, c.defaultQueryMode)
151+
if c.currentTx != nil && m == DataQueryMode {
152+
return c.currentTx.ExecContext(ctx, query, args)
153+
}
154+
switch m {
148155
case DataQueryMode:
149156
_, res, err := c.session.Execute(ctx, txControl(ctx, c.defaultTxControl), query, toQueryParams(args))
150157
if err != nil {
@@ -175,7 +182,11 @@ func (c *conn) QueryContext(ctx context.Context, query string, args []driver.Nam
175182
if c.isClosed() {
176183
return nil, errClosedConn
177184
}
178-
switch m := queryModeFromContext(ctx, c.defaultQueryMode); m {
185+
m := queryModeFromContext(ctx, c.defaultQueryMode)
186+
if c.currentTx != nil && m == DataQueryMode {
187+
return c.currentTx.QueryContext(ctx, query, args)
188+
}
189+
switch m {
179190
case DataQueryMode:
180191
_, res, err := c.session.Execute(ctx, txControl(ctx, c.defaultTxControl), query, toQueryParams(args))
181192
if err != nil {

internal/xsql/tx.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,9 @@ func (tx *tx) Commit() (err error) {
5959
if tx.conn.isClosed() {
6060
return errClosedConn
6161
}
62+
defer func() {
63+
tx.conn.currentTx = nil
64+
}()
6265
_, err = tx.transaction.CommitTx(context.Background())
6366
if err != nil {
6467
return tx.conn.checkClosed(err)
@@ -70,6 +73,9 @@ func (tx *tx) Rollback() (err error) {
7073
if tx.conn.isClosed() {
7174
return errClosedConn
7275
}
76+
defer func() {
77+
tx.conn.currentTx = nil
78+
}()
7379
err = tx.transaction.Rollback(context.Background())
7480
if err != nil {
7581
return tx.conn.checkClosed(err)

sql_e2e_test.go

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,8 +123,13 @@ func TestDatabaseSql(t *testing.T) {
123123
if err != nil {
124124
return fmt.Errorf("cannot upsert views: %w", err)
125125
}
126-
row = db.QueryRowContext(
127-
ydb.WithQueryMode(ctx, ydb.ScanQueryMode),
126+
return nil
127+
}, retry.Idempotent(true))
128+
if err != nil {
129+
t.Fatalf("begin tx failed: %v\n", err)
130+
}
131+
err = retry.DoTx(ctx, db, func(ctx context.Context, tx *sql.Tx) error {
132+
row := tx.QueryRowContext(ctx,
128133
render(
129134
querySelect,
130135
templateConfig{
@@ -135,6 +140,7 @@ func TestDatabaseSql(t *testing.T) {
135140
sql.Named("seasonID", uint64(1)),
136141
sql.Named("episodeID", uint64(1)),
137142
)
143+
var views sql.NullFloat64
138144
if err = row.Scan(&views); err != nil {
139145
return fmt.Errorf("cannot select current views: %w", err)
140146
}

0 commit comments

Comments
 (0)