Skip to content

Commit 73867cc

Browse files
authored
fix(auth): fix race condition in cachedTokenProvider.tokenAsync (googleapis#12586)
1 parent f464d65 commit 73867cc

File tree

2 files changed

+125
-3
lines changed

2 files changed

+125
-3
lines changed

auth/auth.go

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -362,9 +362,6 @@ func (c *cachedTokenProvider) tokenState() tokenState {
362362
// blocking call to Token should likely return the same error on the main goroutine.
363363
func (c *cachedTokenProvider) tokenAsync(ctx context.Context) {
364364
fn := func() {
365-
c.mu.Lock()
366-
c.isRefreshRunning = true
367-
c.mu.Unlock()
368365
t, err := c.tp.Token(ctx)
369366
c.mu.Lock()
370367
defer c.mu.Unlock()
@@ -380,6 +377,7 @@ func (c *cachedTokenProvider) tokenAsync(ctx context.Context) {
380377
c.mu.Lock()
381378
defer c.mu.Unlock()
382379
if !c.isRefreshRunning && !c.isRefreshErr {
380+
c.isRefreshRunning = true
383381
go fn()
384382
}
385383
}

auth/auth_token_async_test.go

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
// Copyright 2025 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package auth
16+
17+
import (
18+
"context"
19+
"fmt"
20+
"runtime"
21+
"sync"
22+
"testing"
23+
"time"
24+
)
25+
26+
type controllableTokenProvider struct {
27+
mu sync.Mutex
28+
count int
29+
tok *Token
30+
err error
31+
block chan struct{}
32+
}
33+
34+
func (p *controllableTokenProvider) Token(ctx context.Context) (*Token, error) {
35+
if ch := p.getBlockChan(); ch != nil {
36+
<-ch
37+
}
38+
p.mu.Lock()
39+
defer p.mu.Unlock()
40+
p.count++
41+
return p.tok, p.err
42+
}
43+
44+
func (p *controllableTokenProvider) getBlockChan() chan struct{} {
45+
p.mu.Lock()
46+
defer p.mu.Unlock()
47+
return p.block
48+
}
49+
50+
func (p *controllableTokenProvider) setBlockChan(ch chan struct{}) {
51+
p.mu.Lock()
52+
defer p.mu.Unlock()
53+
p.block = ch
54+
}
55+
56+
func (p *controllableTokenProvider) getCount() int {
57+
p.mu.Lock()
58+
defer p.mu.Unlock()
59+
return p.count
60+
}
61+
62+
func TestCachedTokenProvider_TokenAsyncRace(t *testing.T) {
63+
for i := 0; i < 10; i++ {
64+
t.Run(fmt.Sprintf("attempt-%d", i), func(t *testing.T) {
65+
now := time.Now()
66+
timeNow = func() time.Time { return now }
67+
defer func() { timeNow = time.Now }()
68+
69+
tp := &controllableTokenProvider{}
70+
ctp := NewCachedTokenProvider(tp, &CachedTokenProviderOptions{
71+
ExpireEarly: 2 * time.Second,
72+
}).(*cachedTokenProvider)
73+
74+
// 1. Cache a stale token.
75+
tp.tok = &Token{Value: "initial", Expiry: now.Add(1 * time.Second)}
76+
if _, err := ctp.Token(context.Background()); err != nil {
77+
t.Fatalf("initial Token() failed: %v", err)
78+
}
79+
if got, want := tp.getCount(), 1; got != want {
80+
t.Fatalf("tp.count = %d; want %d", got, want)
81+
}
82+
if got, want := ctp.tokenState(), stale; got != want {
83+
t.Fatalf("tokenState = %v; want %v", got, want)
84+
}
85+
86+
// 2. Setup for refresh.
87+
tp.setBlockChan(make(chan struct{}))
88+
tp.tok = &Token{Value: "refreshed", Expiry: now.Add(1 * time.Hour)}
89+
90+
// 3. Concurrently call Token to trigger async refresh.
91+
var wg sync.WaitGroup
92+
numGoroutines := 20 * (i + 1)
93+
wg.Add(numGoroutines)
94+
for i := 0; i < numGoroutines; i++ {
95+
go func() {
96+
defer wg.Done()
97+
runtime.Gosched()
98+
ctp.Token(context.Background())
99+
}()
100+
}
101+
102+
// 4. Unblock refresh and wait for all goroutines to finish.
103+
time.Sleep(100 * time.Millisecond) // give time for goroutines to run
104+
close(tp.getBlockChan())
105+
wg.Wait()
106+
time.Sleep(100 * time.Millisecond) // give time for async refresh to complete
107+
108+
// 5. Check results.
109+
if got, want := tp.getCount(), 2; got != want {
110+
t.Errorf("tp.count = %d; want %d. This indicates a race condition where multiple refreshes were triggered.", got, want)
111+
}
112+
if got, want := ctp.tokenState(), fresh; got != want {
113+
t.Errorf("tokenState = %v; want %v", got, want)
114+
}
115+
tok, err := ctp.Token(context.Background())
116+
if err != nil {
117+
t.Fatal(err)
118+
}
119+
if got, want := tok.Value, "refreshed"; got != want {
120+
t.Errorf("tok.Value = %q; want %q", got, want)
121+
}
122+
})
123+
}
124+
}

0 commit comments

Comments
 (0)