Skip to content

Commit c7b7533

Browse files
committed
introduce forceRefresh for GetToken
1 parent 31c438a commit c7b7533

File tree

3 files changed

+83
-17
lines changed

3 files changed

+83
-17
lines changed

credentials_provider.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,14 +68,14 @@ func (e *entraidCredentialsProvider) Subscribe(listener auth.CredentialsListener
6868
}
6969
e.rwLock.Unlock()
7070

71-
token, err := e.tokenManager.GetToken()
71+
token, err := e.tokenManager.GetToken(false)
7272
if err != nil {
73-
listener.OnError(err)
73+
go listener.OnError(err)
7474
return nil, nil, err
7575
}
7676

7777
// Notify the listener with the credentials.
78-
listener.OnNext(token)
78+
go listener.OnNext(token)
7979

8080
cancel := func() error {
8181
// Remove the listener from the list of listeners.

token_manager.go

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,8 @@ type RetryOptions struct {
7575
// It is typically used in conjunction with an IdentityProvider to obtain the token.
7676
type TokenManager interface {
7777
// GetToken returns the token for authentication.
78-
GetToken() (*Token, error)
78+
// It takes a boolean value forceRefresh as an argument.
79+
GetToken(forceRefresh bool) (*Token, error)
7980
// Start starts the token manager and returns a channel that will receive updates.
8081
Start(listener TokenListener) (cancelFunc, error)
8182
// Close closes the token manager and releases any resources.
@@ -234,9 +235,9 @@ type entraidTokenManager struct {
234235
closed chan struct{}
235236
}
236237

237-
func (e *entraidTokenManager) GetToken() (*Token, error) {
238+
func (e *entraidTokenManager) GetToken(forceRefresh bool) (*Token, error) {
238239
// check if the token is nil and if it is not expired
239-
if e.token != nil && e.token.expiresOn.After(time.Now()) && e.durationToRenewal() > e.lowerBoundDuration {
240+
if !forceRefresh && e.token != nil && time.Now().Add(e.lowerBoundDuration).Before(e.token.expiresOn) {
240241
// copy the token so the caller can't modify it
241242
return copyToken(e.token), nil
242243
}
@@ -311,7 +312,7 @@ func (e *entraidTokenManager) Start(listener TokenListener) (cancelFunc, error)
311312
e.listener = listener
312313
e.closed = make(chan struct{})
313314

314-
token, err := e.GetToken()
315+
token, err := e.GetToken(true)
315316
if err != nil {
316317
go listener.OnTokenError(err)
317318
return nil, fmt.Errorf("failed to start token manager: %w", err)
@@ -343,7 +344,7 @@ func (e *entraidTokenManager) Start(listener TokenListener) (cancelFunc, error)
343344
default:
344345
// continue to next attempt
345346
}
346-
token, err := e.GetToken()
347+
token, err := e.GetToken(true)
347348
if err == nil {
348349
listener.OnTokenNext(token)
349350
break

token_manager_test.go

Lines changed: 74 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package entraid
22

33
import (
44
"fmt"
5+
"log"
56
"math/rand"
67
"os"
78
"reflect"
@@ -538,7 +539,7 @@ func TestEntraidTokenManager_GetToken(t *testing.T) {
538539
assert.NoError(t, err)
539540
assert.NotNil(t, tm.listener)
540541

541-
token, err := tokenManager.GetToken()
542+
token, err := tokenManager.GetToken(false)
542543
assert.NoError(t, err)
543544
assert.NotNil(t, token)
544545

@@ -591,7 +592,7 @@ func TestEntraidTokenManager_GetToken(t *testing.T) {
591592

592593
idp.On("RequestToken").Return(idpResponse, nil)
593594

594-
token, err := tokenManager.GetToken()
595+
token, err := tokenManager.GetToken(false)
595596
assert.Error(t, err)
596597
assert.Nil(t, token)
597598
})
@@ -611,7 +612,7 @@ func TestEntraidTokenManager_GetToken(t *testing.T) {
611612

612613
idp.On("RequestToken").Return(rawResponse, nil)
613614

614-
token, err := tokenManager.GetToken()
615+
token, err := tokenManager.GetToken(false)
615616
assert.Error(t, err)
616617
assert.Nil(t, token)
617618
})
@@ -634,7 +635,7 @@ func TestEntraidTokenManager_GetToken(t *testing.T) {
634635
idp.On("RequestToken").Return(idpResponse, nil)
635636
mParser.On("ParseResponse", idpResponse).Return(nil, nil)
636637

637-
token, err := tokenManager.GetToken()
638+
token, err := tokenManager.GetToken(false)
638639
assert.Error(t, err)
639640
assert.Nil(t, token)
640641
})
@@ -654,7 +655,7 @@ func TestEntraidTokenManager_GetToken(t *testing.T) {
654655

655656
idp.On("RequestToken").Return(nil, fmt.Errorf("idp error"))
656657

657-
token, err := tokenManager.GetToken()
658+
token, err := tokenManager.GetToken(false)
658659
assert.Error(t, err)
659660
assert.Nil(t, token)
660661
})
@@ -687,7 +688,7 @@ func TestEntraidTokenManager_durationToRenewal(t *testing.T) {
687688
assert.NoError(t, err)
688689
idp.On("RequestToken").Return(idpResponse, nil).Once()
689690
tm.token = nil
690-
_, err = tm.GetToken()
691+
_, err = tm.GetToken(false)
691692
assert.NoError(t, err)
692693
assert.NotNil(t, tm.token)
693694

@@ -707,7 +708,7 @@ func TestEntraidTokenManager_durationToRenewal(t *testing.T) {
707708
assert.NoError(t, err)
708709
idp.On("RequestToken").Return(idpResponse, nil).Once()
709710
tm.token = nil
710-
_, err = tm.GetToken()
711+
_, err = tm.GetToken(false)
711712
assert.NoError(t, err)
712713
assert.NotNil(t, tm.token)
713714

@@ -771,14 +772,17 @@ func TestEntraidTokenManager_Streaming(t *testing.T) {
771772
assert.NotNil(t, tm.listener)
772773
assert.NoError(t, tokenManager.Close())
773774
assert.Nil(t, tm.listener)
775+
assert.Panics(t, func() {
776+
close(tm.closed)
777+
})
778+
774779
<-time.After(toRenewal)
775780
assert.Error(t, tokenManager.Close())
776781
mock.AssertExpectationsForObjects(t, idp, mParser, listener)
777782
})
778783
t.Run("Start and Listen", func(t *testing.T) {
779784
idp := &mockIdentityProvider{}
780785
listener := &mockTokenListener{}
781-
mParser := &mockIdentityProviderResponseParser{}
782786
tokenManager, err := NewTokenManager(idp,
783787
TokenManagerOptions{},
784788
)
@@ -821,6 +825,67 @@ func TestEntraidTokenManager_Streaming(t *testing.T) {
821825

822826
<-time.After(toRenewal + time.Second)
823827

824-
mock.AssertExpectationsForObjects(t, idp, mParser, listener)
828+
mock.AssertExpectationsForObjects(t, idp, listener)
829+
})
830+
831+
t.Run("Start and Listen with retryable error", func(t *testing.T) {
832+
idp := &mockIdentityProvider{}
833+
listener := &mockTokenListener{}
834+
tokenManager, err := NewTokenManager(idp,
835+
TokenManagerOptions{},
836+
)
837+
assert.NoError(t, err)
838+
assert.NotNil(t, tokenManager)
839+
tm, ok := tokenManager.(*entraidTokenManager)
840+
assert.True(t, ok)
841+
assert.Nil(t, tm.listener)
842+
843+
assert.NoError(t, err)
844+
845+
expiresIn := time.Second
846+
expiresOn := time.Now().Add(expiresIn).UTC()
847+
res := &public.AuthResult{
848+
ExpiresOn: expiresOn,
849+
}
850+
idpResponse, err := NewIDPResponse(ResponseTypeAuthResult,
851+
res)
852+
assert.NoError(t, err)
853+
var returnErr error
854+
var secondResponse bool
855+
856+
idp.On("RequestToken").Run(func(args mock.Arguments) {
857+
if secondResponse {
858+
returnErr = mockError{isTimeout: true}
859+
return
860+
}
861+
expiresOn := time.Now().Add(expiresIn).UTC()
862+
res := &public.AuthResult{
863+
ExpiresOn: expiresOn,
864+
}
865+
response := idpResponse.(*authResult)
866+
response.authResult = res
867+
secondResponse = true
868+
}).Return(idpResponse, returnErr)
869+
870+
listener.On("OnTokenNext", mock.AnythingOfType("*entraid.Token")).Return()
871+
listener.On("OnTokenError", mock.Anything).Run(func(args mock.Arguments) {
872+
err := args.Get(0)
873+
assert.NotNil(t, err)
874+
log.Printf("Found TOKEN Error: %v", err)
875+
}).Return()
876+
877+
cancel, err := tokenManager.Start(listener)
878+
assert.NotNil(t, cancel)
879+
assert.NoError(t, err)
880+
assert.NotNil(t, tm.listener)
881+
882+
toRenewal := tm.durationToRenewal()
883+
assert.NotEqual(t, 0, toRenewal)
884+
assert.NotEqual(t, expiresIn, toRenewal)
885+
assert.True(t, expiresIn > toRenewal)
886+
887+
<-time.After(toRenewal + time.Second)
888+
889+
mock.AssertExpectationsForObjects(t, idp, listener)
825890
})
826891
}

0 commit comments

Comments
 (0)