Skip to content

Commit 4152087

Browse files
committed
more tests, should cover above 80%
1 parent aff8eac commit 4152087

File tree

6 files changed

+979
-14
lines changed

6 files changed

+979
-14
lines changed

identity/managed_identity_provider.go

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,18 @@ import (
66
"fmt"
77

88
mi "github.com/AzureAD/microsoft-authentication-library-for-go/apps/managedidentity"
9+
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/public"
910
"github.com/redis-developer/go-redis-entraid/shared"
1011
)
1112

13+
// ManagedIdentityClient is an interface that defines the methods for a managed identity client.
14+
// It is used to acquire a token using the managed identity.
15+
type ManagedIdentityClient interface {
16+
// AcquireToken acquires a token using the managed identity.
17+
// It returns the token and an error if any.
18+
AcquireToken(ctx context.Context, resource string, opts ...mi.AcquireTokenOption) (public.AuthResult, error)
19+
}
20+
1221
// ManagedIdentityProviderOptions represents the options for the managed identity provider.
1322
// It is used to configure the identity provider when requesting a manager.
1423
type ManagedIdentityProviderOptions struct {
@@ -38,14 +47,22 @@ type ManagedIdentityProvider struct {
3847
scopes []string
3948

4049
// client is the managed identity client used to request a manager.
41-
client *mi.Client
50+
client ManagedIdentityClient
51+
}
52+
53+
// realManagedIdentityClient is a wrapper around the real mi.Client that implements our interface
54+
type realManagedIdentityClient struct {
55+
client mi.Client
56+
}
57+
58+
func (c *realManagedIdentityClient) AcquireToken(ctx context.Context, resource string, opts ...mi.AcquireTokenOption) (public.AuthResult, error) {
59+
return c.client.AcquireToken(ctx, resource, opts...)
4260
}
4361

4462
// NewManagedIdentityProvider creates a new managed identity provider for Azure with managed identity.
4563
// It is used to configure the identity provider when requesting a manager.
4664
func NewManagedIdentityProvider(opts ManagedIdentityProviderOptions) (*ManagedIdentityProvider, error) {
47-
var client mi.Client
48-
var err error
65+
var client ManagedIdentityClient
4966

5067
if opts.ManagedIdentityType != SystemAssignedIdentity && opts.ManagedIdentityType != UserAssignedIdentity {
5168
return nil, errors.New("invalid managed identity type")
@@ -56,25 +73,29 @@ func NewManagedIdentityProvider(opts ManagedIdentityProviderOptions) (*ManagedId
5673
// SystemAssignedIdentity is the type of identity that is automatically managed by Azure.
5774
// This type of identity is automatically created and managed by Azure.
5875
// It is used to authenticate the identity when requesting a manager.
59-
client, err = mi.New(mi.SystemAssigned())
76+
miClient, err := mi.New(mi.SystemAssigned())
77+
if err != nil {
78+
return nil, fmt.Errorf("couldn't create managed identity client: %w", err)
79+
}
80+
client = &realManagedIdentityClient{client: miClient}
6081
case UserAssignedIdentity:
6182
// UserAssignedIdentity is required to be specified when using a user assigned identity.
6283
if opts.UserAssignedClientID == "" {
6384
return nil, errors.New("user assigned client ID is required when using user assigned identity")
6485
}
6586
// UserAssignedIdentity is the type of identity that is managed by the user.
66-
client, err = mi.New(mi.UserAssignedClientID(opts.UserAssignedClientID))
67-
}
68-
69-
if err != nil {
70-
return nil, fmt.Errorf("couldn't create managed identity client: %w", err)
87+
miClient, err := mi.New(mi.UserAssignedClientID(opts.UserAssignedClientID))
88+
if err != nil {
89+
return nil, fmt.Errorf("couldn't create managed identity client: %w", err)
90+
}
91+
client = &realManagedIdentityClient{client: miClient}
7192
}
7293

7394
return &ManagedIdentityProvider{
7495
userAssignedClientID: opts.UserAssignedClientID,
7596
managedIdentityType: opts.ManagedIdentityType,
7697
scopes: opts.Scopes,
77-
client: &client,
98+
client: client,
7899
}, nil
79100
}
80101

Lines changed: 218 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,218 @@
1+
package identity
2+
3+
import (
4+
"context"
5+
"errors"
6+
"testing"
7+
"time"
8+
9+
mi "github.com/AzureAD/microsoft-authentication-library-for-go/apps/managedidentity"
10+
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/public"
11+
"github.com/stretchr/testify/assert"
12+
"github.com/stretchr/testify/mock"
13+
)
14+
15+
// MockManagedIdentityClient is a mock implementation of the managed identity client
16+
type MockManagedIdentityClient struct {
17+
mock.Mock
18+
}
19+
20+
func (m *MockManagedIdentityClient) AcquireToken(ctx context.Context, resource string, opts ...mi.AcquireTokenOption) (public.AuthResult, error) {
21+
args := m.Called(ctx, resource)
22+
return args.Get(0).(public.AuthResult), args.Error(1)
23+
}
24+
25+
func TestNewManagedIdentityProvider(t *testing.T) {
26+
tests := []struct {
27+
name string
28+
opts ManagedIdentityProviderOptions
29+
expectedError string
30+
}{
31+
{
32+
name: "System assigned identity",
33+
opts: ManagedIdentityProviderOptions{
34+
ManagedIdentityType: SystemAssignedIdentity,
35+
Scopes: []string{"https://redis.azure.com"},
36+
},
37+
expectedError: "",
38+
},
39+
{
40+
name: "User assigned identity with client ID",
41+
opts: ManagedIdentityProviderOptions{
42+
ManagedIdentityType: UserAssignedIdentity,
43+
UserAssignedClientID: "test-client-id",
44+
Scopes: []string{"https://redis.azure.com"},
45+
},
46+
expectedError: "",
47+
},
48+
{
49+
name: "User assigned identity without client ID",
50+
opts: ManagedIdentityProviderOptions{
51+
ManagedIdentityType: UserAssignedIdentity,
52+
Scopes: []string{"https://redis.azure.com"},
53+
},
54+
expectedError: "user assigned client ID is required when using user assigned identity",
55+
},
56+
{
57+
name: "Invalid identity type",
58+
opts: ManagedIdentityProviderOptions{
59+
ManagedIdentityType: "invalid-type",
60+
Scopes: []string{"https://redis.azure.com"},
61+
},
62+
expectedError: "invalid managed identity type",
63+
},
64+
}
65+
66+
for _, tt := range tests {
67+
t.Run(tt.name, func(t *testing.T) {
68+
provider, err := NewManagedIdentityProvider(tt.opts)
69+
70+
if tt.expectedError != "" {
71+
assert.Error(t, err)
72+
assert.Contains(t, err.Error(), tt.expectedError)
73+
assert.Nil(t, provider)
74+
} else {
75+
assert.NoError(t, err)
76+
assert.NotNil(t, provider)
77+
assert.Equal(t, tt.opts.ManagedIdentityType, provider.managedIdentityType)
78+
assert.Equal(t, tt.opts.UserAssignedClientID, provider.userAssignedClientID)
79+
assert.Equal(t, tt.opts.Scopes, provider.scopes)
80+
assert.NotNil(t, provider.client)
81+
}
82+
})
83+
}
84+
}
85+
86+
func TestRequestToken(t *testing.T) {
87+
tests := []struct {
88+
name string
89+
provider *ManagedIdentityProvider
90+
expectedError string
91+
}{
92+
{
93+
name: "Success with default resource",
94+
provider: &ManagedIdentityProvider{
95+
scopes: []string{},
96+
client: new(MockManagedIdentityClient),
97+
},
98+
expectedError: "",
99+
},
100+
{
101+
name: "Success with custom resource",
102+
provider: &ManagedIdentityProvider{
103+
scopes: []string{"custom-resource"},
104+
client: new(MockManagedIdentityClient),
105+
},
106+
expectedError: "",
107+
},
108+
{
109+
name: "Error when client is nil",
110+
provider: &ManagedIdentityProvider{
111+
scopes: []string{},
112+
client: nil,
113+
},
114+
expectedError: "managed identity client is not initialized",
115+
},
116+
}
117+
118+
for _, tt := range tests {
119+
t.Run(tt.name, func(t *testing.T) {
120+
// Set up the mock expectations if we have a mock client
121+
if tt.provider.client != nil {
122+
mockClient := tt.provider.client.(*MockManagedIdentityClient)
123+
expectedResource := RedisResource
124+
if len(tt.provider.scopes) > 0 {
125+
expectedResource = tt.provider.scopes[0]
126+
}
127+
128+
if tt.expectedError == "" {
129+
mockClient.On("AcquireToken", mock.Anything, expectedResource).
130+
Return(public.AuthResult{
131+
AccessToken: "test-token",
132+
ExpiresOn: time.Now().Add(time.Hour),
133+
}, nil)
134+
} else {
135+
mockClient.On("AcquireToken", mock.Anything, expectedResource).
136+
Return(public.AuthResult{}, errors.New(tt.expectedError))
137+
}
138+
}
139+
140+
response, err := tt.provider.RequestToken()
141+
142+
if tt.expectedError != "" {
143+
assert.Error(t, err)
144+
assert.Contains(t, err.Error(), tt.expectedError)
145+
assert.Nil(t, response)
146+
} else {
147+
assert.NoError(t, err)
148+
assert.NotNil(t, response)
149+
}
150+
151+
// Verify mock expectations
152+
if tt.provider.client != nil {
153+
mockClient := tt.provider.client.(*MockManagedIdentityClient)
154+
mockClient.AssertExpectations(t)
155+
}
156+
})
157+
}
158+
}
159+
160+
func TestRequestToken_ErrorCases(t *testing.T) {
161+
tests := []struct {
162+
name string
163+
provider *ManagedIdentityProvider
164+
setupMock func(*MockManagedIdentityClient)
165+
expectedError string
166+
}{
167+
{
168+
name: "AcquireToken fails",
169+
provider: &ManagedIdentityProvider{
170+
scopes: []string{},
171+
client: new(MockManagedIdentityClient),
172+
},
173+
setupMock: func(m *MockManagedIdentityClient) {
174+
m.On("AcquireToken", mock.Anything, RedisResource).
175+
Return(public.AuthResult{}, errors.New("failed to acquire token"))
176+
},
177+
expectedError: "coudn't acquire manager: failed to acquire token",
178+
},
179+
{
180+
name: "AcquireToken fails with custom resource",
181+
provider: &ManagedIdentityProvider{
182+
scopes: []string{"custom-resource"},
183+
client: new(MockManagedIdentityClient),
184+
},
185+
setupMock: func(m *MockManagedIdentityClient) {
186+
m.On("AcquireToken", mock.Anything, "custom-resource").
187+
Return(public.AuthResult{}, errors.New("failed to acquire token"))
188+
},
189+
expectedError: "coudn't acquire manager: failed to acquire token",
190+
},
191+
{
192+
name: "AcquireToken fails with invalid resource",
193+
provider: &ManagedIdentityProvider{
194+
scopes: []string{"invalid-resource"},
195+
client: new(MockManagedIdentityClient),
196+
},
197+
setupMock: func(m *MockManagedIdentityClient) {
198+
m.On("AcquireToken", mock.Anything, "invalid-resource").
199+
Return(public.AuthResult{}, errors.New("invalid resource"))
200+
},
201+
expectedError: "coudn't acquire manager: invalid resource",
202+
},
203+
}
204+
205+
for _, tt := range tests {
206+
t.Run(tt.name, func(t *testing.T) {
207+
mockClient := tt.provider.client.(*MockManagedIdentityClient)
208+
tt.setupMock(mockClient)
209+
210+
response, err := tt.provider.RequestToken()
211+
212+
assert.Error(t, err)
213+
assert.Contains(t, err.Error(), tt.expectedError)
214+
assert.Nil(t, response)
215+
mockClient.AssertExpectations(t)
216+
})
217+
}
218+
}

0 commit comments

Comments
 (0)