Skip to content

Commit 3555f62

Browse files
committed
safer concurrent implementation of Start
1 parent 9d9d71a commit 3555f62

File tree

1 file changed

+22
-16
lines changed

1 file changed

+22
-16
lines changed

manager/token_manager.go

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -179,13 +179,21 @@ func (e *entraidTokenManager) GetToken(forceRefresh bool) (*token.Token, error)
179179
e.tokenRWLock.RLock()
180180
// check if the token is nil and if it is not expired
181181
if !forceRefresh && e.token != nil && time.Now().Add(e.lowerBoundDuration).Before(e.token.ExpirationOn()) {
182-
t := e.token.Copy()
182+
t := e.token
183183
e.tokenRWLock.RUnlock()
184-
// copy the token so the caller can't modify it
185184
return t, nil
186185
}
187186
e.tokenRWLock.RUnlock()
188187

188+
// Upgrade to write lock for token update
189+
e.tokenRWLock.Lock()
190+
defer e.tokenRWLock.Unlock()
191+
192+
// Double-check pattern to avoid unnecessary token refresh
193+
if !forceRefresh && e.token != nil && time.Now().Add(e.lowerBoundDuration).Before(e.token.ExpirationOn()) {
194+
return e.token, nil
195+
}
196+
189197
idpResult, err := e.idp.RequestToken()
190198
if err != nil {
191199
return nil, fmt.Errorf("failed to request token from idp: %w", err)
@@ -199,18 +207,20 @@ func (e *entraidTokenManager) GetToken(forceRefresh bool) (*token.Token, error)
199207
if t == nil {
200208
return nil, fmt.Errorf("failed to get token: token is nil")
201209
}
202-
e.tokenRWLock.Lock()
203-
// copy the token so the caller can't modify it
204-
e.token = t.Copy()
205-
e.tokenRWLock.Unlock()
206210

211+
// Store the token
212+
e.token = t
213+
// Return the token - no need to copy since it's immutable
207214
return t, nil
208215
}
209216

210217
// Start starts the token manager and returns cancelFunc to stop the token manager.
211218
// It takes a TokenListener as an argument, which is used to receive updates.
212219
// The token manager will call the listener's OnTokenNext method with the updated token.
213220
// If an error occurs, the token manager will call the listener's OnError method with the error.
221+
//
222+
// Note: The initial token is delivered synchronously.
223+
// The TokenListener will receive the token immediately, before the token manager goroutine starts.
214224
func (e *entraidTokenManager) Start(listener TokenListener) (CancelFunc, error) {
215225
e.lock.Lock()
216226
defer e.lock.Unlock()
@@ -230,33 +240,32 @@ func (e *entraidTokenManager) Start(listener TokenListener) (CancelFunc, error)
230240
return nil, fmt.Errorf("failed to start token manager: %w", err)
231241
}
232242

233-
go listener.OnTokenNext(t)
243+
// Deliver initial token synchronously
244+
listener.OnTokenNext(t)
234245

235246
e.closedChan = make(chan struct{})
236247
e.listener = listener
237248

238249
go func(listener TokenListener, closed <-chan struct{}) {
239250
maxDelay := time.Duration(e.retryOptions.MaxDelayMs) * time.Millisecond
240251
initialDelay := time.Duration(e.retryOptions.InitialDelayMs) * time.Millisecond
252+
241253
for {
242254
timeToRenewal := e.durationToRenewal()
243255
select {
244256
case <-closed:
245-
// Token manager is closed, stop the loop
246-
// TODO(ndyakov): Discuss if we should call OnTokenError here
247257
return
248258
case <-time.After(timeToRenewal):
249259
if timeToRenewal == 0 {
250260
// Token was requested immediately, guard against infinite loop
251261
select {
252262
case <-closed:
253-
// Token manager is closed, stop the loop
254-
// TODO(ndyakov): Discuss if we should call OnTokenError here
255263
return
256264
case <-time.After(initialDelay):
257265
// continue to attempt
258266
}
259267
}
268+
260269
// Token is about to expire, refresh it
261270
delay := initialDelay
262271
for i := 0; i < e.retryOptions.MaxAttempts; i++ {
@@ -265,28 +274,25 @@ func (e *entraidTokenManager) Start(listener TokenListener) (CancelFunc, error)
265274
listener.OnTokenNext(t)
266275
break
267276
}
277+
268278
// check if err is retriable
269279
if e.retryOptions.IsRetryable(err) {
270-
// retriable error, continue to next attempt
271-
// Exponential backoff
272280
if i == e.retryOptions.MaxAttempts-1 {
273281
// last attempt, call OnTokenError
274282
listener.OnTokenError(fmt.Errorf("max attempts reached: %w", err))
275283
return
276284
}
277285

286+
// Exponential backoff
278287
if delay < maxDelay {
279288
delay = time.Duration(float64(delay) * e.retryOptions.BackoffMultiplier)
280289
}
281-
282290
if delay > maxDelay {
283291
delay = maxDelay
284292
}
285293

286294
select {
287295
case <-closed:
288-
// Token manager is closed, stop the loop
289-
// TODO(ndyakov): Discuss if we should call OnTokenError here
290296
return
291297
case <-time.After(delay):
292298
// continue to next attempt

0 commit comments

Comments
 (0)