Skip to content

Commit 9219edc

Browse files
committed
Token Management - Core Interface
1 parent 62ee78b commit 9219edc

File tree

5 files changed

+2099
-0
lines changed

5 files changed

+2099
-0
lines changed

manager/defaults.go

Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
package manager
2+
3+
import (
4+
"errors"
5+
"fmt"
6+
"net"
7+
"os"
8+
"time"
9+
10+
"github.com/golang-jwt/jwt/v5"
11+
"github.com/redis-developer/go-redis-entraid/shared"
12+
"github.com/redis-developer/go-redis-entraid/token"
13+
)
14+
15+
const (
16+
DefaultExpirationRefreshRatio = 0.7
17+
DefaultRetryOptionsMaxAttempts = 3
18+
DefaultRetryOptionsInitialDelayMs = 1000
19+
DefaultRetryOptionsBackoffMultiplier = 2.0
20+
DefaultRetryOptionsMaxDelayMs = 10000
21+
)
22+
23+
// defaultIsRetryable is a function that checks if the error is retriable.
24+
// It takes an error as an argument and returns a boolean value.
25+
// The function checks if the error is a net.Error and if it is a timeout or temporary error.
26+
// Returns true for nil errors.
27+
var defaultIsRetryable = func(err error) bool {
28+
if err == nil {
29+
return true
30+
}
31+
32+
var netErr net.Error
33+
if errors.As(err, &netErr) {
34+
// Check for timeout first as it's more specific
35+
if netErr.Timeout() {
36+
return true
37+
}
38+
// For temporary errors, we'll use a more modern approach
39+
var tempErr interface{ Temporary() bool }
40+
if errors.As(err, &tempErr) {
41+
return tempErr.Temporary()
42+
}
43+
}
44+
45+
return errors.Is(err, os.ErrDeadlineExceeded)
46+
}
47+
48+
// defaultRetryOptionsOr returns the default retry options if the provided options are not set.
49+
// It sets the maximum number of attempts, initial delay, maximum delay, and backoff multiplier.
50+
// The default values are 3 attempts, 1000 ms initial delay, 10000 ms maximum delay, and 2.0 backoff multiplier.
51+
// The values can be overridden by the user.
52+
func defaultRetryOptionsOr(retryOptions RetryOptions) RetryOptions {
53+
if retryOptions.IsRetryable == nil {
54+
retryOptions.IsRetryable = defaultIsRetryable
55+
}
56+
57+
if retryOptions.MaxAttempts <= 0 {
58+
retryOptions.MaxAttempts = DefaultRetryOptionsMaxAttempts
59+
}
60+
if retryOptions.InitialDelayMs == 0 {
61+
retryOptions.InitialDelayMs = DefaultRetryOptionsInitialDelayMs
62+
}
63+
if retryOptions.BackoffMultiplier == 0 {
64+
retryOptions.BackoffMultiplier = DefaultRetryOptionsBackoffMultiplier
65+
}
66+
if retryOptions.MaxDelayMs == 0 {
67+
retryOptions.MaxDelayMs = DefaultRetryOptionsMaxDelayMs
68+
}
69+
return retryOptions
70+
}
71+
72+
// defaultIdentityProviderResponseParserOr returns the default token parser if the provided token parser is not set.
73+
// It sets the default token parser to the defaultIdentityProviderResponseParser function.
74+
// The default token parser is used to parse the raw token and return a Token object.
75+
func defaultIdentityProviderResponseParserOr(idpResponseParser shared.IdentityProviderResponseParser) shared.IdentityProviderResponseParser {
76+
if idpResponseParser == nil {
77+
return &defaultIdentityProviderResponseParser{}
78+
}
79+
return idpResponseParser
80+
}
81+
82+
func defaultTokenManagerOptionsOr(options TokenManagerOptions) TokenManagerOptions {
83+
options.RetryOptions = defaultRetryOptionsOr(options.RetryOptions)
84+
options.IdentityProviderResponseParser = defaultIdentityProviderResponseParserOr(options.IdentityProviderResponseParser)
85+
if options.ExpirationRefreshRatio == 0 {
86+
options.ExpirationRefreshRatio = DefaultExpirationRefreshRatio
87+
}
88+
return options
89+
}
90+
91+
type defaultIdentityProviderResponseParser struct{}
92+
93+
// ParseResponse parses the response from the identity provider and extracts the token.
94+
// It takes an IdentityProviderResponse as an argument and returns a Token and an error if any.
95+
// The IdentityProviderResponse contains the raw token and the expiration time.
96+
func (*defaultIdentityProviderResponseParser) ParseResponse(response shared.IdentityProviderResponse) (*token.Token, error) {
97+
if response == nil {
98+
return nil, fmt.Errorf("identity provider response cannot be nil")
99+
}
100+
101+
var username, password, rawToken string
102+
var expiresOn time.Time
103+
now := time.Now().UTC()
104+
105+
switch response.Type() {
106+
case shared.ResponseTypeAuthResult:
107+
authResult := response.AuthResult()
108+
if authResult.ExpiresOn.IsZero() {
109+
return nil, fmt.Errorf("auth result expiration time is not set")
110+
}
111+
if authResult.IDToken.Oid == "" {
112+
return nil, fmt.Errorf("auth result OID is empty")
113+
}
114+
rawToken = authResult.IDToken.RawToken
115+
username = authResult.IDToken.Oid
116+
password = rawToken
117+
expiresOn = authResult.ExpiresOn.UTC()
118+
119+
case shared.ResponseTypeRawToken, shared.ResponseTypeAccessToken:
120+
tokenStr := response.RawToken()
121+
122+
if response.Type() == shared.ResponseTypeAccessToken {
123+
accessToken := response.AccessToken()
124+
if accessToken.Token == "" {
125+
return nil, fmt.Errorf("access token value is empty")
126+
}
127+
tokenStr = accessToken.Token
128+
expiresOn = accessToken.ExpiresOn.UTC()
129+
}
130+
131+
if tokenStr == "" {
132+
return nil, fmt.Errorf("raw token is empty")
133+
}
134+
135+
claims := struct {
136+
jwt.RegisteredClaims
137+
Oid string `json:"oid,omitempty"`
138+
}{}
139+
140+
// Parse the token to extract claims, but note that signature verification
141+
// should be handled by the identity provider
142+
_, _, err := jwt.NewParser().ParseUnverified(tokenStr, &claims)
143+
if err != nil {
144+
return nil, fmt.Errorf("failed to parse JWT token: %w", err)
145+
}
146+
147+
if claims.Oid == "" {
148+
return nil, fmt.Errorf("JWT token does not contain OID claim")
149+
}
150+
151+
rawToken = tokenStr
152+
username = claims.Oid
153+
password = rawToken
154+
155+
if expiresOn.IsZero() && claims.ExpiresAt != nil {
156+
expiresOn = claims.ExpiresAt.UTC()
157+
}
158+
159+
default:
160+
return nil, fmt.Errorf("unsupported response type: %s", response.Type())
161+
}
162+
163+
if expiresOn.IsZero() {
164+
return nil, fmt.Errorf("token expiration time is not set")
165+
}
166+
167+
if expiresOn.Before(now) {
168+
return nil, fmt.Errorf("token has expired at %s (current time: %s)", expiresOn, now)
169+
}
170+
171+
// Create the token with consistent time reference
172+
return token.New(
173+
username,
174+
password,
175+
rawToken,
176+
expiresOn,
177+
now,
178+
int64(time.Until(expiresOn).Seconds()),
179+
), nil
180+
}

