Skip to content

Commit edd04c8

Browse files
committed
nop Rollback for committed transactions
1 parent dff24d3 commit edd04c8

File tree

3 files changed

+155
-0
lines changed

3 files changed

+155
-0
lines changed

internal/errors/defaults.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,7 @@ var (
1010

1111
// ErrNilConnection is returned when use nil preferred connection
1212
ErrNilConnection = errors.New("nil connection")
13+
14+
// ErrAlreadyCommited returns if transaction Commit called twice
15+
ErrAlreadyCommited = errors.New("already committed")
1316
)

internal/table/session.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -841,6 +841,8 @@ type Transaction struct {
841841
id string
842842
s *session
843843
c *table.TransactionControl
844+
845+
committed bool
844846
}
845847

846848
func (tx *Transaction) ID() string {
@@ -873,6 +875,14 @@ func (tx *Transaction) ExecuteStatement(
873875

874876
// CommitTx commits specified active transaction.
875877
func (tx *Transaction) CommitTx(ctx context.Context, opts ...options.CommitTransactionOption) (r result.Result, err error) {
878+
if tx.committed {
879+
return nil, errors.ErrAlreadyCommited
880+
}
881+
defer func() {
882+
if err == nil {
883+
tx.committed = true
884+
}
885+
}()
876886
onDone := trace.TableOnSessionTransactionCommit(tx.s.trace, &ctx, tx.s, tx)
877887
defer func() {
878888
onDone(err)
@@ -904,6 +914,9 @@ func (tx *Transaction) CommitTx(ctx context.Context, opts ...options.CommitTrans
904914

905915
// Rollback performs a rollback of the specified active transaction.
906916
func (tx *Transaction) Rollback(ctx context.Context) (err error) {
917+
if tx.committed {
918+
return nil
919+
}
907920
onDone := trace.TableOnSessionTransactionRollback(tx.s.trace, &ctx, tx.s, tx)
908921
defer func() {
909922
onDone(err)

internal/table/session_test.go

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,11 @@ import (
88
"time"
99

1010
"google.golang.org/protobuf/proto"
11+
"google.golang.org/protobuf/types/known/anypb"
1112

1213
"github.com/ydb-platform/ydb-go-genproto/Ydb_Table_V1"
1314
"github.com/ydb-platform/ydb-go-genproto/protos/Ydb"
15+
"github.com/ydb-platform/ydb-go-genproto/protos/Ydb_Operations"
1416
"github.com/ydb-platform/ydb-go-genproto/protos/Ydb_Scheme"
1517
"github.com/ydb-platform/ydb-go-genproto/protos/Ydb_Table"
1618

@@ -471,3 +473,140 @@ func TestQueryCachePolicyKeepInCache(t *testing.T) {
471473
})
472474
}
473475
}
476+
477+
func TestTxSkipRollbackForCommitted(t *testing.T) {
478+
var (
479+
begin = 0
480+
commit = 0
481+
rollback = 0
482+
)
483+
b := StubBuilder{
484+
T: t,
485+
Cluster: testutil.NewCluster(
486+
testutil.WithInvokeHandlers(
487+
testutil.InvokeHandlers{
488+
// nolint:unparam
489+
testutil.TableBeginTransaction: func(request interface{}) (proto.Message, error) {
490+
_, ok := request.(*Ydb_Table.BeginTransactionRequest)
491+
if !ok {
492+
t.Fatalf("cannot cast request '%T' to *Ydb_Table.BeginTransactionRequest", request)
493+
}
494+
result, err := anypb.New(
495+
&Ydb_Table.BeginTransactionResult{
496+
TxMeta: &Ydb_Table.TransactionMeta{
497+
Id: "",
498+
},
499+
},
500+
)
501+
if err != nil {
502+
return nil, err
503+
}
504+
begin++
505+
return &Ydb_Table.BeginTransactionResponse{
506+
Operation: &Ydb_Operations.Operation{
507+
Ready: true,
508+
Status: Ydb.StatusIds_SUCCESS,
509+
Result: result,
510+
},
511+
}, nil
512+
},
513+
// nolint:unparam
514+
testutil.TableCommitTransaction: func(request interface{}) (proto.Message, error) {
515+
_, ok := request.(*Ydb_Table.CommitTransactionRequest)
516+
if !ok {
517+
t.Fatalf("cannot cast request '%T' to *Ydb_Table.CommitTransactionRequest", request)
518+
}
519+
result, err := anypb.New(
520+
&Ydb_Table.CommitTransactionResult{},
521+
)
522+
if err != nil {
523+
return nil, err
524+
}
525+
commit++
526+
return &Ydb_Table.CommitTransactionResponse{
527+
Operation: &Ydb_Operations.Operation{
528+
Ready: true,
529+
Status: Ydb.StatusIds_SUCCESS,
530+
Result: result,
531+
},
532+
}, nil
533+
},
534+
// nolint:unparam
535+
testutil.TableRollbackTransaction: func(request interface{}) (proto.Message, error) {
536+
_, ok := request.(*Ydb_Table.RollbackTransactionRequest)
537+
if !ok {
538+
t.Fatalf("cannot cast request '%T' to *Ydb_Table.RollbackTransactionRequest", request)
539+
}
540+
rollback++
541+
return &Ydb_Table.RollbackTransactionResponse{
542+
Operation: &Ydb_Operations.Operation{
543+
Ready: true,
544+
Status: Ydb.StatusIds_SUCCESS,
545+
},
546+
}, nil
547+
},
548+
// nolint:unparam
549+
testutil.TableCreateSession: func(request interface{}) (result proto.Message, err error) {
550+
return &Ydb_Table.CreateSessionResult{
551+
SessionId: testutil.SessionID(),
552+
}, nil
553+
},
554+
},
555+
),
556+
),
557+
}
558+
s, err := b.createSession(context.Background())
559+
if err != nil {
560+
t.Fatal(err)
561+
}
562+
{
563+
x, err := s.BeginTransaction(context.Background(), table.TxSettings())
564+
if err != nil {
565+
t.Fatal(err)
566+
}
567+
if begin != 1 {
568+
t.Fatalf("unexpected begin: %d", begin)
569+
}
570+
_, err = x.CommitTx(context.Background())
571+
if err != nil {
572+
t.Fatal(err)
573+
}
574+
if commit != 1 {
575+
t.Fatalf("unexpected commit: %d", begin)
576+
}
577+
_, _ = x.CommitTx(context.Background())
578+
if commit != 1 {
579+
t.Fatalf("unexpected commit: %d", begin)
580+
}
581+
err = x.Rollback(context.Background())
582+
if err != nil {
583+
t.Fatal(err)
584+
}
585+
if rollback != 0 {
586+
t.Fatalf("unexpected rollback: %d", begin)
587+
}
588+
}
589+
{
590+
x, err := s.BeginTransaction(context.Background(), table.TxSettings())
591+
if err != nil {
592+
t.Fatal(err)
593+
}
594+
if begin != 2 {
595+
t.Fatalf("unexpected begin: %d", begin)
596+
}
597+
err = x.Rollback(context.Background())
598+
if err != nil {
599+
t.Fatal(err)
600+
}
601+
if rollback != 1 {
602+
t.Fatalf("unexpected rollback: %d", begin)
603+
}
604+
_, err = x.CommitTx(context.Background())
605+
if err != nil {
606+
t.Fatal(err)
607+
}
608+
if commit != 2 {
609+
t.Fatalf("unexpected commit: %d", begin)
610+
}
611+
}
612+
}

0 commit comments

Comments
 (0)