Skip to content

Commit 7c86751

Browse files
committed
improve calculation of time to renewal
1 parent 329d9e7 commit 7c86751

File tree

4 files changed

+202
-12
lines changed

4 files changed

+202
-12
lines changed

entraid_test.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,9 @@ type mockIdentityProvider struct {
8282

8383
func (m *mockIdentityProvider) RequestToken() (IdentityProviderResponse, error) {
8484
args := m.Called()
85+
if args.Get(0) == nil {
86+
return nil, args.Error(1)
87+
}
8588
return args.Get(0).(IdentityProviderResponse), args.Error(1)
8689
}
8790

token.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,9 @@ func NewToken(username, password, rawToken string, expiresOn, receivedAt time.Ti
6060

6161
// copyToken creates a copy of the token.
6262
func copyToken(token *Token) *Token {
63+
if token == nil {
64+
return nil
65+
}
6366
return NewToken(token.username, token.password, token.rawToken, token.expiresOn, token.receivedAt, token.ttl)
6467
}
6568

token_manager.go

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ type TokenManagerOptions struct {
2323
// default: 0.7
2424
ExpirationRefreshRatio float64
2525
// LowerRefreshBoundMs is the lower bound for the refresh time in milliseconds.
26-
// Represents the minimum time in milliseconds before token expiration to trigger a refresh, in milliseconds.
26+
// Represents the minimum time in milliseconds before token expiration to trigger a refresh.
2727
// This value sets a fixed lower bound for when a token refresh should occur, regardless
2828
// of the token's total lifetime.
2929
//
@@ -239,7 +239,8 @@ type entraidTokenManager struct {
239239
}
240240

241241
func (e *entraidTokenManager) GetToken() (*Token, error) {
242-
if e.token != nil && e.token.expiresOn.Before(time.Now().Add(e.lowerBoundDuration)) {
242+
// check if the token is nil and if it is not expired
243+
if e.token != nil && e.token.expiresOn.After(time.Now()) && e.durationToRenewal() > e.lowerBoundDuration {
243244
// copy the token so the caller can't modify it
244245
return copyToken(e.token), nil
245246
}
@@ -256,6 +257,10 @@ func (e *entraidTokenManager) GetToken() (*Token, error) {
256257

257258
// copy the token so the caller can't modify it
258259
e.token = copyToken(token)
260+
261+
if e.token == nil {
262+
return nil, fmt.Errorf("failed to get token: token is nil")
263+
}
259264
return token, nil
260265
}
261266

@@ -276,12 +281,25 @@ func (e *entraidTokenManager) durationToRenewal() time.Duration {
276281
if e.token == nil {
277282
return 0
278283
}
284+
timeTillExpiration := time.Until(e.token.expiresOn)
285+
286+
// if lower bound has passed, do it NOW
287+
if timeTillExpiration <= e.lowerBoundDuration {
288+
return 0
289+
}
290+
279291
// Calculate the time to renew the token based on the expiration refresh ratio
280-
duration := time.Duration(float64(time.Until(e.token.expiresOn)) * e.expirationRefreshRatio)
281-
if duration < e.lowerBoundDuration {
282-
return e.lowerBoundDuration
292+
duration := time.Duration(float64(timeTillExpiration) * e.expirationRefreshRatio)
293+
if duration <= 0 {
294+
return 0
295+
}
296+
297+
// if the duration will take us past the lower bound, return the duration to lower bound
298+
if timeTillExpiration-e.lowerBoundDuration < duration {
299+
return timeTillExpiration - e.lowerBoundDuration
283300
}
284301

302+
// return the calculated duration
285303
return duration
286304
}
287305

@@ -309,13 +327,19 @@ func (e *entraidTokenManager) Start(listener TokenListener) (cancelFunc, error)
309327
go func(listener TokenListener) {
310328
// Simulate token refresh
311329
for {
330+
timeToRenewal := e.durationToRenewal()
312331
select {
313332
case <-e.closed:
314333
// Token manager is closed, stop the loop
315334
return
316-
case <-time.After(e.durationToRenewal()):
335+
case <-time.After(timeToRenewal):
317336
// Token is about to expire, refresh it
318337
delay := time.Duration(e.retryOptions.InitialDelayMs) * time.Millisecond
338+
// Token asked to be refreshed asap, but let's make sure we wait a bit
339+
if timeToRenewal == 0 {
340+
time.Sleep(delay)
341+
}
342+
319343
for i := 0; i < e.retryOptions.MaxAttempts; i++ {
320344
select {
321345
case <-e.closed:

token_manager_test.go

Lines changed: 166 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,7 @@ func TestTokenManager_Start(t *testing.T) {
308308
go func() {
309309
defer wg.Done()
310310
time.Sleep(time.Duration(int64(rand.Intn(100)) * int64(time.Millisecond)))
311-
_, err = tokenManager.Start(listener)
311+
_, err := tokenManager.Start(listener)
312312
if err == nil {
313313
hasStarted += 1
314314
return
@@ -554,6 +554,7 @@ func TestEntraidTokenManager_GetToken(t *testing.T) {
554554
assert.NotNil(t, token)
555555

556556
})
557+
557558
t.Run("GetToken with parse error", func(t *testing.T) {
558559
idp := &mockIdentityProvider{}
559560
listener := &mockTokenListener{}
@@ -580,7 +581,93 @@ func TestEntraidTokenManager_GetToken(t *testing.T) {
580581
assert.Error(t, err)
581582
assert.Nil(t, cancel)
582583
assert.NotNil(t, tm.listener)
584+
})
585+
t.Run("GetToken with expired token", func(t *testing.T) {
586+
idp := &mockIdentityProvider{}
587+
tokenManager, err := NewTokenManager(idp,
588+
TokenManagerOptions{},
589+
)
590+
assert.NoError(t, err)
591+
592+
authResult := &public.AuthResult{
593+
ExpiresOn: time.Now().Add(-time.Hour).UTC(),
594+
}
595+
idpResponse, err := NewIDPResponse(ResponseTypeAuthResult,
596+
authResult)
597+
assert.NoError(t, err)
598+
assert.NotNil(t, tokenManager)
599+
tm, ok := tokenManager.(*entraidTokenManager)
600+
assert.True(t, ok)
601+
assert.Nil(t, tm.listener)
602+
603+
idp.On("RequestToken").Return(idpResponse, nil)
604+
605+
token, err := tokenManager.GetToken()
606+
assert.Error(t, err)
607+
assert.Nil(t, token)
608+
})
609+
610+
t.Run("GetToken with nil token", func(t *testing.T) {
611+
idp := &mockIdentityProvider{}
612+
tokenManager, err := NewTokenManager(idp,
613+
TokenManagerOptions{},
614+
)
615+
assert.NoError(t, err)
616+
assert.NotNil(t, tokenManager)
617+
_, ok := tokenManager.(*entraidTokenManager)
618+
assert.True(t, ok)
619+
620+
rawResponse, err := NewIDPResponse(ResponseTypeRawToken, "test")
621+
assert.NoError(t, err)
622+
623+
idp.On("RequestToken").Return(rawResponse, nil)
624+
625+
token, err := tokenManager.GetToken()
626+
assert.Error(t, err)
627+
assert.Nil(t, token)
628+
})
629+
630+
t.Run("GetToken with nil from parser", func(t *testing.T) {
631+
idp := &mockIdentityProvider{}
632+
mParser := &mockIdentityProviderResponseParser{}
633+
tokenManager, err := NewTokenManager(idp,
634+
TokenManagerOptions{
635+
IdentityProviderResponseParser: mParser,
636+
},
637+
)
638+
assert.NoError(t, err)
639+
assert.NotNil(t, tokenManager)
640+
_, ok := tokenManager.(*entraidTokenManager)
641+
assert.True(t, ok)
642+
643+
idpResponse, err := NewIDPResponse(ResponseTypeRawToken, "test")
644+
assert.NoError(t, err)
645+
idp.On("RequestToken").Return(idpResponse, nil)
646+
mParser.On("ParseResponse", idpResponse).Return(nil, nil)
647+
648+
token, err := tokenManager.GetToken()
649+
assert.Error(t, err)
650+
assert.Nil(t, token)
651+
})
583652

653+
t.Run("GetToken with idp error", func(t *testing.T) {
654+
idp := &mockIdentityProvider{}
655+
mParser := &mockIdentityProviderResponseParser{}
656+
tokenManager, err := NewTokenManager(idp,
657+
TokenManagerOptions{
658+
IdentityProviderResponseParser: mParser,
659+
},
660+
)
661+
assert.NoError(t, err)
662+
assert.NotNil(t, tokenManager)
663+
_, ok := tokenManager.(*entraidTokenManager)
664+
assert.True(t, ok)
665+
666+
idp.On("RequestToken").Return(nil, fmt.Errorf("idp error"))
667+
668+
token, err := tokenManager.GetToken()
669+
assert.Error(t, err)
670+
assert.Nil(t, token)
584671
})
585672
}
586673

@@ -591,7 +678,6 @@ func TestEntraidTokenManager_durationToRenewal(t *testing.T) {
591678
idp := &mockIdentityProvider{}
592679
tokenManager, err := NewTokenManager(idp, TokenManagerOptions{
593680
LowerRefreshBoundMs: 1000 * 60 * 60, // 1 hour
594-
595681
})
596682
assert.NoError(t, err)
597683
assert.NotNil(t, tokenManager)
@@ -610,15 +696,89 @@ func TestEntraidTokenManager_durationToRenewal(t *testing.T) {
610696
idpResponse, err := NewIDPResponse(ResponseTypeAuthResult,
611697
expiresSoon)
612698
assert.NoError(t, err)
613-
idp.On("RequestToken").Return(idpResponse, nil)
699+
idp.On("RequestToken").Return(idpResponse, nil).Once()
700+
tm.token = nil
701+
_, err = tm.GetToken()
702+
assert.NoError(t, err)
703+
assert.NotNil(t, tm.token)
704+
705+
// return zero, should happen now since it expires before the lower bound
706+
result = tm.durationToRenewal()
707+
assert.Equal(t, time.Duration(0), result)
708+
})
614709

710+
// get token that expires after the lower bound and expirationRefreshRatio to 1
711+
assert.NotPanics(t, func() {
712+
tm.expirationRefreshRatio = 1
713+
expiresAfterlb := &public.AuthResult{
714+
ExpiresOn: time.Now().Add(time.Duration(tm.lowerBoundDuration) + time.Hour).UTC(),
715+
}
716+
idpResponse, err := NewIDPResponse(ResponseTypeAuthResult,
717+
expiresAfterlb)
718+
assert.NoError(t, err)
719+
idp.On("RequestToken").Return(idpResponse, nil).Once()
720+
tm.token = nil
615721
_, err = tm.GetToken()
616722
assert.NoError(t, err)
617723
assert.NotNil(t, tm.token)
724+
725+
// return time to lower bound, if the returned time will be after the lower bound
726+
result = tm.durationToRenewal()
727+
assert.InDelta(t, time.Until(tm.token.expiresOn.Add(-1*tm.lowerBoundDuration)), result, float64(time.Second))
618728
})
619729

620-
// return the lower bound
621-
result = tm.durationToRenewal()
622-
assert.Equal(t, tm.lowerBoundDuration, result)
730+
})
731+
}
732+
733+
func TestEntraidTokenManager_Streaming(t *testing.T) {
734+
// write a test that will cover the goroutine in the Start
735+
t.Parallel()
736+
t.Run("Streaming", func(t *testing.T) {
737+
idp := &mockIdentityProvider{}
738+
listener := &mockTokenListener{}
739+
mParser := &mockIdentityProviderResponseParser{}
740+
tokenManager, err := NewTokenManager(idp,
741+
TokenManagerOptions{
742+
IdentityProviderResponseParser: mParser,
743+
},
744+
)
745+
assert.NoError(t, err)
746+
assert.NotNil(t, tokenManager)
747+
tm, ok := tokenManager.(*entraidTokenManager)
748+
assert.True(t, ok)
749+
assert.Nil(t, tm.listener)
750+
751+
expiresIn := 10 * time.Millisecond
752+
expiresOn := time.Now().Add(expiresIn).UTC()
753+
authResult := &public.AuthResult{
754+
ExpiresOn: expiresOn,
755+
}
756+
idpResponse, err := NewIDPResponse(ResponseTypeAuthResult,
757+
authResult)
758+
assert.NoError(t, err)
759+
760+
idp.On("RequestToken").Return(idpResponse, nil)
761+
token := NewToken(
762+
"test",
763+
"test",
764+
"test",
765+
expiresOn,
766+
time.Now(),
767+
int64(time.Until(expiresOn)),
768+
)
769+
770+
mParser.On("ParseResponse", idpResponse).Return(token, nil)
771+
listener.On("OnTokenNext", token).Return()
772+
773+
cancel, err := tokenManager.Start(listener)
774+
assert.NotNil(t, cancel)
775+
assert.NoError(t, err)
776+
assert.NotNil(t, tm.listener)
777+
778+
toRenewal := tm.durationToRenewal()
779+
assert.NotEqual(t, 0, toRenewal)
780+
assert.NotEqual(t, expiresIn, toRenewal)
781+
assert.True(t, expiresIn > toRenewal)
782+
// should fail on mocks
623783
})
624784
}

0 commit comments

Comments
 (0)