Skip to content

Commit 41af0a7

Browse files
committed
fix(manager): wip, still have tests to resolve
1 parent 676cf1c commit 41af0a7

File tree

9 files changed

+325
-364
lines changed

9 files changed

+325
-364
lines changed

credentials_provider.go

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -67,26 +67,23 @@ func (e *entraidCredentialsProvider) onTokenError(err error) {
6767
//
6868
// Note: If the listener is already subscribed, it will not receive duplicate notifications.
6969
func (e *entraidCredentialsProvider) Subscribe(listener auth.CredentialsListener) (auth.Credentials, auth.UnsubscribeFunc, error) {
70-
var token *token.Token
7170
// check if the manager is working
7271
// If the stopTokenManager is nil, the token manager is not started.
7372
e.tmLock.Lock()
7473
if e.stopTokenManager == nil {
75-
t, stopTM, err := e.tokenManager.Start(tokenListenerFromCP(e))
74+
stopTM, err := e.tokenManager.Start(tokenListenerFromCP(e))
7675
if err != nil {
7776
return nil, nil, fmt.Errorf("couldn't start token manager: %w", err)
7877
}
7978
e.stopTokenManager = stopTM
80-
token = t
81-
} else {
82-
t, err := e.tokenManager.GetToken(false)
83-
if err != nil {
84-
return nil, nil, fmt.Errorf("couldn't get token: %w", err)
85-
}
86-
token = t
8779
}
8880
e.tmLock.Unlock()
8981

82+
token, err := e.tokenManager.GetToken(false)
83+
if err != nil {
84+
return nil, nil, fmt.Errorf("couldn't get token: %w", err)
85+
}
86+
9087
e.rwLock.Lock()
9188
// Check if the listener is already in the list of listeners.
9289
alreadySubscribed := false
@@ -152,5 +149,11 @@ func NewCredentialsProvider(tokenManager manager.TokenManager, options Credentia
152149
options: options,
153150
listeners: make([]auth.CredentialsListener, 0),
154151
}
152+
// Start the token manager.
153+
stop, err := tokenManager.Start(tokenListenerFromCP(cp))
154+
if err != nil {
155+
return nil, fmt.Errorf("couldn't start token manager: %w", err)
156+
}
157+
cp.stopTokenManager = stop
155158
return cp, nil
156159
}

entraid_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ func (m *fakeTokenManager) GetToken(forceRefresh bool) (*token.Token, error) {
4848
rawTokenString,
4949
time.Now().Add(tokenExpiration),
5050
time.Now(),
51-
int64(100*time.Millisecond),
51+
int64(tokenExpiration.Seconds()),
5252
)
5353
}
5454
return m.token, m.err

