Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
253 changes: 174 additions & 79 deletions core/clients/key_flow_continuous_refresh_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"fmt"
"io"
"net/http"
"net/url"
"testing"
"time"

Expand All @@ -21,27 +22,28 @@ func TestContinuousRefreshToken(t *testing.T) {
jwt.TimePrecision = time.Millisecond

// Refresher settings
timeStartBeforeTokenExpiration := 100 * time.Millisecond
timeBetweenContextCheck := 5 * time.Millisecond
timeBetweenTries := 40 * time.Millisecond
timeStartBeforeTokenExpiration := 500 * time.Millisecond
timeBetweenContextCheck := 10 * time.Millisecond
timeBetweenTries := 100 * time.Millisecond

// All generated acess tokens will have this time to live
accessTokensTimeToLive := 200 * time.Millisecond
accessTokensTimeToLive := 1 * time.Second

tests := []struct {
desc string
contextClosesIn time.Duration
doError error
expectedNumberDoCalls int
expectedCallRange []int // Optional: for tests that can have variable call counts
}{
{
desc: "update access token once",
contextClosesIn: 150 * time.Millisecond,
contextClosesIn: 700 * time.Millisecond, // Should allow one refresh
expectedNumberDoCalls: 1,
},
{
desc: "update access token twice",
contextClosesIn: 250 * time.Millisecond,
contextClosesIn: 1300 * time.Millisecond, // Should allow two refreshes
expectedNumberDoCalls: 2,
},
{
Expand All @@ -61,25 +63,26 @@ func TestContinuousRefreshToken(t *testing.T) {
},
{
desc: "refresh token fails - non-API error",
contextClosesIn: 250 * time.Millisecond,
contextClosesIn: 700 * time.Millisecond,
doError: fmt.Errorf("something went wrong"),
expectedNumberDoCalls: 1,
},
{
desc: "refresh token fails - API non-5xx error",
contextClosesIn: 250 * time.Millisecond,
contextClosesIn: 700 * time.Millisecond,
doError: &oapierror.GenericOpenAPIError{
StatusCode: http.StatusBadRequest,
},
expectedNumberDoCalls: 1,
},
{
desc: "refresh token fails - API 5xx error",
contextClosesIn: 200 * time.Millisecond,
contextClosesIn: 800 * time.Millisecond,
doError: &oapierror.GenericOpenAPIError{
StatusCode: http.StatusInternalServerError,
},
expectedNumberDoCalls: 3,
expectedCallRange: []int{3, 4}, // Allow 3 or 4 calls due to timing race condition
},
}

Expand All @@ -101,19 +104,16 @@ func TestContinuousRefreshToken(t *testing.T) {

numberDoCalls := 0
mockDo := func(_ *http.Request) (resp *http.Response, err error) {
numberDoCalls++

numberDoCalls++ // count refresh attempts
if tt.doError != nil {
return nil, tt.doError
}

newAccessToken, err := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(time.Now().Add(accessTokensTimeToLive)),
}).SignedString([]byte("test"))
if err != nil {
t.Fatalf("Do call: failed to create access token: %v", err)
}

responseBodyStruct := TokenResponseBody{
AccessToken: newAccessToken,
RefreshToken: refreshToken,
Expand All @@ -133,19 +133,34 @@ func TestContinuousRefreshToken(t *testing.T) {
ctx, cancel := context.WithTimeout(ctx, tt.contextClosesIn)
defer cancel()

keyFlow := &KeyFlow{
config: &KeyFlowConfig{
BackgroundTokenRefreshContext: ctx,
},
authClient: &http.Client{
keyFlow := &KeyFlow{}
privateKeyBytes, err := generatePrivateKey()
if err != nil {
t.Fatalf("Error generating private key: %s", err)
}
keyFlowConfig := &KeyFlowConfig{
ServiceAccountKey: fixtureServiceAccountKey(),
PrivateKey: string(privateKeyBytes),
AuthHTTPClient: &http.Client{
Transport: mockTransportFn{mockDo},
},
token: &TokenResponseBody{
AccessToken: accessToken,
RefreshToken: refreshToken,
},
HTTPTransport: mockTransportFn{mockDo},
BackgroundTokenRefreshContext: nil,
}
err = keyFlow.Init(keyFlowConfig)
if err != nil {
t.Fatalf("failed to initialize key flow: %v", err)
}

// Set the token after initialization
err = keyFlow.SetToken(accessToken, refreshToken)
if err != nil {
t.Fatalf("failed to set token: %v", err)
}

// Set the context for continuous refresh
keyFlow.config.BackgroundTokenRefreshContext = ctx

refresher := &continuousTokenRefresher{
keyFlow: keyFlow,
timeStartBeforeTokenExpiration: timeStartBeforeTokenExpiration,
Expand All @@ -157,7 +172,13 @@ func TestContinuousRefreshToken(t *testing.T) {
if err == nil {
t.Fatalf("routine finished with non-nil error")
}
if numberDoCalls != tt.expectedNumberDoCalls {

// Check if we have a range of expected calls (for timing-sensitive tests)
if tt.expectedCallRange != nil {
if !contains(tt.expectedCallRange, numberDoCalls) {
t.Fatalf("expected %v calls to API to refresh token, got %d", tt.expectedCallRange, numberDoCalls)
}
} else if numberDoCalls != tt.expectedNumberDoCalls {
t.Fatalf("expected %d calls to API to refresh token, got %d", tt.expectedNumberDoCalls, numberDoCalls)
}
})
Expand Down Expand Up @@ -194,7 +215,7 @@ func TestContinuousRefreshTokenConcurrency(t *testing.T) {

// The access token at the start
accessTokenFirst, err := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(time.Now().Add(100 * time.Millisecond)),
ExpiresAt: jwt.NewNumericDate(time.Now().Add(10 * time.Second)),
}).SignedString([]byte("token-first"))
if err != nil {
t.Fatalf("failed to create first access token: %v", err)
Expand Down Expand Up @@ -225,60 +246,98 @@ func TestContinuousRefreshTokenConcurrency(t *testing.T) {
ctx, cancel := context.WithCancel(ctx)
defer cancel() // This cancels the refresher goroutine

// Extract host from tokenAPI constant for consistency
tokenURL, _ := url.Parse(tokenAPI)
tokenHost := tokenURL.Host

// The Do() routine, that both the keyFlow and continuousRefreshToken() use to make their requests
// The bools are used to make sure only one request goes through on each test phase
doTestPhase1RequestDone := false
doTestPhase2RequestDone := false
doTestPhase4RequestDone := false
mockDo := func(req *http.Request) (resp *http.Response, err error) {
switch currentTestPhase {
default:
t.Fatalf("Do call: unexpected request during test phase %d", currentTestPhase)
return nil, nil
case 1: // Call by continuousRefreshToken()
if doTestPhase1RequestDone {
t.Fatalf("Do call: multiple requests during test phase 1")
}
doTestPhase1RequestDone = true
// Handle auth requests (token refresh)
if req.URL.Host == tokenHost {
switch currentTestPhase {
default:
// After phase 1, allow additional auth requests but don't fail the test
// This handles the continuous nature of the refresh routine
if currentTestPhase > 1 {
// Return a valid response for any additional auth requests
newAccessToken, err := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)),
}).SignedString([]byte("additional-token"))
if err != nil {
t.Fatalf("Do call: failed to create additional access token: %v", err)
}
responseBodyStruct := TokenResponseBody{
AccessToken: newAccessToken,
RefreshToken: refreshToken,
}
responseBody, err := json.Marshal(responseBodyStruct)
if err != nil {
t.Fatalf("Do call: failed to marshal additional response: %v", err)
}
response := &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(bytes.NewReader(responseBody)),
}
return response, nil
}
t.Fatalf("Do call: unexpected request during test phase %d", currentTestPhase)
return nil, nil
case 1: // Call by continuousRefreshToken()
if doTestPhase1RequestDone {
t.Fatalf("Do call: multiple requests during test phase 1")
}
doTestPhase1RequestDone = true

currentTestPhase = 2
chanBlockContinuousRefreshToken <- true
currentTestPhase = 2
chanBlockContinuousRefreshToken <- true

// Wait until continuousRefreshToken() is to be unblocked
<-chanUnblockContinuousRefreshToken
// Wait until continuousRefreshToken() is to be unblocked
<-chanUnblockContinuousRefreshToken

if currentTestPhase != 3 {
t.Fatalf("Do call: after unlocking refreshToken(), expected test phase to be 3, got %d", currentTestPhase)
}
if currentTestPhase != 3 {
t.Fatalf("Do call: after unlocking refreshToken(), expected test phase to be 3, got %d", currentTestPhase)
}

// Check required fields are passed
err = req.ParseForm()
if err != nil {
t.Fatalf("Do call: failed to parse body form: %v", err)
}
reqGrantType := req.Form.Get("grant_type")
if reqGrantType != "refresh_token" {
t.Fatalf("Do call: failed request to refresh token: call to refresh access expected to have grant type as %q, found %q instead", "refresh_token", reqGrantType)
}
reqRefreshToken := req.Form.Get("refresh_token")
if reqRefreshToken != refreshToken {
t.Fatalf("Do call: failed request to refresh token: call to refresh access token did not have the expected refresh token set")
}
// Check required fields are passed
err = req.ParseForm()
if err != nil {
t.Fatalf("Do call: failed to parse body form: %v", err)
}
reqGrantType := req.Form.Get("grant_type")
if reqGrantType != "refresh_token" {
t.Fatalf("Do call: failed request to refresh token: call to refresh access expected to have grant type as %q, found %q instead", "refresh_token", reqGrantType)
}
reqRefreshToken := req.Form.Get("refresh_token")
if reqRefreshToken != refreshToken {
t.Fatalf("Do call: failed request to refresh token: call to refresh access token did not have the expected refresh token set")
}

// Return response with accessTokenSecond
responseBodyStruct := TokenResponseBody{
AccessToken: accessTokenSecond,
RefreshToken: refreshToken,
}
responseBody, err := json.Marshal(responseBodyStruct)
if err != nil {
t.Fatalf("Do call: failed request to refresh token: marshal access token response: %v", err)
}
response := &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(bytes.NewReader(responseBody)),
// Return response with accessTokenSecond
responseBodyStruct := TokenResponseBody{
AccessToken: accessTokenSecond,
RefreshToken: refreshToken,
}
responseBody, err := json.Marshal(responseBodyStruct)
if err != nil {
t.Fatalf("Do call: failed request to refresh token: marshal access token response: %v", err)
}
response := &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(bytes.NewReader(responseBody)),
}
return response, nil
}
return response, nil
}

// Handle regular HTTP requests
switch currentTestPhase {
default:
t.Fatalf("Do call: unexpected request during test phase %d", currentTestPhase)
return nil, nil
case 2: // Call by tokenFlow, first request
if doTestPhase2RequestDone {
t.Fatalf("Do call: multiple requests during test phase 2")
Expand All @@ -292,8 +351,9 @@ func TestContinuousRefreshTokenConcurrency(t *testing.T) {
t.Fatalf("Do call: first request expected to have host %q, found %q", expectedHost, host)
}
authHeader := req.Header.Get("Authorization")
if authHeader != fmt.Sprintf("Bearer %s", accessTokenFirst) {
t.Fatalf("Do call: first request didn't carry first access token")
expectedAuthHeader := fmt.Sprintf("Bearer %s", accessTokenFirst)
if authHeader != expectedAuthHeader {
t.Fatalf("Do call: first request didn't carry first access token. Expected: %s, Got: %s", expectedAuthHeader, authHeader)
}

// Return empty response
Expand Down Expand Up @@ -328,23 +388,49 @@ func TestContinuousRefreshTokenConcurrency(t *testing.T) {
}
}

keyFlow := &KeyFlow{
config: &KeyFlowConfig{
BackgroundTokenRefreshContext: ctx,
},
authClient: &http.Client{
keyFlow := &KeyFlow{}
privateKeyBytes, err := generatePrivateKey()
if err != nil {
t.Fatalf("Error generating private key: %s", err)
}
keyFlowConfig := &KeyFlowConfig{
ServiceAccountKey: fixtureServiceAccountKey(),
PrivateKey: string(privateKeyBytes),
AuthHTTPClient: &http.Client{
Transport: mockTransportFn{mockDo},
},
rt: mockTransportFn{mockDo},
token: &TokenResponseBody{
AccessToken: accessTokenFirst,
RefreshToken: refreshToken,
},
HTTPTransport: mockTransportFn{mockDo}, // Use same mock for regular requests
// Don't start continuous refresh automatically
BackgroundTokenRefreshContext: nil,
}
err = keyFlow.Init(keyFlowConfig)
if err != nil {
t.Fatalf("failed to initialize key flow: %v", err)
}

// Set the token after initialization
err = keyFlow.SetToken(accessTokenFirst, refreshToken)
if err != nil {
t.Fatalf("failed to set token: %v", err)
}

// Set the context for continuous refresh
keyFlow.config.BackgroundTokenRefreshContext = ctx

// Create a custom refresher with shorter timing for the test
refresher := &continuousTokenRefresher{
keyFlow: keyFlow,
timeStartBeforeTokenExpiration: 9 * time.Second, // Start 9 seconds before expiration
timeBetweenContextCheck: 5 * time.Millisecond,
timeBetweenTries: 40 * time.Millisecond,
}

// TEST START
currentTestPhase = 1
go continuousRefreshToken(keyFlow)
// Ignore returned error as expected in test
go func() {
_ = refresher.continuousRefreshToken()
}()

// Wait until continuousRefreshToken() is blocked
<-chanBlockContinuousRefreshToken
Expand Down Expand Up @@ -389,3 +475,12 @@ func TestContinuousRefreshTokenConcurrency(t *testing.T) {
t.Fatalf("Second request body failed to close: %v", err)
}
}

func contains(arr []int, val int) bool {
for _, v := range arr {
if v == val {
return true
}
}
return false
}
Loading
Loading