@@ -22,6 +22,7 @@ import (
2222 "net/http"
2323 "net/http/httptest"
2424 "strings"
25+ "sync"
2526 "testing"
2627 "time"
2728
@@ -105,9 +106,9 @@ func TestError_Temporary(t *testing.T) {
105106
106107func TestToken_MetadataString (t * testing.T ) {
107108 cases := []struct {
108- name string
109+ name string
109110 metadata map [string ]interface {}
110- want string
111+ want string
111112 }{
112113 {
113114 name : "nil metadata" ,
@@ -142,10 +143,10 @@ func TestToken_isValidWithEarlyExpiry(t *testing.T) {
142143 defer func () { timeNow = time .Now }()
143144
144145 cases := []struct {
145- name string
146- tok * Token
146+ name string
147+ tok * Token
147148 expiry time.Duration
148- want bool
149+ want bool
149150 }{
150151 {name : "4 minutes" , tok : & Token {Expiry : now .Add (4 * 60 * time .Second )}, expiry : defaultExpiryDelta , want : true },
151152 {name : "3 minutes and 45 seconds" , tok : & Token {Expiry : now .Add (defaultExpiryDelta )}, expiry : defaultExpiryDelta , want : true },
@@ -169,12 +170,12 @@ func TestError_Error(t *testing.T) {
169170 tests := []struct {
170171 name string
171172
172- Response * http.Response
173- Body []byte
174- Err error
175- code string
173+ Response * http.Response
174+ Body []byte
175+ Err error
176+ code string
176177 description string
177- uri string
178+ uri string
178179
179180 want string
180181 }{
@@ -187,22 +188,22 @@ func TestError_Error(t *testing.T) {
187188 want : "auth: cannot fetch token: 418\n Response: I'm a teapot" ,
188189 },
189190 {
190- name : "from query" ,
191- code : fmt .Sprint (http .StatusTeapot ),
191+ name : "from query" ,
192+ code : fmt .Sprint (http .StatusTeapot ),
192193 description : "I'm a teapot" ,
193- uri : "somewhere" ,
194- want : "auth: \" 418\" \" I'm a teapot\" \" somewhere\" " ,
194+ uri : "somewhere" ,
195+ want : "auth: \" 418\" \" I'm a teapot\" \" somewhere\" " ,
195196 },
196197 }
197198 for _ , tt := range tests {
198199 t .Run (tt .name , func (t * testing.T ) {
199200 r := & Error {
200- Response : tt .Response ,
201- Body : tt .Body ,
202- Err : tt .Err ,
203- code : tt .code ,
201+ Response : tt .Response ,
202+ Body : tt .Body ,
203+ Err : tt .Err ,
204+ code : tt .code ,
204205 description : tt .description ,
205- uri : tt .uri ,
206+ uri : tt .uri ,
206207 }
207208 if got := r .Error (); got != tt .want {
208209 t .Errorf ("Error.Error() = %v, want %v" , got , tt .want )
@@ -224,9 +225,9 @@ func TestNew2LOTokenProvider_JSONResponse(t *testing.T) {
224225 defer ts .Close ()
225226
226227 opts := & Options2LO {
227- Email : "aaa@example.com" ,
228+ Email : "aaa@example.com" ,
228229 PrivateKey : fakePrivateKey ,
229- TokenURL : ts .URL ,
230+ TokenURL : ts .URL ,
230231 }
231232 tp , err := New2LOTokenProvider (opts )
232233 if err != nil {
@@ -262,9 +263,9 @@ func TestNew2LOTokenProvider_BadResponse(t *testing.T) {
262263 defer ts .Close ()
263264
264265 opts := & Options2LO {
265- Email : "aaa@example.com" ,
266+ Email : "aaa@example.com" ,
266267 PrivateKey : fakePrivateKey ,
267- TokenURL : ts .URL ,
268+ TokenURL : ts .URL ,
268269 }
269270 tp , err := New2LOTokenProvider (opts )
270271 if err != nil {
@@ -299,9 +300,9 @@ func TestNew2LOTokenProvider_BadResponseType(t *testing.T) {
299300 }))
300301 defer ts .Close ()
301302 opts := & Options2LO {
302- Email : "aaa@example.com" ,
303+ Email : "aaa@example.com" ,
303304 PrivateKey : fakePrivateKey ,
304- TokenURL : ts .URL ,
305+ TokenURL : ts .URL ,
305306 }
306307 tp , err := New2LOTokenProvider (opts )
307308 if err != nil {
@@ -333,10 +334,10 @@ func TestNew2LOTokenProvider_Assertion(t *testing.T) {
333334 defer ts .Close ()
334335
335336 opts := & Options2LO {
336- Email : "aaa@example.com" ,
337- PrivateKey : fakePrivateKey ,
337+ Email : "aaa@example.com" ,
338+ PrivateKey : fakePrivateKey ,
338339 PrivateKeyID : "ABCDEFGHIJKLMNOPQRSTUVWXYZ" ,
339- TokenURL : ts .URL ,
340+ TokenURL : ts .URL ,
340341 }
341342
342343 tp , err := New2LOTokenProvider (opts )
@@ -364,8 +365,8 @@ func TestNew2LOTokenProvider_Assertion(t *testing.T) {
364365
365366 want := jwt.Header {
366367 Algorithm : "RS256" ,
367- Type : "JWT" ,
368- KeyID : "ABCDEFGHIJKLMNOPQRSTUVWXYZ" ,
368+ Type : "JWT" ,
369+ KeyID : "ABCDEFGHIJKLMNOPQRSTUVWXYZ" ,
369370 }
370371 if got != want {
371372 t .Errorf ("access token header = %q; want %q" , got , want )
@@ -390,23 +391,23 @@ func TestNew2LOTokenProvider_AssertionPayload(t *testing.T) {
390391
391392 for _ , opts := range []* Options2LO {
392393 {
393- Email : "aaa1@example.com" ,
394- PrivateKey : fakePrivateKey ,
394+ Email : "aaa1@example.com" ,
395+ PrivateKey : fakePrivateKey ,
395396 PrivateKeyID : "ABCDEFGHIJKLMNOPQRSTUVWXYZ" ,
396- TokenURL : ts .URL ,
397+ TokenURL : ts .URL ,
397398 },
398399 {
399- Email : "aaa2@example.com" ,
400- PrivateKey : fakePrivateKey ,
400+ Email : "aaa2@example.com" ,
401+ PrivateKey : fakePrivateKey ,
401402 PrivateKeyID : "ABCDEFGHIJKLMNOPQRSTUVWXYZ" ,
402- TokenURL : ts .URL ,
403- Audience : "https://example.com" ,
403+ TokenURL : ts .URL ,
404+ Audience : "https://example.com" ,
404405 },
405406 {
406- Email : "aaa2@example.com" ,
407- PrivateKey : fakePrivateKey ,
407+ Email : "aaa2@example.com" ,
408+ PrivateKey : fakePrivateKey ,
408409 PrivateKeyID : "ABCDEFGHIJKLMNOPQRSTUVWXYZ" ,
409- TokenURL : ts .URL ,
410+ TokenURL : ts .URL ,
410411 PrivateClaims : map [string ]interface {}{
411412 "private0" : "claim0" ,
412413 "private1" : "claim1" ,
@@ -478,9 +479,9 @@ func TestNew2LOTokenProvider_TokenError(t *testing.T) {
478479 defer ts .Close ()
479480
480481 opts := & Options2LO {
481- Email : "aaa@example.com" ,
482+ Email : "aaa@example.com" ,
482483 PrivateKey : fakePrivateKey ,
483- TokenURL : ts .URL ,
484+ TokenURL : ts .URL ,
484485 }
485486
486487 tp , err := New2LOTokenProvider (opts )
@@ -513,20 +514,20 @@ func TestNew2LOTokenProvider_Validate(t *testing.T) {
513514 name : "missing email" ,
514515 opts : & Options2LO {
515516 PrivateKey : []byte ("key" ),
516- TokenURL : "url" ,
517+ TokenURL : "url" ,
517518 },
518519 },
519520 {
520521 name : "missing key" ,
521522 opts : & Options2LO {
522- Email : "email" ,
523+ Email : "email" ,
523524 TokenURL : "url" ,
524525 },
525526 },
526527 {
527528 name : "missing URL" ,
528529 opts : & Options2LO {
529- Email : "email" ,
530+ Email : "email" ,
530531 PrivateKey : []byte ("key" ),
531532 },
532533 },
@@ -615,25 +616,25 @@ func TestComputeTokenProvider_NonBlockingRefresh(t *testing.T) {
615616
616617func TestComputeTokenProvider_BlockingRefresh (t * testing.T ) {
617618 tests := []struct {
618- name string
619+ name string
619620 disableAutoRefresh bool
620- want1 string
621- want2 string
622- wantState2 tokenState
621+ want1 string
622+ want2 string
623+ wantState2 tokenState
623624 }{
624625 {
625- name : "disableAutoRefresh" ,
626+ name : "disableAutoRefresh" ,
626627 disableAutoRefresh : true ,
627- want1 : "1" ,
628- want2 : "1" ,
628+ want1 : "1" ,
629+ want2 : "1" ,
629630 // Because token "count" does not increase, it will always be stale.
630631 wantState2 : stale ,
631632 },
632633 {
633- name : "autoRefresh" ,
634+ name : "autoRefresh" ,
634635 disableAutoRefresh : false ,
635- want1 : "1" ,
636- want2 : "2" ,
636+ want1 : "1" ,
637+ want2 : "2" ,
637638 // As token "count" increases to 2, it transitions to fresh.
638639 wantState2 : fresh ,
639640 },
@@ -646,7 +647,7 @@ func TestComputeTokenProvider_BlockingRefresh(t *testing.T) {
646647 defer func () { timeNow = time .Now }()
647648 tp := NewCachedTokenProvider (& countingTestProvider {count : 1 }, & CachedTokenProviderOptions {
648649 DisableAsyncRefresh : true ,
649- DisableAutoRefresh : tt .disableAutoRefresh ,
650+ DisableAutoRefresh : tt .disableAutoRefresh ,
650651 // EarlyTokenRefresh ensures that token with early expiry just less than 2 seconds before now is already stale.
651652 ExpireEarly : 1990 * time .Millisecond ,
652653 })
@@ -680,3 +681,97 @@ func TestComputeTokenProvider_BlockingRefresh(t *testing.T) {
680681 })
681682 }
682683}
684+
685+ type controllableTokenProvider struct {
686+ mu sync.Mutex
687+ count int
688+ tok * Token
689+ err error
690+ block chan struct {}
691+ }
692+
693+ func (p * controllableTokenProvider ) Token (ctx context.Context ) (* Token , error ) {
694+ if ch := p .getBlockChan (); ch != nil {
695+ <- ch
696+ }
697+ p .mu .Lock ()
698+ defer p .mu .Unlock ()
699+ p .count ++
700+ return p .tok , p .err
701+ }
702+
703+ func (p * controllableTokenProvider ) getBlockChan () chan struct {} {
704+ p .mu .Lock ()
705+ defer p .mu .Unlock ()
706+ return p .block
707+ }
708+
709+ func (p * controllableTokenProvider ) setBlockChan (ch chan struct {}) {
710+ p .mu .Lock ()
711+ defer p .mu .Unlock ()
712+ p .block = ch
713+ }
714+
715+ func (p * controllableTokenProvider ) getCount () int {
716+ p .mu .Lock ()
717+ defer p .mu .Unlock ()
718+ return p .count
719+ }
720+
721+ func TestCachedTokenProvider_TokenAsyncRace (t * testing.T ) {
722+ now := time .Now ()
723+ timeNow = func () time.Time { return now }
724+ defer func () { timeNow = time .Now }()
725+
726+ tp := & controllableTokenProvider {}
727+ ctp := NewCachedTokenProvider (tp , & CachedTokenProviderOptions {
728+ ExpireEarly : 2 * time .Second ,
729+ }).(* cachedTokenProvider )
730+
731+ // 1. Cache a stale token.
732+ tp .tok = & Token {Value : "initial" , Expiry : now .Add (1 * time .Second )}
733+ if _ , err := ctp .Token (context .Background ()); err != nil {
734+ t .Fatalf ("initial Token() failed: %v" , err )
735+ }
736+ if got , want := tp .getCount (), 1 ; got != want {
737+ t .Fatalf ("tp.count = %d; want %d" , got , want )
738+ }
739+ if got , want := ctp .tokenState (), stale ; got != want {
740+ t .Fatalf ("tokenState = %v; want %v" , got , want )
741+ }
742+
743+ // 2. Setup for refresh.
744+ tp .setBlockChan (make (chan struct {}))
745+ tp .tok = & Token {Value : "refreshed" , Expiry : now .Add (1 * time .Hour )}
746+
747+ // 3. Concurrently call Token to trigger async refresh.
748+ var wg sync.WaitGroup
749+ for i := 0 ; i < 10 ; i ++ {
750+ wg .Add (1 )
751+ go func () {
752+ defer wg .Done ()
753+ ctp .Token (context .Background ())
754+ }()
755+ }
756+
757+ // 4. Unblock refresh and wait for all goroutines to finish.
758+ time .Sleep (100 * time .Millisecond ) // give time for goroutines to run
759+ close (tp .getBlockChan ())
760+ wg .Wait ()
761+ time .Sleep (100 * time .Millisecond ) // give time for async refresh to complete
762+
763+ // 5. Check results.
764+ if got , want := tp .getCount (), 2 ; got != want {
765+ t .Errorf ("tp.count = %d; want %d" , got , want )
766+ }
767+ if got , want := ctp .tokenState (), fresh ; got != want {
768+ t .Errorf ("tokenState = %v; want %v" , got , want )
769+ }
770+ tok , err := ctp .Token (context .Background ())
771+ if err != nil {
772+ t .Fatal (err )
773+ }
774+ if got , want := tok .Value , "refreshed" ; got != want {
775+ t .Errorf ("tok.Value = %q; want %q" , got , want )
776+ }
777+ }
0 commit comments