@@ -18,6 +18,8 @@ package azure
18
18
19
19
import (
20
20
"encoding/json"
21
+ "errors"
22
+ "net/http"
21
23
"strconv"
22
24
"strings"
23
25
"sync"
@@ -172,10 +174,7 @@ func TestAzureTokenSource(t *testing.T) {
172
174
for i , configMode := range configModes {
173
175
t .Run ("validate token against cache" , func (t * testing.T ) {
174
176
fakeAccessToken := "fake token 1"
175
- fakeSource := fakeTokenSource {
176
- accessToken : fakeAccessToken ,
177
- expiresOn : strconv .FormatInt (time .Now ().Add (3600 * time .Second ).Unix (), 10 ),
178
- }
177
+ fakeSource := fakeTokenSource {token : newFakeAzureToken (fakeAccessToken , time .Now ().Add (3600 * time .Second ))}
179
178
cfg := make (map [string ]string )
180
179
persiter := & fakePersister {cache : make (map [string ]string )}
181
180
tokenCache := newAzureTokenCache ()
@@ -210,7 +209,7 @@ func TestAzureTokenSource(t *testing.T) {
210
209
}
211
210
}
212
211
213
- fakeSource .accessToken = "fake token 2"
212
+ fakeSource .token = newFakeAzureToken ( "fake token 2" , time . Now (). Add ( 3600 * time . Second ))
214
213
token , err = tokenSource .Token ()
215
214
if err != nil {
216
215
t .Errorf ("failed to retrieve the cached token: %v" , err )
@@ -223,14 +222,161 @@ func TestAzureTokenSource(t *testing.T) {
223
222
}
224
223
}
225
224
225
+ func TestAzureTokenSourceScenarios (t * testing.T ) {
226
+ configMode := configModeDefault
227
+ expiredToken := newFakeAzureToken ("expired token" , time .Now ().Add (- time .Second ))
228
+ extendedToken := newFakeAzureToken ("extend token" , time .Now ().Add (1000 * time .Second ))
229
+ fakeToken := newFakeAzureToken ("fake token" , time .Now ().Add (1000 * time .Second ))
230
+ wrongToken := newFakeAzureToken ("wrong token" , time .Now ().Add (1000 * time .Second ))
231
+ tests := []struct {
232
+ name string
233
+ sourceToken * azureToken
234
+ refreshToken * azureToken
235
+ cachedToken * azureToken
236
+ configToken * azureToken
237
+ expectToken * azureToken
238
+ tokenErr error
239
+ refreshErr error
240
+ expectErr string
241
+ tokenCalls uint
242
+ refreshCalls uint
243
+ persistCalls uint
244
+ }{
245
+ {
246
+ name : "new config" ,
247
+ sourceToken : fakeToken ,
248
+ expectToken : fakeToken ,
249
+ tokenCalls : 1 ,
250
+ persistCalls : 1 ,
251
+ },
252
+ {
253
+ name : "load token from cache" ,
254
+ sourceToken : wrongToken ,
255
+ cachedToken : fakeToken ,
256
+ configToken : wrongToken ,
257
+ expectToken : fakeToken ,
258
+ },
259
+ {
260
+ name : "load token from config" ,
261
+ sourceToken : wrongToken ,
262
+ configToken : fakeToken ,
263
+ expectToken : fakeToken ,
264
+ },
265
+ {
266
+ name : "cached token timeout, extend success, config token should never load" ,
267
+ cachedToken : expiredToken ,
268
+ refreshToken : extendedToken ,
269
+ configToken : wrongToken ,
270
+ expectToken : extendedToken ,
271
+ refreshCalls : 1 ,
272
+ persistCalls : 1 ,
273
+ },
274
+ {
275
+ name : "config token timeout, extend failure, acquire new token" ,
276
+ configToken : expiredToken ,
277
+ refreshErr : fakeTokenRefreshError {message : "FakeError happened when refreshing" },
278
+ sourceToken : fakeToken ,
279
+ expectToken : fakeToken ,
280
+ refreshCalls : 1 ,
281
+ tokenCalls : 1 ,
282
+ persistCalls : 1 ,
283
+ },
284
+ {
285
+ name : "unexpected error when extend" ,
286
+ configToken : expiredToken ,
287
+ refreshErr : errors .New ("unexpected refresh error" ),
288
+ sourceToken : fakeToken ,
289
+ expectErr : "unexpected refresh error" ,
290
+ refreshCalls : 1 ,
291
+ },
292
+ {
293
+ name : "token error" ,
294
+ tokenErr : errors .New ("tokenerr" ),
295
+ expectErr : "tokenerr" ,
296
+ tokenCalls : 1 ,
297
+ },
298
+ {
299
+ name : "Token() got expired token" ,
300
+ sourceToken : expiredToken ,
301
+ expectErr : "newly acquired token is expired" ,
302
+ tokenCalls : 1 ,
303
+ },
304
+ {
305
+ name : "Token() got nil but no error" ,
306
+ sourceToken : nil ,
307
+ expectErr : "unable to acquire token" ,
308
+ tokenCalls : 1 ,
309
+ },
310
+ }
311
+ for _ , tc := range tests {
312
+ t .Run (tc .name , func (t * testing.T ) {
313
+ persister := newFakePersister ()
314
+
315
+ cfg := map [string ]string {}
316
+ if tc .configToken != nil {
317
+ cfg = token2Cfg (tc .configToken )
318
+ }
319
+
320
+ tokenCache := newAzureTokenCache ()
321
+ if tc .cachedToken != nil {
322
+ tokenCache .setToken (azureTokenKey , tc .cachedToken )
323
+ }
324
+
325
+ fakeSource := fakeTokenSource {
326
+ token : tc .sourceToken ,
327
+ tokenErr : tc .tokenErr ,
328
+ refreshToken : tc .refreshToken ,
329
+ refreshErr : tc .refreshErr ,
330
+ }
331
+
332
+ tokenSource := newAzureTokenSource (& fakeSource , tokenCache , cfg , configMode , & persister )
333
+ token , err := tokenSource .Token ()
334
+
335
+ if fakeSource .tokenCalls != tc .tokenCalls {
336
+ t .Errorf ("expecting tokenCalls: %v, got: %v" , tc .tokenCalls , fakeSource .tokenCalls )
337
+ }
338
+
339
+ if fakeSource .refreshCalls != tc .refreshCalls {
340
+ t .Errorf ("expecting refreshCalls: %v, got: %v" , tc .refreshCalls , fakeSource .refreshCalls )
341
+ }
342
+
343
+ if persister .calls != tc .persistCalls {
344
+ t .Errorf ("expecting persister calls: %v, got: %v" , tc .persistCalls , persister .calls )
345
+ }
346
+
347
+ if tc .expectErr != "" {
348
+ if ! strings .Contains (err .Error (), tc .expectErr ) {
349
+ t .Errorf ("expecting error %v, got %v" , tc .expectErr , err )
350
+ }
351
+ if token != nil {
352
+ t .Errorf ("token should be nil in err situation, got %v" , token )
353
+ }
354
+ } else {
355
+ if err != nil {
356
+ t .Fatalf ("error should be nil, got %v" , err )
357
+ }
358
+ if token .token .AccessToken != tc .expectToken .token .AccessToken {
359
+ t .Errorf ("token should have accessToken %v, got %v" , token .token .AccessToken , tc .expectToken .token .AccessToken )
360
+ }
361
+ }
362
+ })
363
+ }
364
+ }
365
+
226
366
type fakePersister struct {
227
367
lock sync.Mutex
228
368
cache map [string ]string
369
+ calls uint
370
+ }
371
+
372
+ func newFakePersister () fakePersister {
373
+ return fakePersister {cache : make (map [string ]string ), calls : 0 }
229
374
}
230
375
231
376
func (p * fakePersister ) Persist (cache map [string ]string ) error {
232
377
p .lock .Lock ()
233
378
defer p .lock .Unlock ()
379
+ p .calls ++
234
380
p .cache = map [string ]string {}
235
381
for k , v := range cache {
236
382
p .cache [k ] = v
@@ -248,19 +394,24 @@ func (p *fakePersister) Cache() map[string]string {
248
394
return ret
249
395
}
250
396
397
+ // a simple token source simply always returns the token property
251
398
type fakeTokenSource struct {
252
- expiresOn string
253
- accessToken string
399
+ token * azureToken
400
+ tokenCalls uint
401
+ tokenErr error
402
+ refreshToken * azureToken
403
+ refreshCalls uint
404
+ refreshErr error
254
405
}
255
406
256
407
func (ts * fakeTokenSource ) Token () (* azureToken , error ) {
257
- return & azureToken {
258
- token : newFackeAzureToken ( ts .accessToken , ts .expiresOn ),
259
- environment : "testenv" ,
260
- clientID : "fake" ,
261
- tenantID : "fake" ,
262
- apiserverID : "fake" ,
263
- }, nil
408
+ ts . tokenCalls ++
409
+ return ts .token , ts .tokenErr
410
+ }
411
+
412
+ func ( ts * fakeTokenSource ) Refresh ( * azureToken ) ( * azureToken , error ) {
413
+ ts . refreshCalls ++
414
+ return ts . refreshToken , ts . refreshErr
264
415
}
265
416
266
417
func token2Cfg (token * azureToken ) map [string ]string {
@@ -276,7 +427,17 @@ func token2Cfg(token *azureToken) map[string]string {
276
427
return cfg
277
428
}
278
429
279
- func newFackeAzureToken (accessToken string , expiresOn string ) adal.Token {
430
+ func newFakeAzureToken (accessToken string , expiresOnTime time.Time ) * azureToken {
431
+ return & azureToken {
432
+ token : newFakeADALToken (accessToken , strconv .FormatInt (expiresOnTime .Unix (), 10 )),
433
+ environment : "testenv" ,
434
+ clientID : "fake" ,
435
+ tenantID : "fake" ,
436
+ apiserverID : "fake" ,
437
+ }
438
+ }
439
+
440
+ func newFakeADALToken (accessToken string , expiresOn string ) adal.Token {
280
441
return adal.Token {
281
442
AccessToken : accessToken ,
282
443
RefreshToken : "fake" ,
@@ -287,3 +448,19 @@ func newFackeAzureToken(accessToken string, expiresOn string) adal.Token {
287
448
Type : "fake" ,
288
449
}
289
450
}
451
+
452
+ // copied from go-autorest/adal
453
+ type fakeTokenRefreshError struct {
454
+ message string
455
+ resp * http.Response
456
+ }
457
+
458
+ // Error implements the error interface which is part of the TokenRefreshError interface.
459
+ func (tre fakeTokenRefreshError ) Error () string {
460
+ return tre .message
461
+ }
462
+
463
+ // Response implements the TokenRefreshError interface, it returns the raw HTTP response from the refresh operation.
464
+ func (tre fakeTokenRefreshError ) Response () * http.Response {
465
+ return tre .resp
466
+ }
0 commit comments