Skip to content

Commit 50f6efd

Browse files
committed
more tests
1 parent b618245 commit 50f6efd

File tree

4 files changed

+471
-1
lines changed

4 files changed

+471
-1
lines changed

identity/managed_identity_provider.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ type ManagedIdentityProvider struct {
5252

5353
// realManagedIdentityClient is a wrapper around the real mi.Client that implements our interface
5454
type realManagedIdentityClient struct {
55-
client mi.Client
55+
client ManagedIdentityClient
5656
}
5757

5858
func (c *realManagedIdentityClient) AcquireToken(ctx context.Context, resource string, opts ...mi.AcquireTokenOption) (public.AuthResult, error) {

identity/managed_identity_provider_test.go

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,3 +216,90 @@ func TestRequestToken_ErrorCases(t *testing.T) {
216216
})
217217
}
218218
}
219+
220+
// MockMIClient is a mock implementation of the mi.Client interface
221+
type MockMIClient struct {
222+
mock.Mock
223+
}
224+
225+
func (m *MockMIClient) AcquireToken(ctx context.Context, resource string, opts ...mi.AcquireTokenOption) (public.AuthResult, error) {
226+
args := m.Called(ctx, resource)
227+
return args.Get(0).(public.AuthResult), args.Error(1)
228+
}
229+
230+
func (m *MockMIClient) Close() error {
231+
args := m.Called()
232+
return args.Error(0)
233+
}
234+
235+
func TestRealManagedIdentityClient(t *testing.T) {
236+
// Create a mock managed identity client
237+
mockMIClient := new(MockManagedIdentityClient)
238+
client := &realManagedIdentityClient{client: mockMIClient}
239+
240+
tests := []struct {
241+
name string
242+
resource string
243+
setupMock func(*MockManagedIdentityClient)
244+
expectedError string
245+
}{
246+
{
247+
name: "Success with default resource",
248+
resource: RedisResource,
249+
setupMock: func(m *MockManagedIdentityClient) {
250+
m.On("AcquireToken", mock.Anything, RedisResource, mock.Anything).
251+
Return(public.AuthResult{
252+
AccessToken: "test-token",
253+
ExpiresOn: time.Now().Add(time.Hour),
254+
}, nil)
255+
},
256+
},
257+
{
258+
name: "Success with custom resource",
259+
resource: "custom-resource",
260+
setupMock: func(m *MockManagedIdentityClient) {
261+
m.On("AcquireToken", mock.Anything, "custom-resource", mock.Anything).
262+
Return(public.AuthResult{
263+
AccessToken: "test-token",
264+
ExpiresOn: time.Now().Add(time.Hour),
265+
}, nil)
266+
},
267+
},
268+
{
269+
name: "Error from underlying client",
270+
resource: RedisResource,
271+
setupMock: func(m *MockManagedIdentityClient) {
272+
m.On("AcquireToken", mock.Anything, RedisResource, mock.Anything).
273+
Return(public.AuthResult{}, errors.New("underlying client error"))
274+
},
275+
expectedError: "underlying client error",
276+
},
277+
}
278+
279+
for _, tt := range tests {
280+
t.Run(tt.name, func(t *testing.T) {
281+
// Reset the mock for each test
282+
mockMIClient.ExpectedCalls = nil
283+
mockMIClient.Calls = nil
284+
285+
// Set up the mock
286+
tt.setupMock(mockMIClient)
287+
288+
// Call AcquireToken with empty options slice to match mock setup
289+
result, err := client.AcquireToken(context.Background(), tt.resource, []mi.AcquireTokenOption{}...)
290+
291+
if tt.expectedError != "" {
292+
assert.Error(t, err)
293+
assert.Contains(t, err.Error(), tt.expectedError)
294+
assert.Equal(t, public.AuthResult{}, result)
295+
} else {
296+
assert.NoError(t, err)
297+
assert.NotEqual(t, public.AuthResult{}, result)
298+
assert.Equal(t, "test-token", result.AccessToken)
299+
}
300+
301+
// Verify mock expectations
302+
mockMIClient.AssertExpectations(t)
303+
})
304+
}
305+
}

