Skip to content

Commit a9b3437

Browse files
committed
IDPResp struct fields should be unexported
1 parent 842fe6c commit a9b3437

File tree

6 files changed

+365
-157
lines changed

6 files changed

+365
-157
lines changed

internal/idp_response.go

Lines changed: 85 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,110 @@
11
package internal
22

33
import (
4+
"fmt"
5+
46
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
57
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/public"
68
)
79

10+
// IDPResp represents a response from an Identity Provider (IDP)
11+
// It can contain either an AuthResult, AccessToken, or a raw token string
812
type IDPResp struct {
9-
ResultType string
10-
AuthResultVal *public.AuthResult
11-
AccessTokenVal *azcore.AccessToken
12-
RawTokenVal string
13+
// resultType indicates which type of response this is
14+
resultType string
15+
authResultVal *public.AuthResult
16+
accessTokenVal *azcore.AccessToken
17+
rawTokenVal string
18+
}
19+
20+
// NewIDPResp creates a new IDPResp with the given values
21+
// It validates the input and ensures the response type matches the provided value
22+
func NewIDPResp(resultType string, result interface{}) (*IDPResp, error) {
23+
if result == nil {
24+
return nil, fmt.Errorf("result cannot be nil")
25+
}
26+
27+
r := &IDPResp{resultType: resultType}
28+
29+
switch resultType {
30+
case "AuthResult":
31+
switch v := result.(type) {
32+
case *public.AuthResult:
33+
r.authResultVal = v
34+
case public.AuthResult:
35+
r.authResultVal = &v
36+
default:
37+
return nil, fmt.Errorf("invalid auth result type: expected public.AuthResult or *public.AuthResult, got %T", result)
38+
}
39+
case "AccessToken":
40+
switch v := result.(type) {
41+
case *azcore.AccessToken:
42+
r.accessTokenVal = v
43+
r.rawTokenVal = v.Token
44+
case azcore.AccessToken:
45+
r.accessTokenVal = &v
46+
r.rawTokenVal = v.Token
47+
default:
48+
return nil, fmt.Errorf("invalid access token type: expected azcore.AccessToken or *azcore.AccessToken, got %T", result)
49+
}
50+
case "RawToken":
51+
switch v := result.(type) {
52+
case string:
53+
r.rawTokenVal = v
54+
case *string:
55+
if v == nil {
56+
return nil, fmt.Errorf("raw token cannot be nil")
57+
}
58+
r.rawTokenVal = *v
59+
default:
60+
return nil, fmt.Errorf("invalid raw token type: expected string or *string, got %T", result)
61+
}
62+
default:
63+
return nil, fmt.Errorf("unsupported identity provider response type: %s", resultType)
64+
}
65+
66+
return r, nil
1367
}
1468

69+
// Type returns the type of response this IDPResp represents
1570
func (a *IDPResp) Type() string {
16-
return a.ResultType
71+
return a.resultType
1772
}
1873

