diff --git a/CHANGELOG.md b/CHANGELOG.md index 55e781a17..741755226 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,7 @@ +## Release (2025-YY-XX) +- `core`: [v0.17.2](core/CHANGELOG.md#v0172-2025-05-22) + - **Bugfix:** Access tokens generated via key flow authentication are refreshed 5 seconds before expiration to prevent timing issues with upstream systems which could lead to unexpected 401 error responses + ## Release (2025-05-15) - `alb`: - [v0.4.0](services/alb/CHANGELOG.md#v040-2025-05-15) diff --git a/core/CHANGELOG.md b/core/CHANGELOG.md index c38237322..5b070c901 100644 --- a/core/CHANGELOG.md +++ b/core/CHANGELOG.md @@ -1,3 +1,6 @@ +## v0.17.2 (2025-05-22) +- **Bugfix:** Access tokens generated via key flow authentication are refreshed 5 seconds before expiration to prevent timing issues with upstream systems which could lead to unexpected 401 error responses + ## v0.17.1 (2025-04-09) - **Improvement:** Improve error message for key flow authentication diff --git a/core/clients/key_flow.go b/core/clients/key_flow.go index c529a6f80..589774314 100644 --- a/core/clients/key_flow.go +++ b/core/clients/key_flow.go @@ -32,6 +32,8 @@ const ( tokenAPI = "https://service-account.api.stackit.cloud/token" //nolint:gosec // linter false positive defaultTokenType = "Bearer" defaultScope = "" + + defaultTokenExpirationLeeway = time.Second * 5 ) // KeyFlow handles auth with SA key @@ -45,6 +47,10 @@ type KeyFlow struct { tokenMutex sync.RWMutex token *TokenResponseBody + + // If the current access token would expire in less than TokenExpirationLeeway, + // the client will refresh it early to prevent clock skew or other timing issues. + tokenExpirationLeeway time.Duration } // KeyFlowConfig is the flow config @@ -129,6 +135,8 @@ func (c *KeyFlow) Init(cfg *KeyFlowConfig) error { c.config.TokenUrl = tokenAPI } + c.tokenExpirationLeeway = defaultTokenExpirationLeeway + if c.rt = cfg.HTTPTransport; c.rt == nil { c.rt = http.DefaultTransport } @@ -204,7 +212,7 @@ func (c *KeyFlow) GetAccessToken() (string, error) { } c.tokenMutex.RUnlock() - accessTokenExpired, err := tokenExpired(accessToken) + accessTokenExpired, err := tokenExpired(accessToken, c.tokenExpirationLeeway) if err != nil { return "", fmt.Errorf("check access token is expired: %w", err) } @@ -252,6 +260,10 @@ func (c *KeyFlow) validate() error { } c.privateKeyPEM = pem.EncodeToMemory(privKeyPEM) + if c.tokenExpirationLeeway < 0 { + return fmt.Errorf("token expiration leeway cannot be negative") + } + return nil } @@ -268,7 +280,7 @@ func (c *KeyFlow) recreateAccessToken() error { } c.tokenMutex.RUnlock() - refreshTokenExpired, err := tokenExpired(refreshToken) + refreshTokenExpired, err := tokenExpired(refreshToken, c.tokenExpirationLeeway) if err != nil { return err } @@ -389,7 +401,7 @@ func (c *KeyFlow) parseTokenResponse(res *http.Response) error { return nil } -func tokenExpired(token string) (bool, error) { +func tokenExpired(token string, tokenExpirationLeeway time.Duration) (bool, error) { if token == "" { return true, nil } @@ -400,11 +412,15 @@ func tokenExpired(token string) (bool, error) { if err != nil { return false, fmt.Errorf("parse token: %w", err) } + expirationTimestampNumeric, err := tokenParsed.Claims.GetExpirationTime() if err != nil { return false, fmt.Errorf("get expiration timestamp: %w", err) } - expirationTimestamp := expirationTimestampNumeric.Time - now := time.Now() - return now.After(expirationTimestamp), nil + + // Pretend to be `tokenExpirationLeeway` into the future to avoid token expiring + // between retrieving the token and upstream systems validating it. + now := time.Now().Add(tokenExpirationLeeway) + + return now.After(expirationTimestampNumeric.Time), nil } diff --git a/core/clients/key_flow_test.go b/core/clients/key_flow_test.go index b37b9593f..78dc43a3b 100644 --- a/core/clients/key_flow_test.go +++ b/core/clients/key_flow_test.go @@ -190,6 +190,7 @@ func TestSetToken(t *testing.T) { } func TestTokenExpired(t *testing.T) { + tokenExpirationLeeway := 5 * time.Second tests := []struct { desc string tokenInvalid bool @@ -209,6 +210,12 @@ func TestTokenExpired(t *testing.T) { expectedErr: false, expectedIsExpired: true, }, + { + desc: "token almost expired", + tokenExpiresAt: time.Now().Add(tokenExpirationLeeway), + expectedErr: false, + expectedIsExpired: true, + }, { desc: "token invalid", tokenInvalid: true, @@ -229,7 +236,7 @@ func TestTokenExpired(t *testing.T) { } } - isExpired, err := tokenExpired(token) + isExpired, err := tokenExpired(token, tokenExpirationLeeway) if err != nil && !tt.expectedErr { t.Fatalf("failed on valid input: %v", err) }