@@ -140,18 +140,23 @@ func TestRunTransaction(t *testing.T) {
140140 srv .reset ()
141141 srv .addRPC (beginReq , beginRes )
142142 srv .addRPC (commitReq , status .Errorf (codes .Aborted , "" ))
143- srv .addRPC (
144- & pb.BeginTransactionRequest {
145- Database : db ,
146- Options : & pb.TransactionOptions {
147- Mode : & pb.TransactionOptions_ReadWrite_ {
148- ReadWrite : & pb.TransactionOptions_ReadWrite {RetryTransaction : tid },
149- },
143+ rollbackReq := & pb.RollbackRequest {Database : db , Transaction : tid }
144+ srv .addRPC (rollbackReq , & emptypb.Empty {})
145+ tid2 := []byte {2 } // Use a new transaction ID for the response to the retry BeginTransaction
146+ beginReqRetry := & pb.BeginTransactionRequest {
147+ Database : db ,
148+ Options : & pb.TransactionOptions {
149+ Mode : & pb.TransactionOptions_ReadWrite_ {
150+ ReadWrite : & pb.TransactionOptions_ReadWrite {RetryTransaction : tid }, // Retries with the previous transaction's ID
150151 },
151152 },
152- beginRes ,
153- )
154- srv .addRPC (commitReq , & pb.CommitResponse {CommitTime : aTimestamp })
153+ }
154+ beginRes2 := & pb.BeginTransactionResponse {Transaction : tid2 } // New transaction ID for the successful attempt
155+ srv .addRPC (beginReqRetry , beginRes2 )
156+
157+ // Attempt 2: Commit succeeds with the new transaction ID (tid2).
158+ commitReq2 := & pb.CommitRequest {Database : db , Transaction : tid2 }
159+ srv .addRPC (commitReq2 , & pb.CommitResponse {CommitTime : aTimestamp })
155160 err = c .RunTransaction (ctx , func (_ context.Context , tx * Transaction ) error { return nil })
156161 if err != nil {
157162 t .Fatal (err )
@@ -167,7 +172,17 @@ func TestTransactionErrors(t *testing.T) {
167172 var (
168173 tid = []byte {1 }
169174 unknownErr = status .Errorf (codes .Unknown , "so sad" )
170- beginReq = & pb.BeginTransactionRequest {
175+ abortedErr = status .Errorf (codes .Aborted , "not so sad.retryable" )
176+
177+ beginRetryReq = & pb.BeginTransactionRequest {
178+ Database : db ,
179+ Options : & pb.TransactionOptions {
180+ Mode : & pb.TransactionOptions_ReadWrite_ {
181+ ReadWrite : & pb.TransactionOptions_ReadWrite {RetryTransaction : tid },
182+ },
183+ },
184+ }
185+ beginReq = & pb.BeginTransactionRequest {
171186 Database : db ,
172187 }
173188 beginRes = & pb.BeginTransactionResponse {Transaction : tid }
@@ -180,8 +195,19 @@ func TestTransactionErrors(t *testing.T) {
180195 Documents : []string {db + "/documents/C/a" },
181196 ConsistencySelector : & pb.BatchGetDocumentsRequest_Transaction {Transaction : tid },
182197 }
198+ getRes = []interface {}{
199+ & pb.BatchGetDocumentsResponse {
200+ Result : & pb.BatchGetDocumentsResponse_Found {Found : & pb.Document {
201+ Name : "projects/projectID/databases/(default)/documents/C/a" ,
202+ CreateTime : aTimestamp ,
203+ UpdateTime : aTimestamp2 ,
204+ }},
205+ ReadTime : aTimestamp2 ,
206+ },
207+ }
183208 rollbackReq = & pb.RollbackRequest {Database : db , Transaction : tid }
184209 commitReq = & pb.CommitRequest {Database : db , Transaction : tid }
210+ commitRes = & pb.CommitResponse {CommitTime : aTimestamp }
185211 )
186212
187213 t .Run ("BeginTransaction has a permanent error" , func (t * testing.T ) {
@@ -239,6 +265,8 @@ func TestTransactionErrors(t *testing.T) {
239265 },
240266 })
241267 srv .addRPC (commitReq , unknownErr )
268+ srv .addRPC (rollbackReq , & emptypb.Empty {})
269+
242270 err := c .RunTransaction (ctx , get )
243271 if status .Code (err ) != codes .Unknown {
244272 t .Errorf ("got <%v>, want Unknown" , err )
@@ -356,22 +384,38 @@ func TestTransactionErrors(t *testing.T) {
356384 })
357385
358386 t .Run ("Too many retries" , func (t * testing.T ) {
387+ // Use tid = 1 for the first attempt.
388+ // Use tid = 2 for the second attempt.
389+ tid1 := []byte {1 }
390+ tid2 := []byte {2 }
391+ beginRes2 := & pb.BeginTransactionResponse {Transaction : tid2 }
392+
359393 srv .reset ()
360- srv .addRPC (beginReq , beginRes )
361- srv .addRPC (commitReq , status .Errorf (codes .Aborted , "" ))
362- srv .addRPC (
363- & pb.BeginTransactionRequest {
364- Database : db ,
365- Options : & pb.TransactionOptions {
366- Mode : & pb.TransactionOptions_ReadWrite_ {
367- ReadWrite : & pb.TransactionOptions_ReadWrite {RetryTransaction : tid },
368- },
394+
395+ // Attempt 1 (Fails)
396+ srv .addRPC (beginReq , beginRes ) // 1. BeginTransaction (tid1)
397+ srv .addRPC (commitReq , status .Errorf (codes .Aborted , "" )) // 2. Commit (tid1) fails (Aborted)
398+ srv .addRPC (rollbackReq , & emptypb.Empty {}) // 3. Rollback (tid1)
399+
400+ // Attempt 2 (Fails)
401+ beginReqRetry := & pb.BeginTransactionRequest {
402+ Database : db ,
403+ Options : & pb.TransactionOptions {
404+ Mode : & pb.TransactionOptions_ReadWrite_ {
405+ ReadWrite : & pb.TransactionOptions_ReadWrite {RetryTransaction : tid1 }, // Retries with previous ID (tid1)
369406 },
370407 },
371- beginRes ,
372- )
373- srv .addRPC (commitReq , status .Errorf (codes .Aborted , "" ))
374- srv .addRPC (rollbackReq , & emptypb.Empty {})
408+ }
409+ // The retry BeginTransaction should return a new ID (tid2)
410+ srv .addRPC (beginReqRetry , beginRes2 ) // 4. BeginTransaction (tid2)
411+
412+ commitReq2 := & pb.CommitRequest {Database : db , Transaction : tid2 } // New commit request with tid2
413+ srv .addRPC (commitReq2 , status .Errorf (codes .Aborted , "" )) // 5. Commit (tid2) fails (Aborted)
414+
415+ // Final Rollback on Aborted error when MaxAttempts is reached
416+ rollbackReq2 := & pb.RollbackRequest {Database : db , Transaction : tid2 }
417+ srv .addRPC (rollbackReq2 , & emptypb.Empty {}) // 6. Rollback (tid2)
418+
375419 err := c .RunTransaction (ctx , func (context.Context , * Transaction ) error { return nil },
376420 MaxAttempts (2 ))
377421 if status .Code (err ) != codes .Aborted {
@@ -397,6 +441,22 @@ func TestTransactionErrors(t *testing.T) {
397441 }
398442 })
399443
444+ t .Run ("Get has a retryable error" , func (t * testing.T ) {
445+ srv .reset ()
446+ srv .addRPC (beginReq , beginRes )
447+ srv .addRPC (getReq , abortedErr )
448+ srv .addRPC (rollbackReq , & emptypb.Empty {})
449+ srv .addRPC (beginRetryReq , beginRes )
450+ srv .addRPC (getReq , getRes )
451+ srv .addRPC (commitReq , commitRes )
452+ err := c .RunTransaction (ctx , get )
453+ if err != nil {
454+ t .Errorf ("got <%v>, want nil" , err )
455+ }
456+ if ! srv .isEmpty () {
457+ t .Errorf ("Expected %+v requests but not received. srv.reqItems: %+v" , len (srv .reqItems ), srv .reqItems )
458+ }
459+ })
400460}
401461
402462func TestTransactionGetAll (t * testing.T ) {
@@ -439,6 +499,7 @@ func TestRunTransaction_Retries(t *testing.T) {
439499 const db = "projects/projectID/databases/(default)"
440500 tid := []byte {1 }
441501
502+ // Attempt 1: Begin
442503 srv .addRPC (
443504 & pb.BeginTransactionRequest {Database : db },
444505 & pb.BeginTransactionResponse {Transaction : tid },
@@ -455,6 +516,7 @@ func TestRunTransaction_Retries(t *testing.T) {
455516 Fields : map [string ]* pb.Value {"count" : intval (7 )},
456517 }
457518
519+ // Attempt 1: Commit (Fails)
458520 srv .addRPC (
459521 & pb.CommitRequest {
460522 Database : db ,
@@ -470,6 +532,10 @@ func TestRunTransaction_Retries(t *testing.T) {
470532 status .Errorf (codes .Aborted , "something failed! please retry me!" ),
471533 )
472534
535+ rollbackReq := & pb.RollbackRequest {Database : db , Transaction : tid }
536+ srv .addRPC (rollbackReq , & emptypb.Empty {})
537+
538+ // Attempt 2: Begin (Retry)
473539 srv .addRPC (
474540 & pb.BeginTransactionRequest {
475541 Database : db ,
0 commit comments