Skip to content

Commit b7124db

Browse files
authored
fix(firestore): Correct ReadWrite transaction retries (googleapis#12893)
Fixes: - [b/443038452](b/443038452) - [b/291261375](b/291261375) - [b/248162068](b/248162068) Retry logic has been changed as per [go/transaction-retry-matrix](http://go/transaction-retry-matrix) and [go/transaction-retries](go/transaction-retries)
1 parent 5d2550c commit b7124db

File tree

2 files changed

+131
-50
lines changed

2 files changed

+131
-50
lines changed

firestore/transaction.go

Lines changed: 41 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,21 @@ var (
102102

103103
type transactionInProgressKey struct{}
104104

105+
// isAborted checks if an error from a transaction operation
106+
// indicates that the entire transaction should be retried.
107+
func isAborted(err error) bool {
108+
s, ok := status.FromError(err)
109+
if !ok {
110+
return false // Not a gRPC error
111+
}
112+
switch s.Code() {
113+
case codes.Aborted:
114+
return true
115+
default:
116+
return false
117+
}
118+
}
119+
105120
// RunTransaction runs f in a transaction. f should use the transaction it is given
106121
// for all Firestore operations. For any operation requiring a context, f should use
107122
// the context it is passed, not the first argument to RunTransaction.
@@ -162,41 +177,42 @@ func (c *Client) RunTransaction(ctx context.Context, f func(context.Context, *Tr
162177
return err
163178
}
164179
t.id = res.Transaction
180+
165181
err = f(context.WithValue(ctx, transactionInProgressKey{}, 1), t)
166182
// Read after write can only be checked client-side, so we make sure to check
167183
// even if the user does not.
168184
if err == nil && t.readAfterWrite {
169185
err = errReadAfterWrite
170186
}
171-
if err != nil {
172-
t.rollback()
173-
// Prefer f's returned error to rollback error.
174-
return err
175-
}
176-
t.ctx = trace.StartSpan(t.ctx, "cloud.google.com/go/firestore.Client.Commit")
177-
commitResponse, err = t.c.c.Commit(t.ctx, &pb.CommitRequest{
178-
Database: t.c.path(),
179-
Writes: t.writes,
180-
Transaction: t.id,
181-
})
182-
trace.EndSpan(t.ctx, err)
183187

184-
// on success, handle the commit response
185188
if err == nil {
186-
for _, opt := range opts {
187-
opt.handleCommitResponse(commitResponse)
189+
t.ctx = trace.StartSpan(t.ctx, "cloud.google.com/go/firestore.Client.Commit")
190+
commitResponse, err = t.c.c.Commit(t.ctx, &pb.CommitRequest{
191+
Database: t.c.path(),
192+
Writes: t.writes,
193+
Transaction: t.id,
194+
})
195+
trace.EndSpan(t.ctx, err)
196+
197+
// on success, handle the commit response
198+
if err == nil {
199+
for _, opt := range opts {
200+
opt.handleCommitResponse(commitResponse)
201+
}
202+
return nil
188203
}
189204
}
190205

191-
// If a read-write transaction returns Aborted, retry.
192-
// On success or other failures, return here.
193-
if t.readOnly || status.Code(err) != codes.Aborted {
194-
// According to the Firestore team, we should not roll back here
195-
// if err != nil. But spanner does.
196-
// See https://code.googlesource.com/gocloud/+/master/spanner/transaction.go#740.
206+
// At this point, `err` is non-nil. It came from `f` or `Commit`.
207+
t.rollback()
208+
209+
// If not a retryable error, or if read-only, return now.
210+
// (We've already rolled back).
211+
if t.readOnly || !isAborted(err) {
197212
return err
198213
}
199214

215+
// It's a retryable error, so continue to backoff and retry logic.
200216
if txOpts == nil {
201217
// txOpts can only be nil if is the first retry of a read-write transaction.
202218
// (It is only set here and in the body of "if t.readOnly" above.)
@@ -207,6 +223,9 @@ func (c *Client) RunTransaction(ctx context.Context, f func(context.Context, *Tr
207223
ReadWrite: &pb.TransactionOptions_ReadWrite{RetryTransaction: t.id},
208224
},
209225
}
226+
} else if rw := txOpts.GetReadWrite(); rw != nil {
227+
// Update transaction ID for read-write retries.
228+
rw.RetryTransaction = t.id
210229
}
211230
// Use exponential backoff to avoid contention with other running
212231
// transactions.
@@ -218,11 +237,7 @@ func (c *Client) RunTransaction(ctx context.Context, f func(context.Context, *Tr
218237
// Reset state for the next attempt.
219238
t.writes = nil
220239
}
221-
// If we run out of retries, return the last error we saw (which should
222-
// be the Aborted from Commit, or a context error).
223-
if err != nil {
224-
t.rollback()
225-
}
240+
226241
return err
227242
}
228243

firestore/transaction_test.go

Lines changed: 90 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -140,18 +140,23 @@ func TestRunTransaction(t *testing.T) {
140140
srv.reset()
141141
srv.addRPC(beginReq, beginRes)
142142
srv.addRPC(commitReq, status.Errorf(codes.Aborted, ""))
143-
srv.addRPC(
144-
&pb.BeginTransactionRequest{
145-
Database: db,
146-
Options: &pb.TransactionOptions{
147-
Mode: &pb.TransactionOptions_ReadWrite_{
148-
ReadWrite: &pb.TransactionOptions_ReadWrite{RetryTransaction: tid},
149-
},
143+
rollbackReq := &pb.RollbackRequest{Database: db, Transaction: tid}
144+
srv.addRPC(rollbackReq, &emptypb.Empty{})
145+
tid2 := []byte{2} // Use a new transaction ID for the response to the retry BeginTransaction
146+
beginReqRetry := &pb.BeginTransactionRequest{
147+
Database: db,
148+
Options: &pb.TransactionOptions{
149+
Mode: &pb.TransactionOptions_ReadWrite_{
150+
ReadWrite: &pb.TransactionOptions_ReadWrite{RetryTransaction: tid}, // Retries with the previous transaction's ID
150151
},
151152
},
152-
beginRes,
153-
)
154-
srv.addRPC(commitReq, &pb.CommitResponse{CommitTime: aTimestamp})
153+
}
154+
beginRes2 := &pb.BeginTransactionResponse{Transaction: tid2} // New transaction ID for the successful attempt
155+
srv.addRPC(beginReqRetry, beginRes2)
156+
157+
// Attempt 2: Commit succeeds with the new transaction ID (tid2).
158+
commitReq2 := &pb.CommitRequest{Database: db, Transaction: tid2}
159+
srv.addRPC(commitReq2, &pb.CommitResponse{CommitTime: aTimestamp})
155160
err = c.RunTransaction(ctx, func(_ context.Context, tx *Transaction) error { return nil })
156161
if err != nil {
157162
t.Fatal(err)
@@ -167,7 +172,17 @@ func TestTransactionErrors(t *testing.T) {
167172
var (
168173
tid = []byte{1}
169174
unknownErr = status.Errorf(codes.Unknown, "so sad")
170-
beginReq = &pb.BeginTransactionRequest{
175+
abortedErr = status.Errorf(codes.Aborted, "not so sad.retryable")
176+
177+
beginRetryReq = &pb.BeginTransactionRequest{
178+
Database: db,
179+
Options: &pb.TransactionOptions{
180+
Mode: &pb.TransactionOptions_ReadWrite_{
181+
ReadWrite: &pb.TransactionOptions_ReadWrite{RetryTransaction: tid},
182+
},
183+
},
184+
}
185+
beginReq = &pb.BeginTransactionRequest{
171186
Database: db,
172187
}
173188
beginRes = &pb.BeginTransactionResponse{Transaction: tid}
@@ -180,8 +195,19 @@ func TestTransactionErrors(t *testing.T) {
180195
Documents: []string{db + "/documents/C/a"},
181196
ConsistencySelector: &pb.BatchGetDocumentsRequest_Transaction{Transaction: tid},
182197
}
198+
getRes = []interface{}{
199+
&pb.BatchGetDocumentsResponse{
200+
Result: &pb.BatchGetDocumentsResponse_Found{Found: &pb.Document{
201+
Name: "projects/projectID/databases/(default)/documents/C/a",
202+
CreateTime: aTimestamp,
203+
UpdateTime: aTimestamp2,
204+
}},
205+
ReadTime: aTimestamp2,
206+
},
207+
}
183208
rollbackReq = &pb.RollbackRequest{Database: db, Transaction: tid}
184209
commitReq = &pb.CommitRequest{Database: db, Transaction: tid}
210+
commitRes = &pb.CommitResponse{CommitTime: aTimestamp}
185211
)
186212

187213
t.Run("BeginTransaction has a permanent error", func(t *testing.T) {
@@ -239,6 +265,8 @@ func TestTransactionErrors(t *testing.T) {
239265
},
240266
})
241267
srv.addRPC(commitReq, unknownErr)
268+
srv.addRPC(rollbackReq, &emptypb.Empty{})
269+
242270
err := c.RunTransaction(ctx, get)
243271
if status.Code(err) != codes.Unknown {
244272
t.Errorf("got <%v>, want Unknown", err)
@@ -356,22 +384,38 @@ func TestTransactionErrors(t *testing.T) {
356384
})
357385

358386
t.Run("Too many retries", func(t *testing.T) {
387+
// Use tid = 1 for the first attempt.
388+
// Use tid = 2 for the second attempt.
389+
tid1 := []byte{1}
390+
tid2 := []byte{2}
391+
beginRes2 := &pb.BeginTransactionResponse{Transaction: tid2}
392+
359393
srv.reset()
360-
srv.addRPC(beginReq, beginRes)
361-
srv.addRPC(commitReq, status.Errorf(codes.Aborted, ""))
362-
srv.addRPC(
363-
&pb.BeginTransactionRequest{
364-
Database: db,
365-
Options: &pb.TransactionOptions{
366-
Mode: &pb.TransactionOptions_ReadWrite_{
367-
ReadWrite: &pb.TransactionOptions_ReadWrite{RetryTransaction: tid},
368-
},
394+
395+
// Attempt 1 (Fails)
396+
srv.addRPC(beginReq, beginRes) // 1. BeginTransaction (tid1)
397+
srv.addRPC(commitReq, status.Errorf(codes.Aborted, "")) // 2. Commit (tid1) fails (Aborted)
398+
srv.addRPC(rollbackReq, &emptypb.Empty{}) // 3. Rollback (tid1)
399+
400+
// Attempt 2 (Fails)
401+
beginReqRetry := &pb.BeginTransactionRequest{
402+
Database: db,
403+
Options: &pb.TransactionOptions{
404+
Mode: &pb.TransactionOptions_ReadWrite_{
405+
ReadWrite: &pb.TransactionOptions_ReadWrite{RetryTransaction: tid1}, // Retries with previous ID (tid1)
369406
},
370407
},
371-
beginRes,
372-
)
373-
srv.addRPC(commitReq, status.Errorf(codes.Aborted, ""))
374-
srv.addRPC(rollbackReq, &emptypb.Empty{})
408+
}
409+
// The retry BeginTransaction should return a new ID (tid2)
410+
srv.addRPC(beginReqRetry, beginRes2) // 4. BeginTransaction (tid2)
411+
412+
commitReq2 := &pb.CommitRequest{Database: db, Transaction: tid2} // New commit request with tid2
413+
srv.addRPC(commitReq2, status.Errorf(codes.Aborted, "")) // 5. Commit (tid2) fails (Aborted)
414+
415+
// Final Rollback on Aborted error when MaxAttempts is reached
416+
rollbackReq2 := &pb.RollbackRequest{Database: db, Transaction: tid2}
417+
srv.addRPC(rollbackReq2, &emptypb.Empty{}) // 6. Rollback (tid2)
418+
375419
err := c.RunTransaction(ctx, func(context.Context, *Transaction) error { return nil },
376420
MaxAttempts(2))
377421
if status.Code(err) != codes.Aborted {
@@ -397,6 +441,22 @@ func TestTransactionErrors(t *testing.T) {
397441
}
398442
})
399443

444+
t.Run("Get has a retryable error", func(t *testing.T) {
445+
srv.reset()
446+
srv.addRPC(beginReq, beginRes)
447+
srv.addRPC(getReq, abortedErr)
448+
srv.addRPC(rollbackReq, &emptypb.Empty{})
449+
srv.addRPC(beginRetryReq, beginRes)
450+
srv.addRPC(getReq, getRes)
451+
srv.addRPC(commitReq, commitRes)
452+
err := c.RunTransaction(ctx, get)
453+
if err != nil {
454+
t.Errorf("got <%v>, want nil", err)
455+
}
456+
if !srv.isEmpty() {
457+
t.Errorf("Expected %+v requests but not received. srv.reqItems: %+v", len(srv.reqItems), srv.reqItems)
458+
}
459+
})
400460
}
401461

402462
func TestTransactionGetAll(t *testing.T) {
@@ -439,6 +499,7 @@ func TestRunTransaction_Retries(t *testing.T) {
439499
const db = "projects/projectID/databases/(default)"
440500
tid := []byte{1}
441501

502+
// Attempt 1: Begin
442503
srv.addRPC(
443504
&pb.BeginTransactionRequest{Database: db},
444505
&pb.BeginTransactionResponse{Transaction: tid},
@@ -455,6 +516,7 @@ func TestRunTransaction_Retries(t *testing.T) {
455516
Fields: map[string]*pb.Value{"count": intval(7)},
456517
}
457518

519+
// Attempt 1: Commit (Fails)
458520
srv.addRPC(
459521
&pb.CommitRequest{
460522
Database: db,
@@ -470,6 +532,10 @@ func TestRunTransaction_Retries(t *testing.T) {
470532
status.Errorf(codes.Aborted, "something failed! please retry me!"),
471533
)
472534

535+
rollbackReq := &pb.RollbackRequest{Database: db, Transaction: tid}
536+
srv.addRPC(rollbackReq, &emptypb.Empty{})
537+
538+
// Attempt 2: Begin (Retry)
473539
srv.addRPC(
474540
&pb.BeginTransactionRequest{
475541
Database: db,

0 commit comments

Comments
 (0)