Skip to content

Commit 011b148

Browse files
authored
fix(datastore): resolve data race on transaction state (googleapis#12912)
Refactors the state locking for BeginLater transactions. This change: - Removes the racy stateLockDeferUnlock function. - Ensures all reads and writes to t.state and t.id are protected by t.stateLock. - Fixes a related logic bug where a transaction's ID was ignored if the first read operation (like Get or Run) returned an ErrNoSuchEntity. Fixes: googleapis#11038
1 parent a18bc21 commit 011b148

File tree

2 files changed

+36
-57
lines changed

2 files changed

+36
-57
lines changed

datastore/query.go

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -945,9 +945,6 @@ func (c *Client) RunAggregationQueryWithOptions(ctx context.Context, aq *Aggrega
945945

946946
// Parse the read options.
947947
txn := aq.query.trans
948-
if txn != nil {
949-
defer txn.stateLockDeferUnlock()()
950-
}
951948

952949
req.ReadOptions, err = parseQueryReadOptions(aq.query.eventual, txn, c.readSettings)
953950
if err != nil {
@@ -959,8 +956,13 @@ func (c *Client) RunAggregationQueryWithOptions(ctx context.Context, aq *Aggrega
959956
return ar, err
960957
}
961958

962-
if txn != nil && txn.state == transactionStateNotStarted {
963-
txn.setToInProgress(resp.Transaction)
959+
if txn != nil && resp.Transaction != nil {
960+
txn.stateLock.Lock()
961+
if txn.state == transactionStateNotStarted {
962+
txn.id = resp.Transaction
963+
txn.state = transactionStateInProgress
964+
}
965+
txn.stateLock.Unlock()
964966
}
965967

966968
if req.ExplainOptions == nil || req.ExplainOptions.Analyze {
@@ -1128,9 +1130,6 @@ func (t *Iterator) nextBatch() error {
11281130
}
11291131

11301132
txn := t.trans
1131-
if txn != nil {
1132-
defer txn.stateLockDeferUnlock()()
1133-
}
11341133

11351134
var err error
11361135
t.req.ReadOptions, err = parseQueryReadOptions(t.eventual, txn, t.client.readSettings)
@@ -1144,8 +1143,13 @@ func (t *Iterator) nextBatch() error {
11441143
return err
11451144
}
11461145

1147-
if txn != nil && txn.state == transactionStateNotStarted {
1148-
txn.setToInProgress(resp.Transaction)
1146+
if txn != nil && resp.Transaction != nil {
1147+
txn.stateLock.Lock()
1148+
if txn.state == transactionStateNotStarted {
1149+
txn.id = resp.Transaction
1150+
txn.state = transactionStateInProgress
1151+
}
1152+
txn.stateLock.Unlock()
11491153
}
11501154

11511155
if t.req.ExplainOptions != nil && !t.req.ExplainOptions.Analyze {

datastore/transaction.go

Lines changed: 22 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -229,57 +229,22 @@ func (t *Transaction) beginTransaction() (txnID []byte, err error) {
229229

230230
// beginLaterTransaction makes BeginTransaction rpc if transaction has not yet started
231231
func (t *Transaction) beginLaterTransaction() (err error) {
232-
if t.state != transactionStateNotStarted {
233-
return nil
234-
}
235-
236-
// Obtain state lock since the state needs to be updated
237-
// after transaction has started
238232
t.stateLock.Lock()
239-
defer t.stateLock.Unlock()
240233
if t.state != transactionStateNotStarted {
234+
t.stateLock.Unlock()
241235
return nil
242236
}
243237

244238
txnID, err := t.beginTransaction()
245239
if err != nil {
246-
return err
247-
}
248-
249-
t.setToInProgress(txnID)
250-
return nil
251-
}
252-
253-
// Acquires state lock if transaction has not started. No-op otherwise
254-
// The returned function unlocks the state if it was locked.
255-
//
256-
// Usage:
257-
//
258-
// func (t *Transaction) someFunction() {
259-
// ...
260-
// if t != nil {
261-
// defer t.stateLockDeferUnlock()()
262-
// }
263-
// ....
264-
// }
265-
//
266-
// This ensures that state is locked before any of the following lines are exexcuted
267-
// The lock will be released after 'someFunction' ends
268-
func (t *Transaction) stateLockDeferUnlock() func() {
269-
if t.state == transactionStateNotStarted {
270-
t.stateLock.Lock()
271-
// Check whether state changed while waiting to acquire lock
272-
if t.state == transactionStateNotStarted {
273-
return func() { t.stateLock.Unlock() }
274-
}
275240
t.stateLock.Unlock()
241+
return err
276242
}
277-
return func() {}
278-
}
279243

280-
func (t *Transaction) setToInProgress(id []byte) {
281-
t.id = id
244+
t.id = txnID
282245
t.state = transactionStateInProgress
246+
t.stateLock.Unlock()
247+
return nil
283248
}
284249

285250
// backoffBeforeRetry returns:
@@ -332,7 +297,8 @@ func (c *Client) newTransaction(ctx context.Context, s *transactionSettings) (_
332297
if err != nil {
333298
return nil, err
334299
}
335-
t.setToInProgress(txnID)
300+
t.id = txnID
301+
t.state = transactionStateInProgress
336302
}
337303

338304
return t, nil
@@ -449,9 +415,12 @@ func (t *Transaction) Commit() (c *Commit, err error) {
449415
t.ctx = trace.StartSpan(t.ctx, "cloud.google.com/go/datastore.Transaction.Commit")
450416
defer func() { trace.EndSpan(t.ctx, err) }()
451417

418+
t.stateLock.Lock()
452419
if t.state == transactionStateExpired {
420+
t.stateLock.Unlock()
453421
return nil, errExpiredTransaction
454422
}
423+
t.stateLock.Unlock()
455424

456425
err = t.beginLaterTransaction()
457426
if err != nil {
@@ -473,7 +442,9 @@ func (t *Transaction) Commit() (c *Commit, err error) {
473442
return nil, err
474443
}
475444

445+
t.stateLock.Lock()
476446
t.state = transactionStateExpired
447+
t.stateLock.Unlock()
477448

478449
c = &Commit{}
479450
// Copy any newly minted keys into the returned keys.
@@ -548,6 +519,9 @@ func (t *Transaction) parseReadOptions() (*pb.ReadOptions, error) {
548519
return nil, errTxnClientReadTime
549520
}
550521

522+
t.stateLock.Lock()
523+
defer t.stateLock.Unlock()
524+
551525
var opts *pb.ReadOptions
552526
switch t.state {
553527
case transactionStateExpired:
@@ -571,19 +545,20 @@ func (t *Transaction) get(spanName string, keys []*Key, dst interface{}) (err er
571545
t.ctx = trace.StartSpan(t.ctx, spanName)
572546
defer func() { trace.EndSpan(t.ctx, err) }()
573547

574-
if t != nil {
575-
defer t.stateLockDeferUnlock()()
576-
}
577-
578548
opts, err := t.parseReadOptions()
579549
if err != nil {
580550
return err
581551
}
582552

583553
txnID, err := t.client.get(t.ctx, keys, dst, opts)
584554

585-
if txnID != nil && err == nil {
586-
t.setToInProgress(txnID)
555+
if txnID != nil {
556+
t.stateLock.Lock()
557+
if t.state == transactionStateNotStarted {
558+
t.id = txnID
559+
t.state = transactionStateInProgress
560+
}
561+
t.stateLock.Unlock()
587562
}
588563
return t.client.processFieldMismatchError(err)
589564
}

0 commit comments

Comments
 (0)