manager/errors.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
package manager
2+
3+
import "fmt"
4+
5+
// ErrTokenManagerAlreadyCanceled is returned when the token manager is already canceled.
6+
var ErrTokenManagerAlreadyCanceled = fmt.Errorf("token manager already canceled")
7+
8+
// ErrTokenManagerAlreadyStarted is returned when the token manager is already started.
9+
var ErrTokenManagerAlreadyStarted = fmt.Errorf("token manager already started")

manager/manager_test.go

Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
package manager
2+
3+
import (
4+
"net"
5+
"os"
6+
"time"
7+
8+
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
9+
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/public"
10+
"github.com/redis-developer/go-redis-entraid/shared"
11+
"github.com/redis-developer/go-redis-entraid/token"
12+
"github.com/stretchr/testify/mock"
13+
)
14+
15+
// testJWTToken is a JWT token for testing
16+
//
17+
// {
18+
// "iss": "test jwt",
19+
// "iat": 1743515011,
20+
// "exp": 1775051011,
21+
// "aud": "www.example.com",
22+
// "sub": "[email protected]",
23+
// "oid": "test"
24+
// }
25+
//
26+
// key: qwertyuiopasdfghjklzxcvbnm123456
27+
const testJWTToken = "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJpc3MiOiJ0ZXN0IGp3dCIsImlhdCI6MTc0MzUxNTAxMSwiZXhwIjoxNzc1MDUxMDExLCJhdWQiOiJ3d3cuZXhhbXBsZS5jb20iLCJzdWIiOiJ0ZXN0QHRlc3QuY29tIiwib2lkIjoidGVzdCJ9.6RG721V2eFlSLsCRmo53kSRRrTZIe1UPdLZCUEvIarU"
28+
29+
// testJWTExpiredToken is an expired JWT token for testing
30+
//
31+
// {
32+
// "iss": "test jwt",
33+
// "iat": 1617795148,
34+
// "exp": 1617795148,
35+
// "aud": "www.example.com",
36+
// "sub": "[email protected]",
37+
// "oid": "test"
38+
// }
39+
//
40+
// key: qwertyuiopasdfghjklzxcvbnm123456
41+
const testJWTExpiredToken = "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJpc3MiOiJ0ZXN0IGp3dCIsImlhdCI6MTYxNzc5NTE0OCwiZXhwIjoxNjE3Nzk1MTQ4LCJhdWQiOiJ3d3cuZXhhbXBsZS5jb20iLCJzdWIiOiJ0ZXN0QHRlc3QuY29tIiwib2lkIjoidGVzdCJ9.IbGPhHRiPYcpUDrhAPf4h3gH1XXBOu560NYT59rUMzc"
42+
43+
// testJWTWithZeroExpiryToken is a JWT token with zero expiry for testing
44+
//
45+
// {
46+
// "iss": "test jwt",
47+
// "iat": 1744025944,
48+
// "exp": null,
49+
// "aud": "www.example.com",
50+
// "sub": "[email protected]",
51+
// "oid": "test"
52+
// }
53+
// key: qwertyuiopasdfghjklzxcvbnm123456
54+
const testJWTWithZeroExpiryToken = "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJpc3MiOiJ0ZXN0IGp3dCIsImlhdCI6MTc0NDAyNTk0NCwiZXhwIjpudWxsLCJhdWQiOiJ3d3cuZXhhbXBsZS5jb20iLCJzdWIiOiJ0ZXN0QHRlc3QuY29tIiwib2lkIjoidGVzdCJ9.bLSANIzawE5Y6rgspvvUaRhkBq6Y4E0ggjXlmHRn8ew"
55+
56+
var testTokenValid = token.New(
57+
"test",
58+
"password",
59+
"test",
60+
time.Now().Add(time.Hour),
61+
time.Now(),
62+
int64(time.Hour),
63+
)
64+
65+
type mockIdentityProviderResponseParser struct {
66+
// Mock implementation of the IdentityProviderResponseParser interface
67+
mock.Mock
68+
}
69+
70+
func (m *mockIdentityProviderResponseParser) ParseResponse(response shared.IdentityProviderResponse) (*token.Token, error) {
71+
args := m.Called(response)
72+
if args.Get(0) == nil {
73+
return nil, args.Error(1)
74+
}
75+
return args.Get(0).(*token.Token), args.Error(1)
76+
}
77+
78+
type mockIdentityProvider struct {
79+
// Mock implementation of the mockIdentityProvider interface
80+
// Add any necessary fields or methods for the mock identity provider here
81+
mock.Mock
82+
}
83+
84+
func (m *mockIdentityProvider) RequestToken() (shared.IdentityProviderResponse, error) {
85+
args := m.Called()
86+
if args.Get(0) == nil {
87+
return nil, args.Error(1)
88+
}
89+
return args.Get(0).(shared.IdentityProviderResponse), args.Error(1)
90+
}
91+
92+
func newMockError(retriable bool) error {
93+
if retriable {
94+
return &mockError{
95+
isTimeout: true,
96+
isTemporary: true,
97+
error: os.ErrDeadlineExceeded,
98+
}
99+
} else {
100+
return &mockError{
101+
isTimeout: false,
102+
isTemporary: false,
103+
error: os.ErrInvalid,
104+
}
105+
}
106+
}
107+
108+
type mockError struct {
109+
// Mock implementation of the network error
110+
error
111+
isTimeout bool
112+
isTemporary bool
113+
}
114+
115+
func (m *mockError) Error() string {
116+
return "this is mock error"
117+
}
118+
119+
func (m *mockError) Timeout() bool {
120+
return m.isTimeout
121+
}
122+
func (m *mockError) Temporary() bool {
123+
return m.isTemporary
124+
}
125+
func (m *mockError) Unwrap() error {
126+
return m.error
127+
}
128+
129+
func (m *mockError) Is(err error) bool {
130+
return m.error == err
131+
}
132+
133+
var _ net.Error = (*mockError)(nil)
134+
135+
type mockTokenListener struct {
136+
// Mock implementation of the TokenManagerListener interface
137+
mock.Mock
138+
Id int32
139+
}
140+
141+
func (m *mockTokenListener) OnTokenNext(token *token.Token) {
142+
_ = m.Called(token)
143+
}
144+
145+
func (m *mockTokenListener) OnTokenError(err error) {
146+
_ = m.Called(err)
147+
}
148+
149+
type authResult struct {
150+
// ResultType is the type of the response (AuthResult, AccessToken, or RawToken)
151+
ResultType string
152+
// AuthResultVal is the auth result value
153+
AuthResultVal *public.AuthResult
154+
// AccessTokenVal is the access token value
155+
AccessTokenVal *azcore.AccessToken
156+
// RawTokenVal is the raw token value
157+
RawTokenVal string
158+
}
159+
160+
func (a *authResult) Type() string {
161+
return a.ResultType
162+
}
163+
164+
func (a *authResult) AuthResult() public.AuthResult {
165+
if a.AuthResultVal == nil {
166+
return public.AuthResult{}
167+
}
168+
return *a.AuthResultVal
169+
}
170+
171+
func (a *authResult) AccessToken() azcore.AccessToken {
172+
if a.AccessTokenVal == nil {
173+
return azcore.AccessToken{}
174+
}
175+
return *a.AccessTokenVal
176+
}
177+
178+
func (a *authResult) RawToken() string {
179+
return a.RawTokenVal
180+
}

0 commit comments

Comments
 (0)