Skip to content

Commit 04f4bf0

Browse files
committed
improvements around error handling
1 parent 2ccdfe2 commit 04f4bf0

File tree

3 files changed

+114
-26
lines changed

3 files changed

+114
-26
lines changed

azure_default_identity_provider.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,10 @@ import (
1010

1111
// DefaultAzureIdentityProviderOptions represents the options for the DefaultAzureIdentityProvider.
1212
type DefaultAzureIdentityProviderOptions struct {
13-
// Scopes is the list of scopes used to request a token from the identity provider.
13+
// AzureOptions is the options used to configure the Azure identity provider.
1414
AzureOptions *azidentity.DefaultAzureCredentialOptions
15-
Scopes []string
15+
// Scopes is the list of scopes used to request a token from the identity provider.
16+
Scopes []string
1617
}
1718

1819
type DefaultAzureIdentityProvider struct {

credentials_provider.go

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -99,23 +99,18 @@ func (e *entraidCredentialsProvider) Subscribe(listener auth.CredentialsListener
9999
}
100100
}
101101
if len(e.listeners) == 0 {
102-
e.close()
102+
if e.cancelTokenManager != nil {
103+
e.cancelTokenManager()
104+
}
105+
e.cancelTokenManager = nil
106+
e.listeners = nil
103107
}
104108
return nil
105109
}
106110

107111
return credentials, cancel, nil
108112
}
109113

