Skip to content

Commit 2a286c6

Browse files
committed
Address PR comments, refactor
- Change type names to make more sense (e.g. Start / Stop ) - Add context.Context to the IDP RequestToken - Add RequestTimeout to TokenManagerOptions which will be utilized by the context - Change the LowerRefreshBoundMs from int64 to time.Duration and use better name (dropping the Ms suffix) TODO: - Address changes in the documentation
1 parent 87f7d84 commit 2a286c6

20 files changed

+211
-177
lines changed

.github/workflows/build.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ jobs:
1919
- name: Install dependencies
2020
run: go mod tidy
2121
- name: Run tests with coverage
22-
run: go test ./... -coverprofile=./cover.out -covermode=atomic -race -count 2 -timeout 1m
22+
run: go test ./... -coverprofile=./cover.out -covermode=atomic -race -count 2 -timeout 5m
2323
- name: Upload coverage
2424
uses: actions/upload-artifact@v4
2525
with:

README.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,7 @@ type TokenManagerOptions struct {
231231

232232
// Optional: Minimum time before expiration to refresh (ms)
233233
// Default: 10000 (10 seconds)
234-
LowerRefreshBoundMs int64
234+
LowerRefreshBounds int64
235235

236236
// Optional: Configuration for retry behavior
237237
RetryOptions RetryOptions
@@ -350,7 +350,7 @@ options := entraid.CredentialsProviderOptions{
350350
ClientID: os.Getenv("AZURE_CLIENT_ID"),
351351
TokenManagerOptions: manager.TokenManagerOptions{
352352
ExpirationRefreshRatio: 0.7,
353-
LowerRefreshBoundMs: 10000,
353+
LowerRefreshBounds: 10000,
354354
},
355355
}
356356
```
@@ -361,7 +361,7 @@ options := entraid.CredentialsProviderOptions{
361361
ClientID: os.Getenv("AZURE_CLIENT_ID"),
362362
TokenManagerOptions: manager.TokenManagerOptions{
363363
ExpirationRefreshRatio: 0.7,
364-
LowerRefreshBoundMs: 10000,
364+
LowerRefreshBounds: 10000,
365365
RetryOptions: manager.RetryOptions{
366366
MaxAttempts: 3,
367367
InitialDelayMs: 1000,
@@ -516,7 +516,7 @@ func main() {
516516
tokenManager, err := manager.NewTokenManager(customProvider, manager.TokenManagerOptions{
517517
// Configure token refresh behavior
518518
ExpirationRefreshRatio: 0.7,
519-
LowerRefreshBoundMs: 10000,
519+
LowerRefreshBounds: 10000,
520520
})
521521
if err != nil {
522522
log.Fatalf("Failed to create token manager: %v", err)

credentials_provider.go

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@ var _ auth.StreamingCredentialsProvider = (*entraidCredentialsProvider)(nil)
1919
type entraidCredentialsProvider struct {
2020
options CredentialsProviderOptions // Configuration options for the provider.
2121

22-
tokenManager manager.TokenManager // Manages token retrieval.
23-
closeTokenManager manager.CloseFunc // Function to cancel the token manager.
22+
tokenManager manager.TokenManager // Manages token retrieval.
23+
stopTokenManager manager.StopFunc // Function to stop the token manager.
2424

2525
// listeners is a slice of listeners that are notified when the token manager receives a new token.
2626
listeners []auth.CredentialsListener // Slice of listeners notified on token updates.
@@ -64,7 +64,7 @@ func (e *entraidCredentialsProvider) onTokenError(err error) {
6464
// - error: An error if the subscription fails, such as if the token cannot be retrieved.
6565
//
6666
// Note: If the listener is already subscribed, it will not receive duplicate notifications.
67-
func (e *entraidCredentialsProvider) Subscribe(listener auth.CredentialsListener) (auth.Credentials, auth.CancelProviderFunc, error) {
67+
func (e *entraidCredentialsProvider) Subscribe(listener auth.CredentialsListener) (auth.Credentials, auth.UnsubscribeFunc, error) {
6868
// First try to get a token, only then subscribe the listener.
6969
token, err := e.tokenManager.GetToken(false)
7070
if err != nil {
@@ -87,7 +87,7 @@ func (e *entraidCredentialsProvider) Subscribe(listener auth.CredentialsListener
8787
}
8888
e.rwLock.Unlock()
8989

90-
cancel := func() error {
90+
unsub := func() error {
9191
// Remove the listener from the list of listeners.
9292
e.rwLock.Lock()
9393
defer e.rwLock.Unlock()
@@ -102,20 +102,20 @@ func (e *entraidCredentialsProvider) Subscribe(listener auth.CredentialsListener
102102
// Clear the listeners slice if it's empty
103103
if len(e.listeners) == 0 {
104104
e.listeners = make([]auth.CredentialsListener, 0)
105-
if e.closeTokenManager != nil {
106-
err := e.closeTokenManager()
105+
if e.stopTokenManager != nil {
106+
err := e.stopTokenManager()
107107
if err != nil {
108108
return fmt.Errorf("couldn't cancel token manager: %w", err)
109109
}
110-
// Set the cancelTokenManager to nil to indicate that it has been canceled.
111-
// This prevents multiple calls to cancelTokenManager.
112-
e.closeTokenManager = nil
110+
// Set the stopTokenManager to nil to indicate that it has been stopped.
111+
// This prevents multiple calls to stopTokenManager.
112+
e.stopTokenManager = nil
113113
}
114114
}
115115
return nil
116116
}
117117

118-
return token, cancel, nil
118+
return token, unsub, nil
119119
}
120120

121121
// NewCredentialsProvider creates a new credentials provider with the specified token manager and options.
@@ -134,10 +134,10 @@ func NewCredentialsProvider(tokenManager manager.TokenManager, options Credentia
134134
options: options,
135135
listeners: make([]auth.CredentialsListener, 0),
136136
}
137-
cancelTokenManager, err := cp.tokenManager.Start(tokenListenerFromCP(cp))
137+
stopTM, err := cp.tokenManager.Start(tokenListenerFromCP(cp))
138138
if err != nil {
139139
return nil, fmt.Errorf("couldn't start token manager: %w", err)
140140
}
141-
cp.closeTokenManager = cancelTokenManager
141+
cp.stopTokenManager = stopTM
142142
return cp, nil
143143
}

credentials_provider_test.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -343,7 +343,7 @@ func TestCredentialsProviderSubscribe(t *testing.T) {
343343
mtm.On("GetToken", false).Return(testToken, nil)
344344
mtm.On("Start", mock.Anything).
345345
Run(mockTokenManagerLoop(mtm, tokenExpiration, testToken, nil)).
346-
Return(manager.CloseFunc(mtm.Close), nil)
346+
Return(manager.StopFunc(mtm.Stop), nil)
347347
provider, err := NewConfidentialCredentialsProvider(options)
348348
require.NoError(t, err)
349349
require.NotNil(t, provider)
@@ -396,13 +396,13 @@ func TestCredentialsProviderSubscribe(t *testing.T) {
396396

397397
mtm.On("Start", mock.Anything).
398398
Run(mockTokenManagerLoop(mtm, tokenExpiration, nil, errTokenError)).
399-
Return(manager.CloseFunc(mtm.Close), nil)
399+
Return(manager.StopFunc(mtm.Stop), nil)
400400
provider, err := NewConfidentialCredentialsProvider(options)
401401
require.NoError(t, err)
402402
require.NotNil(t, provider)
403403
var wg sync.WaitGroup
404404
listeners := make([]*mockCredentialsListener, numListeners)
405-
cancels := make([]auth.CancelProviderFunc, numListeners)
405+
cancels := make([]auth.UnsubscribeFunc, numListeners)
406406

407407
// Subscribe multiple listeners concurrently
408408
for i := 0; i < numListeners; i++ {
@@ -467,7 +467,7 @@ func TestCredentialsProviderSubscribe(t *testing.T) {
467467

468468
mtm.On("Start", mock.Anything).
469469
Run(mockTokenManagerLoop(mtm, tokenExpiration, nil, errTokenError)).
470-
Return(manager.CloseFunc(mtm.Close), nil)
470+
Return(manager.StopFunc(mtm.Stop), nil)
471471
provider, err := NewConfidentialCredentialsProvider(options)
472472
require.NoError(t, err)
473473
require.NotNil(t, provider)
@@ -525,7 +525,7 @@ func TestCredentialsProviderSubscribe(t *testing.T) {
525525
require.NotNil(t, provider)
526526
var wg sync.WaitGroup
527527
listeners := make([]*mockCredentialsListener, numListeners)
528-
cancels := make([]auth.CancelProviderFunc, numListeners)
528+
cancels := make([]auth.UnsubscribeFunc, numListeners)
529529

530530
// Subscribe multiple listeners concurrently
531531
for i := 0; i < numListeners; i++ {

entraid_test.go

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ func (m *fakeTokenManager) GetToken(forceRefresh bool) (*token.Token, error) {
5454
return m.token, m.err
5555
}
5656

57-
func (m *fakeTokenManager) Start(listener manager.TokenListener) (manager.CloseFunc, error) {
57+
func (m *fakeTokenManager) Start(listener manager.TokenListener) (manager.StopFunc, error) {
5858
if m.err != nil {
5959
return nil, m.err
6060
}
@@ -65,10 +65,10 @@ func (m *fakeTokenManager) Start(listener manager.TokenListener) (manager.CloseF
6565
case <-time.After(tokenExpiration):
6666
m.lock.Lock()
6767
if m.err != nil {
68-
listener.OnTokenError(m.err)
68+
listener.OnError(m.err)
6969
return
7070
}
71-
listener.OnTokenNext(m.token)
71+
listener.OnNext(m.token)
7272
m.lock.Unlock()
7373
case <-done:
7474
// Exit the loop if done channel is closed
@@ -84,7 +84,7 @@ func (m *fakeTokenManager) Start(listener manager.TokenListener) (manager.CloseF
8484
}, nil
8585
}
8686

87-
func (m *fakeTokenManager) Close() error {
87+
func (m *fakeTokenManager) Stop() error {
8888
return nil
8989
}
9090

@@ -147,7 +147,7 @@ func (m *mockTokenManager) GetToken(forceRefresh bool) (*token.Token, error) {
147147
return args.Get(0).(*token.Token), args.Error(1)
148148
}
149149

150-
func (m *mockTokenManager) Start(listener manager.TokenListener) (manager.CloseFunc, error) {
150+
func (m *mockTokenManager) Start(listener manager.TokenListener) (manager.StopFunc, error) {
151151
args := m.Called(listener)
152152
m.lock.Lock()
153153
if m.done == nil {
@@ -161,13 +161,13 @@ func (m *mockTokenManager) Start(listener manager.TokenListener) (manager.CloseF
161161
m.listener = listener
162162
}
163163
m.lock.Unlock()
164-
return args.Get(0).(manager.CloseFunc), args.Error(1)
164+
return args.Get(0).(manager.StopFunc), args.Error(1)
165165
}
166-
func (m *mockTokenManager) Close() error {
166+
func (m *mockTokenManager) Stop() error {
167167
m.lock.Lock()
168168
defer m.lock.Unlock()
169169
if m.listener == nil {
170-
return manager.ErrTokenManagerAlreadyClosed
170+
return manager.ErrTokenManagerAlreadyStopped
171171
}
172172
if m.listener != nil {
173173
m.listener = nil
@@ -200,9 +200,9 @@ func mockTokenManagerLoop(mtm *mockTokenManager, tokenExpiration time.Duration,
200200
case <-time.After(tokenExpiration):
201201
mtm.lock.Lock()
202202
if err != nil {
203-
mtm.listener.OnTokenError(err)
203+
mtm.listener.OnError(err)
204204
} else {
205-
mtm.listener.OnTokenNext(testToken)
205+
mtm.listener.OnNext(testToken)
206206
}
207207
mtm.lock.Unlock()
208208
}

go.mod

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,14 @@ go 1.18
44

55
require (
66
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.8.0-beta.1
7-
github.com/redis/go-redis/v9 v9.5.3-0.20250331212737-c248425ade4a
7+
github.com/redis/go-redis/v9 v9.5.3-0.20250416091253-d0a8c76d8420
88
github.com/stretchr/testify v1.10.0
99
)
1010

1111
require (
12+
github.com/cespare/xxhash/v2 v2.3.0 // indirect
1213
github.com/davecgh/go-spew v1.1.1 // indirect
14+
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
1315
github.com/pmezard/go-difflib v1.0.0 // indirect
1416
github.com/stretchr/objx v0.5.2 // indirect
1517
gopkg.in/yaml.v3 v3.0.1 // indirect

go.sum

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,12 @@ github.com/Azure/azure-sdk-for-go/sdk/internal v1.10.0/go.mod h1:iZDifYGJTIgIIkY
88
github.com/AzureAD/microsoft-authentication-extensions-for-go/cache v0.1.1 h1:WJTmL004Abzc5wDB5VtZG2PJk5ndYDgVacGqfirKxjM=
99
github.com/AzureAD/microsoft-authentication-library-for-go v1.4.1 h1:8BKxhZZLX/WosEeoCvWysmKUscfa9v8LIPEEU0JjE2o=
1010
github.com/AzureAD/microsoft-authentication-library-for-go v1.4.1/go.mod h1:wP83P5OoQ5p6ip3ScPr0BAq0BvuPAvacpEuSzyouqAI=
11+
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
12+
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
1113
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
1214
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
15+
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
16+
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
1317
github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk=
1418
github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
1519
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
@@ -25,6 +29,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
2529
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
2630
github.com/redis/go-redis/v9 v9.5.3-0.20250331212737-c248425ade4a h1:R5xgk8m+CF7lVE0EGr+tLkT1eM3Zfd39BJfnANQqpKA=
2731
github.com/redis/go-redis/v9 v9.5.3-0.20250331212737-c248425ade4a/go.mod h1:huWgSWd8mW6+m0VPhJjSSQ+d6Nh1VICQ6Q5lHuCH/Iw=
32+
github.com/redis/go-redis/v9 v9.5.3-0.20250416091253-d0a8c76d8420 h1:/dxO9rhmlhKP5pyI7omDH3QQzC0AppWxHT1w5TBsdTU=
33+
github.com/redis/go-redis/v9 v9.5.3-0.20250416091253-d0a8c76d8420/go.mod h1:huWgSWd8mW6+m0VPhJjSSQ+d6Nh1VICQ6Q5lHuCH/Iw=
2834
github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8=
2935
github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
3036
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=

identity/azure_default_identity_provider.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ func NewDefaultAzureIdentityProvider(opts DefaultAzureIdentityProviderOptions) (
5858

5959
// RequestToken requests a token from the Azure Default Identity provider.
6060
// It returns the token, the expiration time, and an error if any.
61-
func (a *DefaultAzureIdentityProvider) RequestToken() (shared.IdentityProviderResponse, error) {
61+
func (a *DefaultAzureIdentityProvider) RequestToken(ctx context.Context) (shared.IdentityProviderResponse, error) {
6262
credFactory := a.credFactory
6363
if credFactory == nil {
6464
credFactory = &defaultCredFactory{}
@@ -68,7 +68,7 @@ func (a *DefaultAzureIdentityProvider) RequestToken() (shared.IdentityProviderRe
6868
return nil, fmt.Errorf("failed to create default azure credential: %w", err)
6969
}
7070

71-
token, err := cred.GetToken(context.TODO(), policy.TokenRequestOptions{Scopes: a.scopes})
71+
token, err := cred.GetToken(ctx, policy.TokenRequestOptions{Scopes: a.scopes})
7272
if err != nil {
7373
return nil, fmt.Errorf("failed to get token: %w", err)
7474
}

identity/azure_default_identity_provider_test.go

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package identity
22

33
import (
4+
"context"
45
"testing"
56

67
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
@@ -38,7 +39,7 @@ func TestAzureDefaultIdentityProvider_RequestToken(t *testing.T) {
3839

3940
// Request a token from the provider in incorrect environment
4041
// should fail.
41-
token, err := provider.RequestToken()
42+
token, err := provider.RequestToken(context.Background())
4243
assert.Nil(t, token, "token should be nil")
4344
assert.Error(t, err, "failed to request token")
4445

@@ -51,7 +52,7 @@ func TestAzureDefaultIdentityProvider_RequestToken(t *testing.T) {
5152
mCredFactory := &mockCredFactory{}
5253
mCredFactory.On("NewDefaultAzureCredential", mock.Anything).Return(mCreds, nil)
5354
provider.credFactory = mCredFactory
54-
token, err = provider.RequestToken()
55+
token, err = provider.RequestToken(context.Background())
5556
assert.NotNil(t, token, "token should not be nil")
5657
assert.NoError(t, err, "failed to request token")
5758
assert.Equal(t, shared.ResponseTypeAccessToken, token.Type(), "token type should be access token")
@@ -70,7 +71,7 @@ func TestAzureDefaultIdentityProvider_RequestTokenWithScopes(t *testing.T) {
7071

7172
t.Run("RequestToken with custom scopes", func(t *testing.T) {
7273
// Request a token from the provider
73-
token, err := provider.RequestToken()
74+
token, err := provider.RequestToken(context.Background())
7475
assert.Nil(t, token, "token should be nil")
7576
assert.Error(t, err, "failed to request token")
7677

@@ -83,7 +84,7 @@ func TestAzureDefaultIdentityProvider_RequestTokenWithScopes(t *testing.T) {
8384
mCredFactory := &mockCredFactory{}
8485
mCredFactory.On("NewDefaultAzureCredential", mock.Anything).Return(mCreds, nil)
8586
provider.credFactory = mCredFactory
86-
token, err = provider.RequestToken()
87+
token, err = provider.RequestToken(context.Background())
8788
assert.NotNil(t, token, "token should not be nil")
8889
assert.NoError(t, err, "failed to request token")
8990
assert.Equal(t, shared.ResponseTypeAccessToken, token.Type(), "token type should be access token")
@@ -94,7 +95,7 @@ func TestAzureDefaultIdentityProvider_RequestTokenWithScopes(t *testing.T) {
9495
mCredFactory := &mockCredFactory{}
9596
mCredFactory.On("NewDefaultAzureCredential", mock.Anything).Return(nil, assert.AnError)
9697
provider.credFactory = mCredFactory
97-
token, err := provider.RequestToken()
98+
token, err := provider.RequestToken(context.Background())
9899
assert.Nil(t, token, "token should be nil")
99100
assert.Error(t, err, "failed to request token")
100101
})

identity/confidential_identity_provider.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -155,12 +155,12 @@ func NewConfidentialIdentityProvider(opts ConfidentialIdentityProviderOptions) (
155155

156156
// RequestToken requests a token from the identity provider.
157157
// It returns the identity provider response, including the auth result.
158-
func (c *ConfidentialIdentityProvider) RequestToken() (shared.IdentityProviderResponse, error) {
158+
func (c *ConfidentialIdentityProvider) RequestToken(ctx context.Context) (shared.IdentityProviderResponse, error) {
159159
if c.client == nil {
160160
return nil, fmt.Errorf("client is not initialized")
161161
}
162162

163-
result, err := c.client.AcquireTokenByCredential(context.TODO(), c.scopes)
163+
result, err := c.client.AcquireTokenByCredential(ctx, c.scopes)
164164
if err != nil {
165165
return nil, fmt.Errorf("failed to acquire token: %w", err)
166166
}

0 commit comments

Comments
 (0)