Skip to content

Commit ef716b0

Browse files
committed
feat: Remove refresh_token grant type
Signed-off-by: Jorge Turrado <[email protected]>
1 parent 43846a3 commit ef716b0

File tree

4 files changed

+27
-156
lines changed

4 files changed

+27
-156
lines changed

core/clients/key_flow.go

Lines changed: 14 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -68,11 +68,10 @@ type KeyFlowConfig struct {
6868
// TokenResponseBody is the API response
6969
// when requesting a new token
7070
type TokenResponseBody struct {
71-
AccessToken string `json:"access_token"`
72-
ExpiresIn int `json:"expires_in"`
73-
RefreshToken string `json:"refresh_token"`
74-
Scope string `json:"scope"`
75-
TokenType string `json:"token_type"`
71+
AccessToken string `json:"access_token"`
72+
ExpiresIn int `json:"expires_in"`
73+
Scope string `json:"scope"`
74+
TokenType string `json:"token_type"`
7675
}
7776

7877
// ServiceAccountKeyResponse is the API response
@@ -158,9 +157,9 @@ func (c *KeyFlow) Init(cfg *KeyFlowConfig) error {
158157
return nil
159158
}
160159

161-
// SetToken can be used to set an access and refresh token manually in the client.
160+
// SetToken can be used to set an access token manually in the client.
162161
// The other fields in the token field are determined by inspecting the token or setting default values.
163-
func (c *KeyFlow) SetToken(accessToken, refreshToken string) error {
162+
func (c *KeyFlow) SetToken(accessToken string) error {
164163
// We can safely use ParseUnverified because we are not authenticating the user,
165164
// We are parsing the token just to get the expiration time claim
166165
parsedAccessToken, _, err := jwt.NewParser().ParseUnverified(accessToken, &jwt.RegisteredClaims{})
@@ -174,11 +173,10 @@ func (c *KeyFlow) SetToken(accessToken, refreshToken string) error {
174173

175174
c.tokenMutex.Lock()
176175
c.token = &TokenResponseBody{
177-
AccessToken: accessToken,
178-
ExpiresIn: int(exp.Time.Unix()),
179-
Scope: defaultScope,
180-
RefreshToken: refreshToken,
181-
TokenType: defaultTokenType,
176+
AccessToken: accessToken,
177+
ExpiresIn: int(exp.Time.Unix()),
178+
Scope: defaultScope,
179+
TokenType: defaultTokenType,
182180
}
183181
c.tokenMutex.Unlock()
184182
return nil
@@ -198,7 +196,7 @@ func (c *KeyFlow) RoundTrip(req *http.Request) (*http.Response, error) {
198196
return c.rt.RoundTrip(req)
199197
}
200198

201-
// GetAccessToken returns a short-lived access token and saves the access and refresh tokens in the token field
199+
// GetAccessToken returns a short-lived access token and saves the access token in the token field
202200
func (c *KeyFlow) GetAccessToken() (string, error) {
203201
if c.rt == nil {
204202
return "", fmt.Errorf("nil http round tripper, please run Init()")
@@ -219,7 +217,7 @@ func (c *KeyFlow) GetAccessToken() (string, error) {
219217
if !accessTokenExpired {
220218
return accessToken, nil
221219
}
222-
if err = c.recreateAccessToken(); err != nil {
220+
if err = c.createAccessToken(); err != nil {
223221
var oapiErr *oapierror.GenericOpenAPIError
224222
if ok := errors.As(err, &oapiErr); ok {
225223
reg := regexp.MustCompile("Key with kid .*? was not found")
@@ -269,27 +267,6 @@ func (c *KeyFlow) validate() error {
269267

270268
// Flow auth functions
271269

272-
// recreateAccessToken is used to create a new access token
273-
// when the existing one isn't valid anymore
274-
func (c *KeyFlow) recreateAccessToken() error {
275-
var refreshToken string
276-
277-
c.tokenMutex.RLock()
278-
if c.token != nil {
279-
refreshToken = c.token.RefreshToken
280-
}
281-
c.tokenMutex.RUnlock()
282-
283-
refreshTokenExpired, err := tokenExpired(refreshToken, c.tokenExpirationLeeway)
284-
if err != nil {
285-
return err
286-
}
287-
if !refreshTokenExpired {
288-
return c.createAccessTokenWithRefreshToken()
289-
}
290-
return c.createAccessToken()
291-
}
292-
293270
// createAccessToken creates an access token using self signed JWT
294271
func (c *KeyFlow) createAccessToken() (err error) {
295272
grant := "urn:ietf:params:oauth:grant-type:jwt-bearer"
@@ -310,26 +287,6 @@ func (c *KeyFlow) createAccessToken() (err error) {
310287
return c.parseTokenResponse(res)
311288
}
312289

313-
// createAccessTokenWithRefreshToken creates an access token using
314-
// an existing pre-validated refresh token
315-
func (c *KeyFlow) createAccessTokenWithRefreshToken() (err error) {
316-
c.tokenMutex.RLock()
317-
refreshToken := c.token.RefreshToken
318-
c.tokenMutex.RUnlock()
319-
320-
res, err := c.requestToken("refresh_token", refreshToken)
321-
if err != nil {
322-
return err
323-
}
324-
defer func() {
325-
tempErr := res.Body.Close()
326-
if tempErr != nil && err == nil {
327-
err = fmt.Errorf("close request access token with refresh token response: %w", tempErr)
328-
}
329-
}()
330-
return c.parseTokenResponse(res)
331-
}
332-
333290
// generateSelfSignedJWT generates JWT token
334291
func (c *KeyFlow) generateSelfSignedJWT() (string, error) {
335292
claims := jwt.MapClaims{
@@ -353,11 +310,8 @@ func (c *KeyFlow) generateSelfSignedJWT() (string, error) {
353310
func (c *KeyFlow) requestToken(grant, assertion string) (*http.Response, error) {
354311
body := url.Values{}
355312
body.Set("grant_type", grant)
356-
if grant == "refresh_token" {
357-
body.Set("refresh_token", assertion)
358-
} else {
359-
body.Set("assertion", assertion)
360-
}
313+
body.Set("assertion", assertion)
314+
361315
payload := strings.NewReader(body.Encode())
362316
req, err := http.NewRequest(http.MethodPost, c.config.TokenUrl, payload)
363317
if err != nil {

core/clients/key_flow_continuous_refresh.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ func (refresher *continuousTokenRefresher) waitUntilTimestamp(timestamp time.Tim
125125
// - (false, nil) if not successful but should be retried.
126126
// - (_, err) if not successful and shouldn't be retried.
127127
func (refresher *continuousTokenRefresher) refreshToken() (bool, error) {
128-
err := refresher.keyFlow.recreateAccessToken()
128+
err := refresher.keyFlow.createAccessToken()
129129
if err == nil {
130130
return true, nil
131131
}

core/clients/key_flow_continuous_refresh_test.go

Lines changed: 9 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -95,15 +95,8 @@ func TestContinuousRefreshToken(t *testing.T) {
9595
t.Fatalf("failed to create access token: %v", err)
9696
}
9797

98-
refreshToken, err := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{
99-
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)),
100-
}).SignedString([]byte("test"))
101-
if err != nil {
102-
t.Fatalf("failed to create refresh token: %v", err)
103-
}
104-
10598
numberDoCalls := 0
106-
mockDo := func(_ *http.Request) (resp *http.Response, err error) {
99+
mockDo := func(r *http.Request) (resp *http.Response, err error) {
107100
numberDoCalls++ // count refresh attempts
108101
if tt.doError != nil {
109102
return nil, tt.doError
@@ -115,8 +108,7 @@ func TestContinuousRefreshToken(t *testing.T) {
115108
t.Fatalf("Do call: failed to create access token: %v", err)
116109
}
117110
responseBodyStruct := TokenResponseBody{
118-
AccessToken: newAccessToken,
119-
RefreshToken: refreshToken,
111+
AccessToken: newAccessToken,
120112
}
121113
responseBody, err := json.Marshal(responseBodyStruct)
122114
if err != nil {
@@ -153,7 +145,7 @@ func TestContinuousRefreshToken(t *testing.T) {
153145
}
154146

155147
// Set the token after initialization
156-
err = keyFlow.SetToken(accessToken, refreshToken)
148+
err = keyFlow.SetToken(accessToken)
157149
if err != nil {
158150
t.Fatalf("failed to set token: %v", err)
159151
}
@@ -186,7 +178,7 @@ func TestContinuousRefreshToken(t *testing.T) {
186178
}
187179

188180
// Tests if
189-
// - continuousRefreshToken() updates access token using the refresh token
181+
// - continuousRefreshToken() updates access token
190182
// - The access token can be accessed while continuousRefreshToken() is trying to update it
191183
func TestContinuousRefreshTokenConcurrency(t *testing.T) {
192184
// The times here are in the order of miliseconds (so they run faster)
@@ -234,14 +226,6 @@ func TestContinuousRefreshTokenConcurrency(t *testing.T) {
234226
t.Fatalf("created tokens are equal")
235227
}
236228

237-
// The refresh token used to update the access token
238-
refreshToken, err := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{
239-
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)),
240-
}).SignedString([]byte("test"))
241-
if err != nil {
242-
t.Fatalf("failed to create refresh token: %v", err)
243-
}
244-
245229
ctx := context.Background()
246230
ctx, cancel := context.WithCancel(ctx)
247231
defer cancel() // This cancels the refresher goroutine
@@ -271,8 +255,7 @@ func TestContinuousRefreshTokenConcurrency(t *testing.T) {
271255
t.Fatalf("Do call: failed to create additional access token: %v", err)
272256
}
273257
responseBodyStruct := TokenResponseBody{
274-
AccessToken: newAccessToken,
275-
RefreshToken: refreshToken,
258+
AccessToken: newAccessToken,
276259
}
277260
responseBody, err := json.Marshal(responseBodyStruct)
278261
if err != nil {
@@ -308,18 +291,12 @@ func TestContinuousRefreshTokenConcurrency(t *testing.T) {
308291
t.Fatalf("Do call: failed to parse body form: %v", err)
309292
}
310293
reqGrantType := req.Form.Get("grant_type")
311-
if reqGrantType != "refresh_token" {
312-
t.Fatalf("Do call: failed request to refresh token: call to refresh access expected to have grant type as %q, found %q instead", "refresh_token", reqGrantType)
294+
if reqGrantType != "urn:ietf:params:oauth:grant-type:jwt-bearer" {
295+
t.Fatalf("Do call: failed request to refresh token: call to refresh access expected to have grant type as %q, found %q instead", "urn:ietf:params:oauth:grant-type:jwt-bearer", reqGrantType)
313296
}
314-
reqRefreshToken := req.Form.Get("refresh_token")
315-
if reqRefreshToken != refreshToken {
316-
t.Fatalf("Do call: failed request to refresh token: call to refresh access token did not have the expected refresh token set")
317-
}
318-
319297
// Return response with accessTokenSecond
320298
responseBodyStruct := TokenResponseBody{
321-
AccessToken: accessTokenSecond,
322-
RefreshToken: refreshToken,
299+
AccessToken: accessTokenSecond,
323300
}
324301
responseBody, err := json.Marshal(responseBodyStruct)
325302
if err != nil {
@@ -409,7 +386,7 @@ func TestContinuousRefreshTokenConcurrency(t *testing.T) {
409386
}
410387

411388
// Set the token after initialization
412-
err = keyFlow.SetToken(accessTokenFirst, refreshToken)
389+
err = keyFlow.SetToken(accessTokenFirst)
413390
if err != nil {
414391
t.Fatalf("failed to set token: %v", err)
415392
}

core/clients/key_flow_test.go

Lines changed: 3 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -130,65 +130,6 @@ func TestKeyFlowInit(t *testing.T) {
130130
}
131131
}
132132

133-
func TestSetToken(t *testing.T) {
134-
tests := []struct {
135-
name string
136-
tokenInvalid bool
137-
refreshToken string
138-
wantErr bool
139-
}{
140-
{
141-
name: "ok",
142-
tokenInvalid: false,
143-
refreshToken: "refresh_token",
144-
wantErr: false,
145-
},
146-
{
147-
name: "invalid_token",
148-
tokenInvalid: true,
149-
refreshToken: "",
150-
wantErr: true,
151-
},
152-
}
153-
for _, tt := range tests {
154-
t.Run(tt.name, func(t *testing.T) {
155-
var accessToken string
156-
var err error
157-
158-
timestamp := time.Now().Add(24 * time.Hour)
159-
if tt.tokenInvalid {
160-
accessToken = "foo"
161-
} else {
162-
accessTokenJWT := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{
163-
ExpiresAt: jwt.NewNumericDate(timestamp)})
164-
accessToken, err = accessTokenJWT.SignedString(testSigningKey)
165-
if err != nil {
166-
t.Fatalf("get test access token as string: %s", err)
167-
}
168-
}
169-
170-
keyFlow := &KeyFlow{}
171-
err = keyFlow.SetToken(accessToken, tt.refreshToken)
172-
173-
if (err != nil) != tt.wantErr {
174-
t.Errorf("KeyFlow.SetToken() error = %v, wantErr %v", err, tt.wantErr)
175-
}
176-
if err == nil {
177-
expectedKeyFlowToken := &TokenResponseBody{
178-
AccessToken: accessToken,
179-
ExpiresIn: int(timestamp.Unix()),
180-
RefreshToken: tt.refreshToken,
181-
Scope: defaultScope,
182-
TokenType: defaultTokenType,
183-
}
184-
if !cmp.Equal(expectedKeyFlowToken, keyFlow.token) {
185-
t.Errorf("The returned result is wrong. Expected %+v, got %+v", expectedKeyFlowToken, keyFlow.token)
186-
}
187-
}
188-
})
189-
}
190-
}
191-
192133
func TestTokenExpired(t *testing.T) {
193134
tokenExpirationLeeway := 5 * time.Second
194135
tests := []struct {
@@ -442,10 +383,9 @@ func TestKeyFlow_Do(t *testing.T) {
442383
res.Header().Set("Content-Type", "application/json")
443384

444385
token := &TokenResponseBody{
445-
AccessToken: testBearerToken,
446-
ExpiresIn: 2147483647,
447-
RefreshToken: testBearerToken,
448-
TokenType: "Bearer",
386+
AccessToken: testBearerToken,
387+
ExpiresIn: 2147483647,
388+
TokenType: "Bearer",
449389
}
450390

451391
if err := json.NewEncoder(res.Body).Encode(token); err != nil {

0 commit comments

Comments
 (0)