Skip to content

Commit 676cf1c

Browse files
committed
fix(manager): starting and stopping the manager
Starting and stopping the manager can be executed multiple times
1 parent 7755fb9 commit 676cf1c

File tree

3 files changed

+104
-40
lines changed

3 files changed

+104
-40
lines changed

credentials_provider.go

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ type entraidCredentialsProvider struct {
2727

2828
// rwLock is a mutex that is used to synchronize access to the listeners slice.
2929
rwLock sync.RWMutex // Mutex for synchronizing access to the listeners slice.
30+
31+
tmLock sync.Mutex
3032
}
3133

3234
// onTokenNext is a method that is called when the token manager receives a new token.
@@ -65,11 +67,25 @@ func (e *entraidCredentialsProvider) onTokenError(err error) {
6567
//
6668
// Note: If the listener is already subscribed, it will not receive duplicate notifications.
6769
func (e *entraidCredentialsProvider) Subscribe(listener auth.CredentialsListener) (auth.Credentials, auth.UnsubscribeFunc, error) {
68-
// First try to get a token, only then subscribe the listener.
69-
token, err := e.tokenManager.GetToken(false)
70-
if err != nil {
71-
return nil, nil, err
70+
var token *token.Token
71+
// check if the manager is working
72+
// If the stopTokenManager is nil, the token manager is not started.
73+
e.tmLock.Lock()
74+
if e.stopTokenManager == nil {
75+
t, stopTM, err := e.tokenManager.Start(tokenListenerFromCP(e))
76+
if err != nil {
77+
return nil, nil, fmt.Errorf("couldn't start token manager: %w", err)
78+
}
79+
e.stopTokenManager = stopTM
80+
token = t
81+
} else {
82+
t, err := e.tokenManager.GetToken(false)
83+
if err != nil {
84+
return nil, nil, fmt.Errorf("couldn't get token: %w", err)
85+
}
86+
token = t
7287
}
88+
e.tmLock.Unlock()
7389

7490
e.rwLock.Lock()
7591
// Check if the listener is already in the list of listeners.
@@ -102,6 +118,7 @@ func (e *entraidCredentialsProvider) Subscribe(listener auth.CredentialsListener
102118
// Clear the listeners slice if it's empty
103119
if len(e.listeners) == 0 {
104120
e.listeners = make([]auth.CredentialsListener, 0)
121+
e.tmLock.Lock()
105122
if e.stopTokenManager != nil {
106123
err := e.stopTokenManager()
107124
if err != nil {
@@ -111,6 +128,7 @@ func (e *entraidCredentialsProvider) Subscribe(listener auth.CredentialsListener
111128
// This prevents multiple calls to stopTokenManager.
112129
e.stopTokenManager = nil
113130
}
131+
e.tmLock.Unlock()
114132
}
115133
return nil
116134
}
@@ -134,10 +152,5 @@ func NewCredentialsProvider(tokenManager manager.TokenManager, options Credentia
134152
options: options,
135153
listeners: make([]auth.CredentialsListener, 0),
136154
}
137-
stopTM, err := cp.tokenManager.Start(tokenListenerFromCP(cp))
138-
if err != nil {
139-
return nil, fmt.Errorf("couldn't start token manager: %w", err)
140-
}
141-
cp.stopTokenManager = stopTM
142155
return cp, nil
143156
}

manager/token_manager.go

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ type TokenManager interface {
8282
// It takes a boolean value forceRefresh as an argument.
8383
GetToken(forceRefresh bool) (*token.Token, error)
8484
// Start starts the token manager and returns a channel that will receive updates.
85-
Start(listener TokenListener) (StopFunc, error)
85+
Start(listener TokenListener) (*token.Token, StopFunc, error)
8686
}
8787

8888
// StopFunc is a function that stops the token manager.
@@ -239,11 +239,11 @@ func (e *entraidTokenManager) GetToken(forceRefresh bool) (*token.Token, error)
239239
//
240240
// Note: The initial token is delivered synchronously.
241241
// The TokenListener will receive the token immediately, before the token manager goroutine starts.
242-
func (e *entraidTokenManager) Start(listener TokenListener) (StopFunc, error) {
242+
func (e *entraidTokenManager) Start(listener TokenListener) (*token.Token, StopFunc, error) {
243243
e.lock.Lock()
244244
defer e.lock.Unlock()
245245
if e.listener != nil {
246-
return nil, ErrTokenManagerAlreadyStarted
246+
return nil, nil, ErrTokenManagerAlreadyStarted
247247
}
248248

249249
if e.closedChan != nil && !internal.IsClosed(e.closedChan) {
@@ -252,15 +252,25 @@ func (e *entraidTokenManager) Start(listener TokenListener) (StopFunc, error) {
252252
close(e.closedChan)
253253
}
254254

255-
t, err := e.GetToken(true)
255+
ctx, ctxCancel := context.WithCancel(context.Background())
256+
e.ctx = ctx
257+
e.ctxCancel = ctxCancel
258+
259+
t, err := e.GetToken(false)
260+
// If a token was found in the cache, check if:
261+
// - it is expired (based on the lower bound)
262+
// - it is about to expire (based on the expiration refresh ratio)
263+
// if so, get a new token
264+
expirationRefreshTime := t.ReceivedAt().Add(time.Duration(float64(t.TTL()) * float64(time.Second) * e.expirationRefreshRatio))
265+
expirationWithoutLowerBound := t.ExpirationOn().Add(-1 * e.lowerBoundDuration)
266+
now := time.Now()
267+
if t != nil && (expirationWithoutLowerBound.Before(now) || expirationRefreshTime.Before(now)) {
268+
t, err = e.GetToken(true)
269+
}
256270
if err != nil {
257-
go listener.OnError(err)
258-
return nil, fmt.Errorf("failed to start token manager: %w", err)
271+
return nil, nil, fmt.Errorf("failed to start token manager: %w", err)
259272
}
260273

261-
// Deliver initial token synchronously
262-
listener.OnNext(t)
263-
264274
e.closedChan = make(chan struct{})
265275
e.listener = listener
266276

@@ -325,7 +335,7 @@ func (e *entraidTokenManager) Start(listener TokenListener) (StopFunc, error) {
325335
}
326336
}(listener, e.closedChan)
327337

328-
return e.stop, nil
338+
return t, e.stop, nil
329339
}
330340

331341
// stop closes the token manager and releases any resources.

manager/token_manager_test.go

Lines changed: 62 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ func TestTokenManager_Close(t *testing.T) {
182182

183183
var stopper StopFunc
184184
assert.NotPanics(t, func() {
185-
stopper, err = tokenManager.Start(listener)
185+
_, stopper, err = tokenManager.Start(listener)
186186
assert.NotNil(t, stopper)
187187
assert.NoError(t, err)
188188
})
@@ -222,7 +222,7 @@ func TestTokenManager_Close(t *testing.T) {
222222
listener.On("OnNext", testTokenValid).Return()
223223

224224
assert.NotPanics(t, func() {
225-
cancel, err := tokenManager.Start(listener)
225+
_, cancel, err := tokenManager.Start(listener)
226226
assert.NotNil(t, cancel)
227227
assert.NoError(t, err)
228228
assert.NotNil(t, tm.listener)
@@ -258,7 +258,7 @@ func TestTokenManager_Close(t *testing.T) {
258258
listener.On("OnNext", testTokenValid).Return()
259259

260260
assert.NotPanics(t, func() {
261-
stopper, err := tokenManager.Start(listener)
261+
_, stopper, err := tokenManager.Start(listener)
262262
assert.NotNil(t, stopper)
263263
assert.NoError(t, err)
264264
assert.NotNil(t, tm.listener)
@@ -329,7 +329,7 @@ func TestTokenManager_Start(t *testing.T) {
329329
go func() {
330330
defer wg.Done()
331331
time.Sleep(time.Duration(int64(rand.Intn(100)) * int64(time.Millisecond)))
332-
_, err := tokenManager.Start(listener)
332+
_, _, err := tokenManager.Start(listener)
333333
if err == nil {
334334
hasStarted += 1
335335
return
@@ -344,7 +344,7 @@ func TestTokenManager_Start(t *testing.T) {
344344
assert.NotNil(t, tm.listener)
345345
assert.Equal(t, 1, hasStarted)
346346
assert.Equal(t, int32(numExecutions-1), atomic.LoadInt32(&alreadyStarted))
347-
cancel, err := tokenManager.Start(listener)
347+
_, cancel, err := tokenManager.Start(listener)
348348
assert.Nil(t, cancel)
349349
assert.Error(t, err)
350350
assert.NotNil(t, tm.listener)
@@ -389,7 +389,7 @@ func TestTokenManager_Start(t *testing.T) {
389389
} else {
390390
l := &mockTokenListener{Id: num}
391391
l.On("OnNext", testTokenValid).Return()
392-
_, err = tokenManager.Start(l)
392+
_, _, err = tokenManager.Start(l)
393393
}
394394
if err != nil {
395395
if err != ErrTokenManagerAlreadyStopped && err != ErrTokenManagerAlreadyStarted {
@@ -412,7 +412,7 @@ func TestTokenManager_Start(t *testing.T) {
412412
log.Printf("FAILING WITH lastExecution[STOPPED]: %d", lastExecution)
413413
}
414414
assert.NotNil(t, tm.listener)
415-
stopper, err := tokenManager.Start(listener)
415+
_, stopper, err := tokenManager.Start(listener)
416416
assert.Nil(t, stopper)
417417
assert.Error(t, err)
418418
// Stop the token manager with internal stop, since stopper should be nil
@@ -588,7 +588,8 @@ func TestEntraidTokenManager_GetToken(t *testing.T) {
588588
mParser.On("ParseResponse", rawResponse).Return(testTokenValid, nil)
589589
listener.On("OnNext", testTokenValid).Return()
590590

591-
cancel, err := tokenManager.Start(listener)
591+
initialToken, cancel, err := tokenManager.Start(listener)
592+
assert.NotNil(t, initialToken)
592593
assert.NotNil(t, cancel)
593594
assert.NoError(t, err)
594595
assert.NotNil(t, tm.listener)
@@ -598,6 +599,46 @@ func TestEntraidTokenManager_GetToken(t *testing.T) {
598599
assert.NotNil(t, token1)
599600
})
600601

602+
t.Run("GetToken with cached token", func(t *testing.T) {
603+
t.Parallel()
604+
idp := &mockIdentityProvider{}
605+
listener := &mockTokenListener{}
606+
mParser := &mockIdentityProviderResponseParser{}
607+
tokenManager, err := NewTokenManager(idp,
608+
TokenManagerOptions{
609+
IdentityProviderResponseParser: mParser,
610+
},
611+
)
612+
assert.NoError(t, err)
613+
assert.NotNil(t, tokenManager)
614+
tm, ok := tokenManager.(*entraidTokenManager)
615+
assert.True(t, ok)
616+
assert.Nil(t, tm.listener)
617+
618+
rawResponse := &authResult{
619+
ResultType: shared.ResponseTypeRawToken,
620+
RawTokenVal: "test",
621+
}
622+
623+
idp.On("RequestToken", mock.Anything).Return(rawResponse, nil)
624+
mParser.On("ParseResponse", rawResponse).Return(testTokenValid, nil)
625+
listener.On("OnNext", testTokenValid).Return()
626+
627+
initialToken, cancel, err := tokenManager.Start(listener)
628+
assert.NotNil(t, initialToken)
629+
assert.NotNil(t, cancel)
630+
assert.NoError(t, err)
631+
assert.NotNil(t, tm.listener)
632+
633+
token1, err := tokenManager.GetToken(false)
634+
assert.NoError(t, err)
635+
assert.NotNil(t, token1)
636+
637+
token2, err := tokenManager.GetToken(false)
638+
assert.NoError(t, err)
639+
assert.Equal(t, token1, token2)
640+
})
641+
601642
t.Run("GetToken with parse error", func(t *testing.T) {
602643
t.Parallel()
603644
idp := &mockIdentityProvider{}
@@ -623,7 +664,7 @@ func TestEntraidTokenManager_GetToken(t *testing.T) {
623664
mParser.On("ParseResponse", rawResponse).Return(nil, fmt.Errorf("parse error"))
624665
listener.On("OnError", mock.Anything).Return()
625666

626-
cancel, err := tokenManager.Start(listener)
667+
_, cancel, err := tokenManager.Start(listener)
627668
assert.Error(t, err)
628669
assert.Nil(t, cancel)
629670
assert.Nil(t, tm.listener)
@@ -814,7 +855,7 @@ func TestEntraidTokenManager_Streaming(t *testing.T) {
814855
mParser.On("ParseResponse", idpResponse).Return(token1, nil).Once()
815856
listener.On("OnNext", token1).Return().Once()
816857

817-
stopper, err := tokenManager.Start(listener)
858+
_, stopper, err := tokenManager.Start(listener)
818859
assert.NotNil(t, stopper)
819860
assert.NoError(t, err)
820861
assert.NotNil(t, tm.listener)
@@ -881,7 +922,7 @@ func TestEntraidTokenManager_Streaming(t *testing.T) {
881922

882923
listener.On("OnNext", mock.AnythingOfType("*token.Token")).Return()
883924

884-
cancel, err := tokenManager.Start(listener)
925+
_, cancel, err := tokenManager.Start(listener)
885926
assert.NotNil(t, cancel)
886927
assert.NoError(t, err)
887928
assert.NotNil(t, tm.listener)
@@ -938,7 +979,7 @@ func TestEntraidTokenManager_Streaming(t *testing.T) {
938979

939980
listener.On("OnNext", mock.AnythingOfType("*token.Token")).Return()
940981

941-
cancel, err := tokenManager.Start(listener)
982+
_, cancel, err := tokenManager.Start(listener)
942983
assert.NotNil(t, cancel)
943984
assert.NoError(t, err)
944985
assert.NotNil(t, tm.listener)
@@ -991,7 +1032,7 @@ func TestEntraidTokenManager_Streaming(t *testing.T) {
9911032

9921033
listener.On("OnNext", mock.AnythingOfType("*token.Token")).Return()
9931034

994-
cancel, err := tokenManager.Start(listener)
1035+
_, cancel, err := tokenManager.Start(listener)
9951036
assert.NotNil(t, cancel)
9961037
assert.NoError(t, err)
9971038
assert.NotNil(t, tm.listener)
@@ -1041,7 +1082,7 @@ func TestEntraidTokenManager_Streaming(t *testing.T) {
10411082
assert.NotNil(t, err)
10421083
}).Return().Maybe()
10431084

1044-
cancel, err := tokenManager.Start(listener)
1085+
_, cancel, err := tokenManager.Start(listener)
10451086
assert.NotNil(t, cancel)
10461087
assert.NoError(t, err)
10471088
assert.NotNil(t, tm.listener)
@@ -1095,7 +1136,7 @@ func TestEntraidTokenManager_Streaming(t *testing.T) {
10951136
assert.NotNil(t, err)
10961137
}).Return()
10971138

1098-
cancel, err := tokenManager.Start(listener)
1139+
_, cancel, err := tokenManager.Start(listener)
10991140
assert.NotNil(t, cancel)
11001141
assert.NoError(t, err)
11011142
assert.NotNil(t, tm.listener)
@@ -1174,7 +1215,7 @@ func TestEntraidTokenManager_Streaming(t *testing.T) {
11741215
close(maxAttemptsReached)
11751216
}).Return()
11761217

1177-
cancel, err := tokenManager.Start(listener)
1218+
_, cancel, err := tokenManager.Start(listener)
11781219
assert.NotNil(t, cancel)
11791220
assert.NoError(t, err)
11801221
assert.NotNil(t, tm.listener)
@@ -1248,7 +1289,7 @@ func TestEntraidTokenManager_Streaming(t *testing.T) {
12481289
close(maxAttemptsReached)
12491290
}).Return().Maybe()
12501291

1251-
cancel, err := tokenManager.Start(listener)
1292+
_, cancel, err := tokenManager.Start(listener)
12521293
assert.NotNil(t, cancel)
12531294
assert.NoError(t, err)
12541295
assert.NotNil(t, tm.listener)
@@ -1338,7 +1379,7 @@ func BenchmarkTokenManager_Start(b *testing.B) {
13381379

13391380
b.ResetTimer()
13401381
for i := 0; i < b.N; i++ {
1341-
_, _ = tokenManager.Start(listener)
1382+
_, _, _ = tokenManager.Start(listener)
13421383
}
13431384
}
13441385

@@ -1364,7 +1405,7 @@ func BenchmarkTokenManager_Close(b *testing.B) {
13641405
mParser.On("ParseResponse", rawResponse).Return(testTokenValid, nil)
13651406
listener.On("OnNext", testTokenValid).Return()
13661407

1367-
stopper, err := tokenManager.Start(listener)
1408+
_, stopper, err := tokenManager.Start(listener)
13681409
if err != nil {
13691410
b.Fatal(err)
13701411
}
@@ -1477,7 +1518,7 @@ func TestConcurrentTokenManagerOperations(t *testing.T) {
14771518
case 0:
14781519
// Start the token manager with a new listener
14791520
// t.Logf("Goroutine %d, Operation %d: Attempting to start token manager", routineID, j)
1480-
closeFunc, err := tm.Start(listener)
1521+
_, closeFunc, err := tm.Start(listener)
14811522

14821523
if err != nil {
14831524
if err != ErrTokenManagerAlreadyStarted {
@@ -1655,7 +1696,7 @@ func TestConcurrentTokenManagerOperations(t *testing.T) {
16551696
},
16561697
}
16571698

1658-
closeFunc, err := tm.Start(finalListener)
1699+
_, closeFunc, err := tm.Start(finalListener)
16591700
if err != nil && err != ErrTokenManagerAlreadyStarted {
16601701
t.Fatalf("Failed to start token manager after concurrent operations: %v", err)
16611702
}

0 commit comments

Comments
 (0)