Skip to content

Commit 06bb497

Browse files
committed
improve token parsing
1 parent 3555f62 commit 06bb497

File tree

4 files changed

+83
-88
lines changed

4 files changed

+83
-88
lines changed

manager/defaults.go

Lines changed: 48 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -20,18 +20,26 @@ const (
2020
DefaultRetryOptionsMaxDelayMs = 10000
2121
)
2222

23-
// defaultRetryableFunc is a function that checks if the error is retriable.
23+
// defaultIsRetryable is a function that checks if the error is retriable.
2424
// It takes an error as an argument and returns a boolean value.
2525
// The function checks if the error is a net.Error and if it is a timeout or temporary error.
26+
// Returns true for nil errors.
2627
var defaultIsRetryable = func(err error) bool {
27-
var netErr net.Error
2828
if err == nil {
2929
return true
3030
}
3131

32-
// nolint:staticcheck // SA1019 deprecated netErr.Temporary
33-
if ok := errors.As(err, &netErr); ok {
34-
return netErr.Timeout() || netErr.Temporary()
32+
var netErr net.Error
33+
if errors.As(err, &netErr) {
34+
// Check for timeout first as it's more specific
35+
if netErr.Timeout() {
36+
return true
37+
}
38+
// For temporary errors, we'll use a more modern approach
39+
var tempErr interface{ Temporary() bool }
40+
if errors.As(err, &tempErr) {
41+
return tempErr.Temporary()
42+
}
3543
}
3644

3745
return errors.Is(err, os.ErrDeadlineExceeded)
@@ -86,29 +94,40 @@ type defaultIdentityProviderResponseParser struct{}
8694
// It takes an IdentityProviderResponse as an argument and returns a Token and an error if any.
8795
// The IdentityProviderResponse contains the raw token and the expiration time.
8896
func (*defaultIdentityProviderResponseParser) ParseResponse(response shared.IdentityProviderResponse) (*token.Token, error) {
89-
var username, password, rawToken string
90-
var expiresOn time.Time
9197
if response == nil {
92-
return nil, fmt.Errorf("response is nil")
98+
return nil, fmt.Errorf("identity provider response cannot be nil")
9399
}
100+
101+
var username, password, rawToken string
102+
var expiresOn time.Time
103+
now := time.Now().UTC()
104+
94105
switch response.Type() {
95106
case shared.ResponseTypeAuthResult:
96107
authResult := response.AuthResult()
97108
if authResult.ExpiresOn.IsZero() {
98-
return nil, fmt.Errorf("auth result invalid")
109+
return nil, fmt.Errorf("auth result expiration time is not set")
110+
}
111+
if authResult.IDToken.Oid == "" {
112+
return nil, fmt.Errorf("auth result OID is empty")
99113
}
100114
rawToken = authResult.IDToken.RawToken
101115
username = authResult.IDToken.Oid
102116
password = rawToken
103117
expiresOn = authResult.ExpiresOn.UTC()
118+
104119
case shared.ResponseTypeRawToken, shared.ResponseTypeAccessToken:
105-
token := response.RawToken()
120+
tokenStr := response.RawToken()
121+
if tokenStr == "" {
122+
return nil, fmt.Errorf("raw token is empty")
123+
}
124+
106125
if response.Type() == shared.ResponseTypeAccessToken {
107126
accessToken := response.AccessToken()
108127
if accessToken.Token == "" {
109-
return nil, fmt.Errorf("access token is empty")
128+
return nil, fmt.Errorf("access token value is empty")
110129
}
111-
token = accessToken.Token
130+
tokenStr = accessToken.Token
112131
expiresOn = accessToken.ExpiresOn.UTC()
113132
}
114133

@@ -117,40 +136,44 @@ func (*defaultIdentityProviderResponseParser) ParseResponse(response shared.Iden
117136
Oid string `json:"oid,omitempty"`
118137
}{}
119138

120-
// jwt token should be verified from the identity provider
121-
_, _, err := jwt.NewParser().ParseUnverified(token, &claims)
139+
// Parse the token to extract claims, but note that signature verification
140+
// should be handled by the identity provider
141+
_, _, err := jwt.NewParser().ParseUnverified(tokenStr, &claims)
122142
if err != nil {
123-
return nil, fmt.Errorf("failed to parse jwt token: %w", err)
143+
return nil, fmt.Errorf("failed to parse JWT token: %w", err)
124144
}
125-
rawToken = token
145+
146+
if claims.Oid == "" {
147+
return nil, fmt.Errorf("JWT token does not contain OID claim")
148+
}
149+
150+
rawToken = tokenStr
126151
username = claims.Oid
127152
password = rawToken
128153

129154
if expiresOn.IsZero() && claims.ExpiresAt != nil {
130-
expiresOn = claims.ExpiresAt.Time
155+
expiresOn = claims.ExpiresAt.UTC()
131156
}
132157

133158
default:
134-
return nil, fmt.Errorf("unknown response type: %s", response.Type())
159+
return nil, fmt.Errorf("unsupported response type: %s", response.Type())
135160
}
136161

137-
expiresOn = expiresOn.UTC()
138-
139162
if expiresOn.IsZero() {
140-
return nil, fmt.Errorf("expires on is zero")
163+
return nil, fmt.Errorf("token expiration time is not set")
141164
}
142165

143-
if expiresOn.Before(time.Now()) {
144-
return nil, fmt.Errorf("expires on is in the past")
166+
if expiresOn.Before(now) {
167+
return nil, fmt.Errorf("token has expired at %s (current time: %s)", expiresOn, now)
145168
}
146169

147-
// parse token as jwt token and get claims
170+
// Create the token with consistent time reference
148171
return token.New(
149172
username,
150173
password,
151174
rawToken,
152175
expiresOn,
153-
time.Now().UTC(),
176+
now,
154177
int64(time.Until(expiresOn).Seconds()),
155178
), nil
156179
}

manager/token_manager.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -334,6 +334,7 @@ func (e *entraidTokenManager) durationToRenewal() time.Duration {
334334
e.tokenRWLock.RUnlock()
335335
return 0
336336
}
337+
337338
timeTillExpiration := time.Until(e.token.ExpirationOn())
338339
e.tokenRWLock.RUnlock()
339340

manager/token_manager_test.go

Lines changed: 33 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -426,9 +426,7 @@ func TestDefaultIdentityProviderResponseParser(t *testing.T) {
426426
parser := &defaultIdentityProviderResponseParser{}
427427
t.Run("Default IdentityProviderResponseParser with type AuthResult", func(t *testing.T) {
428428
t.Parallel()
429-
authResult := &public.AuthResult{
430-
ExpiresOn: time.Now().Add(time.Hour).UTC(),
431-
}
429+
authResult := testAuthResult(time.Now().Add(time.Hour).UTC())
432430
idpResponse, err := shared.NewIDPResponse(shared.ResponseTypeAuthResult,
433431
authResult)
434432
assert.NoError(t, err)
@@ -530,9 +528,7 @@ func TestDefaultIdentityProviderResponseParser(t *testing.T) {
530528
})
531529
t.Run("Default IdentityProviderResponseParser with expired token", func(t *testing.T) {
532530
t.Parallel()
533-
authResult := &public.AuthResult{
534-
ExpiresOn: time.Now().Add(-time.Hour).UTC(),
535-
}
531+
authResult := testAuthResult(time.Now().Add(-time.Hour))
536532
idpResponse, err := shared.NewIDPResponse(shared.ResponseTypeAuthResult,
537533
authResult)
538534
assert.NoError(t, err)
@@ -542,9 +538,7 @@ func TestDefaultIdentityProviderResponseParser(t *testing.T) {
542538
})
543539
t.Run("Default IdentityProviderResponseParser with token that expired", func(t *testing.T) {
544540
t.Parallel()
545-
authResult := &public.AuthResult{
546-
ExpiresOn: time.Now().Add(-time.Hour).UTC(),
547-
}
541+
authResult := testAuthResult(time.Now().Add(-time.Hour))
548542
idpResponse, err := shared.NewIDPResponse(shared.ResponseTypeAuthResult,
549543
authResult)
550544
assert.NoError(t, err)
@@ -626,9 +620,7 @@ func TestEntraidTokenManager_GetToken(t *testing.T) {
626620
)
627621
assert.NoError(t, err)
628622

629-
authResult := &public.AuthResult{
630-
ExpiresOn: time.Now().Add(-time.Hour).UTC(),
631-
}
623+
authResult := testAuthResult(time.Now().Add(-time.Hour))
632624
idpResponse, err := shared.NewIDPResponse(shared.ResponseTypeAuthResult,
633625
authResult)
634626
assert.NoError(t, err)
@@ -730,9 +722,7 @@ func TestEntraidTokenManager_durationToRenewal(t *testing.T) {
730722

731723
// get token that expires before the lower bound
732724
assert.NotPanics(t, func() {
733-
expiresSoon := &public.AuthResult{
734-
ExpiresOn: time.Now().Add(tm.lowerBoundDuration - time.Minute).UTC(),
735-
}
725+
expiresSoon := testAuthResult(time.Now().Add(tm.lowerBoundDuration - time.Minute).UTC())
736726
idpResponse, err := shared.NewIDPResponse(shared.ResponseTypeAuthResult,
737727
expiresSoon)
738728
assert.NoError(t, err)
@@ -750,9 +740,7 @@ func TestEntraidTokenManager_durationToRenewal(t *testing.T) {
750740
// get token that expires after the lower bound and expirationRefreshRatio to 1
751741
assert.NotPanics(t, func() {
752742
tm.expirationRefreshRatio = 1
753-
expiresAfterlb := &public.AuthResult{
754-
ExpiresOn: time.Now().Add(tm.lowerBoundDuration + time.Hour).UTC(),
755-
}
743+
expiresAfterlb := testAuthResult(time.Now().Add(tm.lowerBoundDuration + time.Hour).UTC())
756744
idpResponse, err := shared.NewIDPResponse(shared.ResponseTypeAuthResult,
757745
expiresAfterlb)
758746
assert.NoError(t, err)
@@ -790,9 +778,7 @@ func TestEntraidTokenManager_Streaming(t *testing.T) {
790778

791779
expiresIn := time.Second
792780
expiresOn := time.Now().Add(expiresIn).UTC()
793-
authResult := &public.AuthResult{
794-
ExpiresOn: expiresOn,
795-
}
781+
authResult := testAuthResult(expiresOn)
796782
idpResponse, err := shared.NewIDPResponse(shared.ResponseTypeAuthResult,
797783
authResult)
798784
assert.NoError(t, err)
@@ -851,9 +837,8 @@ func TestEntraidTokenManager_Streaming(t *testing.T) {
851837

852838
expiresIn := time.Second
853839
expiresOn := time.Now().Add(expiresIn).UTC()
854-
res := &public.AuthResult{
855-
ExpiresOn: expiresOn,
856-
}
840+
841+
res := testAuthResult(expiresOn)
857842
idpResponse, err := shared.NewIDPResponse(shared.ResponseTypeAuthResult,
858843
res)
859844
assert.NoError(t, err)
@@ -862,9 +847,7 @@ func TestEntraidTokenManager_Streaming(t *testing.T) {
862847
var start, stop time.Time
863848
idp.On("RequestToken").Run(func(args mock.Arguments) {
864849
expiresOn := time.Now().Add(expiresIn).UTC()
865-
res := &public.AuthResult{
866-
ExpiresOn: expiresOn,
867-
}
850+
res := testAuthResult(expiresOn)
868851
response := idpResponse.(*authResult)
869852
response.AuthResultVal = res
870853
if atomic.LoadInt32(&twice) == 1 {
@@ -920,17 +903,13 @@ func TestEntraidTokenManager_Streaming(t *testing.T) {
920903

921904
expiresIn := time.Second
922905
expiresOn := time.Now().Add(expiresIn).UTC()
923-
res := &public.AuthResult{
924-
ExpiresOn: expiresOn,
925-
}
906+
res := testAuthResult(expiresOn)
926907
idpResponse, err := shared.NewIDPResponse(shared.ResponseTypeAuthResult,
927908
res)
928909
assert.NoError(t, err)
929910
idp.On("RequestToken").Run(func(args mock.Arguments) {
930911
expiresOn := time.Now().Add(expiresIn).UTC()
931-
res := &public.AuthResult{
932-
ExpiresOn: expiresOn,
933-
}
912+
res := testAuthResult(expiresOn)
934913
response := idpResponse.(*authResult)
935914
response.AuthResultVal = res
936915
}).Return(idpResponse, nil)
@@ -976,17 +955,14 @@ func TestEntraidTokenManager_Streaming(t *testing.T) {
976955

977956
expiresIn := time.Second
978957
expiresOn := time.Now().Add(expiresIn).UTC()
979-
res := &public.AuthResult{
980-
ExpiresOn: expiresOn,
981-
}
958+
959+
res := testAuthResult(expiresOn)
982960
idpResponse, err := shared.NewIDPResponse(shared.ResponseTypeAuthResult,
983961
res)
984962
assert.NoError(t, err)
985963
idp.On("RequestToken").Run(func(args mock.Arguments) {
986964
expiresOn := time.Now().Add(expiresIn).UTC()
987-
res := &public.AuthResult{
988-
ExpiresOn: expiresOn,
989-
}
965+
res := testAuthResult(expiresOn)
990966
response := idpResponse.(*authResult)
991967
response.AuthResultVal = res
992968
}).Return(idpResponse, nil)
@@ -1025,18 +1001,14 @@ func TestEntraidTokenManager_Streaming(t *testing.T) {
10251001

10261002
expiresIn := time.Second
10271003
expiresOn := time.Now().Add(expiresIn).UTC()
1028-
res := &public.AuthResult{
1029-
ExpiresOn: expiresOn,
1030-
}
1004+
res := testAuthResult(expiresOn)
10311005
idpResponse, err := shared.NewIDPResponse(shared.ResponseTypeAuthResult,
10321006
res)
10331007
assert.NoError(t, err)
10341008

10351009
noErrCall := idp.On("RequestToken").Run(func(args mock.Arguments) {
10361010
expiresOn := time.Now().Add(expiresIn).UTC()
1037-
res := &public.AuthResult{
1038-
ExpiresOn: expiresOn,
1039-
}
1011+
res := testAuthResult(expiresOn)
10401012
response := idpResponse.(*authResult)
10411013
response.AuthResultVal = res
10421014
}).Return(idpResponse, nil)
@@ -1083,18 +1055,14 @@ func TestEntraidTokenManager_Streaming(t *testing.T) {
10831055

10841056
expiresIn := time.Second
10851057
expiresOn := time.Now().Add(expiresIn).UTC()
1086-
res := &public.AuthResult{
1087-
ExpiresOn: expiresOn,
1088-
}
1058+
res := testAuthResult(expiresOn)
10891059
idpResponse, err := shared.NewIDPResponse(shared.ResponseTypeAuthResult,
10901060
res)
10911061
assert.NoError(t, err)
10921062

10931063
noErrCall := idp.On("RequestToken").Run(func(args mock.Arguments) {
10941064
expiresOn := time.Now().Add(expiresIn).UTC()
1095-
res := &public.AuthResult{
1096-
ExpiresOn: expiresOn,
1097-
}
1065+
res := testAuthResult(expiresOn)
10981066
response := idpResponse.(*authResult)
10991067
response.AuthResultVal = res
11001068
}).Return(idpResponse, nil)
@@ -1153,18 +1121,16 @@ func TestEntraidTokenManager_Streaming(t *testing.T) {
11531121

11541122
expiresIn := time.Second
11551123
expiresOn := time.Now().Add(expiresIn).UTC()
1156-
res := &public.AuthResult{
1157-
ExpiresOn: expiresOn,
1158-
}
1124+
res := testAuthResult(expiresOn)
1125+
res.IDToken.Oid = "test"
11591126
idpResponse, err := shared.NewIDPResponse(shared.ResponseTypeAuthResult,
11601127
res)
11611128
assert.NoError(t, err)
11621129

11631130
noErrCall := idp.On("RequestToken").Run(func(args mock.Arguments) {
11641131
expiresOn := time.Now().Add(expiresIn).UTC()
1165-
res := &public.AuthResult{
1166-
ExpiresOn: expiresOn,
1167-
}
1132+
res := testAuthResult(expiresOn)
1133+
res.IDToken.Oid = "test"
11681134
response := idpResponse.(*authResult)
11691135
response.AuthResultVal = res
11701136
}).Return(idpResponse, nil)
@@ -1239,18 +1205,14 @@ func TestEntraidTokenManager_Streaming(t *testing.T) {
12391205

12401206
expiresIn := time.Second
12411207
expiresOn := time.Now().Add(expiresIn).UTC()
1242-
res := &public.AuthResult{
1243-
ExpiresOn: expiresOn,
1244-
}
1208+
res := testAuthResult(expiresOn)
12451209
idpResponse, err := shared.NewIDPResponse(shared.ResponseTypeAuthResult,
12461210
res)
12471211
assert.NoError(t, err)
12481212

12491213
noErrCall := idp.On("RequestToken").Run(func(args mock.Arguments) {
12501214
expiresOn := time.Now().Add(expiresIn).UTC()
1251-
res := &public.AuthResult{
1252-
ExpiresOn: expiresOn,
1253-
}
1215+
res := testAuthResult(expiresOn)
12541216
response := idpResponse.(*authResult)
12551217
response.AuthResultVal = res
12561218
}).Return(idpResponse, nil)
@@ -1295,3 +1257,11 @@ func TestEntraidTokenManager_Streaming(t *testing.T) {
12951257
mock.AssertExpectationsForObjects(t, idp, listener)
12961258
})
12971259
}
1260+
1261+
func testAuthResult(expiersOn time.Time) *public.AuthResult {
1262+
r := &public.AuthResult{
1263+
ExpiresOn: expiersOn,
1264+
}
1265+
r.IDToken.Oid = "test"
1266+
return r
1267+
}

shared/identity_provider_response.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ func NewIDPResponse(responseType string, result interface{}) (IdentityProviderRe
6363
return nil, fmt.Errorf("expected AccessToken, got %T", result)
6464
} else {
6565
r.AccessTokenVal = typed
66+
r.RawTokenVal = typed.Token
6667
}
6768
case ResponseTypeRawToken:
6869
if typed, ok := result.(string); !ok {

0 commit comments

Comments
 (0)