Skip to content

Commit a0cf97a

Browse files
committed
Merge remote-tracking branch 'origin/main' into ndyakov/examples
2 parents e44f848 + 1e25b29 commit a0cf97a

15 files changed

+472
-390
lines changed

credentials_provider.go

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ type entraidCredentialsProvider struct {
2727

2828
// rwLock is a mutex that is used to synchronize access to the listeners slice.
2929
rwLock sync.RWMutex // Mutex for synchronizing access to the listeners slice.
30+
31+
tmLock sync.Mutex
3032
}
3133

3234
// onTokenNext is a method that is called when the token manager receives a new token.
@@ -60,15 +62,26 @@ func (e *entraidCredentialsProvider) onTokenError(err error) {
6062
//
6163
// Returns:
6264
// - auth.Credentials: The current credentials for the listener.
63-
// - auth.CancelProviderFunc: A function that can be called to unsubscribe the listener.
65+
// - auth.UnsubscribeFunc: A function that can be called to unsubscribe the listener.
6466
// - error: An error if the subscription fails, such as if the token cannot be retrieved.
6567
//
6668
// Note: If the listener is already subscribed, it will not receive duplicate notifications.
6769
func (e *entraidCredentialsProvider) Subscribe(listener auth.CredentialsListener) (auth.Credentials, auth.UnsubscribeFunc, error) {
68-
// First try to get a token, only then subscribe the listener.
70+
// check if the manager is working
71+
// If the stopTokenManager is nil, the token manager is not started.
72+
e.tmLock.Lock()
73+
if e.stopTokenManager == nil {
74+
stopTM, err := e.tokenManager.Start(tokenListenerFromCP(e))
75+
if err != nil {
76+
return nil, nil, fmt.Errorf("couldn't start token manager: %w", err)
77+
}
78+
e.stopTokenManager = stopTM
79+
}
80+
e.tmLock.Unlock()
81+
6982
token, err := e.tokenManager.GetToken(false)
7083
if err != nil {
71-
return nil, nil, err
84+
return nil, nil, fmt.Errorf("couldn't get token: %w", err)
7285
}
7386

7487
e.rwLock.Lock()
@@ -102,6 +115,7 @@ func (e *entraidCredentialsProvider) Subscribe(listener auth.CredentialsListener
102115
// Clear the listeners slice if it's empty
103116
if len(e.listeners) == 0 {
104117
e.listeners = make([]auth.CredentialsListener, 0)
118+
e.tmLock.Lock()
105119
if e.stopTokenManager != nil {
106120
err := e.stopTokenManager()
107121
if err != nil {
@@ -111,6 +125,7 @@ func (e *entraidCredentialsProvider) Subscribe(listener auth.CredentialsListener
111125
// This prevents multiple calls to stopTokenManager.
112126
e.stopTokenManager = nil
113127
}
128+
e.tmLock.Unlock()
114129
}
115130
return nil
116131
}
@@ -134,10 +149,11 @@ func NewCredentialsProvider(tokenManager manager.TokenManager, options Credentia
134149
options: options,
135150
listeners: make([]auth.CredentialsListener, 0),
136151
}
137-
stopTM, err := cp.tokenManager.Start(tokenListenerFromCP(cp))
152+
// Start the token manager.
153+
stop, err := tokenManager.Start(tokenListenerFromCP(cp))
138154
if err != nil {
139155
return nil, fmt.Errorf("couldn't start token manager: %w", err)
140156
}
141-
cp.stopTokenManager = stopTM
157+
cp.stopTokenManager = stop
142158
return cp, nil
143159
}

credentials_provider_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -343,7 +343,7 @@ func TestCredentialsProviderSubscribe(t *testing.T) {
343343
mtm.On("GetToken", false).Return(testToken, nil)
344344
mtm.On("Start", mock.Anything).
345345
Run(mockTokenManagerLoop(mtm, tokenExpiration, testToken, nil)).
346-
Return(manager.StopFunc(mtm.Stop), nil)
346+
Return(manager.StopFunc(mtm.stop), nil)
347347
provider, err := NewConfidentialCredentialsProvider(options)
348348
require.NoError(t, err)
349349
require.NotNil(t, provider)
@@ -396,7 +396,7 @@ func TestCredentialsProviderSubscribe(t *testing.T) {
396396

397397
mtm.On("Start", mock.Anything).
398398
Run(mockTokenManagerLoop(mtm, tokenExpiration, nil, errTokenError)).
399-
Return(manager.StopFunc(mtm.Stop), nil)
399+
Return(manager.StopFunc(mtm.stop), nil)
400400
provider, err := NewConfidentialCredentialsProvider(options)
401401
require.NoError(t, err)
402402
require.NotNil(t, provider)
@@ -467,7 +467,7 @@ func TestCredentialsProviderSubscribe(t *testing.T) {
467467

468468
mtm.On("Start", mock.Anything).
469469
Run(mockTokenManagerLoop(mtm, tokenExpiration, nil, errTokenError)).
470-
Return(manager.StopFunc(mtm.Stop), nil)
470+
Return(manager.StopFunc(mtm.stop), nil)
471471
provider, err := NewConfidentialCredentialsProvider(options)
472472
require.NoError(t, err)
473473
require.NotNil(t, provider)

entraid_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ func (m *fakeTokenManager) GetToken(forceRefresh bool) (*token.Token, error) {
4848
rawTokenString,
4949
time.Now().Add(tokenExpiration),
5050
time.Now(),
51-
int64(100*time.Millisecond),
51+
int64(tokenExpiration.Seconds()),
5252
)
5353
}
5454
return m.token, m.err
@@ -84,7 +84,7 @@ func (m *fakeTokenManager) Start(listener manager.TokenListener) (manager.StopFu
8484
}, nil
8585
}
8686

87-
func (m *fakeTokenManager) Stop() error {
87+
func (m *fakeTokenManager) stop() error {
8888
return nil
8989
}
9090

@@ -163,7 +163,7 @@ func (m *mockTokenManager) Start(listener manager.TokenListener) (manager.StopFu
163163
m.lock.Unlock()
164164
return args.Get(0).(manager.StopFunc), args.Error(1)
165165
}
166-
func (m *mockTokenManager) Stop() error {
166+
func (m *mockTokenManager) stop() error {
167167
m.lock.Lock()
168168
defer m.lock.Unlock()
169169
if m.listener == nil {

internal/errors.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ package internal
33
import "fmt"
44

55
var ErrInvalidIDPResponse = fmt.Errorf("invalid identity provider response")
6-
var ErrInvalidIDPResponseType = fmt.Errorf("invalid identity provider response type")
76
var ErrAuthResultNotFound = fmt.Errorf("auth result not found")
87
var ErrAccessTokenNotFound = fmt.Errorf("access token not found")
98
var ErrRawTokenNotFound = fmt.Errorf("raw token not found")

internal/idp_response.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,17 +36,17 @@ func NewIDPResp(resultType string, result interface{}) (*IDPResp, error) {
3636
default:
3737
return nil, fmt.Errorf("invalid auth result type: expected public.AuthResult or *public.AuthResult, got %T", result)
3838
}
39+
r.rawTokenVal = r.authResultVal.AccessToken
3940
case "AccessToken":
4041
switch v := result.(type) {
4142
case *azcore.AccessToken:
4243
r.accessTokenVal = v
43-
r.rawTokenVal = v.Token
4444
case azcore.AccessToken:
4545
r.accessTokenVal = &v
46-
r.rawTokenVal = v.Token
4746
default:
4847
return nil, fmt.Errorf("invalid access token type: expected azcore.AccessToken or *azcore.AccessToken, got %T", result)
4948
}
49+
r.rawTokenVal = r.accessTokenVal.Token
5050
case "RawToken":
5151
switch v := result.(type) {
5252
case string:

internal/utils.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
package internal
22

33
// IsClosed checks if a channel is closed.
4+
//
5+
// NOTE: It returns true if the channel is closed as well
6+
// as if the channel is not empty. Used internally
7+
// to check if the channel is closed.
48
func IsClosed(ch <-chan struct{}) bool {
59
select {
610
case <-ch:

manager/defaults.go

Lines changed: 46 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package manager
33
import (
44
"errors"
55
"fmt"
6+
"math"
67
"net"
78
"os"
89
"time"
@@ -13,6 +14,7 @@ import (
1314
)
1415

1516
const (
17+
DefaultRequestTimeout = 30 * time.Second
1618
DefaultExpirationRefreshRatio = 0.7
1719
DefaultRetryOptionsMaxAttempts = 3
1820
DefaultRetryOptionsBackoffMultiplier = 2.0
@@ -85,91 +87,79 @@ func defaultTokenManagerOptionsOr(options TokenManagerOptions) TokenManagerOptio
8587
if options.ExpirationRefreshRatio == 0 {
8688
options.ExpirationRefreshRatio = DefaultExpirationRefreshRatio
8789
}
90+
if options.RequestTimeout == 0 {
91+
options.RequestTimeout = DefaultRequestTimeout
92+
}
8893
return options
8994
}
9095

9196
type defaultIdentityProviderResponseParser struct{}
9297

9398
// ParseResponse parses the response from the identity provider and extracts the token.
9499
// It takes an IdentityProviderResponse as an argument and returns a Token and an error if any.
95-
// The IdentityProviderResponse contains the raw token and the expiration time.
100+
// The raw token is extracted based on the IdentityProviderResponse Type and then
101+
// is parsed as a JWT token to extract the claims.
96102
func (*defaultIdentityProviderResponseParser) ParseResponse(response shared.IdentityProviderResponse) (*token.Token, error) {
97103
if response == nil {
98104
return nil, fmt.Errorf("identity provider response cannot be nil")
99105
}
100106

101107
var username, password, rawToken string
102108
var expiresOn time.Time
103-
now := time.Now().UTC()
109+
now := time.Now().UTC().Truncate(time.Second).Add(time.Second)
104110

105111
switch response.Type() {
106112
case shared.ResponseTypeAuthResult:
107113
authResult, err := response.(shared.AuthResultIDPResponse).AuthResult()
108114
if err != nil {
109115
return nil, fmt.Errorf("failed to get auth result: %w", err)
110116
}
111-
if authResult.ExpiresOn.IsZero() {
112-
return nil, fmt.Errorf("auth result expiration time is not set")
113-
}
114-
if authResult.IDToken.Oid == "" {
115-
return nil, fmt.Errorf("auth result OID is empty")
116-
}
117-
rawToken = authResult.IDToken.RawToken
118-
username = authResult.IDToken.Oid
119-
password = rawToken
120-
expiresOn = authResult.ExpiresOn.UTC()
121117

122-
case shared.ResponseTypeRawToken, shared.ResponseTypeAccessToken:
123-
var tokenStr string
124-
var err error
125-
if response.Type() == shared.ResponseTypeRawToken {
126-
tokenStr, err = response.(shared.RawTokenIDPResponse).RawToken()
127-
if err != nil {
128-
return nil, fmt.Errorf("failed to get raw token: %w", err)
129-
}
130-
}
131-
if response.Type() == shared.ResponseTypeAccessToken {
132-
accessToken, err := response.(shared.AccessTokenIDPResponse).AccessToken()
133-
if err != nil {
134-
return nil, fmt.Errorf("failed to get access token: %w", err)
135-
}
136-
if accessToken.Token == "" {
137-
return nil, fmt.Errorf("access token value is empty")
138-
}
139-
tokenStr = accessToken.Token
140-
expiresOn = accessToken.ExpiresOn.UTC()
141-
}
142-
143-
if tokenStr == "" {
144-
return nil, fmt.Errorf("raw token is empty")
118+
expiresOn = authResult.ExpiresOn.UTC()
119+
rawToken = authResult.AccessToken
120+
case shared.ResponseTypeAccessToken:
121+
accessToken, err := response.(shared.AccessTokenIDPResponse).AccessToken()
122+
if err != nil {
123+
return nil, fmt.Errorf("failed to get access token: %w", err)
145124
}
146125

147-
claims := struct {
148-
jwt.RegisteredClaims
149-
Oid string `json:"oid,omitempty"`
150-
}{}
151-
152-
// Parse the token to extract claims, but note that signature verification
153-
// should be handled by the identity provider
154-
_, _, err = jwt.NewParser().ParseUnverified(tokenStr, &claims)
126+
rawToken = accessToken.Token
127+
expiresOn = accessToken.ExpiresOn.UTC()
128+
case shared.ResponseTypeRawToken:
129+
tokenStr, err := response.(shared.RawTokenIDPResponse).RawToken()
155130
if err != nil {
156-
return nil, fmt.Errorf("failed to parse JWT token: %w", err)
131+
return nil, fmt.Errorf("failed to get raw token: %w", err)
157132
}
133+
rawToken = tokenStr
134+
default:
135+
return nil, fmt.Errorf("unsupported response type: %s", response.Type())
136+
}
158137

159-
if claims.Oid == "" {
160-
return nil, fmt.Errorf("JWT token does not contain OID claim")
161-
}
138+
if rawToken == "" {
139+
return nil, fmt.Errorf("raw token is empty")
140+
}
162141

163-
rawToken = tokenStr
164-
username = claims.Oid
165-
password = rawToken
142+
// Parse JWT
143+
claims := struct {
144+
jwt.RegisteredClaims
145+
Oid string `json:"oid,omitempty"`
146+
}{}
147+
148+
// Parse the token to extract claims, but note that signature verification
149+
// should be handled by the identity provider
150+
_, _, err := jwt.NewParser().ParseUnverified(rawToken, &claims)
151+
if err != nil {
152+
return nil, fmt.Errorf("failed to parse JWT token: %w", err)
153+
}
166154

167-
if expiresOn.IsZero() && claims.ExpiresAt != nil {
168-
expiresOn = claims.ExpiresAt.UTC()
169-
}
155+
if claims.Oid == "" {
156+
return nil, fmt.Errorf("JWT token does not contain OID claim")
157+
}
170158

171-
default:
172-
return nil, fmt.Errorf("unsupported response type: %s", response.Type())
159+
username = claims.Oid
160+
password = rawToken
161+
if expiresOn.IsZero() && claims.ExpiresAt != nil {
162+
expiresOn = claims.ExpiresAt.UTC()
173163
}
174164

175165
if expiresOn.IsZero() {
@@ -187,6 +177,6 @@ func (*defaultIdentityProviderResponseParser) ParseResponse(response shared.Iden
187177
rawToken,
188178
expiresOn,
189179
now,
190-
int64(time.Until(expiresOn).Seconds()),
180+
int64(math.Ceil(time.Until(expiresOn).Seconds())),
191181
), nil
192182
}

0 commit comments

Comments
 (0)