Skip to content

Commit eeb1065

Browse files
authored
feat: refresh cognito token when expired (#346)
This change inspects the token expiry time each time `Token()` is called on the `CognitoTokenSource` and refreshes the token when it is expired. If the refresh fails, it falls back to full re-authentication.
1 parent fc8369e commit eeb1065

File tree

2 files changed

+377
-13
lines changed

2 files changed

+377
-13
lines changed

offchain/jd/cognito.go

Lines changed: 82 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"crypto/hmac"
66
"crypto/sha256"
77
"encoding/base64"
8+
"time"
89

910
"github.com/aws/aws-sdk-go-v2/aws"
1011
"github.com/aws/aws-sdk-go-v2/config"
@@ -31,6 +32,9 @@ type CognitoTokenSource struct {
3132
// The cached authentication result from Cognito
3233
authResult *types.AuthenticationResultType
3334

35+
// The time when the cached token expires
36+
tokenExpiry time.Time
37+
3438
// The Cognito client interface for making API calls
3539
client CognitoClient
3640
}
@@ -85,29 +89,44 @@ func (c *CognitoTokenSource) Authenticate(ctx context.Context) error {
8589
return err
8690
}
8791

88-
c.authResult = output.AuthenticationResult
92+
c.setAuthResult(output.AuthenticationResult)
8993

9094
return nil
9195
}
9296

9397
// Token retrieves an OAuth2 access token for authenticating with the Job Distributor service.
9498
//
95-
// This method implements a lazy loading pattern:
96-
// 1. If an authentication result is already cached, it returns the cached access token
97-
// 2. If no cached result exists, it automatically authenticates with AWS Cognito first
98-
// 3. Returns the access token wrapped in an oauth2.Token struct
99+
// This method implements a lazy loading pattern with automatic token refresh:
100+
// 1. If no authentication result is cached, it authenticates with AWS Cognito
101+
// 2. If the cached token has expired, it refreshes using the refresh token (REFRESH_TOKEN_AUTH flow)
102+
// 3. Otherwise, it returns the cached access token
103+
// 4. Returns the access token wrapped in an oauth2.Token struct
99104
//
100105
// The method implements the oauth2.TokenSource interface, making it compatible with
101-
// standard OAuth2 client libraries and HTTP clients that support token sources.
106+
// standard OAuth2 client libraries and HTTP clients that support token sources. Following OAuth2
107+
// best practices, it uses the refresh token when available to refresh the access token.
102108
//
103-
// Note: This method uses context.Background() for authentication if no cached token exists.
104-
// For more control over authentication context and timeout behavior, consider calling
105-
// Authenticate() explicitly before calling Token().
109+
// Note: This method uses context.Background() for authentication/refresh if no cached token exists
110+
// or if the token needs to be refreshed. For more control over authentication context and
111+
// timeout behavior, consider calling Authenticate() or RefreshToken() explicitly before calling
112+
// Token().
106113
//
107114
// Returns an OAuth2 token containing the Cognito access token
108115
func (c *CognitoTokenSource) Token() (*oauth2.Token, error) {
116+
ctx := context.Background()
117+
118+
// Check if we need to authenticate (no token or token expired)
109119
if c.authResult == nil {
110-
if err := c.Authenticate(context.Background()); err != nil {
120+
// No token cached, perform full authentication
121+
if err := c.Authenticate(ctx); err != nil {
122+
return nil, err
123+
}
124+
}
125+
126+
// Check if the token has expired and refresh if necessary
127+
if time.Now().After(c.tokenExpiry) {
128+
// Token expired, try to refresh using refresh token
129+
if err := c.RefreshToken(ctx); err != nil {
111130
return nil, err
112131
}
113132
}
@@ -117,6 +136,49 @@ func (c *CognitoTokenSource) Token() (*oauth2.Token, error) {
117136
}, nil
118137
}
119138

139+
// RefreshToken refreshes the access token using the stored refresh token via the REFRESH_TOKEN_AUTH flow.
140+
//
141+
// This method uses the refresh token from the cached authentication result to obtain new access and ID tokens
142+
// without reusing the user's credentials. This is more efficient than full re-authentication
143+
// and follows OAuth2 best practices.
144+
//
145+
// The method:
146+
// 1. Uses the cached refresh token to call InitiateAuth with REFRESH_TOKEN_AUTH flow
147+
// 2. Updates the cached authentication result with new tokens
148+
// 3. Calculates and stores the new token expiry time
149+
//
150+
// Returns an error if the refresh fails (e.g., refresh token expired or invalid).
151+
func (c *CognitoTokenSource) RefreshToken(ctx context.Context) error {
152+
if c.authResult == nil || c.authResult.RefreshToken == nil {
153+
return c.Authenticate(ctx) // Fall back to full authentication if no refresh token
154+
}
155+
156+
input := &cognitoidentityprovider.InitiateAuthInput{
157+
AuthFlow: types.AuthFlowTypeRefreshTokenAuth,
158+
ClientId: aws.String(c.auth.AppClientID),
159+
AuthParameters: map[string]string{
160+
"REFRESH_TOKEN": aws.ToString(c.authResult.RefreshToken),
161+
"SECRET_HASH": c.secretHash(),
162+
},
163+
}
164+
165+
output, err := c.client.InitiateAuth(ctx, input)
166+
if err != nil {
167+
// If refresh fails, fall back to full authentication
168+
return c.Authenticate(ctx)
169+
}
170+
171+
// Set the new authentication result and token expiry time
172+
c.setAuthResult(output.AuthenticationResult)
173+
174+
return nil
175+
}
176+
177+
// TokenExpiresAt returns the time when the cached token expires.
178+
func (c *CognitoTokenSource) TokenExpiresAt() time.Time {
179+
return c.tokenExpiry
180+
}
181+
120182
// secretHash computes the AWS Cognito secret hash required for authentication with app clients that have a client secret.
121183
//
122184
// The secret hash is calculated using HMAC-SHA256 with the following formula:
@@ -142,3 +204,13 @@ func (c *CognitoTokenSource) secretHash() string {
142204

143205
return base64.StdEncoding.EncodeToString(dataHmac)
144206
}
207+
208+
// setAuthResult sets the authentication result and token expiry time
209+
func (c *CognitoTokenSource) setAuthResult(authResult *types.AuthenticationResultType) {
210+
// Set the new authentication result
211+
c.authResult = authResult
212+
213+
// Calculate and set the token expiry by appending expiresIn (seconds) to the current time
214+
expiresDuration := time.Duration(authResult.ExpiresIn) * time.Second
215+
c.tokenExpiry = time.Now().Add(expiresDuration)
216+
}

0 commit comments

Comments
 (0)