@@ -2,6 +2,7 @@ package entraid
2
2
3
3
import (
4
4
"fmt"
5
+ "sync"
5
6
6
7
"github.com/redis/go-redis/v9/auth"
7
8
)
@@ -16,11 +17,18 @@ type entraidCredentialsProvider struct {
16
17
tokenManager TokenManager
17
18
cancelTokenManager cancelFunc
18
19
20
+ // listeners is a slice of listeners that are notified when the token manager receives a new token.
19
21
listeners []auth.CredentialsListener
22
+
23
+ // rwLock is a mutex that is used to synchronize access to the listeners slice.
24
+ // It is used to ensure that only one goroutine can access the listeners slice at a time.
25
+ rwLock sync.RWMutex
20
26
}
21
27
22
28
// onTokenNext is a method that is called when the token manager receives a new token.
23
29
func (e * entraidCredentialsProvider ) onTokenNext (token * Token ) {
30
+ e .rwLock .RLock ()
31
+ defer e .rwLock .RUnlock ()
24
32
// Notify all listeners with the new token.
25
33
for _ , listener := range e .listeners {
26
34
listener .OnNext (& authCredentials {
@@ -34,6 +42,8 @@ func (e *entraidCredentialsProvider) onTokenNext(token *Token) {
34
42
// onError is a method that is called when the token manager encounters an error.
35
43
// It notifies all listeners with the error.
36
44
func (e * entraidCredentialsProvider ) onTokenError (err error ) {
45
+ e .rwLock .RLock ()
46
+ defer e .rwLock .RUnlock ()
37
47
// Notify all listeners with the error.
38
48
for _ , listener := range e .listeners {
39
49
listener .OnError (err )
@@ -46,6 +56,7 @@ func (e *entraidCredentialsProvider) onTokenError(err error) {
46
56
// The listener will be notified with an error if there is an error obtaining the credentials.
47
57
// The caller can cancel the subscription by calling the cancel function which is the second return value.
48
58
func (e * entraidCredentialsProvider ) Subscribe (listener auth.CredentialsListener ) (auth.Credentials , auth.CancelProviderFunc , error ) {
59
+ e .rwLock .Lock ()
49
60
// Check if the listener is already in the list of listeners.
50
61
alreadySubscribed := false
51
62
for _ , l := range e .listeners {
@@ -59,6 +70,7 @@ func (e *entraidCredentialsProvider) Subscribe(listener auth.CredentialsListener
59
70
// Get the token from the identity provider.
60
71
e .listeners = append (e .listeners , listener )
61
72
}
73
+ e .rwLock .Unlock ()
62
74
63
75
token , err := e .tokenManager .GetToken ()
64
76
if err != nil {
@@ -78,21 +90,32 @@ func (e *entraidCredentialsProvider) Subscribe(listener auth.CredentialsListener
78
90
79
91
cancel := func () error {
80
92
// Remove the listener from the list of listeners.
93
+ e .rwLock .Lock ()
94
+ defer e .rwLock .Unlock ()
81
95
for i , l := range e .listeners {
82
96
if l == listener {
83
97
e .listeners = append (e .listeners [:i ], e .listeners [i + 1 :]... )
84
98
break
85
99
}
86
100
}
87
101
if len (e .listeners ) == 0 {
88
- e .cancelTokenManager ()
102
+ e .close ()
89
103
}
90
104
return nil
91
105
}
92
106
93
107
return credentials , cancel , nil
94
108
}
95
109
110
+ func (e * entraidCredentialsProvider ) close () {
111
+ if e .cancelTokenManager != nil {
112
+ e .cancelTokenManager ()
113
+ }
114
+ e .cancelTokenManager = nil
115
+ e .listeners = nil
116
+ e .rwLock = sync.RWMutex {}
117
+ }
118
+
96
119
type entraidTokenListener struct {
97
120
onTokenNext func (token * Token )
98
121
onTokenError func (err error )
0 commit comments