Skip to content

Commit 842fe6c

Browse files
committed
allow for pointer and non-pointer types in the idp response
1 parent 06bb497 commit 842fe6c

File tree

2 files changed

+106
-48
lines changed

2 files changed

+106
-48
lines changed

shared/identity_provider_response.go

Lines changed: 31 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -49,30 +49,47 @@ type IdentityProvider interface {
4949
// Type can be either AuthResult, AccessToken, or RawToken.
5050
// Second argument is the result of the type provided in the first argument.
5151
func NewIDPResponse(responseType string, result interface{}) (IdentityProviderResponse, error) {
52+
if result == nil {
53+
return nil, fmt.Errorf("result cannot be nil")
54+
}
55+
5256
r := &internal.IDPResp{ResultType: responseType}
5357

5458
switch responseType {
5559
case ResponseTypeAuthResult:
56-
if typed, ok := result.(*public.AuthResult); !ok {
57-
return nil, fmt.Errorf("expected AuthResult, got %T", result)
58-
} else {
59-
r.AuthResultVal = typed
60+
switch v := result.(type) {
61+
case *public.AuthResult:
62+
r.AuthResultVal = v
63+
case public.AuthResult:
64+
r.AuthResultVal = &v
65+
default:
66+
return nil, fmt.Errorf("invalid auth result type: expected public.AuthResult or *public.AuthResult, got %T with value %v", result, result)
6067
}
6168
case ResponseTypeAccessToken:
62-
if typed, ok := result.(*azcore.AccessToken); !ok {
63-
return nil, fmt.Errorf("expected AccessToken, got %T", result)
64-
} else {
65-
r.AccessTokenVal = typed
66-
r.RawTokenVal = typed.Token
69+
switch v := result.(type) {
70+
case *azcore.AccessToken:
71+
r.AccessTokenVal = v
72+
r.RawTokenVal = v.Token
73+
case azcore.AccessToken:
74+
r.AccessTokenVal = &v
75+
r.RawTokenVal = v.Token
76+
default:
77+
return nil, fmt.Errorf("invalid access token type: expected azcore.AccessToken or *azcore.AccessToken, got %T with value %v", result, result)
6778
}
6879
case ResponseTypeRawToken:
69-
if typed, ok := result.(string); !ok {
70-
return nil, fmt.Errorf("expected string, got %T", result)
71-
} else {
72-
r.RawTokenVal = typed
80+
switch v := result.(type) {
81+
case string:
82+
r.RawTokenVal = v
83+
case *string:
84+
if v == nil {
85+
return nil, fmt.Errorf("raw token cannot be nil")
86+
}
87+
r.RawTokenVal = *v
88+
default:
89+
return nil, fmt.Errorf("invalid raw token type: expected string or *string, got %T with value %v", result, result)
7390
}
7491
default:
75-
return nil, fmt.Errorf("unknown idp response type: %s", responseType)
92+
return nil, fmt.Errorf("unsupported identity provider response type: %s", responseType)
7693
}
7794
return r, nil
7895
}

shared/identity_provider_response_test.go

Lines changed: 75 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -66,70 +66,111 @@ func (m *mockIDP) RequestToken() (IdentityProviderResponse, error) {
6666

6767
func TestNewIDPResponse(t *testing.T) {
6868
tests := []struct {
69-
name string
70-
responseType string
71-
result interface{}
72-
wantErr bool
69+
name string
70+
responseType string
71+
result interface{}
72+
expectedError string
7373
}{
7474
{
75-
name: "Valid AuthResult",
75+
name: "Valid AuthResult pointer",
7676
responseType: ResponseTypeAuthResult,
7777
result: &public.AuthResult{},
78-
wantErr: false,
7978
},
8079
{
81-
name: "Valid AccessToken",
80+
name: "Valid AuthResult value",
81+
responseType: ResponseTypeAuthResult,
82+
result: public.AuthResult{},
83+
},
84+
{
85+
name: "Valid AccessToken pointer",
8286
responseType: ResponseTypeAccessToken,
83-
result: &azcore.AccessToken{},
84-
wantErr: false,
87+
result: &azcore.AccessToken{Token: "test-token"},
8588
},
8689
{
87-
name: "Valid RawToken",
90+
name: "Valid AccessToken value",
91+
responseType: ResponseTypeAccessToken,
92+
result: azcore.AccessToken{Token: "test-token"},
93+
},
94+
{
95+
name: "Valid RawToken string",
8896
responseType: ResponseTypeRawToken,
8997
result: "test-token",
90-
wantErr: false,
9198
},
9299
{
93-
name: "Invalid AuthResult type",
94-
responseType: ResponseTypeAuthResult,
95-
result: "not-an-auth-result",
96-
wantErr: true,
100+
name: "Valid RawToken string pointer",
101+
responseType: ResponseTypeRawToken,
102+
result: stringPtr("test-token"),
97103
},
98104
{
99-
name: "Invalid AccessToken type",
100-
responseType: ResponseTypeAccessToken,
101-
result: "not-an-access-token",
102-
wantErr: true,
105+
name: "Nil result",
106+
responseType: ResponseTypeAuthResult,
107+
result: nil,
108+
expectedError: "result cannot be nil",
103109
},
104110
{
105-
name: "Invalid RawToken type",
106-
responseType: ResponseTypeRawToken,
107-
result: 123,
108-
wantErr: true,
111+
name: "Nil string pointer",
112+
responseType: ResponseTypeRawToken,
113+
result: (*string)(nil),
114+
expectedError: "raw token cannot be nil",
115+
},
116+
{
117+
name: "Invalid AuthResult type",
118+
responseType: ResponseTypeAuthResult,
119+
result: "not-an-auth-result",
120+
expectedError: "invalid auth result type: expected public.AuthResult or *public.AuthResult",
121+
},
122+
{
123+
name: "Invalid AccessToken type",
124+
responseType: ResponseTypeAccessToken,
125+
result: "not-an-access-token",
126+
expectedError: "invalid access token type: expected azcore.AccessToken or *azcore.AccessToken",
109127
},
110128
{
111-
name: "Unknown response type",
112-
responseType: "UnknownType",
113-
result: nil,
114-
wantErr: true,
129+
name: "Invalid RawToken type",
130+
responseType: ResponseTypeRawToken,
131+
result: 123,
132+
expectedError: "invalid raw token type: expected string or *string",
133+
},
134+
{
135+
name: "Invalid response type",
136+
responseType: "InvalidType",
137+
result: "test",
138+
expectedError: "unsupported identity provider response type: InvalidType",
115139
},
116140
}
117141

118142
for _, tt := range tests {
119143
t.Run(tt.name, func(t *testing.T) {
120-
response, err := NewIDPResponse(tt.responseType, tt.result)
121-
if tt.wantErr {
144+
resp, err := NewIDPResponse(tt.responseType, tt.result)
145+
146+
if tt.expectedError != "" {
122147
assert.Error(t, err)
123-
assert.Nil(t, response)
124-
} else {
125-
assert.NoError(t, err)
126-
assert.NotNil(t, response)
127-
assert.Equal(t, tt.responseType, response.Type())
148+
assert.Contains(t, err.Error(), tt.expectedError)
149+
assert.Nil(t, resp)
150+
return
151+
}
152+
153+
assert.NoError(t, err)
154+
assert.NotNil(t, resp)
155+
assert.Equal(t, tt.responseType, resp.Type())
156+
157+
switch tt.responseType {
158+
case ResponseTypeAuthResult:
159+
assert.NotNil(t, resp.AuthResult())
160+
case ResponseTypeAccessToken:
161+
assert.NotNil(t, resp.AccessToken())
162+
assert.NotEmpty(t, resp.AccessToken().Token)
163+
case ResponseTypeRawToken:
164+
assert.NotEmpty(t, resp.RawToken())
128165
}
129166
})
130167
}
131168
}
132169

170+
func stringPtr(s string) *string {
171+
return &s
172+
}
173+
133174
func TestIdentityProviderResponse(t *testing.T) {
134175
now := time.Now()
135176
expires := now.Add(time.Hour)

0 commit comments

Comments
 (0)