@@ -179,13 +179,21 @@ func (e *entraidTokenManager) GetToken(forceRefresh bool) (*token.Token, error)
179
179
e .tokenRWLock .RLock ()
180
180
// check if the token is nil and if it is not expired
181
181
if ! forceRefresh && e .token != nil && time .Now ().Add (e .lowerBoundDuration ).Before (e .token .ExpirationOn ()) {
182
- t := e .token . Copy ()
182
+ t := e .token
183
183
e .tokenRWLock .RUnlock ()
184
- // copy the token so the caller can't modify it
185
184
return t , nil
186
185
}
187
186
e .tokenRWLock .RUnlock ()
188
187
188
+ // Upgrade to write lock for token update
189
+ e .tokenRWLock .Lock ()
190
+ defer e .tokenRWLock .Unlock ()
191
+
192
+ // Double-check pattern to avoid unnecessary token refresh
193
+ if ! forceRefresh && e .token != nil && time .Now ().Add (e .lowerBoundDuration ).Before (e .token .ExpirationOn ()) {
194
+ return e .token , nil
195
+ }
196
+
189
197
idpResult , err := e .idp .RequestToken ()
190
198
if err != nil {
191
199
return nil , fmt .Errorf ("failed to request token from idp: %w" , err )
@@ -199,18 +207,20 @@ func (e *entraidTokenManager) GetToken(forceRefresh bool) (*token.Token, error)
199
207
if t == nil {
200
208
return nil , fmt .Errorf ("failed to get token: token is nil" )
201
209
}
202
- e .tokenRWLock .Lock ()
203
- // copy the token so the caller can't modify it
204
- e .token = t .Copy ()
205
- e .tokenRWLock .Unlock ()
206
210
211
+ // Store the token
212
+ e .token = t
213
+ // Return the token - no need to copy since it's immutable
207
214
return t , nil
208
215
}
209
216
210
217
// Start starts the token manager and returns cancelFunc to stop the token manager.
211
218
// It takes a TokenListener as an argument, which is used to receive updates.
212
219
// The token manager will call the listener's OnTokenNext method with the updated token.
213
220
// If an error occurs, the token manager will call the listener's OnError method with the error.
221
+ //
222
+ // Note: The initial token is delivered synchronously.
223
+ // The TokenListener will receive the token immediately, before the token manager goroutine starts.
214
224
func (e * entraidTokenManager ) Start (listener TokenListener ) (CancelFunc , error ) {
215
225
e .lock .Lock ()
216
226
defer e .lock .Unlock ()
@@ -230,33 +240,32 @@ func (e *entraidTokenManager) Start(listener TokenListener) (CancelFunc, error)
230
240
return nil , fmt .Errorf ("failed to start token manager: %w" , err )
231
241
}
232
242
233
- go listener .OnTokenNext (t )
243
+ // Deliver initial token synchronously
244
+ listener .OnTokenNext (t )
234
245
235
246
e .closedChan = make (chan struct {})
236
247
e .listener = listener
237
248
238
249
go func (listener TokenListener , closed <- chan struct {}) {
239
250
maxDelay := time .Duration (e .retryOptions .MaxDelayMs ) * time .Millisecond
240
251
initialDelay := time .Duration (e .retryOptions .InitialDelayMs ) * time .Millisecond
252
+
241
253
for {
242
254
timeToRenewal := e .durationToRenewal ()
243
255
select {
244
256
case <- closed :
245
- // Token manager is closed, stop the loop
246
- // TODO(ndyakov): Discuss if we should call OnTokenError here
247
257
return
248
258
case <- time .After (timeToRenewal ):
249
259
if timeToRenewal == 0 {
250
260
// Token was requested immediately, guard against infinite loop
251
261
select {
252
262
case <- closed :
253
- // Token manager is closed, stop the loop
254
- // TODO(ndyakov): Discuss if we should call OnTokenError here
255
263
return
256
264
case <- time .After (initialDelay ):
257
265
// continue to attempt
258
266
}
259
267
}
268
+
260
269
// Token is about to expire, refresh it
261
270
delay := initialDelay
262
271
for i := 0 ; i < e .retryOptions .MaxAttempts ; i ++ {
@@ -265,28 +274,25 @@ func (e *entraidTokenManager) Start(listener TokenListener) (CancelFunc, error)
265
274
listener .OnTokenNext (t )
266
275
break
267
276
}
277
+
268
278
// check if err is retriable
269
279
if e .retryOptions .IsRetryable (err ) {
270
- // retriable error, continue to next attempt
271
- // Exponential backoff
272
280
if i == e .retryOptions .MaxAttempts - 1 {
273
281
// last attempt, call OnTokenError
274
282
listener .OnTokenError (fmt .Errorf ("max attempts reached: %w" , err ))
275
283
return
276
284
}
277
285
286
+ // Exponential backoff
278
287
if delay < maxDelay {
279
288
delay = time .Duration (float64 (delay ) * e .retryOptions .BackoffMultiplier )
280
289
}
281
-
282
290
if delay > maxDelay {
283
291
delay = maxDelay
284
292
}
285
293
286
294
select {
287
295
case <- closed :
288
- // Token manager is closed, stop the loop
289
- // TODO(ndyakov): Discuss if we should call OnTokenError here
290
296
return
291
297
case <- time .After (delay ):
292
298
// continue to next attempt
0 commit comments