Skip to content

Commit 26c97fa

Browse files
committed
Azure auth fallback to real auth if refresh token fails, refactor and add more tests.
Signed-off-by: Ping He <[email protected]>
1 parent df908c3 commit 26c97fa

File tree

2 files changed

+251
-34
lines changed

2 files changed

+251
-34
lines changed

staging/src/k8s.io/client-go/plugin/pkg/client/auth/azure/azure.go

Lines changed: 59 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,7 @@ type azureToken struct {
180180

181181
type tokenSource interface {
182182
Token() (*azureToken, error)
183+
Refresh(*azureToken) (*azureToken, error)
183184
}
184185

185186
type azureTokenSource struct {
@@ -210,33 +211,66 @@ func (ts *azureTokenSource) Token() (*azureToken, error) {
210211

211212
var err error
212213
token := ts.cache.getToken(azureTokenKey)
214+
215+
if token != nil && !token.token.IsExpired() {
216+
return token, nil
217+
}
218+
219+
// retrieve from config if no cache
213220
if token == nil {
214-
token, err = ts.retrieveTokenFromCfg()
215-
if err != nil {
216-
token, err = ts.source.Token()
217-
if err != nil {
218-
return nil, fmt.Errorf("acquiring a new fresh token: %v", err)
219-
}
221+
tokenFromCfg, err := ts.retrieveTokenFromCfg()
222+
223+
if err == nil {
224+
token = tokenFromCfg
220225
}
226+
}
227+
228+
if token != nil {
229+
// cache and return if the token is as good
230+
// avoids frequent persistor calls
221231
if !token.token.IsExpired() {
222232
ts.cache.setToken(azureTokenKey, token)
223-
err = ts.storeTokenInCfg(token)
224-
if err != nil {
225-
return nil, fmt.Errorf("storing the token in configuration: %v", err)
226-
}
233+
return token, nil
227234
}
228-
}
229-
if token.token.IsExpired() {
230-
token, err = ts.refreshToken(token)
231-
if err != nil {
232-
return nil, fmt.Errorf("refreshing the expired token: %v", err)
235+
236+
klog.V(4).Info("Refreshing token.")
237+
tokenFromRefresh, err := ts.Refresh(token)
238+
switch {
239+
case err == nil:
240+
token = tokenFromRefresh
241+
case autorest.IsTokenRefreshError(err):
242+
klog.V(4).Infof("Failed to refresh expired token, proceed to auth: %v", err)
243+
// reset token to nil so that the token source will be used to acquire new
244+
token = nil
245+
default:
246+
return nil, fmt.Errorf("unexpected error when refreshing token: %v", err)
233247
}
234-
ts.cache.setToken(azureTokenKey, token)
235-
err = ts.storeTokenInCfg(token)
248+
}
249+
250+
if token == nil {
251+
tokenFromSource, err := ts.source.Token()
236252
if err != nil {
237-
return nil, fmt.Errorf("storing the refreshed token in configuration: %v", err)
253+
return nil, fmt.Errorf("failed acquiring new token: %v", err)
238254
}
255+
token = tokenFromSource
256+
}
257+
258+
// sanity check
259+
if token == nil {
260+
return nil, fmt.Errorf("unable to acquire token")
261+
}
262+
263+
// corner condition, newly got token is valid but expired
264+
if token.token.IsExpired() {
265+
return nil, fmt.Errorf("newly acquired token is expired")
239266
}
267+
268+
err = ts.storeTokenInCfg(token)
269+
if err != nil {
270+
return nil, fmt.Errorf("storing the refreshed token in configuration: %v", err)
271+
}
272+
ts.cache.setToken(azureTokenKey, token)
273+
240274
return token, nil
241275
}
242276

@@ -314,7 +348,13 @@ func (ts *azureTokenSource) storeTokenInCfg(token *azureToken) error {
314348
return nil
315349
}
316350

317-
func (ts *azureTokenSource) refreshToken(token *azureToken) (*azureToken, error) {
351+
func (ts *azureTokenSource) Refresh(token *azureToken) (*azureToken, error) {
352+
return ts.source.Refresh(token)
353+
}
354+
355+
// refresh outdated token with adal.
356+
// adal.RefreshTokenError will be returned if error occur during refreshing.
357+
func (ts *azureTokenSourceDeviceCode) Refresh(token *azureToken) (*azureToken, error) {
318358
env, err := azure.EnvironmentFromName(token.environment)
319359
if err != nil {
320360
return nil, err

staging/src/k8s.io/client-go/plugin/pkg/client/auth/azure/azure_test.go

Lines changed: 192 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ package azure
1818

1919
import (
2020
"encoding/json"
21+
"errors"
22+
"net/http"
2123
"strconv"
2224
"strings"
2325
"sync"
@@ -172,10 +174,7 @@ func TestAzureTokenSource(t *testing.T) {
172174
for i, configMode := range configModes {
173175
t.Run("validate token against cache", func(t *testing.T) {
174176
fakeAccessToken := "fake token 1"
175-
fakeSource := fakeTokenSource{
176-
accessToken: fakeAccessToken,
177-
expiresOn: strconv.FormatInt(time.Now().Add(3600*time.Second).Unix(), 10),
178-
}
177+
fakeSource := fakeTokenSource{token: newFakeAzureToken(fakeAccessToken, time.Now().Add(3600*time.Second))}
179178
cfg := make(map[string]string)
180179
persiter := &fakePersister{cache: make(map[string]string)}
181180
tokenCache := newAzureTokenCache()
@@ -210,7 +209,7 @@ func TestAzureTokenSource(t *testing.T) {
210209
}
211210
}
212211

213-
fakeSource.accessToken = "fake token 2"
212+
fakeSource.token = newFakeAzureToken("fake token 2", time.Now().Add(3600*time.Second))
214213
token, err = tokenSource.Token()
215214
if err != nil {
216215
t.Errorf("failed to retrieve the cached token: %v", err)
@@ -223,14 +222,161 @@ func TestAzureTokenSource(t *testing.T) {
223222
}
224223
}
225224

225+
func TestAzureTokenSourceScenarios(t *testing.T) {
226+
configMode := configModeDefault
227+
expiredToken := newFakeAzureToken("expired token", time.Now().Add(-time.Second))
228+
extendedToken := newFakeAzureToken("extend token", time.Now().Add(1000*time.Second))
229+
fakeToken := newFakeAzureToken("fake token", time.Now().Add(1000*time.Second))
230+
wrongToken := newFakeAzureToken("wrong token", time.Now().Add(1000*time.Second))
231+
tests := []struct {
232+
name string
233+
sourceToken *azureToken
234+
refreshToken *azureToken
235+
cachedToken *azureToken
236+
configToken *azureToken
237+
expectToken *azureToken
238+
tokenErr error
239+
refreshErr error
240+
expectErr string
241+
tokenCalls uint
242+
refreshCalls uint
243+
persistCalls uint
244+
}{
245+
{
246+
name: "new config",
247+
sourceToken: fakeToken,
248+
expectToken: fakeToken,
249+
tokenCalls: 1,
250+
persistCalls: 1,
251+
},
252+
{
253+
name: "load token from cache",
254+
sourceToken: wrongToken,
255+
cachedToken: fakeToken,
256+
configToken: wrongToken,
257+
expectToken: fakeToken,
258+
},
259+
{
260+
name: "load token from config",
261+
sourceToken: wrongToken,
262+
configToken: fakeToken,
263+
expectToken: fakeToken,
264+
},
265+
{
266+
name: "cached token timeout, extend success, config token should never load",
267+
cachedToken: expiredToken,
268+
refreshToken: extendedToken,
269+
configToken: wrongToken,
270+
expectToken: extendedToken,
271+
refreshCalls: 1,
272+
persistCalls: 1,
273+
},
274+
{
275+
name: "config token timeout, extend failure, acquire new token",
276+
configToken: expiredToken,
277+
refreshErr: fakeTokenRefreshError{message: "FakeError happened when refreshing"},
278+
sourceToken: fakeToken,
279+
expectToken: fakeToken,
280+
refreshCalls: 1,
281+
tokenCalls: 1,
282+
persistCalls: 1,
283+
},
284+
{
285+
name: "unexpected error when extend",
286+
configToken: expiredToken,
287+
refreshErr: errors.New("unexpected refresh error"),
288+
sourceToken: fakeToken,
289+
expectErr: "unexpected refresh error",
290+
refreshCalls: 1,
291+
},
292+
{
293+
name: "token error",
294+
tokenErr: errors.New("tokenerr"),
295+
expectErr: "tokenerr",
296+
tokenCalls: 1,
297+
},
298+
{
299+
name: "Token() got expired token",
300+
sourceToken: expiredToken,
301+
expectErr: "newly acquired token is expired",
302+
tokenCalls: 1,
303+
},
304+
{
305+
name: "Token() got nil but no error",
306+
sourceToken: nil,
307+
expectErr: "unable to acquire token",
308+
tokenCalls: 1,
309+
},
310+
}
311+
for _, tc := range tests {
312+
t.Run(tc.name, func(t *testing.T) {
313+
persister := newFakePersister()
314+
315+
cfg := map[string]string{}
316+
if tc.configToken != nil {
317+
cfg = token2Cfg(tc.configToken)
318+
}
319+
320+
tokenCache := newAzureTokenCache()
321+
if tc.cachedToken != nil {
322+
tokenCache.setToken(azureTokenKey, tc.cachedToken)
323+
}
324+
325+
fakeSource := fakeTokenSource{
326+
token: tc.sourceToken,
327+
tokenErr: tc.tokenErr,
328+
refreshToken: tc.refreshToken,
329+
refreshErr: tc.refreshErr,
330+
}
331+
332+
tokenSource := newAzureTokenSource(&fakeSource, tokenCache, cfg, configMode, &persister)
333+
token, err := tokenSource.Token()
334+
335+
if fakeSource.tokenCalls != tc.tokenCalls {
336+
t.Errorf("expecting tokenCalls: %v, got: %v", tc.tokenCalls, fakeSource.tokenCalls)
337+
}
338+
339+
if fakeSource.refreshCalls != tc.refreshCalls {
340+
t.Errorf("expecting refreshCalls: %v, got: %v", tc.refreshCalls, fakeSource.refreshCalls)
341+
}
342+
343+
if persister.calls != tc.persistCalls {
344+
t.Errorf("expecting persister calls: %v, got: %v", tc.persistCalls, persister.calls)
345+
}
346+
347+
if tc.expectErr != "" {
348+
if !strings.Contains(err.Error(), tc.expectErr) {
349+
t.Errorf("expecting error %v, got %v", tc.expectErr, err)
350+
}
351+
if token != nil {
352+
t.Errorf("token should be nil in err situation, got %v", token)
353+
}
354+
} else {
355+
if err != nil {
356+
t.Fatalf("error should be nil, got %v", err)
357+
}
358+
if token.token.AccessToken != tc.expectToken.token.AccessToken {
359+
t.Errorf("token should have accessToken %v, got %v", token.token.AccessToken, tc.expectToken.token.AccessToken)
360+
}
361+
}
362+
})
363+
}
364+
}
365+
226366
type fakePersister struct {
227367
lock sync.Mutex
228368
cache map[string]string
369+
calls uint
370+
}
371+
372+
func newFakePersister() fakePersister {
373+
return fakePersister{cache: make(map[string]string), calls: 0}
229374
}
230375

231376
func (p *fakePersister) Persist(cache map[string]string) error {
232377
p.lock.Lock()
233378
defer p.lock.Unlock()
379+
p.calls++
234380
p.cache = map[string]string{}
235381
for k, v := range cache {
236382
p.cache[k] = v
@@ -248,19 +394,24 @@ func (p *fakePersister) Cache() map[string]string {
248394
return ret
249395
}
250396

397+
// a simple token source simply always returns the token property
251398
type fakeTokenSource struct {
252-
expiresOn string
253-
accessToken string
399+
token *azureToken
400+
tokenCalls uint
401+
tokenErr error
402+
refreshToken *azureToken
403+
refreshCalls uint
404+
refreshErr error
254405
}
255406

256407
func (ts *fakeTokenSource) Token() (*azureToken, error) {
257-
return &azureToken{
258-
token: newFackeAzureToken(ts.accessToken, ts.expiresOn),
259-
environment: "testenv",
260-
clientID: "fake",
261-
tenantID: "fake",
262-
apiserverID: "fake",
263-
}, nil
408+
ts.tokenCalls++
409+
return ts.token, ts.tokenErr
410+
}
411+
412+
func (ts *fakeTokenSource) Refresh(*azureToken) (*azureToken, error) {
413+
ts.refreshCalls++
414+
return ts.refreshToken, ts.refreshErr
264415
}
265416

266417
func token2Cfg(token *azureToken) map[string]string {
@@ -276,7 +427,17 @@ func token2Cfg(token *azureToken) map[string]string {
276427
return cfg
277428
}
278429

279-
func newFackeAzureToken(accessToken string, expiresOn string) adal.Token {
430+
func newFakeAzureToken(accessToken string, expiresOnTime time.Time) *azureToken {
431+
return &azureToken{
432+
token: newFakeADALToken(accessToken, strconv.FormatInt(expiresOnTime.Unix(), 10)),
433+
environment: "testenv",
434+
clientID: "fake",
435+
tenantID: "fake",
436+
apiserverID: "fake",
437+
}
438+
}
439+
440+
func newFakeADALToken(accessToken string, expiresOn string) adal.Token {
280441
return adal.Token{
281442
AccessToken: accessToken,
282443
RefreshToken: "fake",
@@ -287,3 +448,19 @@ func newFackeAzureToken(accessToken string, expiresOn string) adal.Token {
287448
Type: "fake",
288449
}
289450
}
451+
452+
// copied from go-autorest/adal
453+
type fakeTokenRefreshError struct {
454+
message string
455+
resp *http.Response
456+
}
457+
458+
// Error implements the error interface which is part of the TokenRefreshError interface.
459+
func (tre fakeTokenRefreshError) Error() string {
460+
return tre.message
461+
}
462+
463+
// Response implements the TokenRefreshError interface, it returns the raw HTTP response from the refresh operation.
464+
func (tre fakeTokenRefreshError) Response() *http.Response {
465+
return tre.resp
466+
}

0 commit comments

Comments
 (0)