Skip to content

Commit 6d7a9dd

Browse files
committed
improve tests and implementation
1 parent 40a87e0 commit 6d7a9dd

File tree

2 files changed

+58
-29
lines changed

2 files changed

+58
-29
lines changed

token_manager.go

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,8 @@ func (e *entraidTokenManager) Start(listener TokenListener) (cancelFunc, error)
321321
go listener.OnTokenNext(token)
322322

323323
go func(listener TokenListener) {
324+
maxDelay := time.Duration(e.retryOptions.MaxDelayMs) * time.Millisecond
325+
initialDelay := time.Duration(e.retryOptions.InitialDelayMs) * time.Millisecond
324326
// Simulate token refresh
325327
for {
326328
timeToRenewal := e.durationToRenewal()
@@ -330,22 +332,20 @@ func (e *entraidTokenManager) Start(listener TokenListener) (cancelFunc, error)
330332
// TODO(ndyakov): Discuss if we should call OnTokenError here
331333
return
332334
case <-time.After(timeToRenewal):
333-
// Token is about to expire, refresh it
334-
delay := time.Duration(e.retryOptions.InitialDelayMs) * time.Millisecond
335-
// Token asked to be refreshed asap, but let's make sure we wait a bit
336335
if timeToRenewal == 0 {
337-
time.Sleep(delay)
338-
}
339-
340-
for i := 0; i < e.retryOptions.MaxAttempts; i++ {
336+
// Token was requested immediately, guard against infinite loop
341337
select {
342338
case <-e.closed:
343339
// Token manager is closed, stop the loop
344340
// TODO(ndyakov): Discuss if we should call OnTokenError here
345341
return
346-
default:
347-
// continue to next attempt
342+
case <-time.After(initialDelay):
343+
// continue to attempt
348344
}
345+
}
346+
// Token is about to expire, refresh it
347+
delay := initialDelay
348+
for i := 0; i < e.retryOptions.MaxAttempts; i++ {
349349
token, err := e.GetToken(true)
350350
if err == nil {
351351
listener.OnTokenNext(token)
@@ -361,16 +361,22 @@ func (e *entraidTokenManager) Start(listener TokenListener) (cancelFunc, error)
361361
return
362362
}
363363

364-
if delay < time.Duration(e.retryOptions.MaxDelayMs)*time.Millisecond {
364+
if delay < maxDelay {
365365
delay = time.Duration(float64(delay) * e.retryOptions.BackoffMultiplier)
366366
}
367367

368-
time.Sleep(delay)
368+
if delay > maxDelay {
369+
delay = maxDelay
370+
}
369371

370-
if delay > time.Duration(e.retryOptions.MaxDelayMs)*time.Millisecond {
371-
delay = time.Duration(e.retryOptions.MaxDelayMs) * time.Millisecond
372+
select {
373+
case <-e.closed:
374+
// Token manager is closed, stop the loop
375+
// TODO(ndyakov): Discuss if we should call OnTokenError here
376+
return
377+
case <-time.After(delay):
378+
// continue to next attempt
372379
}
373-
continue
374380
} else {
375381
// not retriable
376382
listener.OnTokenError(err)

token_manager_test.go

Lines changed: 38 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -714,7 +714,7 @@ func TestEntraidTokenManager_durationToRenewal(t *testing.T) {
714714

715715
// return time to lower bound, if the returned time will be after the lower bound
716716
result = tm.durationToRenewal()
717-
assert.InDelta(t, time.Until(tm.token.expiresOn.Add(-1*tm.lowerBoundDuration)), result, float64(time.Second))
717+
assert.InEpsilon(t, time.Until(tm.token.expiresOn.Add(-1*tm.lowerBoundDuration)), result, float64(time.Second))
718718
})
719719

720720
})
@@ -997,11 +997,21 @@ func TestEntraidTokenManager_Streaming(t *testing.T) {
997997
mock.AssertExpectationsForObjects(t, idp, listener)
998998
})
999999

1000-
t.Run("Start and Listen with retriable error - max retries", func(t *testing.T) {
1000+
t.Run("Start and Listen with retriable error - max retries and max delay", func(t *testing.T) {
10011001
idp := &mockIdentityProvider{}
10021002
listener := &mockTokenListener{}
1003+
maxAttempts := 3
1004+
maxDelayMs := 500
1005+
initialDelayMs := 100
10031006
tokenManager, err := NewTokenManager(idp,
1004-
TokenManagerOptions{},
1007+
TokenManagerOptions{
1008+
RetryOptions: RetryOptions{
1009+
MaxAttempts: maxAttempts,
1010+
MaxDelayMs: maxDelayMs,
1011+
InitialDelayMs: initialDelayMs,
1012+
BackoffMultiplier: 10,
1013+
},
1014+
},
10051015
)
10061016
assert.NoError(t, err)
10071017
assert.NotNil(t, tokenManager)
@@ -1028,11 +1038,19 @@ func TestEntraidTokenManager_Streaming(t *testing.T) {
10281038
response := idpResponse.(*authResult)
10291039
response.authResult = res
10301040
}).Return(idpResponse, nil)
1031-
1032-
listener.On("OnTokenNext", mock.AnythingOfType("*entraid.Token")).Return()
1041+
var start, end time.Time
1042+
var elapsed time.Duration
1043+
1044+
_ = listener.
1045+
On("OnTokenNext", mock.AnythingOfType("*entraid.Token")).
1046+
Run(func(_ mock.Arguments) {
1047+
start = time.Now()
1048+
}).Return()
10331049
maxAttemptsReached := make(chan struct{})
10341050
listener.On("OnTokenError", mock.Anything).Run(func(args mock.Arguments) {
10351051
err := args.Get(0).(error)
1052+
end = time.Now()
1053+
elapsed = end.Sub(start)
10361054
assert.NotNil(t, err)
10371055
assert.ErrorContains(t, err, "max attempts reached")
10381056
close(maxAttemptsReached)
@@ -1042,29 +1060,34 @@ func TestEntraidTokenManager_Streaming(t *testing.T) {
10421060
assert.NotNil(t, cancel)
10431061
assert.NoError(t, err)
10441062
assert.NotNil(t, tm.listener)
1045-
1046-
noErrCall.Unset()
1047-
returnErr := newMockError(true)
1048-
idp.On("RequestToken").Return(nil, returnErr)
1049-
10501063
toRenewal := tm.durationToRenewal()
10511064
assert.NotEqual(t, time.Duration(0), toRenewal)
10521065
assert.NotEqual(t, expiresIn, toRenewal)
10531066
assert.True(t, expiresIn > toRenewal)
10541067

1068+
noErrCall.Unset()
1069+
returnErr := newMockError(true)
1070+
1071+
idp.On("RequestToken").Return(nil, returnErr)
1072+
10551073
select {
1056-
case <-time.After(toRenewal + time.Duration(tm.retryOptions.MaxAttempts*tm.retryOptions.MaxDelayMs)*time.Millisecond):
1057-
assert.Fail(t, "Timeout - max retries not reached ")
1074+
case <-time.After(toRenewal + time.Duration(maxAttempts*maxDelayMs)*time.Millisecond):
1075+
assert.Fail(t, "Timeout - max retries not reached")
10581076
case <-maxAttemptsReached:
10591077
}
10601078

1061-
// maxAttempts + the initial one
1079+
// initialRenewal window, maxAttempts - 1 * max delay + the initial one which was lower than max delay
1080+
allDelaysShouldBe := toRenewal
1081+
allDelaysShouldBe += time.Duration(initialDelayMs) * time.Millisecond
1082+
allDelaysShouldBe += time.Duration(maxAttempts-1) * time.Duration(maxDelayMs) * time.Millisecond
1083+
1084+
assert.InEpsilon(t, elapsed, allDelaysShouldBe, float64(10*time.Millisecond))
1085+
10621086
idp.AssertNumberOfCalls(t, "RequestToken", tm.retryOptions.MaxAttempts+1)
10631087
listener.AssertNumberOfCalls(t, "OnTokenNext", 1)
10641088
listener.AssertNumberOfCalls(t, "OnTokenError", 1)
10651089
mock.AssertExpectationsForObjects(t, idp, listener)
10661090
})
1067-
10681091
t.Run("Start and Listen and close during retries", func(t *testing.T) {
10691092
idp := &mockIdentityProvider{}
10701093
listener := &mockTokenListener{}
@@ -1124,7 +1147,7 @@ func TestEntraidTokenManager_Streaming(t *testing.T) {
11241147
assert.NotEqual(t, expiresIn, toRenewal)
11251148
assert.True(t, expiresIn > toRenewal)
11261149

1127-
<-time.After(toRenewal + 50*time.Millisecond)
1150+
<-time.After(toRenewal + 500*time.Millisecond)
11281151
assert.Nil(t, cancel())
11291152

11301153
select {

0 commit comments

Comments
 (0)