Skip to content

Commit bf1e069

Browse files
authored
chore(spanner): handle commit retry protocol extension for mux rw (googleapis#11472)
* chore(spanner): handle commit retry protocol extension for mux rw * incorporate changes * fix tests
1 parent 586669a commit bf1e069

File tree

3 files changed

+163
-42
lines changed

3 files changed

+163
-42
lines changed

spanner/internal/testutil/inmem_spanner_server.go

Lines changed: 43 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,7 @@ type SimulatedExecutionTime struct {
250250
MinimumExecutionTime time.Duration
251251
RandomExecutionTime time.Duration
252252
Errors []error
253+
Responses []interface{}
253254
// Keep error after execution. The error will continue to be returned until
254255
// it is cleared.
255256
KeepError bool
@@ -678,47 +679,57 @@ func (s *inMemSpannerServer) getStatementResult(sql string) (*StatementResult, e
678679
return result, nil
679680
}
680681

681-
func (s *inMemSpannerServer) simulateExecutionTime(method string, req interface{}) error {
682+
func (s *inMemSpannerServer) simulateExecutionTime(method string, req interface{}) (interface{}, error) {
682683
s.mu.Lock()
684+
defer s.mu.Unlock()
685+
686+
// Check if the server is stopped
683687
if s.stopped {
684-
s.mu.Unlock()
685-
return gstatus.Error(codes.Unavailable, "server has been stopped")
688+
return nil, gstatus.Error(codes.Unavailable, "server has been stopped")
686689
}
690+
691+
// Send the request to the receivedRequests channel
687692
s.receivedRequests <- req
688-
s.mu.Unlock()
689-
s.ready()
690-
s.mu.Lock()
693+
694+
// Check for a simulated error
691695
if s.err != nil {
692696
err := s.err
693697
s.err = nil
694-
s.mu.Unlock()
695-
return err
698+
return nil, err
696699
}
700+
701+
// Check for a simulated execution time
697702
executionTime, ok := s.executionTimes[method]
698-
s.mu.Unlock()
699703
if ok {
700704
var randTime int64
701705
if executionTime.RandomExecutionTime > 0 {
702706
randTime = rand.Int63n(int64(executionTime.RandomExecutionTime))
703707
}
704708
totalExecutionTime := time.Duration(int64(executionTime.MinimumExecutionTime) + randTime)
705709
<-time.After(totalExecutionTime)
706-
s.mu.Lock()
710+
711+
// Check for errors in the execution time
707712
if len(executionTime.Errors) > 0 {
708713
err := executionTime.Errors[0]
709714
if !executionTime.KeepError {
710715
executionTime.Errors = executionTime.Errors[1:]
711716
}
712-
s.mu.Unlock()
713-
return err
717+
return nil, err
718+
}
719+
720+
// Check for responses in the execution time
721+
if len(executionTime.Responses) > 0 {
722+
response := executionTime.Responses[0]
723+
executionTime.Responses = executionTime.Responses[1:]
724+
return response, nil
714725
}
715-
s.mu.Unlock()
716726
}
717-
return nil
727+
728+
return nil, nil
718729
}
719730

720731
func (s *inMemSpannerServer) CreateSession(ctx context.Context, req *spannerpb.CreateSessionRequest) (*spannerpb.Session, error) {
721-
if err := s.simulateExecutionTime(MethodCreateSession, req); err != nil {
732+
if _, err := s.simulateExecutionTime(MethodCreateSession, req); err != nil {
722733
return nil, err
723734
}
724735
if req.Database == "" {
@@ -750,7 +761,7 @@ func (s *inMemSpannerServer) CreateSession(ctx context.Context, req *spannerpb.C
750761
}
751762

752763
func (s *inMemSpannerServer) BatchCreateSessions(ctx context.Context, req *spannerpb.BatchCreateSessionsRequest) (*spannerpb.BatchCreateSessionsResponse, error) {
753-
if err := s.simulateExecutionTime(MethodBatchCreateSession, req); err != nil {
764+
if _, err := s.simulateExecutionTime(MethodBatchCreateSession, req); err != nil {
754765
return nil, err
755766
}
756767
if req.Database == "" {
@@ -792,7 +803,7 @@ func (s *inMemSpannerServer) BatchCreateSessions(ctx context.Context, req *spann
792803
}
793804

794805
func (s *inMemSpannerServer) GetSession(ctx context.Context, req *spannerpb.GetSessionRequest) (*spannerpb.Session, error) {
795-
if err := s.simulateExecutionTime(MethodGetSession, req); err != nil {
806+
if _, err := s.simulateExecutionTime(MethodGetSession, req); err != nil {
796807
return nil, err
797808
}
798809
if req.Name == "" {
@@ -833,7 +844,7 @@ func (s *inMemSpannerServer) ListSessions(ctx context.Context, req *spannerpb.Li
833844
}
834845

835846
func (s *inMemSpannerServer) DeleteSession(ctx context.Context, req *spannerpb.DeleteSessionRequest) (*emptypb.Empty, error) {
836-
if err := s.simulateExecutionTime(MethodDeleteSession, req); err != nil {
847+
if _, err := s.simulateExecutionTime(MethodDeleteSession, req); err != nil {
837848
return nil, err
838849
}
839850
if req.Name == "" {
@@ -850,7 +861,7 @@ func (s *inMemSpannerServer) DeleteSession(ctx context.Context, req *spannerpb.D
850861
}
851862

852863
func (s *inMemSpannerServer) ExecuteSql(ctx context.Context, req *spannerpb.ExecuteSqlRequest) (*spannerpb.ResultSet, error) {
853-
if err := s.simulateExecutionTime(MethodExecuteSql, req); err != nil {
864+
if _, err := s.simulateExecutionTime(MethodExecuteSql, req); err != nil {
854865
return nil, err
855866
}
856867
if req.Sql == "SELECT 1" {
@@ -905,7 +916,7 @@ func (s *inMemSpannerServer) ExecuteSql(ctx context.Context, req *spannerpb.Exec
905916
}
906917

907918
func (s *inMemSpannerServer) ExecuteStreamingSql(req *spannerpb.ExecuteSqlRequest, stream spannerpb.Spanner_ExecuteStreamingSqlServer) error {
908-
if err := s.simulateExecutionTime(MethodExecuteStreamingSql, req); err != nil {
919+
if _, err := s.simulateExecutionTime(MethodExecuteStreamingSql, req); err != nil {
909920
return err
910921
}
911922
return s.executeStreamingSQL(req, stream)
@@ -984,7 +995,7 @@ func (s *inMemSpannerServer) executeStreamingSQL(req *spannerpb.ExecuteSqlReques
984995
}
985996

986997
func (s *inMemSpannerServer) ExecuteBatchDml(ctx context.Context, req *spannerpb.ExecuteBatchDmlRequest) (*spannerpb.ExecuteBatchDmlResponse, error) {
987-
if err := s.simulateExecutionTime(MethodExecuteBatchDml, req); err != nil {
998+
if _, err := s.simulateExecutionTime(MethodExecuteBatchDml, req); err != nil {
988999
return nil, err
9891000
}
9901001
if req.Session == "" {
@@ -1047,7 +1058,7 @@ func (s *inMemSpannerServer) Read(ctx context.Context, req *spannerpb.ReadReques
10471058
}
10481059

10491060
func (s *inMemSpannerServer) StreamingRead(req *spannerpb.ReadRequest, stream spannerpb.Spanner_StreamingReadServer) error {
1050-
if err := s.simulateExecutionTime(MethodStreamingRead, req); err != nil {
1061+
if _, err := s.simulateExecutionTime(MethodStreamingRead, req); err != nil {
10511062
return err
10521063
}
10531064
sqlReq := &spannerpb.ExecuteSqlRequest{
@@ -1066,7 +1077,7 @@ func (s *inMemSpannerServer) StreamingRead(req *spannerpb.ReadRequest, stream sp
10661077
}
10671078

10681079
func (s *inMemSpannerServer) BeginTransaction(ctx context.Context, req *spannerpb.BeginTransactionRequest) (*spannerpb.Transaction, error) {
1069-
if err := s.simulateExecutionTime(MethodBeginTransaction, req); err != nil {
1080+
if _, err := s.simulateExecutionTime(MethodBeginTransaction, req); err != nil {
10701081
return nil, err
10711082
}
10721083
if req.Session == "" {
@@ -1085,7 +1096,8 @@ func (s *inMemSpannerServer) BeginTransaction(ctx context.Context, req *spannerp
10851096
}
10861097

10871098
func (s *inMemSpannerServer) Commit(ctx context.Context, req *spannerpb.CommitRequest) (*spannerpb.CommitResponse, error) {
1088-
if err := s.simulateExecutionTime(MethodCommitTransaction, req); err != nil {
1099+
mockResponse, err := s.simulateExecutionTime(MethodCommitTransaction, req)
1100+
if err != nil {
10891101
return nil, err
10901102
}
10911103
if req.Session == "" {
@@ -1107,8 +1119,11 @@ func (s *inMemSpannerServer) Commit(ctx context.Context, req *spannerpb.CommitRe
11071119
} else {
11081120
return nil, gstatus.Error(codes.InvalidArgument, "Missing transaction in commit request")
11091121
}
1110-
s.removeTransaction(tx)
1111-
resp := &spannerpb.CommitResponse{CommitTimestamp: getCurrentTimestamp()}
1122+
resp, ok := mockResponse.(*spannerpb.CommitResponse)
1123+
if !ok {
1124+
resp = &spannerpb.CommitResponse{CommitTimestamp: getCurrentTimestamp()}
1125+
s.removeTransaction(tx)
1126+
}
11121127
if req.ReturnCommitStats {
11131128
resp.CommitStats = &spannerpb.CommitResponse_CommitStats{
11141129
MutationCount: int64(1),
@@ -1142,7 +1157,7 @@ func (s *inMemSpannerServer) Rollback(ctx context.Context, req *spannerpb.Rollba
11421157
}
11431158

11441159
func (s *inMemSpannerServer) PartitionQuery(ctx context.Context, req *spannerpb.PartitionQueryRequest) (*spannerpb.PartitionResponse, error) {
1145-
if err := s.simulateExecutionTime(MethodPartitionQuery, req); err != nil {
1160+
if _, err := s.simulateExecutionTime(MethodPartitionQuery, req); err != nil {
11461161
return nil, err
11471162
}
11481163
s.mu.Lock()
@@ -1214,7 +1229,7 @@ func DecodeResumeToken(t []byte) (uint64, error) {
12141229
}
12151230

12161231
func (s *inMemSpannerServer) BatchWrite(req *spannerpb.BatchWriteRequest, stream spannerpb.Spanner_BatchWriteServer) error {
1217-
if err := s.simulateExecutionTime(MethodBatchWrite, req); err != nil {
1232+
if _, err := s.simulateExecutionTime(MethodBatchWrite, req); err != nil {
12181233
return err
12191234
}
12201235
return s.batchWrite(req, stream)

spanner/transaction.go

Lines changed: 28 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1684,7 +1684,6 @@ func (t *ReadWriteTransaction) commit(ctx context.Context, options CommitOptions
16841684
}
16851685
t.state = txClosed // No further operations after commit.
16861686
close(t.txReadyOrClosed)
1687-
precommitToken := t.precommitToken
16881687
t.mu.Unlock()
16891688
if err != nil {
16901689
return resp, err
@@ -1703,17 +1702,32 @@ func (t *ReadWriteTransaction) commit(ctx context.Context, options CommitOptions
17031702
if options.MaxCommitDelay != nil {
17041703
maxCommitDelay = durationpb.New(*(options.MaxCommitDelay))
17051704
}
1706-
res, e := client.Commit(contextWithOutgoingMetadata(ctx, t.sh.getMetadata(), t.disableRouteToLeader), &sppb.CommitRequest{
1707-
Session: sid,
1708-
Transaction: &sppb.CommitRequest_TransactionId{
1709-
TransactionId: t.tx,
1710-
},
1711-
PrecommitToken: precommitToken,
1712-
RequestOptions: createRequestOptions(t.txOpts.CommitPriority, "", t.txOpts.TransactionTag),
1713-
Mutations: mutationProtos,
1714-
ReturnCommitStats: options.ReturnCommitStats,
1715-
MaxCommitDelay: maxCommitDelay,
1716-
}, gax.WithGRPCOptions(grpc.Header(&md)))
1705+
performCommit := func(includeMutations bool) (*sppb.CommitResponse, error) {
1706+
req := &sppb.CommitRequest{
1707+
Session: sid,
1708+
Transaction: &sppb.CommitRequest_TransactionId{
1709+
TransactionId: t.tx,
1710+
},
1711+
PrecommitToken: t.precommitToken,
1712+
RequestOptions: createRequestOptions(t.txOpts.CommitPriority, "", t.txOpts.TransactionTag),
1713+
ReturnCommitStats: options.ReturnCommitStats,
1714+
MaxCommitDelay: maxCommitDelay,
1715+
}
1716+
if includeMutations {
1717+
req.Mutations = mutationProtos
1718+
}
1719+
return client.Commit(contextWithOutgoingMetadata(ctx, t.sh.getMetadata(), t.disableRouteToLeader), req, gax.WithGRPCOptions(grpc.Header(&md)))
1720+
}
1721+
// Initial commit attempt with mutations
1722+
res, err := performCommit(true)
1723+
if err != nil {
1724+
return resp, t.txReadOnly.updateTxState(toSpannerErrorWithCommitInfo(err, true))
1725+
}
1726+
// Retry if MultiplexedSessionRetry is present, without mutations
1727+
if res.GetMultiplexedSessionRetry() != nil {
1728+
t.updatePrecommitToken(res.GetPrecommitToken())
1729+
res, err = performCommit(false)
1730+
}
17171731
if getGFELatencyMetricsFlag() && md != nil && t.ct != nil {
17181732
if err := createContextAndCaptureGFELatencyMetrics(ctx, t.ct, md, "commit"); err != nil {
17191733
trace.TracePrintf(ctx, nil, "Error in recording GFE Latency. Try disabling and rerunning. Error: %v", err)
@@ -1722,8 +1736,8 @@ func (t *ReadWriteTransaction) commit(ctx context.Context, options CommitOptions
17221736
if metricErr := recordGFELatencyMetricsOT(ctx, md, "commit", t.otConfig); metricErr != nil {
17231737
trace.TracePrintf(ctx, nil, "Error in recording GFE Latency through OpenTelemetry. Error: %v", metricErr)
17241738
}
1725-
if e != nil {
1726-
return resp, t.txReadOnly.updateTxState(toSpannerErrorWithCommitInfo(e, true))
1739+
if err != nil {
1740+
return resp, t.txReadOnly.updateTxState(toSpannerErrorWithCommitInfo(err, true))
17271741
}
17281742
if tstamp := res.GetCommitTimestamp(); tstamp != nil {
17291743
resp.CommitTs = time.Unix(tstamp.Seconds, int64(tstamp.Nanos))

spanner/transaction_test.go

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -504,6 +504,98 @@ func TestReadWriteTransaction_PrecommitToken(t *testing.T) {
504504
}
505505
}
506506

507+
func TestCommitWithMultiplexedSessionRetry(t *testing.T) {
508+
ctx := context.Background()
509+
server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{
510+
DisableNativeMetrics: true,
511+
SessionPoolConfig: SessionPoolConfig{
512+
MinOpened: 1,
513+
MaxOpened: 1,
514+
enableMultiplexSession: true,
515+
enableMultiplexedSessionForRW: true,
516+
},
517+
})
518+
defer teardown()
519+
520+
// newCommitResponseWithPrecommitToken creates a simulated response with a PrecommitToken
521+
newCommitResponseWithPrecommitToken := func() *sppb.CommitResponse {
522+
precommitToken := &sppb.MultiplexedSessionPrecommitToken{
523+
PrecommitToken: []byte("commit-retry-precommit-token"),
524+
SeqNum: 4,
525+
}
526+
527+
// Create a CommitResponse with the PrecommitToken
528+
return &sppb.CommitResponse{
529+
MultiplexedSessionRetry: &sppb.CommitResponse_PrecommitToken{PrecommitToken: precommitToken},
530+
}
531+
}
532+
533+
// Simulate a commit response with a MultiplexedSessionRetry
534+
server.TestSpanner.PutExecutionTime(MethodCommitTransaction,
535+
SimulatedExecutionTime{
536+
Responses: []interface{}{newCommitResponseWithPrecommitToken()},
537+
})
538+
539+
_, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error {
540+
ms := []*Mutation{
541+
Insert("t_foo", []string{"col1", "col2"}, []interface{}{int64(1), int64(2)}),
542+
Update("t_foo", []string{"col1", "col2"}, []interface{}{"one", []byte(nil)}),
543+
}
544+
if err := tx.BufferWrite(ms); err != nil {
545+
return err
546+
}
547+
548+
iter := tx.Query(ctx, NewStatement(SelectSingerIDAlbumIDAlbumTitleFromAlbums))
549+
defer iter.Stop()
550+
for {
551+
_, err := iter.Next()
552+
if err == iterator.Done {
553+
break
554+
}
555+
if err != nil {
556+
return err
557+
}
558+
}
559+
return nil
560+
})
561+
if err != nil {
562+
t.Fatalf("Commit failed: %v", err)
563+
}
564+
565+
// Verify that the commit was retried
566+
requests := drainRequestsFromServer(server.TestSpanner)
567+
commitCount := 0
568+
for _, req := range requests {
569+
if commitReq, ok := req.(*sppb.CommitRequest); ok {
570+
if !strings.Contains(commitReq.GetSession(), "multiplexed") {
571+
t.Errorf("Expected session to be multiplexed")
572+
}
573+
commitCount++
574+
if commitCount == 1 {
575+
// Validate that the first commit had mutations set
576+
if len(commitReq.Mutations) == 0 {
577+
t.Fatalf("Expected first commit to have mutations set")
578+
}
579+
if commitReq.PrecommitToken == nil || !strings.Contains(string(commitReq.PrecommitToken.PrecommitToken), "ResultSetPrecommitToken") {
580+
t.Fatalf("Expected first commit to have precommit token 'ResultSetPrecommitToken', got: %v", commitReq.PrecommitToken)
581+
}
582+
} else if commitCount == 2 {
583+
// Validate that the second commit attempt had mutations un-set
584+
if len(commitReq.Mutations) != 0 {
585+
t.Fatalf("Expected second commit to have no mutations set")
586+
}
587+
// Validate that the second commit had the precommit token set
588+
if commitReq.PrecommitToken == nil || string(commitReq.PrecommitToken.PrecommitToken) != "commit-retry-precommit-token" {
589+
t.Fatalf("Expected second commit to have precommit token 'commit-retry-precommit-token', got: %v", commitReq.PrecommitToken)
590+
}
591+
}
592+
}
593+
}
594+
if commitCount != 2 {
595+
t.Fatalf("Expected 2 commit attempts, got %d", commitCount)
596+
}
597+
}
598+
507599
func TestMutationOnlyCaseAborted(t *testing.T) {
508600
t.Parallel()
509601
ctx := context.Background()

0 commit comments

Comments
 (0)