manager/defaults.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package manager
33
import (
44
"errors"
55
"fmt"
6+
"math"
67
"net"
78
"os"
89
"time"
@@ -105,7 +106,7 @@ func (*defaultIdentityProviderResponseParser) ParseResponse(response shared.Iden
105106

106107
var username, password, rawToken string
107108
var expiresOn time.Time
108-
now := time.Now().UTC().Truncate(time.Second)
109+
now := time.Now().UTC().Truncate(time.Second).Add(time.Second)
109110

110111
switch response.Type() {
111112
case shared.ResponseTypeAuthResult:
@@ -176,6 +177,6 @@ func (*defaultIdentityProviderResponseParser) ParseResponse(response shared.Iden
176177
rawToken,
177178
expiresOn,
178179
now,
179-
int64(time.Until(expiresOn).Seconds()),
180+
int64(math.Ceil(time.Until(expiresOn).Seconds())),
180181
), nil
181182
}

manager/entraid_manager.go

Lines changed: 267 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,267 @@
1+
package manager
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"sync"
7+
"time"
8+
9+
"github.com/redis-developer/go-redis-entraid/internal"
10+
"github.com/redis-developer/go-redis-entraid/shared"
11+
"github.com/redis-developer/go-redis-entraid/token"
12+
)
13+
14+
// entraidTokenManager is a struct that implements the TokenManager interface.
15+
type entraidTokenManager struct {
16+
// idp is the identity provider used to obtain the token.
17+
idp shared.IdentityProvider
18+
19+
// token is the authentication token for the user which should be kept in memory if valid.
20+
token *token.Token
21+
22+
// tokenRWLock is a read-write lock used to protect the token from concurrent access.
23+
tokenRWLock sync.RWMutex
24+
25+
// identityProviderResponseParser is the parser used to parse the response from the identity provider.
26+
// It`s ParseResponse method will be called to parse the response and return the token.
27+
identityProviderResponseParser shared.IdentityProviderResponseParser
28+
29+
// retryOptions is a struct that contains the options for retrying the token request.
30+
// It contains the maximum number of attempts, initial delay, maximum delay, and backoff multiplier.
31+
// The default values are 3 attempts, 1000 ms initial delay, 10000 ms maximum delay, and 2.0 backoff multiplier.
32+
// The values can be overridden by the user.
33+
retryOptions RetryOptions
34+
35+
// listener is the single listener for the token manager.
36+
// It is used to receive updates from the token manager.
37+
// The token manager will call the listener's OnNext method with the updated token.
38+
// If an error occurs, the token manager will call the listener's OnError method with the error.
39+
// if listener is set, Start will fail
40+
listener TokenListener
41+
42+
// lock locks the listener to prevent concurrent access.
43+
lock sync.Mutex
44+
45+
// expirationRefreshRatio is the ratio of the token expiration time to refresh the token.
46+
// It is used to determine when to refresh the token.
47+
// The value should be between 0 and 1.
48+
// For example, if the expiration time is 1 hour and the ratio is 0.75,
49+
// the token will be refreshed after 45 minutes. (the token is refreshed when 75% of its lifetime has passed)
50+
expirationRefreshRatio float64
51+
52+
// lowerBoundDuration is the lower bound for the refresh time in time.Duration.
53+
lowerBoundDuration time.Duration
54+
55+
// closedChan is a channel that is closedChan when the token manager is closedChan.
56+
// It is used to signal the token manager to stop requesting tokens.
57+
closedChan chan struct{}
58+
59+
// context is the context used to request the token from the identity provider.
60+
ctx context.Context
61+
62+
// ctxCancel is the cancel function for the context.
63+
ctxCancel context.CancelFunc
64+
65+
// requestTimeout is the timeout for the request to the identity provider.
66+
requestTimeout time.Duration
67+
}
68+
69+
func (e *entraidTokenManager) GetToken(forceRefresh bool) (*token.Token, error) {
70+
e.tokenRWLock.RLock()
71+
// check if the token is nil and if it is not expired
72+
t := e.token
73+
duration := e.durationToRenewal(t)
74+
if !forceRefresh && t != nil && duration > 0 {
75+
e.tokenRWLock.RUnlock()
76+
return t, nil
77+
}
78+
e.tokenRWLock.RUnlock()
79+
80+
// start the context early,
81+
// since at heavy concurrent load
82+
// locks may take some time to acquire
83+
ctx, ctxCancel := context.WithTimeout(e.ctx, e.requestTimeout)
84+
defer ctxCancel()
85+
86+
// Upgrade to write lock for token update
87+
e.tokenRWLock.Lock()
88+
defer e.tokenRWLock.Unlock()
89+
90+
// Double-check pattern to avoid unnecessary token refresh
91+
t = e.token
92+
duration = e.durationToRenewal(t)
93+
if !forceRefresh && t != nil && duration > 0 {
94+
return t, nil
95+
}
96+
97+
// Request a new token from the identity provider
98+
idpResult, err := e.idp.RequestToken(ctx)
99+
if err != nil {
100+
return nil, fmt.Errorf("failed to request token from idp: %w", err)
101+
}
102+
103+
t, err = e.identityProviderResponseParser.ParseResponse(idpResult)
104+
if err != nil {
105+
return nil, fmt.Errorf("failed to parse token: %w", err)
106+
}
107+
108+
if t == nil {
109+
return nil, fmt.Errorf("failed to get token: token is nil")
110+
}
111+
112+
// Store the token
113+
e.token = t
114+
// Return the token - no need to copy since it's immutable
115+
return t, nil
116+
}
117+
118+
// Start starts the token manager and returns cancelFunc to stop the token manager.
119+
// It takes a TokenListener as an argument, which is used to receive updates.
120+
// The token manager will call the listener's OnNext method with the updated token.
121+
// If an error occurs, the token manager will call the listener's OnError method with the error.
122+
func (e *entraidTokenManager) Start(listener TokenListener) (StopFunc, error) {
123+
e.lock.Lock()
124+
defer e.lock.Unlock()
125+
if e.listener != nil {
126+
return nil, ErrTokenManagerAlreadyStarted
127+
}
128+
129+
if e.closedChan != nil && !internal.IsClosed(e.closedChan) {
130+
// there is a hanging goroutine that is waiting for the closedChan to be closed
131+
// if the closedChan is not nil and not closed, close it
132+
close(e.closedChan)
133+
}
134+
135+
ctx, ctxCancel := context.WithCancel(context.Background())
136+
e.ctx = ctx
137+
e.ctxCancel = ctxCancel
138+
139+
// make sure there is token in memory before starting the loop
140+
_, err := e.GetToken(false)
141+
if err != nil {
142+
return nil, fmt.Errorf("failed to get token: %w", err)
143+
}
144+
145+
e.closedChan = make(chan struct{})
146+
e.listener = listener
147+
148+
go func(listener TokenListener, closed <-chan struct{}) {
149+
maxDelay := e.retryOptions.MaxDelay
150+
initialDelay := e.retryOptions.InitialDelay
151+
152+
for {
153+
e.tokenRWLock.RLock()
154+
timeToRenewal := e.durationToRenewal(e.token)
155+
e.tokenRWLock.RUnlock()
156+
select {
157+
case <-closed:
158+
return
159+
case <-time.After(timeToRenewal):
160+
if timeToRenewal == 0 {
161+
// Token was requested immediately, guard against infinite loop
162+
select {
163+
case <-closed:
164+
return
165+
case <-time.After(initialDelay):
166+
// continue to attempt
167+
}
168+
}
169+
170+
// Token is about to expire, refresh it
171+
delay := initialDelay
172+
for i := 0; i < e.retryOptions.MaxAttempts; i++ {
173+
t, err := e.GetToken(true)
174+
if err == nil {
175+
listener.OnNext(t)
176+
break
177+
}
178+
179+
// check if err is retriable
180+
if e.retryOptions.IsRetryable(err) {
181+
if i == e.retryOptions.MaxAttempts-1 {
182+
// last attempt, call OnError
183+
listener.OnError(fmt.Errorf("max attempts reached: %w", err))
184+
return
185+
}
186+
187+
// Exponential backoff
188+
if delay < maxDelay {
189+
delay = time.Duration(float64(delay) * e.retryOptions.BackoffMultiplier)
190+
}
191+
if delay > maxDelay {
192+
delay = maxDelay
193+
}
194+
195+
select {
196+
case <-closed:
197+
return
198+
case <-time.After(delay):
199+
// continue to next attempt
200+
}
201+
} else {
202+
// not retriable
203+
listener.OnError(err)
204+
return
205+
}
206+
}
207+
}
208+
}
209+
}(listener, e.closedChan)
210+
211+
return e.stop, nil
212+
}
213+
214+
// stop closes the token manager and releases any resources.
215+
func (e *entraidTokenManager) stop() (err error) {
216+
e.lock.Lock()
217+
defer e.lock.Unlock()
218+
defer func() {
219+
// recover from panic and return the error
220+
if r := recover(); r != nil {
221+
err = fmt.Errorf("failed to stop token manager: %s", r)
222+
}
223+
}()
224+
225+
if e.closedChan == nil || e.listener == nil {
226+
return ErrTokenManagerAlreadyStopped
227+
}
228+
229+
e.ctxCancel()
230+
e.listener = nil
231+
close(e.closedChan)
232+
233+
return nil
234+
}
235+
236+
// durationToRenewal calculates the duration to the next token renewal.
237+
// It returns the duration to the next token renewal based on the expiration refresh ratio and the lower bound duration.
238+
// If the token is nil, it returns 0.
239+
// If the time till expiration is less than the lower bound duration, it returns 0 to renew the token now.
240+
func (e *entraidTokenManager) durationToRenewal(t *token.Token) time.Duration {
241+
if t == nil {
242+
return 0
243+
}
244+
expirationRefreshTime := t.ReceivedAt().Add(time.Duration(float64(t.TTL()) * float64(time.Second) * e.expirationRefreshRatio))
245+
timeTillExpiration := time.Until(t.ExpirationOn())
246+
now := time.Now().UTC()
247+
248+
if expirationRefreshTime.Before(now) {
249+
return 0
250+
}
251+
252+
// if the timeTillExpiration is less than the lower bound (or 0), return 0 to renew the token NOW
253+
if timeTillExpiration <= e.lowerBoundDuration || timeTillExpiration <= 0 {
254+
return 0
255+
}
256+
257+
// Calculate the time to renew the token based on the expiration refresh ratio
258+
duration := time.Until(expirationRefreshTime)
259+
260+
// if the duration will take us past the lower bound, return the duration to lower bound
261+
if timeTillExpiration-e.lowerBoundDuration < duration {
262+
return timeTillExpiration - e.lowerBoundDuration
263+
}
264+
265+
// return the calculated duration
266+
return duration
267+
}

manager/manager_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ var testTokenValid = token.New(
6161
"test",
6262
time.Now().Add(time.Hour),
6363
time.Now(),
64-
int64(time.Hour),
64+
int64(time.Hour.Seconds()),
6565
)
6666

6767
func newTestJWTToken(expiresOn time.Time) string {

0 commit comments

Comments
 (0)