Skip to content

Commit 77fe203

Browse files
committed
use idpResponse instead of plain raw token
1 parent 04f4bf0 commit 77fe203

7 files changed

+131
-48
lines changed

azure_default_identity_provider.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ package entraid
22

33
import (
44
"context"
5-
"time"
5+
"fmt"
66

77
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
88
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
@@ -32,16 +32,16 @@ func NewDefaultAzureIdentityProvider(opts DefaultAzureIdentityProviderOptions) (
3232

3333
// RequestToken requests a token from the Azure Default Identity provider.
3434
// It returns the token, the expiration time, and an error if any.
35-
func (a *DefaultAzureIdentityProvider) RequestToken() (string, time.Time, error) {
35+
func (a *DefaultAzureIdentityProvider) RequestToken() (IdentityProviderResponse, error) {
3636
cred, err := azidentity.NewDefaultAzureCredential(a.options)
3737
if err != nil {
38-
return "", time.Time{}, err
38+
return nil, fmt.Errorf("failed to create default azure credential: %w", err)
3939
}
4040

4141
token, err := cred.GetToken(context.TODO(), policy.TokenRequestOptions{Scopes: a.scopes})
4242
if err != nil {
43-
return "", time.Time{}, err
43+
return nil, fmt.Errorf("failed to get token: %w", err)
4444
}
4545

46-
return token.Token, token.ExpiresOn.UTC(), nil
46+
return newIDPResponse(typeAccessToken, &token)
4747
}

confidential_identity_provider.go

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ import (
55
"crypto"
66
"crypto/x509"
77
"fmt"
8-
"time"
98

109
confidential "github.com/AzureAD/microsoft-authentication-library-for-go/apps/confidential"
1110
)
@@ -117,17 +116,16 @@ func NewConfidentialIdentityProvider(opts ConfidentialIdentityProviderOptions) (
117116
}
118117

119118
// RequestToken requests a token from the identity provider.
120-
// It returns the token, the expiration time, and an error if any.
121-
// The token is used to authenticate the identity when requesting a token.
122-
func (c *ConfidentialIdentityProvider) RequestToken() (string, time.Time, error) {
119+
// It returns the identity provider response, including the auth result.
120+
func (c *ConfidentialIdentityProvider) RequestToken() (IdentityProviderResponse, error) {
123121
if c.client == nil {
124-
return "", time.Time{}, fmt.Errorf("client is not initialized")
122+
return nil, fmt.Errorf("client is not initialized")
125123
}
126124

127125
result, err := c.client.AcquireTokenByCredential(context.TODO(), c.scopes)
128126
if err != nil {
129-
return "", time.Time{}, fmt.Errorf("failed to acquire token: %w", err)
127+
return nil, fmt.Errorf("failed to acquire token: %w", err)
130128
}
131129

132-
return result.AccessToken, result.ExpiresOn.UTC(), nil
130+
return newIDPResponse(typeAuthResult, &result)
133131
}

identity_provider.go

Lines changed: 66 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,76 @@
11
package entraid
22

3-
import "time"
3+
import (
4+
"fmt"
5+
6+
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
7+
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/public"
8+
)
9+
10+
const (
11+
// typeAuthResult is the type of the auth result.
12+
typeAuthResult = "AuthResult"
13+
// typeAccessToken is the type of the access token.
14+
typeAccessToken = "AccessToken"
15+
)
16+
17+
// IdentityProviderResponse is an interface that defines the methods for an identity provider authentication result.
18+
// It is used to get the type of the authentication result, the authentication result itself (can be AuthResult or AccessToken),
19+
type IdentityProviderResponse interface {
20+
// Type returns the type of the auth result
21+
Type() string
22+
AuthResult() *public.AuthResult
23+
AccessToken() *azcore.AccessToken
24+
}
425

526
// IdentityProvider is an interface that defines the methods for an identity provider.
627
// It is used to request a token for authentication.
728
// The identity provider is responsible for providing the raw authentication token.
829
type IdentityProvider interface {
930
// RequestToken requests a token from the identity provider.
1031
// It returns the token, the expiration time, and an error if any.
11-
RequestToken() (string, time.Time, error)
32+
RequestToken() (IdentityProviderResponse, error)
33+
}
34+
35+
type authResult struct {
36+
resultType string
37+
authResult *public.AuthResult
38+
accessToken *azcore.AccessToken
39+
}
40+
41+
func (a *authResult) Type() string {
42+
return a.resultType
43+
}
44+
45+
func (a *authResult) AuthResult() *public.AuthResult {
46+
return a.authResult
47+
}
48+
49+
func (a *authResult) AccessToken() *azcore.AccessToken {
50+
return a.accessToken
51+
}
52+
53+
// newAuthResult creates a new auth result based on the type provided.
54+
// It returns an IdentityProviderResponse interface.
55+
func newIDPResponse(t string, result interface{}) (IdentityProviderResponse, error) {
56+
r := &authResult{resultType: t}
57+
58+
switch t {
59+
case typeAuthResult:
60+
if typed, ok := result.(*public.AuthResult); !ok {
61+
return nil, fmt.Errorf("expected AuthResult, got %T", result)
62+
} else {
63+
r.authResult = typed
64+
}
65+
case typeAccessToken:
66+
if typed, ok := result.(*azcore.AccessToken); !ok {
67+
return nil, fmt.Errorf("expected AccessToken, got %T", result)
68+
} else {
69+
r.accessToken = typed
70+
}
71+
default:
72+
return nil, fmt.Errorf("unknown type: %s", t)
73+
}
74+
75+
return r, nil
1276
}

managed_identity_provider.go

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,12 @@ import (
44
"context"
55
"errors"
66
"fmt"
7-
"time"
87

98
mi "github.com/AzureAD/microsoft-authentication-library-for-go/apps/managedidentity"
109
)
1110

11+
// ManagedIdentityProviderOptions represents the options for the managed identity provider.
12+
// It is used to configure the identity provider when requesting a token.
1213
type ManagedIdentityProviderOptions struct {
1314
// UserAssignedClientID is the client ID of the user assigned identity.
1415
// This is used to identify the identity when requesting a token.
@@ -21,6 +22,7 @@ type ManagedIdentityProviderOptions struct {
2122
Scopes []string
2223
}
2324

25+
// ManagedIdentityProvider represents a managed identity provider.
2426
type ManagedIdentityProvider struct {
2527
// userAssignedClientID is the client ID of the user assigned identity.
2628
// This is used to identify the identity when requesting a token.
@@ -38,6 +40,8 @@ type ManagedIdentityProvider struct {
3840
client *mi.Client
3941
}
4042

43+
// NewManagedIdentityProvider creates a new managed identity provider for Azure with managed identity.
44+
// It is used to configure the identity provider when requesting a token.
4145
func NewManagedIdentityProvider(opts ManagedIdentityProviderOptions) (*ManagedIdentityProvider, error) {
4246
var client mi.Client
4347
var err error
@@ -73,9 +77,11 @@ func NewManagedIdentityProvider(opts ManagedIdentityProviderOptions) (*ManagedId
7377
}, nil
7478
}
7579

76-
func (m *ManagedIdentityProvider) RequestToken() (string, time.Time, error) {
80+
// RequestToken requests a token from the managed identity provider.
81+
// It returns IdentityProviderResponse, which contains the Acc and the expiration time.
82+
func (m *ManagedIdentityProvider) RequestToken() (IdentityProviderResponse, error) {
7783
if m.client == nil {
78-
return "", time.Time{}, errors.New("managed identity client is not initialized")
84+
return nil, errors.New("managed identity client is not initialized")
7985
}
8086

8187
// default resource is RedisResource == "https://redis.azure.com"
@@ -87,11 +93,10 @@ func (m *ManagedIdentityProvider) RequestToken() (string, time.Time, error) {
8793
}
8894
// acquire token using the managed identity client
8995
// the resource is the URL of the resource that the identity has access to
90-
token, err := m.client.AcquireToken(context.TODO(), resource)
96+
authResult, err := m.client.AcquireToken(context.TODO(), resource)
9197
if err != nil {
92-
return "", time.Time{}, fmt.Errorf("coudn't acquire token: %w", err)
98+
return nil, fmt.Errorf("coudn't acquire token: %w", err)
9399
}
94100

95-
// return the access token
96-
return token.AccessToken, token.ExpiresOn.UTC(), nil
101+
return newIDPResponse(typeAuthResult, &authResult)
97102
}

providers.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@ import (
66
"github.com/redis/go-redis/v9/auth"
77
)
88

9+
// CredentialsProviderOptions is a struct that holds the options for the credentials provider.
10+
// It is used to configure the credentials provider when requesting a token.
11+
// It is used to specify the client ID, TokenManagerOptions, and callback functions for re-authentication and retryable errors.
912
type CredentialsProviderOptions struct {
1013
// ClientID is the client ID of the identity.
1114
// This is used to identify the identity when requesting a token.

token.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@ type Token struct {
2020
RawToken string `json:"raw_token"`
2121
}
2222

23-
// TokenParserFunc is a function that parses the token and returns the username and password.
24-
type TokenParserFunc func(token string, expiresOn time.Time) (*Token, error)
23+
// IdentityProviderResponseParserFunc is a function that parses the token and returns the username and password.
24+
type IdentityProviderResponseParserFunc func(response IdentityProviderResponse) (*Token, error)
2525

2626
// copyToken creates a copy of the token.
2727
func copyToken(token *Token) *Token {

token_manager.go

Lines changed: 38 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,13 @@ type TokenManagerOptions struct {
3131
// default: 0 ms (no lower bound, refresh based on ExpirationRefreshRatio)
3232
LowerRefreshBoundMs int
3333

34-
// TokenParser is a function that parses the raw token and returns a Token object.
35-
// The function takes the raw token as a string and returns a Token object and an error.
34+
// IdentityProviderResponseParser is a function that parses the IdentityProviderResponse.
35+
// The function takes the response and based on its type returns the populated Token object.
3636
// If this function is not provided, the default implementation will be used.
3737
//
3838
// required: true
39-
TokenParser TokenParserFunc
39+
// default: defaultIdentityProviderResponseParser
40+
IdentityProviderResponseParser IdentityProviderResponseParserFunc
4041

4142
// RetryOptions is a struct that contains the options for retrying the token request.
4243
// It contains the maximum number of attempts, initial delay, maximum delay, and backoff multiplier.
@@ -84,17 +85,28 @@ type TokenManager interface {
8485
Close() error
8586
}
8687

87-
// defaultTokenParser is a function that parses the raw token and returns Token object.
88-
var defaultTokenParser = func(rawToken string, expiresOn time.Time) (*Token, error) {
89-
// Parse the token and return the username and password.
90-
// In this example, we are just returning the raw token as the password.
91-
// In a real implementation, you would parse the token and extract the username and password.
92-
// For example, if the token is a JWT, you would decode the JWT and extract the claims.
93-
// This is just a placeholder implementation.
94-
// You should replace this with your own implementation.
95-
if rawToken == "" {
96-
return nil, fmt.Errorf("raw token is empty")
88+
// defaultIdentityProviderResponseParser is a function that parses the token and returns the username and password.
89+
var defaultIdentityProviderResponseParser = func(response IdentityProviderResponse) (*Token, error) {
90+
var username, password, rawToken string
91+
var expiresOn time.Time
92+
if response == nil {
93+
return nil, fmt.Errorf("response is nil")
9794
}
95+
switch response.Type() {
96+
case typeAuthResult:
97+
authResult := response.AuthResult()
98+
if authResult == nil {
99+
return nil, fmt.Errorf("auth result is nil")
100+
}
101+
case typeAccessToken:
102+
accessToken := response.AccessToken()
103+
if accessToken == nil {
104+
return nil, fmt.Errorf("access token is nil")
105+
}
106+
default:
107+
return nil, fmt.Errorf("unknown response type: %s", response.Type())
108+
}
109+
98110
if expiresOn.IsZero() {
99111
return nil, fmt.Errorf("expires on is zero")
100112
}
@@ -143,8 +155,9 @@ func NewTokenManager(idp IdentityProvider, options TokenManagerOptions) (TokenMa
143155
type entraidTokenManager struct {
144156
idp IdentityProvider
145157
token *Token
158+
146159
// TokenParser is a function that parses the token.
147-
tokenParser TokenParserFunc
160+
identityProviderResponseParser IdentityProviderResponseParserFunc
148161

149162
// retryOptions is a struct that contains the options for retrying the token request.
150163
// It contains the maximum number of attempts, initial delay, maximum delay, and backoff multiplier.
@@ -175,12 +188,12 @@ func (e *entraidTokenManager) GetToken() (*Token, error) {
175188
return copyToken(e.token), nil
176189
}
177190

178-
rawToken, expiresOn, err := e.idp.RequestToken()
191+
idpResult, err := e.idp.RequestToken()
179192
if err != nil {
180-
return nil, fmt.Errorf("failed to request token: %w", err)
193+
return nil, fmt.Errorf("failed to request token from idp: %w", err)
181194
}
182195

183-
token, err := e.tokenParser(rawToken, expiresOn)
196+
token, err := e.identityProviderResponseParser(idpResult)
184197
if err != nil {
185198
return nil, fmt.Errorf("failed to parse token: %w", err)
186199
}
@@ -299,7 +312,7 @@ var defaultRetryableFunc = func(err error) bool {
299312
}
300313

301314
if ok := errors.As(err, netErr); ok {
302-
return netErr.Timeout() || netErr.Temporary()
315+
return netErr.Timeout()
303316
}
304317
return false
305318
}
@@ -328,19 +341,19 @@ func defaultRetryOptionsOr(retryOptions RetryOptions) RetryOptions {
328341
return retryOptions
329342
}
330343

331-
// defaultTokenParserOr returns the default token parser if the provided token parser is not set.
332-
// It sets the default token parser to the defaultTokenParser function.
344+
// defaultIdentityProviderResponseParserOr returns the default token parser if the provided token parser is not set.
345+
// It sets the default token parser to the defaultIdentityProviderResponseParser function.
333346
// The default token parser is used to parse the raw token and return a Token object.
334-
func defaultTokenParserOr(tokenParser TokenParserFunc) TokenParserFunc {
335-
if tokenParser == nil {
336-
return defaultTokenParser
347+
func defaultIdentityProviderResponseParserOr(idpResponseParser IdentityProviderResponseParserFunc) IdentityProviderResponseParserFunc {
348+
if idpResponseParser == nil {
349+
return defaultIdentityProviderResponseParser
337350
}
338-
return tokenParser
351+
return idpResponseParser
339352
}
340353

341354
func defaultTokenManagerOptionsOr(options TokenManagerOptions) TokenManagerOptions {
342355
options.RetryOptions = defaultRetryOptionsOr(options.RetryOptions)
343-
options.TokenParser = defaultTokenParserOr(options.TokenParser)
356+
options.IdentityProviderResponseParser = defaultIdentityProviderResponseParserOr(options.IdentityProviderResponseParser)
344357
if options.ExpirationRefreshRatio <= 0 || options.ExpirationRefreshRatio > 1 {
345358
options.ExpirationRefreshRatio = 0.7
346359
}

0 commit comments

Comments
 (0)