Skip to content

Commit cb7b155

Browse files
committed
rest: refresh oauth tokens
The rest catalog was using a fixed token for the lifetime of the catalog. We need to refresh the token when the oauth server gives us an expiration. This means the credential fetch needs to move into the roundtripper. Also, since we use the same http client for refreshing and making catalog requests, we add a context key to prevent recursion.
1 parent de373a5 commit cb7b155

File tree

6 files changed

+263
-187
lines changed

6 files changed

+263
-187
lines changed

catalog/rest/auth.go

Lines changed: 33 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,10 @@
1818
package rest
1919

2020
import (
21-
"encoding/json"
21+
"errors"
2222
"fmt"
23-
"io"
24-
"net/http"
25-
"net/url"
26-
"strings"
23+
24+
"golang.org/x/oauth2"
2725
)
2826

2927
// AuthManager is an interface for providing custom authorization headers.
@@ -32,115 +30,50 @@ type AuthManager interface {
3230
AuthHeader() (string, string, error)
3331
}
3432

35-
type oauthTokenResponse struct {
36-
AccessToken string `json:"access_token"`
37-
TokenType string `json:"token_type"`
38-
ExpiresIn int `json:"expires_in"`
39-
Scope string `json:"scope"`
40-
RefreshToken string `json:"refresh_token"`
41-
}
42-
43-
type oauthErrorResponse struct {
44-
Err string `json:"error"`
45-
ErrDesc string `json:"error_description"`
46-
ErrURI string `json:"error_uri"`
47-
}
48-
49-
func (o oauthErrorResponse) Unwrap() error { return ErrOAuthError }
50-
func (o oauthErrorResponse) Error() string {
51-
msg := o.Err
52-
if o.ErrDesc != "" {
53-
msg += ": " + o.ErrDesc
54-
}
55-
56-
if o.ErrURI != "" {
57-
msg += " (" + o.ErrURI + ")"
58-
}
59-
60-
return msg
61-
}
62-
6333
// Oauth2AuthManager is an implementation of the AuthManager interface which
64-
// simply returns the provided token as a bearer token. If a credential
65-
// is provided instead of a static token, it will fetch and refresh the
66-
// token as needed.
34+
// uses an oauth2.TokenSource to provide bearer tokens. The token source
35+
// handles caching, thread-safe refresh, and expiry management.
6736
type Oauth2AuthManager struct {
68-
Token string
69-
Credential string
70-
71-
AuthURI *url.URL
72-
Scope string
73-
Client *http.Client
37+
tokenSource oauth2.TokenSource
7438
}
7539

