Skip to content

Commit 5b21e16

Browse files
committed
wip
1 parent 4d8ea71 commit 5b21e16

File tree

2 files changed

+80
-12
lines changed

2 files changed

+80
-12
lines changed

credentials_provider.go

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package entraid
22

33
import (
44
"fmt"
5+
"sync"
56

67
"github.com/redis/go-redis/v9/auth"
78
)
@@ -16,11 +17,18 @@ type entraidCredentialsProvider struct {
1617
tokenManager TokenManager
1718
cancelTokenManager cancelFunc
1819

20+
// listeners is a slice of listeners that are notified when the token manager receives a new token.
1921
listeners []auth.CredentialsListener
22+
23+
// rwLock is a mutex that is used to synchronize access to the listeners slice.
24+
// It is used to ensure that only one goroutine can access the listeners slice at a time.
25+
rwLock sync.RWMutex
2026
}
2127

2228
// onTokenNext is a method that is called when the token manager receives a new token.
2329
func (e *entraidCredentialsProvider) onTokenNext(token *Token) {
30+
e.rwLock.RLock()
31+
defer e.rwLock.RUnlock()
2432
// Notify all listeners with the new token.
2533
for _, listener := range e.listeners {
2634
listener.OnNext(&authCredentials{
@@ -34,6 +42,8 @@ func (e *entraidCredentialsProvider) onTokenNext(token *Token) {
3442
// onError is a method that is called when the token manager encounters an error.
3543
// It notifies all listeners with the error.
3644
func (e *entraidCredentialsProvider) onTokenError(err error) {
45+
e.rwLock.RLock()
46+
defer e.rwLock.RUnlock()
3747
// Notify all listeners with the error.
3848
for _, listener := range e.listeners {
3949
listener.OnError(err)
@@ -46,6 +56,7 @@ func (e *entraidCredentialsProvider) onTokenError(err error) {
4656
// The listener will be notified with an error if there is an error obtaining the credentials.
4757
// The caller can cancel the subscription by calling the cancel function which is the second return value.
4858
func (e *entraidCredentialsProvider) Subscribe(listener auth.CredentialsListener) (auth.Credentials, auth.CancelProviderFunc, error) {
59+
e.rwLock.Lock()
4960
// Check if the listener is already in the list of listeners.
5061
alreadySubscribed := false
5162
for _, l := range e.listeners {
@@ -59,6 +70,7 @@ func (e *entraidCredentialsProvider) Subscribe(listener auth.CredentialsListener
5970
// Get the token from the identity provider.
6071
e.listeners = append(e.listeners, listener)
6172
}
73+
e.rwLock.Unlock()
6274

6375
token, err := e.tokenManager.GetToken()
6476
if err != nil {
@@ -78,21 +90,32 @@ func (e *entraidCredentialsProvider) Subscribe(listener auth.CredentialsListener
7890

7991
cancel := func() error {
8092
// Remove the listener from the list of listeners.
93+
e.rwLock.Lock()
94+
defer e.rwLock.Unlock()
8195
for i, l := range e.listeners {
8296
if l == listener {
8397
e.listeners = append(e.listeners[:i], e.listeners[i+1:]...)
8498
break
8599
}
86100
}
87101
if len(e.listeners) == 0 {
88-
e.cancelTokenManager()
102+
e.close()
89103
}
90104
return nil
91105
}
92106

93107
return credentials, cancel, nil
94108
}
95109

110+
func (e *entraidCredentialsProvider) close() {
111+
if e.cancelTokenManager != nil {
112+
e.cancelTokenManager()
113+
}
114+
e.cancelTokenManager = nil
115+
e.listeners = nil
116+
e.rwLock = sync.RWMutex{}
117+
}
118+
96119
type entraidTokenListener struct {
97120
onTokenNext func(token *Token)
98121
onTokenError func(err error)

token_manager.go

Lines changed: 56 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@ type TokenManager interface {
4848
GetToken() (*Token, error)
4949
// Start starts the token manager and returns a channel that will receive updates.
5050
Start(listener TokenListener) (cancelFunc, error)
51+
// Close closes the token manager and releases any resources.
52+
Close() error
5153
}
5254

5355
// defaultTokenParser is a function that parses the raw token and returns Token object.
@@ -74,6 +76,7 @@ func NewTokenManager(idp IdentityProvider, options TokenManagerOptions) TokenMan
7476
return &entraidTokenManager{
7577
idp: idp,
7678
token: nil,
79+
closed: make(chan struct{}),
7780
tokenParser: tokenParser,
7881
retryOptions: retryOptions,
7982
}
@@ -101,6 +104,8 @@ type entraidTokenManager struct {
101104

102105
// lock locks the listener to prevent concurrent access.
103106
lock sync.Mutex
107+
108+
closed chan struct{}
104109
}
105110

106111
func (e *entraidTokenManager) GetToken() (*Token, error) {
@@ -148,6 +153,7 @@ func (e *entraidTokenManager) Start(listener TokenListener) (cancelFunc, error)
148153
return nil, fmt.Errorf("token manager already started")
149154
}
150155
e.listener = listener
156+
e.closed = make(chan struct{})
151157

152158
token, err := e.GetToken()
153159
if err != nil {
@@ -159,22 +165,61 @@ func (e *entraidTokenManager) Start(listener TokenListener) (cancelFunc, error)
159165
go func(listener TokenListener) {
160166
// Simulate token refresh
161167
for {
162-
time.Sleep(
163-
time.Duration(e.retryOptions.InitialDelayMs) * time.Millisecond)
164-
newToken, err := e.GetToken()
165-
if err != nil {
166-
listener.OnTokenError(err)
168+
select {
169+
case <-time.After(time.Duration(e.token.TTL) * time.Second):
170+
// Token is about to expire, refresh it
171+
for i := 0; i < e.retryOptions.MaxAttempts; i++ {
172+
token, err := e.GetToken()
173+
if err == nil {
174+
listener.OnTokenNext(token)
175+
break
176+
}
177+
// check if err is retryable
178+
if err.Error() == "retryable error" {
179+
// retry
180+
continue
181+
} else {
182+
// not retryable
183+
listener.OnTokenError(err)
184+
return
185+
}
186+
187+
// check if max attempts reached
188+
if i == e.retryOptions.MaxAttempts-1 {
189+
listener.OnTokenError(err)
190+
return
191+
}
192+
193+
// Exponential backoff
194+
delay := time.Duration(e.retryOptions.InitialDelayMs) * time.Millisecond
195+
if delay < time.Duration(e.retryOptions.MaxDelayMs)*time.Millisecond {
196+
delay = time.Duration(float64(delay) * e.retryOptions.BackoffMultiplier)
197+
}
198+
199+
time.Sleep(delay)
200+
201+
if delay > time.Duration(e.retryOptions.MaxDelayMs)*time.Millisecond {
202+
delay = time.Duration(e.retryOptions.MaxDelayMs) * time.Millisecond
203+
}
204+
}
205+
case <-e.closed:
206+
// Token manager is closed, stop the loop
167207
return
168208
}
169-
listener.OnTokenNext(newToken)
170209
}
171210
}(e.listener)
172-
cancel := func() error {
173-
// Stop the token manager.
174-
return nil
175-
}
176211

177-
return cancel, nil
212+
return e.Close, nil
213+
}
214+
215+
func (e *entraidTokenManager) Close() error {
216+
e.lock.Lock()
217+
defer e.lock.Unlock()
218+
if e.listener != nil {
219+
e.listener = nil
220+
}
221+
close(e.closed)
222+
return nil
178223
}
179224

180225
// defaultRetryOptionsOr returns the default retry options if the provided options are not set.

0 commit comments

Comments
 (0)