55 "crypto/hmac"
66 "crypto/sha256"
77 "encoding/base64"
8+ "time"
89
910 "github.com/aws/aws-sdk-go-v2/aws"
1011 "github.com/aws/aws-sdk-go-v2/config"
@@ -31,6 +32,9 @@ type CognitoTokenSource struct {
3132 // The cached authentication result from Cognito
3233 authResult * types.AuthenticationResultType
3334
35+ // The time when the cached token expires
36+ tokenExpiry time.Time
37+
3438 // The Cognito client interface for making API calls
3539 client CognitoClient
3640}
@@ -85,29 +89,44 @@ func (c *CognitoTokenSource) Authenticate(ctx context.Context) error {
8589 return err
8690 }
8791
88- c .authResult = output .AuthenticationResult
92+ c .setAuthResult ( output .AuthenticationResult )
8993
9094 return nil
9195}
9296
9397// Token retrieves an OAuth2 access token for authenticating with the Job Distributor service.
9498//
95- // This method implements a lazy loading pattern:
96- // 1. If an authentication result is already cached, it returns the cached access token
97- // 2. If no cached result exists, it automatically authenticates with AWS Cognito first
98- // 3. Returns the access token wrapped in an oauth2.Token struct
99+ // This method implements a lazy loading pattern with automatic token refresh:
100+ // 1. If no authentication result is cached, it authenticates with AWS Cognito
101+ // 2. If the cached token has expired, it refreshes using the refresh token (REFRESH_TOKEN_AUTH flow)
102+ // 3. Otherwise, it returns the cached access token
103+ // 4. Returns the access token wrapped in an oauth2.Token struct
99104//
100105// The method implements the oauth2.TokenSource interface, making it compatible with
101- // standard OAuth2 client libraries and HTTP clients that support token sources.
106+ // standard OAuth2 client libraries and HTTP clients that support token sources. Following OAuth2
107+ // best practices, it uses the refresh token when available to refresh the access token.
102108//
103- // Note: This method uses context.Background() for authentication if no cached token exists.
104- // For more control over authentication context and timeout behavior, consider calling
105- // Authenticate() explicitly before calling Token().
109+ // Note: This method uses context.Background() for authentication/refresh if no cached token exists
110+ // or if the token needs to be refreshed. For more control over authentication context and
111+ // timeout behavior, consider calling Authenticate() or RefreshToken() explicitly before calling
112+ // Token().
106113//
107114// Returns an OAuth2 token containing the Cognito access token
108115func (c * CognitoTokenSource ) Token () (* oauth2.Token , error ) {
116+ ctx := context .Background ()
117+
118+ // Check if we need to authenticate (no token or token expired)
109119 if c .authResult == nil {
110- if err := c .Authenticate (context .Background ()); err != nil {
120+ // No token cached, perform full authentication
121+ if err := c .Authenticate (ctx ); err != nil {
122+ return nil , err
123+ }
124+ }
125+
126+ // Check if the token has expired and refresh if necessary
127+ if time .Now ().After (c .tokenExpiry ) {
128+ // Token expired, try to refresh using refresh token
129+ if err := c .RefreshToken (ctx ); err != nil {
111130 return nil , err
112131 }
113132 }
@@ -117,6 +136,49 @@ func (c *CognitoTokenSource) Token() (*oauth2.Token, error) {
117136 }, nil
118137}
119138
139+ // RefreshToken refreshes the access token using the stored refresh token via the REFRESH_TOKEN_AUTH flow.
140+ //
141+ // This method uses the refresh token from the cached authentication result to obtain new access and ID tokens
142+ // without reusing the user's credentials. This is more efficient than full re-authentication
143+ // and follows OAuth2 best practices.
144+ //
145+ // The method:
146+ // 1. Uses the cached refresh token to call InitiateAuth with REFRESH_TOKEN_AUTH flow
147+ // 2. Updates the cached authentication result with new tokens
148+ // 3. Calculates and stores the new token expiry time
149+ //
150+ // Returns an error if the refresh fails (e.g., refresh token expired or invalid).
151+ func (c * CognitoTokenSource ) RefreshToken (ctx context.Context ) error {
152+ if c .authResult == nil || c .authResult .RefreshToken == nil {
153+ return c .Authenticate (ctx ) // Fall back to full authentication if no refresh token
154+ }
155+
156+ input := & cognitoidentityprovider.InitiateAuthInput {
157+ AuthFlow : types .AuthFlowTypeRefreshTokenAuth ,
158+ ClientId : aws .String (c .auth .AppClientID ),
159+ AuthParameters : map [string ]string {
160+ "REFRESH_TOKEN" : aws .ToString (c .authResult .RefreshToken ),
161+ "SECRET_HASH" : c .secretHash (),
162+ },
163+ }
164+
165+ output , err := c .client .InitiateAuth (ctx , input )
166+ if err != nil {
167+ // If refresh fails, fall back to full authentication
168+ return c .Authenticate (ctx )
169+ }
170+
171+ // Set the new authentication result and token expiry time
172+ c .setAuthResult (output .AuthenticationResult )
173+
174+ return nil
175+ }
176+
177+ // TokenExpiresAt returns the time when the cached token expires.
178+ func (c * CognitoTokenSource ) TokenExpiresAt () time.Time {
179+ return c .tokenExpiry
180+ }
181+
120182// secretHash computes the AWS Cognito secret hash required for authentication with app clients that have a client secret.
121183//
122184// The secret hash is calculated using HMAC-SHA256 with the following formula:
@@ -142,3 +204,13 @@ func (c *CognitoTokenSource) secretHash() string {
142204
143205 return base64 .StdEncoding .EncodeToString (dataHmac )
144206}
207+
208+ // setAuthResult sets the authentication result and token expiry time
209+ func (c * CognitoTokenSource ) setAuthResult (authResult * types.AuthenticationResultType ) {
210+ // Set the new authentication result
211+ c .authResult = authResult
212+
213+ // Calculate and set the token expiry by appending expiresIn (seconds) to the current time
214+ expiresDuration := time .Duration (authResult .ExpiresIn ) * time .Second
215+ c .tokenExpiry = time .Now ().Add (expiresDuration )
216+ }
0 commit comments