Skip to content

Commit cc1d1de

Browse files
committed
fix
1 parent 6641717 commit cc1d1de

File tree

2 files changed

+78
-41
lines changed

2 files changed

+78
-41
lines changed

internal/table/session.go

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -680,9 +680,10 @@ func (s *session) executeQueryResult(res *Ydb_Table.ExecuteQueryResult, txContro
680680
table.Transaction, result.Result, error,
681681
) {
682682
t := &transaction{
683-
id: res.GetTxMeta().GetId(),
684-
s: s,
685-
c: txControl,
683+
id: res.GetTxMeta().GetId(),
684+
state: txStateInitialized,
685+
s: s,
686+
control: txControl,
686687
}
687688
r := scanner.NewUnary(
688689
res.GetResultSets(),
@@ -1084,8 +1085,9 @@ func (s *session) BeginTransaction(
10841085
return
10851086
}
10861087
return &transaction{
1087-
id: result.GetTxMeta().GetId(),
1088-
s: s,
1089-
c: table.TxControl(table.WithTxID(result.GetTxMeta().GetId())),
1088+
id: result.GetTxMeta().GetId(),
1089+
state: txStateInitialized,
1090+
s: s,
1091+
control: table.TxControl(table.WithTxID(result.GetTxMeta().GetId())),
10901092
}, nil
10911093
}

internal/table/transaction.go

Lines changed: 70 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package table
33
import (
44
"context"
55
"fmt"
6+
"sync/atomic"
67

78
"github.com/ydb-platform/ydb-go-genproto/protos/Ydb_Table"
89

@@ -16,35 +17,47 @@ import (
1617
"github.com/ydb-platform/ydb-go-sdk/v3/trace"
1718
)
1819

19-
//nolint:gofumpt
20-
//nolint:nolintlint
2120
var (
22-
// errAlreadyCommited returns if transaction Commit called twice
23-
errAlreadyCommited = xerrors.Wrap(fmt.Errorf("already committed"))
21+
errTxInvalidatedWithCommit = xerrors.Wrap(fmt.Errorf("transaction invalidated from WithCommit() call"))
22+
errTxAlreadyCommitted = xerrors.Wrap(fmt.Errorf("transaction already committed"))
23+
errTxRollbackedEarly = xerrors.Wrap(fmt.Errorf("transaction rollbacked early"))
2424
)
2525

26-
type transaction struct {
27-
id string
28-
s *session
29-
c *table.TransactionControl
26+
type txState int32
27+
28+
const (
29+
txStateInitialized = iota
30+
txStateInvalidatedWithCommit
31+
txStateCommitted
32+
txStateRollbacked
33+
)
3034

31-
committed bool
35+
type transaction struct {
36+
id string
37+
s *session
38+
control *table.TransactionControl
39+
state txState
3240
}
3341

3442
func (tx *transaction) ID() string {
3543
return tx.id
3644
}
3745

38-
func (tx *transaction) IsNil() bool {
39-
return tx == nil
40-
}
41-
4246
// Execute executes query represented by text within transaction tx.
4347
func (tx *transaction) Execute(
4448
ctx context.Context,
4549
query string, params *table.QueryParameters,
4650
opts ...options.ExecuteDataQueryOption,
4751
) (r result.Result, err error) {
52+
switch txState(atomic.LoadInt32((*int32)(&tx.state))) {
53+
case txStateInvalidatedWithCommit:
54+
return nil, xerrors.WithStackTrace(errTxInvalidatedWithCommit)
55+
case txStateCommitted:
56+
return nil, xerrors.WithStackTrace(errTxAlreadyCommitted)
57+
case txStateRollbacked:
58+
return nil, xerrors.WithStackTrace(errTxRollbackedEarly)
59+
default:
60+
}
4861
var (
4962
a = allocator.New()
5063
q = queryFromText(query)
@@ -72,7 +85,7 @@ func (tx *transaction) Execute(
7285
defer func() {
7386
onDone(r, err)
7487
}()
75-
_, r, err = tx.s.Execute(ctx, tx.txc(), query, params, opts...)
88+
_, r, err = tx.s.Execute(ctx, tx.control, query, params, opts...)
7689
if err != nil {
7790
return nil, xerrors.WithStackTrace(err)
7891
}
@@ -86,6 +99,15 @@ func (tx *transaction) ExecuteStatement(
8699
stmt table.Statement, params *table.QueryParameters,
87100
opts ...options.ExecuteDataQueryOption,
88101
) (r result.Result, err error) {
102+
switch txState(atomic.LoadInt32((*int32)(&tx.state))) {
103+
case txStateInvalidatedWithCommit:
104+
return nil, xerrors.WithStackTrace(errTxInvalidatedWithCommit)
105+
case txStateCommitted:
106+
return nil, xerrors.WithStackTrace(errTxAlreadyCommitted)
107+
case txStateRollbacked:
108+
return nil, xerrors.WithStackTrace(errTxRollbackedEarly)
109+
default:
110+
}
89111
if params == nil {
90112
params = table.NewQueryParameters()
91113
}
@@ -110,7 +132,7 @@ func (tx *transaction) ExecuteStatement(
110132
onDone(r, err)
111133
}()
112134

113-
_, r, err = stmt.Execute(ctx, tx.txc(), params, opts...)
135+
_, r, err = stmt.Execute(ctx, tx.control, params, opts...)
114136
if err != nil {
115137
return nil, xerrors.WithStackTrace(err)
116138
}
@@ -119,11 +141,14 @@ func (tx *transaction) ExecuteStatement(
119141
}
120142

121143
func (tx *transaction) WithCommit() table.TransactionActor {
144+
defer func() {
145+
atomic.StoreInt32((*int32)(&tx.state), txStateInvalidatedWithCommit)
146+
}()
122147
return &transaction{
123-
id: tx.id,
124-
s: tx.s,
125-
c: table.TxControl(table.WithTxID(tx.id), table.CommitTx()),
126-
committed: tx.committed,
148+
id: tx.id,
149+
s: tx.s,
150+
control: table.TxControl(table.WithTxID(tx.id), table.CommitTx()),
151+
state: txState(atomic.LoadInt32((*int32)(&tx.state))),
127152
}
128153
}
129154

@@ -132,14 +157,20 @@ func (tx *transaction) CommitTx(
132157
ctx context.Context,
133158
opts ...options.CommitTransactionOption,
134159
) (r result.Result, err error) {
135-
if tx.committed {
136-
return nil, xerrors.WithStackTrace(errAlreadyCommited)
160+
switch txState(atomic.LoadInt32((*int32)(&tx.state))) {
161+
case txStateInvalidatedWithCommit:
162+
return nil, xerrors.WithStackTrace(errTxInvalidatedWithCommit)
163+
case txStateCommitted:
164+
return nil, xerrors.WithStackTrace(errTxAlreadyCommitted)
165+
case txStateRollbacked:
166+
return nil, xerrors.WithStackTrace(errTxRollbackedEarly)
167+
default:
168+
defer func() {
169+
if err == nil {
170+
atomic.StoreInt32((*int32)(&tx.state), txStateCommitted)
171+
}
172+
}()
137173
}
138-
defer func() {
139-
if err == nil {
140-
tx.committed = true
141-
}
142-
}()
143174
onDone := trace.TableOnSessionTransactionCommit(
144175
tx.s.config.Trace(),
145176
&ctx,
@@ -185,8 +216,19 @@ func (tx *transaction) CommitTx(
185216

186217
// Rollback performs a rollback of the specified active transaction.
187218
func (tx *transaction) Rollback(ctx context.Context) (err error) {
188-
if tx.committed {
189-
return nil
219+
switch txState(atomic.LoadInt32((*int32)(&tx.state))) {
220+
case txStateInvalidatedWithCommit:
221+
return xerrors.WithStackTrace(errTxInvalidatedWithCommit)
222+
case txStateCommitted:
223+
return xerrors.WithStackTrace(errTxAlreadyCommitted)
224+
case txStateRollbacked:
225+
return xerrors.WithStackTrace(errTxRollbackedEarly)
226+
default:
227+
defer func() {
228+
if err == nil {
229+
atomic.StoreInt32((*int32)(&tx.state), txStateRollbacked)
230+
}
231+
}()
190232
}
191233
onDone := trace.TableOnSessionTransactionRollback(
192234
tx.s.config.Trace(),
@@ -212,10 +254,3 @@ func (tx *transaction) Rollback(ctx context.Context) (err error) {
212254
)
213255
return xerrors.WithStackTrace(err)
214256
}
215-
216-
func (tx *transaction) txc() *table.TransactionControl {
217-
if tx.c == nil {
218-
tx.c = table.TxControl(table.WithTx(tx))
219-
}
220-
return tx.c
221-
}

0 commit comments

Comments
 (0)