7640
// AuthHeader returns the authorization header with the bearer token.
7741
func (o *Oauth2AuthManager) AuthHeader() (string, string, error) {
78-
if o.Token == "" && o.Credential != "" {
79-
if o.Client == nil {
80-
return "", "", fmt.Errorf("%w: cannot fetch token without http client", ErrRESTError)
42+
tok, err := o.tokenSource.Token()
43+
if err != nil {
44+
var re *oauth2.RetrieveError
45+
if errors.As(err, &re) {
46+
return "", "", oauthError{
47+
code: re.ErrorCode,
48+
desc: re.ErrorDescription,
49+
uri: re.ErrorURI,
50+
}
8151
}
8252

83-
tok, err := o.fetchAccessToken()
84-
if err != nil {
85-
return "", "", err
86-
}
87-
o.Token = tok
53+
return "", "", fmt.Errorf("%w: %s", ErrOAuthError, err)
8854
}
8955

90-
return "Authorization", "Bearer " + o.Token, nil
56+
return "Authorization", tok.Type() + " " + tok.AccessToken, nil
9157
}
9258

93-
func (o *Oauth2AuthManager) fetchAccessToken() (string, error) {
94-
clientID, clientSecret, hasID := strings.Cut(o.Credential, ":")
95-
if !hasID {
96-
clientID, clientSecret = "", o.Credential
97-
}
98-
99-
scope := "catalog"
100-
if o.Scope != "" {
101-
scope = o.Scope
102-
}
103-
data := url.Values{
104-
"grant_type": {"client_credentials"},
105-
"client_id": {clientID},
106-
"client_secret": {clientSecret},
107-
"scope": {scope},
108-
}
59+
// oauthError wraps OAuth2 error details and implements the error chain
60+
// so that errors.Is(err, ErrOAuthError) returns true.
61+
type oauthError struct {
62+
code string
63+
desc string
64+
uri string
65+
}
10966

110-
if o.AuthURI == nil {
111-
return "", fmt.Errorf("%w: missing auth uri for fetching token", ErrRESTError)
67+
func (e oauthError) Error() string {
68+
msg := e.code
69+
if e.desc != "" {
70+
msg += ": " + e.desc
11271
}
113-
114-
rsp, err := o.Client.PostForm(o.AuthURI.String(), data)
115-
if err != nil {
116-
return "", err
72+
if e.uri != "" {
73+
msg += " (" + e.uri + ")"
11774
}
11875

119-
if rsp.StatusCode == http.StatusOK {
120-
defer rsp.Body.Close()
121-
dec := json.NewDecoder(rsp.Body)
122-
var tok oauthTokenResponse
123-
if err := dec.Decode(&tok); err != nil {
124-
return "", fmt.Errorf("failed to decode oauth token response: %w", err)
125-
}
126-
127-
return tok.AccessToken, nil
128-
}
129-
130-
switch rsp.StatusCode {
131-
case http.StatusUnauthorized, http.StatusBadRequest:
132-
defer func() {
133-
_, _ = io.Copy(io.Discard, rsp.Body)
134-
_ = rsp.Body.Close()
135-
}()
136-
dec := json.NewDecoder(rsp.Body)
137-
var oauthErr oauthErrorResponse
138-
if err := dec.Decode(&oauthErr); err != nil {
139-
return "", fmt.Errorf("failed to decode oauth error: %w", err)
140-
}
141-
142-
return "", oauthErr
143-
default:
144-
return "", handleNon200(rsp, nil)
145-
}
76+
return msg
14677
}
78+
79+
func (e oauthError) Unwrap() error { return ErrOAuthError }

catalog/rest/auth_test.go

Lines changed: 38 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -18,19 +18,25 @@
1818
package rest
1919

2020
import (
21+
"context"
2122
"encoding/json"
23+
"errors"
2224
"net/http"
2325
"net/http/httptest"
24-
"net/url"
2526
"testing"
2627

2728
"github.com/stretchr/testify/assert"
2829
"github.com/stretchr/testify/require"
30+
"golang.org/x/oauth2"
31+
"golang.org/x/oauth2/clientcredentials"
2932
)
3033

3134
func TestOauth2AuthManager_AuthHeader_StaticToken(t *testing.T) {
3235
manager := &Oauth2AuthManager{
33-
Token: "static_token",
36+
tokenSource: oauth2.StaticTokenSource(&oauth2.Token{
37+
AccessToken: "static_token",
38+
TokenType: "Bearer",
39+
}),
3440
}
3541

3642
key, value, err := manager.AuthHeader()
@@ -39,16 +45,6 @@ func TestOauth2AuthManager_AuthHeader_StaticToken(t *testing.T) {
3945
assert.Equal(t, "Bearer static_token", value)
4046
}
4147

42-
func TestOauth2AuthManager_AuthHeader_MissingClient(t *testing.T) {
43-
manager := &Oauth2AuthManager{
44-
Credential: "client:secret",
45-
}
46-
47-
_, _, err := manager.AuthHeader()
48-
require.Error(t, err)
49-
assert.Contains(t, err.Error(), "cannot fetch token without http client")
50-
}
51-
5248
func TestOauth2AuthManager_AuthHeader_FetchToken_Success(t *testing.T) {
5349
mux := http.NewServeMux()
5450
server := httptest.NewServer(mux)
@@ -61,28 +57,32 @@ func TestOauth2AuthManager_AuthHeader_FetchToken_Success(t *testing.T) {
6157
assert.Equal(t, "secret", r.FormValue("client_secret"))
6258
assert.Equal(t, "catalog", r.FormValue("scope"))
6359

60+
w.Header().Set("Content-Type", "application/json")
6461
w.WriteHeader(http.StatusOK)
65-
json.NewEncoder(w).Encode(oauthTokenResponse{
66-
AccessToken: "fetched_token",
67-
TokenType: "Bearer",
68-
ExpiresIn: 3600,
62+
json.NewEncoder(w).Encode(map[string]any{
63+
"access_token": "fetched_token",
64+
"token_type": "Bearer",
65+
"expires_in": 3600,
6966
})
7067
})
7168

72-
authURL, err := url.Parse(server.URL + "/oauth/token")
73-
require.NoError(t, err)
69+
cfg := &clientcredentials.Config{
70+
ClientID: "client",
71+
ClientSecret: "secret",
72+
TokenURL: server.URL + "/oauth/token",
73+
Scopes: []string{"catalog"},
74+
AuthStyle: oauth2.AuthStyleInParams,
75+
}
7476

77+
ctx := context.WithValue(context.Background(), oauth2.HTTPClient, server.Client())
7578
manager := &Oauth2AuthManager{
76-
Credential: "client:secret",
77-
AuthURI: authURL,
78-
Client: server.Client(),
79+
tokenSource: cfg.TokenSource(ctx),
7980
}
8081

8182
key, value, err := manager.AuthHeader()
8283
require.NoError(t, err)
8384
assert.Equal(t, "Authorization", key)
8485
assert.Equal(t, "Bearer fetched_token", value)
85-
assert.Equal(t, "fetched_token", manager.Token)
8686
}
8787

8888
func TestOauth2AuthManager_AuthHeader_FetchToken_ErrorResponse(t *testing.T) {
@@ -91,23 +91,29 @@ func TestOauth2AuthManager_AuthHeader_FetchToken_ErrorResponse(t *testing.T) {
9191
defer server.Close()
9292

9393
mux.HandleFunc("/oauth/token", func(w http.ResponseWriter, r *http.Request) {
94+
w.Header().Set("Content-Type", "application/json")
9495
w.WriteHeader(http.StatusBadRequest)
95-
json.NewEncoder(w).Encode(oauthErrorResponse{
96-
Err: "invalid_client",
97-
ErrDesc: "Invalid client credentials",
96+
json.NewEncoder(w).Encode(map[string]any{
97+
"error": "invalid_client",
98+
"error_description": "Invalid client credentials",
9899
})
99100
})
100101

101-
authURL, err := url.Parse(server.URL + "/oauth/token")
102-
require.NoError(t, err)
102+
cfg := &clientcredentials.Config{
103+
ClientID: "client",
104+
ClientSecret: "secret",
105+
TokenURL: server.URL + "/oauth/token",
106+
AuthStyle: oauth2.AuthStyleInParams,
107+
}
103108

109+
ctx := context.WithValue(context.Background(), oauth2.HTTPClient, server.Client())
104110
manager := &Oauth2AuthManager{
105-
Credential: "client:secret",
106-
AuthURI: authURL,
107-
Client: server.Client(),
111+
tokenSource: cfg.TokenSource(ctx),
108112
}
109113

110-
_, _, err = manager.AuthHeader()
114+
_, _, err := manager.AuthHeader()
111115
require.Error(t, err)
112-
assert.Contains(t, err.Error(), "invalid_client: Invalid client credentials")
116+
assert.True(t, errors.Is(err, ErrOAuthError), "error should wrap ErrOAuthError")
117+
assert.Contains(t, err.Error(), "invalid_client")
118+
assert.Contains(t, err.Error(), "Invalid client credentials")
113119
}

0 commit comments

Comments
 (0)