providers_test.go

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package entraid
22

33
import (
44
"errors"
5+
"sync"
56
"testing"
67
"time"
78

@@ -504,3 +505,96 @@ func TestCredentialsProviderInterface(t *testing.T) {
504505
})
505506
}
506507
}
508+
509+
func TestCredentialsProviderSubscribe(t *testing.T) {
510+
// Create a test token
511+
testToken := token.New(
512+
"test",
513+
"test",
514+
"mock-token",
515+
time.Now().Add(time.Hour),
516+
time.Now(),
517+
int64(time.Hour),
518+
)
519+
520+
// Create a test provider
521+
options := ConfidentialCredentialsProviderOptions{
522+
CredentialsProviderOptions: CredentialsProviderOptions{
523+
ClientID: "test-client-id",
524+
TokenManagerOptions: manager.TokenManagerOptions{
525+
ExpirationRefreshRatio: 0.7,
526+
},
527+
},
528+
ConfidentialIdentityProviderOptions: identity.ConfidentialIdentityProviderOptions{
529+
ClientID: "test-client-id",
530+
CredentialsType: identity.ClientSecretCredentialType,
531+
ClientSecret: "test-secret",
532+
Scopes: []string{identity.RedisScopeDefault},
533+
Authority: identity.AuthorityConfiguration{},
534+
},
535+
}
536+
537+
// Set the token manager factory in the options
538+
options.tokenManagerFactory = testTokenManagerFactory(testToken, nil)
539+
540+
provider, err := NewConfidentialCredentialsProvider(options)
541+
require.NoError(t, err)
542+
require.NotNil(t, provider)
543+
544+
t.Run("concurrent subscribe and cancel", func(t *testing.T) {
545+
const numListeners = 10
546+
var wg sync.WaitGroup
547+
listeners := make([]*mockCredentialsListener, numListeners)
548+
cancels := make([]auth.CancelProviderFunc, numListeners)
549+
550+
// Subscribe multiple listeners concurrently
551+
for i := 0; i < numListeners; i++ {
552+
wg.Add(1)
553+
go func(idx int) {
554+
defer wg.Done()
555+
listener := &mockCredentialsListener{
556+
LastTokenCh: make(chan string, 1),
557+
LastErrCh: make(chan error, 1),
558+
}
559+
listeners[idx] = listener
560+
_, cancel, err := provider.Subscribe(listener)
561+
require.NoError(t, err)
562+
cancels[idx] = cancel
563+
}(i)
564+
}
565+
wg.Wait()
566+
567+
// Verify all listeners received the token
568+
for i, listener := range listeners {
569+
select {
570+
case token := <-listener.LastTokenCh:
571+
assert.Equal(t, "mock-token", token, "listener %d received wrong token", i)
572+
case err := <-listener.LastErrCh:
573+
t.Fatalf("listener %d received error: %v", i, err)
574+
}
575+
}
576+
577+
// Cancel all subscriptions concurrently
578+
for i := 0; i < numListeners; i++ {
579+
wg.Add(1)
580+
go func(idx int) {
581+
defer wg.Done()
582+
err := cancels[idx]()
583+
require.NoError(t, err)
584+
}(i)
585+
}
586+
wg.Wait()
587+
588+
// Verify no more tokens are sent after cancellation
589+
for i, listener := range listeners {
590+
select {
591+
case token := <-listener.LastTokenCh:
592+
t.Fatalf("listener %d received unexpected token after cancellation: %s", i, token)
593+
case err := <-listener.LastErrCh:
594+
t.Fatalf("listener %d received unexpected error after cancellation: %v", i, err)
595+
default:
596+
// No message received, which is expected
597+
}
598+
}
599+
})
600+
}

0 commit comments

Comments
 (0)