110-
func (e *entraidCredentialsProvider) close() {
111-
if e.cancelTokenManager != nil {
112-
e.cancelTokenManager()
113-
}
114-
e.cancelTokenManager = nil
115-
e.listeners = nil
116-
e.rwLock = sync.RWMutex{}
117-
}
118-
119114
type entraidTokenListener struct {
120115
onTokenNext func(token *Token)
121116
onTokenError func(err error)

token_manager.go

Lines changed: 106 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
package entraid
22

33
import (
4+
"errors"
45
"fmt"
6+
"log"
7+
"net"
58
"sync"
69
"time"
710
)
@@ -15,27 +18,56 @@ type TokenManagerOptions struct {
1518
// The value should be between 0 and 1.
1619
// For example, if the expiration time is 1 hour and the ratio is 0.5,
1720
// the token will be refreshed after 30 minutes.
21+
//
22+
// default: 0.7
1823
ExpirationRefreshRatio float64
1924

25+
// LowerRefreshBoundMs is the lower bound for the refresh time in milliseconds.
26+
// It is used to determine when to refresh the token.
27+
// The value should be greater than 0.
28+
// For example, if the expiration time is 1 hour and the lower bound is 30 minutes,
29+
// the token will be refreshed after 30 minutes.
30+
//
31+
// default: 0 ms (no lower bound, refresh based on ExpirationRefreshRatio)
32+
LowerRefreshBoundMs int
33+
2034
// TokenParser is a function that parses the raw token and returns a Token object.
2135
// The function takes the raw token as a string and returns a Token object and an error.
2236
// If this function is not provided, the default implementation will be used.
37+
//
38+
// required: true
2339
TokenParser TokenParserFunc
2440

2541
// RetryOptions is a struct that contains the options for retrying the token request.
2642
// It contains the maximum number of attempts, initial delay, maximum delay, and backoff multiplier.
43+
//
44+
// The default values are 3 attempts, 1000 ms initial delay, 10000 ms maximum delay, and 2.0 backoff multiplier.
2745
RetryOptions RetryOptions
2846
}
2947

3048
// RetryOptions is a struct that contains the options for retrying the token request.
3149
type RetryOptions struct {
50+
// IsRetryable is a function that checks if the error is retryable.
51+
// It takes an error as an argument and returns a boolean value.
52+
//
53+
// default: defaultRetryableFunc
54+
IsRetryable func(err error) bool
3255
// MaxAttempts is the maximum number of attempts to retry the token request.
56+
//
57+
// default: 3
3358
MaxAttempts int
3459
// InitialDelayMs is the initial delay in milliseconds before retrying the token request.
60+
//
61+
// default: 1000 ms
3562
InitialDelayMs int
63+
3664
// MaxDelayMs is the maximum delay in milliseconds between retry attempts.
65+
//
66+
// default: 10000 ms
3767
MaxDelayMs int
68+
3869
// BackoffMultiplier is the multiplier for the backoff delay.
70+
// default: 2.0
3971
BackoffMultiplier float64
4072
}
4173

@@ -55,12 +87,30 @@ type TokenManager interface {
5587
// defaultTokenParser is a function that parses the raw token and returns Token object.
5688
var defaultTokenParser = func(rawToken string, expiresOn time.Time) (*Token, error) {
5789
// Parse the token and return the username and password.
58-
// This is a placeholder implementation.
90+
// In this example, we are just returning the raw token as the password.
91+
// In a real implementation, you would parse the token and extract the username and password.
92+
// For example, if the token is a JWT, you would decode the JWT and extract the claims.
93+
// This is just a placeholder implementation.
94+
// You should replace this with your own implementation.
95+
if rawToken == "" {
96+
return nil, fmt.Errorf("raw token is empty")
97+
}
98+
if expiresOn.IsZero() {
99+
return nil, fmt.Errorf("expires on is zero")
100+
}
101+
if expiresOn.Before(time.Now()) {
102+
return nil, fmt.Errorf("expires on is in the past")
103+
}
104+
if expiresOn.Sub(time.Now()) < MinTokenTTL {
105+
return nil, fmt.Errorf("expires on is less than minimum token TTL")
106+
}
107+
// parse token as jwt token and get claims
108+
59109
return &Token{
60110
Username: "username",
61-
Password: "password",
62-
ExpiresOn: expiresOn,
63-
TTL: expiresOn.Unix() - time.Now().Unix(),
111+
Password: rawToken,
112+
ExpiresOn: expiresOn.UTC(),
113+
TTL: expiresOn.UTC().Unix() - time.Now().UTC().Unix(),
64114
RawToken: rawToken,
65115
}, nil
66116
}
@@ -70,18 +120,22 @@ var defaultTokenParser = func(rawToken string, expiresOn time.Time) (*Token, err
70120
// The IdentityProvider is used to obtain the token, and the TokenManagerOptions contains options for the TokenManager.
71121
// The TokenManager is responsible for managing the token and refreshing it when necessary.
72122
func NewTokenManager(idp IdentityProvider, options TokenManagerOptions) (TokenManager, error) {
73-
tokenParser := defaultTokenParserOr(options.TokenParser)
74-
retryOptions := defaultRetryOptionsOr(options.RetryOptions)
123+
options = defaultTokenManagerOptionsOr(options)
124+
if options.ExpirationRefreshRatio <= 0 || options.ExpirationRefreshRatio > 1 {
125+
return nil, fmt.Errorf("expiration refresh ratio must be between 0 and 1")
126+
}
127+
75128
if idp == nil {
76129
return nil, fmt.Errorf("identity provider is required")
77130
}
78131

79132
return &entraidTokenManager{
80-
idp: idp,
81-
token: nil,
82-
closed: make(chan struct{}),
83-
tokenParser: tokenParser,
84-
retryOptions: retryOptions,
133+
idp: idp,
134+
token: nil,
135+
closed: make(chan struct{}),
136+
expirationRefreshRatio: options.ExpirationRefreshRatio,
137+
tokenParser: options.TokenParser,
138+
retryOptions: options.RetryOptions,
85139
}, nil
86140
}
87141

@@ -108,6 +162,10 @@ type entraidTokenManager struct {
108162
// lock locks the listener to prevent concurrent access.
109163
lock sync.Mutex
110164

165+
// expirationRefreshRatio is the ratio of the token expiration time to refresh the token.
166+
// It is used to determine when to refresh the token.
167+
expirationRefreshRatio float64
168+
111169
closed chan struct{}
112170
}
113171

@@ -169,7 +227,7 @@ func (e *entraidTokenManager) Start(listener TokenListener) (cancelFunc, error)
169227
// Simulate token refresh
170228
for {
171229
select {
172-
case <-time.After(time.Duration(e.token.TTL) * time.Second):
230+
case <-time.After(time.Until(token.ExpiresOn) * time.Duration(e.expirationRefreshRatio)):
173231
// Token is about to expire, refresh it
174232
for i := 0; i < e.retryOptions.MaxAttempts; i++ {
175233
token, err := e.GetToken()
@@ -178,8 +236,8 @@ func (e *entraidTokenManager) Start(listener TokenListener) (cancelFunc, error)
178236
break
179237
}
180238
// check if err is retryable
181-
if err.Error() == "retryable error" {
182-
// retry
239+
if e.retryOptions.IsRetryable(err) {
240+
// retryable error, continue to next attempt
183241
continue
184242
} else {
185243
// not retryable
@@ -216,6 +274,12 @@ func (e *entraidTokenManager) Start(listener TokenListener) (cancelFunc, error)
216274
}
217275

218276
func (e *entraidTokenManager) Close() error {
277+
defer func() {
278+
if r := recover(); r != nil {
279+
// handle panic
280+
log.Printf("Recovered from panic: %v", r)
281+
}
282+
}()
219283
e.lock.Lock()
220284
defer e.lock.Unlock()
221285
if e.listener != nil {
@@ -225,11 +289,30 @@ func (e *entraidTokenManager) Close() error {
225289
return nil
226290
}
227291

292+
// defaultRetryableFunc is a function that checks if the error is retryable.
293+
// It takes an error as an argument and returns a boolean value.
294+
// The function checks if the error is a net.Error and if it is a timeout or temporary error.
295+
var defaultRetryableFunc = func(err error) bool {
296+
var netErr net.Error
297+
if err == nil {
298+
return true
299+
}
300+
301+
if ok := errors.As(err, netErr); ok {
302+
return netErr.Timeout() || netErr.Temporary()
303+
}
304+
return false
305+
}
306+
228307
// defaultRetryOptionsOr returns the default retry options if the provided options are not set.
229308
// It sets the maximum number of attempts, initial delay, maximum delay, and backoff multiplier.
230309
// The default values are 3 attempts, 1000 ms initial delay, 10000 ms maximum delay, and 2.0 backoff multiplier.
231310
// The values can be overridden by the user.
232311
func defaultRetryOptionsOr(retryOptions RetryOptions) RetryOptions {
312+
if retryOptions.IsRetryable == nil {
313+
retryOptions.IsRetryable = defaultRetryableFunc
314+
}
315+
233316
if retryOptions.MaxAttempts <= 0 {
234317
retryOptions.MaxAttempts = 3
235318
}
@@ -254,3 +337,12 @@ func defaultTokenParserOr(tokenParser TokenParserFunc) TokenParserFunc {
254337
}
255338
return tokenParser
256339
}
340+
341+
func defaultTokenManagerOptionsOr(options TokenManagerOptions) TokenManagerOptions {
342+
options.RetryOptions = defaultRetryOptionsOr(options.RetryOptions)
343+
options.TokenParser = defaultTokenParserOr(options.TokenParser)
344+
if options.ExpirationRefreshRatio <= 0 || options.ExpirationRefreshRatio > 1 {
345+
options.ExpirationRefreshRatio = 0.7
346+
}
347+
return options
348+
}

0 commit comments

Comments
 (0)