@@ -3,6 +3,7 @@ package table
33import (
44 "context"
55 "fmt"
6+ "sync/atomic"
67
78 "github.com/ydb-platform/ydb-go-genproto/protos/Ydb_Table"
89
@@ -16,35 +17,47 @@ import (
1617 "github.com/ydb-platform/ydb-go-sdk/v3/trace"
1718)
1819
19- //nolint:gofumpt
20- //nolint:nolintlint
2120var (
22- // errAlreadyCommited returns if transaction Commit called twice
23- errAlreadyCommited = xerrors .Wrap (fmt .Errorf ("already committed" ))
21+ errTxInvalidatedWithCommit = xerrors .Wrap (fmt .Errorf ("transaction invalidated from WithCommit() call" ))
22+ errTxAlreadyCommitted = xerrors .Wrap (fmt .Errorf ("transaction already committed" ))
23+ errTxRollbackedEarly = xerrors .Wrap (fmt .Errorf ("transaction rollbacked early" ))
2424)
2525
26- type transaction struct {
27- id string
28- s * session
29- c * table.TransactionControl
26+ type txState int32
27+
28+ const (
29+ txStateInitialized = iota
30+ txStateInvalidatedWithCommit
31+ txStateCommitted
32+ txStateRollbacked
33+ )
3034
31- committed bool
35+ type transaction struct {
36+ id string
37+ s * session
38+ control * table.TransactionControl
39+ state txState
3240}
3341
3442func (tx * transaction ) ID () string {
3543 return tx .id
3644}
3745
38- func (tx * transaction ) IsNil () bool {
39- return tx == nil
40- }
41-
4246// Execute executes query represented by text within transaction tx.
4347func (tx * transaction ) Execute (
4448 ctx context.Context ,
4549 query string , params * table.QueryParameters ,
4650 opts ... options.ExecuteDataQueryOption ,
4751) (r result.Result , err error ) {
52+ switch txState (atomic .LoadInt32 ((* int32 )(& tx .state ))) {
53+ case txStateInvalidatedWithCommit :
54+ return nil , xerrors .WithStackTrace (errTxInvalidatedWithCommit )
55+ case txStateCommitted :
56+ return nil , xerrors .WithStackTrace (errTxAlreadyCommitted )
57+ case txStateRollbacked :
58+ return nil , xerrors .WithStackTrace (errTxRollbackedEarly )
59+ default :
60+ }
4861 var (
4962 a = allocator .New ()
5063 q = queryFromText (query )
@@ -72,7 +85,7 @@ func (tx *transaction) Execute(
7285 defer func () {
7386 onDone (r , err )
7487 }()
75- _ , r , err = tx .s .Execute (ctx , tx .txc () , query , params , opts ... )
88+ _ , r , err = tx .s .Execute (ctx , tx .control , query , params , opts ... )
7689 if err != nil {
7790 return nil , xerrors .WithStackTrace (err )
7891 }
@@ -86,6 +99,15 @@ func (tx *transaction) ExecuteStatement(
8699 stmt table.Statement , params * table.QueryParameters ,
87100 opts ... options.ExecuteDataQueryOption ,
88101) (r result.Result , err error ) {
102+ switch txState (atomic .LoadInt32 ((* int32 )(& tx .state ))) {
103+ case txStateInvalidatedWithCommit :
104+ return nil , xerrors .WithStackTrace (errTxInvalidatedWithCommit )
105+ case txStateCommitted :
106+ return nil , xerrors .WithStackTrace (errTxAlreadyCommitted )
107+ case txStateRollbacked :
108+ return nil , xerrors .WithStackTrace (errTxRollbackedEarly )
109+ default :
110+ }
89111 if params == nil {
90112 params = table .NewQueryParameters ()
91113 }
@@ -110,7 +132,7 @@ func (tx *transaction) ExecuteStatement(
110132 onDone (r , err )
111133 }()
112134
113- _ , r , err = stmt .Execute (ctx , tx .txc () , params , opts ... )
135+ _ , r , err = stmt .Execute (ctx , tx .control , params , opts ... )
114136 if err != nil {
115137 return nil , xerrors .WithStackTrace (err )
116138 }
@@ -119,11 +141,14 @@ func (tx *transaction) ExecuteStatement(
119141}
120142
121143func (tx * transaction ) WithCommit () table.TransactionActor {
144+ defer func () {
145+ atomic .StoreInt32 ((* int32 )(& tx .state ), txStateInvalidatedWithCommit )
146+ }()
122147 return & transaction {
123- id : tx .id ,
124- s : tx .s ,
125- c : table .TxControl (table .WithTxID (tx .id ), table .CommitTx ()),
126- committed : tx .committed ,
148+ id : tx .id ,
149+ s : tx .s ,
150+ control : table .TxControl (table .WithTxID (tx .id ), table .CommitTx ()),
151+ state : txState ( atomic . LoadInt32 (( * int32 )( & tx .state ))) ,
127152 }
128153}
129154
@@ -132,14 +157,20 @@ func (tx *transaction) CommitTx(
132157 ctx context.Context ,
133158 opts ... options.CommitTransactionOption ,
134159) (r result.Result , err error ) {
135- if tx .committed {
136- return nil , xerrors .WithStackTrace (errAlreadyCommited )
160+ switch txState (atomic .LoadInt32 ((* int32 )(& tx .state ))) {
161+ case txStateInvalidatedWithCommit :
162+ return nil , xerrors .WithStackTrace (errTxInvalidatedWithCommit )
163+ case txStateCommitted :
164+ return nil , xerrors .WithStackTrace (errTxAlreadyCommitted )
165+ case txStateRollbacked :
166+ return nil , xerrors .WithStackTrace (errTxRollbackedEarly )
167+ default :
168+ defer func () {
169+ if err == nil {
170+ atomic .StoreInt32 ((* int32 )(& tx .state ), txStateCommitted )
171+ }
172+ }()
137173 }
138- defer func () {
139- if err == nil {
140- tx .committed = true
141- }
142- }()
143174 onDone := trace .TableOnSessionTransactionCommit (
144175 tx .s .config .Trace (),
145176 & ctx ,
@@ -185,8 +216,19 @@ func (tx *transaction) CommitTx(
185216
186217// Rollback performs a rollback of the specified active transaction.
187218func (tx * transaction ) Rollback (ctx context.Context ) (err error ) {
188- if tx .committed {
189- return nil
219+ switch txState (atomic .LoadInt32 ((* int32 )(& tx .state ))) {
220+ case txStateInvalidatedWithCommit :
221+ return xerrors .WithStackTrace (errTxInvalidatedWithCommit )
222+ case txStateCommitted :
223+ return xerrors .WithStackTrace (errTxAlreadyCommitted )
224+ case txStateRollbacked :
225+ return xerrors .WithStackTrace (errTxRollbackedEarly )
226+ default :
227+ defer func () {
228+ if err == nil {
229+ atomic .StoreInt32 ((* int32 )(& tx .state ), txStateRollbacked )
230+ }
231+ }()
190232 }
191233 onDone := trace .TableOnSessionTransactionRollback (
192234 tx .s .config .Trace (),
@@ -212,10 +254,3 @@ func (tx *transaction) Rollback(ctx context.Context) (err error) {
212254 )
213255 return xerrors .WithStackTrace (err )
214256}
215-
216- func (tx * transaction ) txc () * table.TransactionControl {
217- if tx .c == nil {
218- tx .c = table .TxControl (table .WithTx (tx ))
219- }
220- return tx .c
221- }
0 commit comments