Skip to content

Commit cdefe0c

Browse files
committed
fix(auth): fix race condition in cachedTokenProvider.tokenAsync
1 parent 73867cc commit cdefe0c

File tree

1 file changed

+151
-56
lines changed

1 file changed

+151
-56
lines changed

auth/auth_test.go

Lines changed: 151 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import (
2222
"net/http"
2323
"net/http/httptest"
2424
"strings"
25+
"sync"
2526
"testing"
2627
"time"
2728

@@ -105,9 +106,9 @@ func TestError_Temporary(t *testing.T) {
105106

106107
func TestToken_MetadataString(t *testing.T) {
107108
cases := []struct {
108-
name string
109+
name string
109110
metadata map[string]interface{}
110-
want string
111+
want string
111112
}{
112113
{
113114
name: "nil metadata",
@@ -142,10 +143,10 @@ func TestToken_isValidWithEarlyExpiry(t *testing.T) {
142143
defer func() { timeNow = time.Now }()
143144

144145
cases := []struct {
145-
name string
146-
tok *Token
146+
name string
147+
tok *Token
147148
expiry time.Duration
148-
want bool
149+
want bool
149150
}{
150151
{name: "4 minutes", tok: &Token{Expiry: now.Add(4 * 60 * time.Second)}, expiry: defaultExpiryDelta, want: true},
151152
{name: "3 minutes and 45 seconds", tok: &Token{Expiry: now.Add(defaultExpiryDelta)}, expiry: defaultExpiryDelta, want: true},
@@ -169,12 +170,12 @@ func TestError_Error(t *testing.T) {
169170
tests := []struct {
170171
name string
171172

172-
Response *http.Response
173-
Body []byte
174-
Err error
175-
code string
173+
Response *http.Response
174+
Body []byte
175+
Err error
176+
code string
176177
description string
177-
uri string
178+
uri string
178179

179180
want string
180181
}{
@@ -187,22 +188,22 @@ func TestError_Error(t *testing.T) {
187188
want: "auth: cannot fetch token: 418\nResponse: I'm a teapot",
188189
},
189190
{
190-
name: "from query",
191-
code: fmt.Sprint(http.StatusTeapot),
191+
name: "from query",
192+
code: fmt.Sprint(http.StatusTeapot),
192193
description: "I'm a teapot",
193-
uri: "somewhere",
194-
want: "auth: \"418\" \"I'm a teapot\" \"somewhere\"",
194+
uri: "somewhere",
195+
want: "auth: \"418\" \"I'm a teapot\" \"somewhere\"",
195196
},
196197
}
197198
for _, tt := range tests {
198199
t.Run(tt.name, func(t *testing.T) {
199200
r := &Error{
200-
Response: tt.Response,
201-
Body: tt.Body,
202-
Err: tt.Err,
203-
code: tt.code,
201+
Response: tt.Response,
202+
Body: tt.Body,
203+
Err: tt.Err,
204+
code: tt.code,
204205
description: tt.description,
205-
uri: tt.uri,
206+
uri: tt.uri,
206207
}
207208
if got := r.Error(); got != tt.want {
208209
t.Errorf("Error.Error() = %v, want %v", got, tt.want)
@@ -224,9 +225,9 @@ func TestNew2LOTokenProvider_JSONResponse(t *testing.T) {
224225
defer ts.Close()
225226

226227
opts := &Options2LO{
227-
Email: "aaa@example.com",
228+
Email: "aaa@example.com",
228229
PrivateKey: fakePrivateKey,
229-
TokenURL: ts.URL,
230+
TokenURL: ts.URL,
230231
}
231232
tp, err := New2LOTokenProvider(opts)
232233
if err != nil {
@@ -262,9 +263,9 @@ func TestNew2LOTokenProvider_BadResponse(t *testing.T) {
262263
defer ts.Close()
263264

264265
opts := &Options2LO{
265-
Email: "aaa@example.com",
266+
Email: "aaa@example.com",
266267
PrivateKey: fakePrivateKey,
267-
TokenURL: ts.URL,
268+
TokenURL: ts.URL,
268269
}
269270
tp, err := New2LOTokenProvider(opts)
270271
if err != nil {
@@ -299,9 +300,9 @@ func TestNew2LOTokenProvider_BadResponseType(t *testing.T) {
299300
}))
300301
defer ts.Close()
301302
opts := &Options2LO{
302-
Email: "aaa@example.com",
303+
Email: "aaa@example.com",
303304
PrivateKey: fakePrivateKey,
304-
TokenURL: ts.URL,
305+
TokenURL: ts.URL,
305306
}
306307
tp, err := New2LOTokenProvider(opts)
307308
if err != nil {
@@ -333,10 +334,10 @@ func TestNew2LOTokenProvider_Assertion(t *testing.T) {
333334
defer ts.Close()
334335

335336
opts := &Options2LO{
336-
Email: "aaa@example.com",
337-
PrivateKey: fakePrivateKey,
337+
Email: "aaa@example.com",
338+
PrivateKey: fakePrivateKey,
338339
PrivateKeyID: "ABCDEFGHIJKLMNOPQRSTUVWXYZ",
339-
TokenURL: ts.URL,
340+
TokenURL: ts.URL,
340341
}
341342

342343
tp, err := New2LOTokenProvider(opts)
@@ -364,8 +365,8 @@ func TestNew2LOTokenProvider_Assertion(t *testing.T) {
364365

365366
want := jwt.Header{
366367
Algorithm: "RS256",
367-
Type: "JWT",
368-
KeyID: "ABCDEFGHIJKLMNOPQRSTUVWXYZ",
368+
Type: "JWT",
369+
KeyID: "ABCDEFGHIJKLMNOPQRSTUVWXYZ",
369370
}
370371
if got != want {
371372
t.Errorf("access token header = %q; want %q", got, want)
@@ -390,23 +391,23 @@ func TestNew2LOTokenProvider_AssertionPayload(t *testing.T) {
390391

391392
for _, opts := range []*Options2LO{
392393
{
393-
Email: "aaa1@example.com",
394-
PrivateKey: fakePrivateKey,
394+
Email: "aaa1@example.com",
395+
PrivateKey: fakePrivateKey,
395396
PrivateKeyID: "ABCDEFGHIJKLMNOPQRSTUVWXYZ",
396-
TokenURL: ts.URL,
397+
TokenURL: ts.URL,
397398
},
398399
{
399-
Email: "aaa2@example.com",
400-
PrivateKey: fakePrivateKey,
400+
Email: "aaa2@example.com",
401+
PrivateKey: fakePrivateKey,
401402
PrivateKeyID: "ABCDEFGHIJKLMNOPQRSTUVWXYZ",
402-
TokenURL: ts.URL,
403-
Audience: "https://example.com",
403+
TokenURL: ts.URL,
404+
Audience: "https://example.com",
404405
},
405406
{
406-
Email: "aaa2@example.com",
407-
PrivateKey: fakePrivateKey,
407+
Email: "aaa2@example.com",
408+
PrivateKey: fakePrivateKey,
408409
PrivateKeyID: "ABCDEFGHIJKLMNOPQRSTUVWXYZ",
409-
TokenURL: ts.URL,
410+
TokenURL: ts.URL,
410411
PrivateClaims: map[string]interface{}{
411412
"private0": "claim0",
412413
"private1": "claim1",
@@ -478,9 +479,9 @@ func TestNew2LOTokenProvider_TokenError(t *testing.T) {
478479
defer ts.Close()
479480

480481
opts := &Options2LO{
481-
Email: "aaa@example.com",
482+
Email: "aaa@example.com",
482483
PrivateKey: fakePrivateKey,
483-
TokenURL: ts.URL,
484+
TokenURL: ts.URL,
484485
}
485486

486487
tp, err := New2LOTokenProvider(opts)
@@ -513,20 +514,20 @@ func TestNew2LOTokenProvider_Validate(t *testing.T) {
513514
name: "missing email",
514515
opts: &Options2LO{
515516
PrivateKey: []byte("key"),
516-
TokenURL: "url",
517+
TokenURL: "url",
517518
},
518519
},
519520
{
520521
name: "missing key",
521522
opts: &Options2LO{
522-
Email: "email",
523+
Email: "email",
523524
TokenURL: "url",
524525
},
525526
},
526527
{
527528
name: "missing URL",
528529
opts: &Options2LO{
529-
Email: "email",
530+
Email: "email",
530531
PrivateKey: []byte("key"),
531532
},
532533
},
@@ -615,25 +616,25 @@ func TestComputeTokenProvider_NonBlockingRefresh(t *testing.T) {
615616

616617
func TestComputeTokenProvider_BlockingRefresh(t *testing.T) {
617618
tests := []struct {
618-
name string
619+
name string
619620
disableAutoRefresh bool
620-
want1 string
621-
want2 string
622-
wantState2 tokenState
621+
want1 string
622+
want2 string
623+
wantState2 tokenState
623624
}{
624625
{
625-
name: "disableAutoRefresh",
626+
name: "disableAutoRefresh",
626627
disableAutoRefresh: true,
627-
want1: "1",
628-
want2: "1",
628+
want1: "1",
629+
want2: "1",
629630
// Because token "count" does not increase, it will always be stale.
630631
wantState2: stale,
631632
},
632633
{
633-
name: "autoRefresh",
634+
name: "autoRefresh",
634635
disableAutoRefresh: false,
635-
want1: "1",
636-
want2: "2",
636+
want1: "1",
637+
want2: "2",
637638
// As token "count" increases to 2, it transitions to fresh.
638639
wantState2: fresh,
639640
},
@@ -646,7 +647,7 @@ func TestComputeTokenProvider_BlockingRefresh(t *testing.T) {
646647
defer func() { timeNow = time.Now }()
647648
tp := NewCachedTokenProvider(&countingTestProvider{count: 1}, &CachedTokenProviderOptions{
648649
DisableAsyncRefresh: true,
649-
DisableAutoRefresh: tt.disableAutoRefresh,
650+
DisableAutoRefresh: tt.disableAutoRefresh,
650651
// EarlyTokenRefresh ensures that token with early expiry just less than 2 seconds before now is already stale.
651652
ExpireEarly: 1990 * time.Millisecond,
652653
})
@@ -680,3 +681,97 @@ func TestComputeTokenProvider_BlockingRefresh(t *testing.T) {
680681
})
681682
}
682683
}
684+
685+
type controllableTokenProvider struct {
686+
mu sync.Mutex
687+
count int
688+
tok *Token
689+
err error
690+
block chan struct{}
691+
}
692+
693+
func (p *controllableTokenProvider) Token(ctx context.Context) (*Token, error) {
694+
if ch := p.getBlockChan(); ch != nil {
695+
<-ch
696+
}
697+
p.mu.Lock()
698+
defer p.mu.Unlock()
699+
p.count++
700+
return p.tok, p.err
701+
}
702+
703+
func (p *controllableTokenProvider) getBlockChan() chan struct{} {
704+
p.mu.Lock()
705+
defer p.mu.Unlock()
706+
return p.block
707+
}
708+
709+
func (p *controllableTokenProvider) setBlockChan(ch chan struct{}) {
710+
p.mu.Lock()
711+
defer p.mu.Unlock()
712+
p.block = ch
713+
}
714+
715+
func (p *controllableTokenProvider) getCount() int {
716+
p.mu.Lock()
717+
defer p.mu.Unlock()
718+
return p.count
719+
}
720+
721+
func TestCachedTokenProvider_TokenAsyncRace(t *testing.T) {
722+
now := time.Now()
723+
timeNow = func() time.Time { return now }
724+
defer func() { timeNow = time.Now }()
725+
726+
tp := &controllableTokenProvider{}
727+
ctp := NewCachedTokenProvider(tp, &CachedTokenProviderOptions{
728+
ExpireEarly: 2 * time.Second,
729+
}).(*cachedTokenProvider)
730+
731+
// 1. Cache a stale token.
732+
tp.tok = &Token{Value: "initial", Expiry: now.Add(1 * time.Second)}
733+
if _, err := ctp.Token(context.Background()); err != nil {
734+
t.Fatalf("initial Token() failed: %v", err)
735+
}
736+
if got, want := tp.getCount(), 1; got != want {
737+
t.Fatalf("tp.count = %d; want %d", got, want)
738+
}
739+
if got, want := ctp.tokenState(), stale; got != want {
740+
t.Fatalf("tokenState = %v; want %v", got, want)
741+
}
742+
743+
// 2. Setup for refresh.
744+
tp.setBlockChan(make(chan struct{}))
745+
tp.tok = &Token{Value: "refreshed", Expiry: now.Add(1 * time.Hour)}
746+
747+
// 3. Concurrently call Token to trigger async refresh.
748+
var wg sync.WaitGroup
749+
for i := 0; i < 10; i++ {
750+
wg.Add(1)
751+
go func() {
752+
defer wg.Done()
753+
ctp.Token(context.Background())
754+
}()
755+
}
756+
757+
// 4. Unblock refresh and wait for all goroutines to finish.
758+
time.Sleep(100 * time.Millisecond) // give time for goroutines to run
759+
close(tp.getBlockChan())
760+
wg.Wait()
761+
time.Sleep(100 * time.Millisecond) // give time for async refresh to complete
762+
763+
// 5. Check results.
764+
if got, want := tp.getCount(), 2; got != want {
765+
t.Errorf("tp.count = %d; want %d", got, want)
766+
}
767+
if got, want := ctp.tokenState(), fresh; got != want {
768+
t.Errorf("tokenState = %v; want %v", got, want)
769+
}
770+
tok, err := ctp.Token(context.Background())
771+
if err != nil {
772+
t.Fatal(err)
773+
}
774+
if got, want := tok.Value, "refreshed"; got != want {
775+
t.Errorf("tok.Value = %q; want %q", got, want)
776+
}
777+
}

0 commit comments

Comments
 (0)