74+
// AuthResult returns the AuthResult if present, or an empty AuthResult if not set
75+
// Use HasAuthResult() to check if the value is actually set
1976
func (a *IDPResp) AuthResult() public.AuthResult {
20-
if a.AuthResultVal == nil {
77+
if a.authResultVal == nil {
2178
return public.AuthResult{}
2279
}
23-
return *a.AuthResultVal
80+
return *a.authResultVal
81+
}
82+
83+
// HasAuthResult returns true if an AuthResult is set
84+
func (a *IDPResp) HasAuthResult() bool {
85+
return a.authResultVal != nil
2486
}
2587

88+
// AccessToken returns the AccessToken if present, or an empty AccessToken if not set
89+
// Use HasAccessToken() to check if the value is actually set
2690
func (a *IDPResp) AccessToken() azcore.AccessToken {
27-
if a.AccessTokenVal == nil {
91+
if a.accessTokenVal == nil {
2892
return azcore.AccessToken{}
2993
}
30-
return *a.AccessTokenVal
94+
return *a.accessTokenVal
3195
}
3296

97+
// HasAccessToken returns true if an AccessToken is set
98+
func (a *IDPResp) HasAccessToken() bool {
99+
return a.accessTokenVal != nil
100+
}
101+
102+
// RawToken returns the raw token string
33103
func (a *IDPResp) RawToken() string {
34-
return a.RawTokenVal
104+
return a.rawTokenVal
105+
}
106+
107+
// HasRawToken returns true if a raw token is set
108+
func (a *IDPResp) HasRawToken() bool {
109+
return a.rawTokenVal != ""
35110
}

internal/idp_response_test.go

Lines changed: 151 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66

77
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
88
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/public"
9+
"github.com/stretchr/testify/assert"
910
)
1011

1112
func TestIDPResp_Type(t *testing.T) {
@@ -29,7 +30,7 @@ func TestIDPResp_Type(t *testing.T) {
2930
for _, tt := range tests {
3031
t.Run(tt.name, func(t *testing.T) {
3132
resp := &IDPResp{
32-
ResultType: tt.resultType,
33+
resultType: tt.resultType,
3334
}
3435
if got := resp.Type(); got != tt.want {
3536
t.Errorf("IDPResp.Type() = %v, want %v", got, tt.want)
@@ -68,7 +69,7 @@ func TestIDPResp_AuthResult(t *testing.T) {
6869
for _, tt := range tests {
6970
t.Run(tt.name, func(t *testing.T) {
7071
resp := &IDPResp{
71-
AuthResultVal: tt.authResult,
72+
authResultVal: tt.authResult,
7273
}
7374
got := resp.AuthResult()
7475
if got.AccessToken != tt.wantToken {
@@ -111,7 +112,7 @@ func TestIDPResp_AccessToken(t *testing.T) {
111112
for _, tt := range tests {
112113
t.Run(tt.name, func(t *testing.T) {
113114
resp := &IDPResp{
114-
AccessTokenVal: tt.accessToken,
115+
accessTokenVal: tt.accessToken,
115116
}
116117
got := resp.AccessToken()
117118
if got.Token != tt.wantToken {
@@ -145,11 +146,157 @@ func TestIDPResp_RawToken(t *testing.T) {
145146
for _, tt := range tests {
146147
t.Run(tt.name, func(t *testing.T) {
147148
resp := &IDPResp{
148-
RawTokenVal: tt.rawToken,
149+
rawTokenVal: tt.rawToken,
149150
}
150151
if got := resp.RawToken(); got != tt.want {
151152
t.Errorf("IDPResp.RawToken() = %v, want %v", got, tt.want)
152153
}
153154
})
154155
}
155156
}
157+
158+
func TestNewIDPResp(t *testing.T) {
159+
tests := []struct {
160+
name string
161+
resultType string
162+
result interface{}
163+
wantErr bool
164+
checkResult func(t *testing.T, resp *IDPResp)
165+
}{
166+
{
167+
name: "valid AuthResult pointer",
168+
resultType: "AuthResult",
169+
result: &public.AuthResult{
170+
AccessToken: "test-token",
171+
},
172+
wantErr: false,
173+
checkResult: func(t *testing.T, resp *IDPResp) {
174+
assert.True(t, resp.HasAuthResult())
175+
assert.Equal(t, "test-token", resp.AuthResult().AccessToken)
176+
assert.False(t, resp.HasAccessToken())
177+
assert.False(t, resp.HasRawToken())
178+
},
179+
},
180+
{
181+
name: "valid AuthResult value",
182+
resultType: "AuthResult",
183+
result: public.AuthResult{
184+
AccessToken: "test-token",
185+
},
186+
wantErr: false,
187+
checkResult: func(t *testing.T, resp *IDPResp) {
188+
assert.True(t, resp.HasAuthResult())
189+
assert.Equal(t, "test-token", resp.AuthResult().AccessToken)
190+
},
191+
},
192+
{
193+
name: "valid AccessToken pointer",
194+
resultType: "AccessToken",
195+
result: &azcore.AccessToken{
196+
Token: "test-token",
197+
ExpiresOn: time.Now(),
198+
},
199+
wantErr: false,
200+
checkResult: func(t *testing.T, resp *IDPResp) {
201+
assert.True(t, resp.HasAccessToken())
202+
assert.Equal(t, "test-token", resp.AccessToken().Token)
203+
assert.Equal(t, "test-token", resp.RawToken())
204+
},
205+
},
206+
{
207+
name: "valid AccessToken value",
208+
resultType: "AccessToken",
209+
result: azcore.AccessToken{
210+
Token: "test-token",
211+
ExpiresOn: time.Now(),
212+
},
213+
wantErr: false,
214+
checkResult: func(t *testing.T, resp *IDPResp) {
215+
assert.True(t, resp.HasAccessToken())
216+
assert.Equal(t, "test-token", resp.AccessToken().Token)
217+
assert.Equal(t, "test-token", resp.RawToken())
218+
},
219+
},
220+
{
221+
name: "valid RawToken string",
222+
resultType: "RawToken",
223+
result: "test-token",
224+
wantErr: false,
225+
checkResult: func(t *testing.T, resp *IDPResp) {
226+
assert.True(t, resp.HasRawToken())
227+
assert.Equal(t, "test-token", resp.RawToken())
228+
assert.False(t, resp.HasAuthResult())
229+
assert.False(t, resp.HasAccessToken())
230+
},
231+
},
232+
{
233+
name: "valid RawToken string pointer",
234+
resultType: "RawToken",
235+
result: stringPtr("test-token"),
236+
wantErr: false,
237+
checkResult: func(t *testing.T, resp *IDPResp) {
238+
assert.True(t, resp.HasRawToken())
239+
assert.Equal(t, "test-token", resp.RawToken())
240+
},
241+
},
242+
{
243+
name: "nil result",
244+
resultType: "AuthResult",
245+
result: nil,
246+
wantErr: true,
247+
},
248+
{
249+
name: "nil RawToken pointer",
250+
resultType: "RawToken",
251+
result: (*string)(nil),
252+
wantErr: true,
253+
},
254+
{
255+
name: "invalid AuthResult type",
256+
resultType: "AuthResult",
257+
result: "not-an-auth-result",
258+
wantErr: true,
259+
},
260+
{
261+
name: "invalid AccessToken type",
262+
resultType: "AccessToken",
263+
result: "not-an-access-token",
264+
wantErr: true,
265+
},
266+
{
267+
name: "invalid RawToken type",
268+
resultType: "RawToken",
269+
result: 123,
270+
wantErr: true,
271+
},
272+
{
273+
name: "unsupported result type",
274+
resultType: "InvalidType",
275+
result: "test",
276+
wantErr: true,
277+
},
278+
}
279+
280+
for _, tt := range tests {
281+
t.Run(tt.name, func(t *testing.T) {
282+
got, err := NewIDPResp(tt.resultType, tt.result)
283+
if tt.wantErr {
284+
assert.Error(t, err)
285+
assert.Nil(t, got)
286+
return
287+
}
288+
289+
assert.NoError(t, err)
290+
assert.NotNil(t, got)
291+
assert.Equal(t, tt.resultType, got.Type())
292+
293+
if tt.checkResult != nil {
294+
tt.checkResult(t, got)
295+
}
296+
})
297+
}
298+
}
299+
300+
func stringPtr(s string) *string {
301+
return &s
302+
}

manager/defaults.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -118,9 +118,6 @@ func (*defaultIdentityProviderResponseParser) ParseResponse(response shared.Iden
118118

119119
case shared.ResponseTypeRawToken, shared.ResponseTypeAccessToken:
120120
tokenStr := response.RawToken()
121-
if tokenStr == "" {
122-
return nil, fmt.Errorf("raw token is empty")
123-
}
124121

125122
if response.Type() == shared.ResponseTypeAccessToken {
126123
accessToken := response.AccessToken()
@@ -131,6 +128,10 @@ func (*defaultIdentityProviderResponseParser) ParseResponse(response shared.Iden
131128
expiresOn = accessToken.ExpiresOn.UTC()
132129
}
133130

131+
if tokenStr == "" {
132+
return nil, fmt.Errorf("raw token is empty")
133+
}
134+
134135
claims := struct {
135136
jwt.RegisteredClaims
136137
Oid string `json:"oid,omitempty"`

manager/manager_test.go

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@ import (
55
"os"
66
"time"
77

8-
"github.com/redis-developer/go-redis-entraid/internal"
8+
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
9+
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/public"
910
"github.com/redis-developer/go-redis-entraid/shared"
1011
"github.com/redis-developer/go-redis-entraid/token"
1112
"github.com/stretchr/testify/mock"
@@ -145,4 +146,35 @@ func (m *mockTokenListener) OnTokenError(err error) {
145146
_ = m.Called(err)
146147
}
147148

148-
type authResult = internal.IDPResp
149+
type authResult struct {
150+
// ResultType is the type of the auth result
151+
ResultType string
152+
// AuthResultVal is the auth result value
153+
AuthResultVal *public.AuthResult
154+
// AccessTokenVal is the access token value
155+
AccessTokenVal *azcore.AccessToken
156+
// RawTokenVal is the raw token value
157+
RawTokenVal string
158+
}
159+
160+
func (a *authResult) Type() string {
161+
return a.ResultType
162+
}
163+
164+
func (a *authResult) AuthResult() public.AuthResult {
165+
if a.AuthResultVal == nil {
166+
return public.AuthResult{}
167+
}
168+
return *a.AuthResultVal
169+
}
170+
171+
func (a *authResult) AccessToken() azcore.AccessToken {
172+
if a.AccessTokenVal == nil {
173+
return azcore.AccessToken{}
174+
}
175+
return *a.AccessTokenVal
176+
}
177+
178+
func (a *authResult) RawToken() string {
179+
return a.RawTokenVal
180+
}

0 commit comments

Comments
 (0)