Skip to content

Commit 6770e84

Browse files
authored
Add middleware to swap the downstream ticket for the upstream ticket (#2113)
* Add middleware to swap the downstream ticket for the upstream ticket Implements HTTP middleware that automatically exchanges downstream authentication tokens for backend-specific tokens using RFC 8693 OAuth 2.0 Token Exchange. The middleware extracts subject tokens from authenticated requests and replaces them with exchanged tokens, supporting two injection strategies: replacing the Authorization header or adding a custom header while preserving the original. Fixes: #2065 * review feedback: Change scopes in Config to []strings * review feedback: Make the strategy selection a closure called by the middleware handler * review feedback: move exhcnageConfig outside the handler to CreateTokenExchangeMiddlewareFromClaims * throw an error instead of nil in case the middleware is misconfigured * Review feedback: Make CreateTokenExchangeMiddlewareFromClaims return middleware, err
1 parent dd77f43 commit 6770e84

File tree

4 files changed

+876
-16
lines changed

4 files changed

+876
-16
lines changed

pkg/auth/tokenexchange/exchange.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -158,8 +158,8 @@ func (c clientAuthentication) String() string {
158158
c.ClientID, clientSecret)
159159
}
160160

161-
// Config holds the configuration for token exchange.
162-
type Config struct {
161+
// ExchangeConfig holds the configuration for token exchange.
162+
type ExchangeConfig struct {
163163
// TokenURL is the OAuth 2.0 token endpoint URL
164164
TokenURL string
165165

@@ -185,8 +185,8 @@ type Config struct {
185185
HTTPClient *http.Client
186186
}
187187

188-
// Validate checks if the Config contains all required fields.
189-
func (c *Config) Validate() error {
188+
// Validate checks if the ExchangeConfig contains all required fields.
189+
func (c *ExchangeConfig) Validate() error {
190190
if c.TokenURL == "" {
191191
return fmt.Errorf("TokenURL is required")
192192
}
@@ -211,7 +211,7 @@ func (c *Config) Validate() error {
211211
// tokenSource implements oauth2.TokenSource for token exchange.
212212
type tokenSource struct {
213213
ctx context.Context
214-
conf *Config
214+
conf *ExchangeConfig
215215
}
216216

217217
// Token implements oauth2.TokenSource interface.
@@ -281,7 +281,7 @@ func (ts *tokenSource) Token() (*oauth2.Token, error) {
281281
}
282282

283283
// TokenSource returns an oauth2.TokenSource that performs token exchange.
284-
func (c *Config) TokenSource(ctx context.Context) oauth2.TokenSource {
284+
func (c *ExchangeConfig) TokenSource(ctx context.Context) oauth2.TokenSource {
285285
return &tokenSource{
286286
ctx: ctx,
287287
conf: c,

pkg/auth/tokenexchange/exchange_test.go

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ func TestTokenSource_Token_Success(t *testing.T) {
125125
defer server.Close()
126126

127127
// Create config with test server
128-
config := &Config{
128+
config := &ExchangeConfig{
129129
TokenURL: server.URL,
130130
ClientID: "test-client-id",
131131
ClientSecret: "test-client-secret",
@@ -166,7 +166,7 @@ func TestTokenSource_Token_WithRefreshToken(t *testing.T) {
166166
}))
167167
defer server.Close()
168168

169-
config := &Config{
169+
config := &ExchangeConfig{
170170
TokenURL: server.URL,
171171
ClientID: "test-client-id",
172172
ClientSecret: "test-client-secret",
@@ -198,7 +198,7 @@ func TestTokenSource_Token_NoExpiry(t *testing.T) {
198198
}))
199199
defer server.Close()
200200

201-
config := &Config{
201+
config := &ExchangeConfig{
202202
TokenURL: server.URL,
203203
ClientID: "test-client-id",
204204
ClientSecret: "test-client-secret",
@@ -221,7 +221,7 @@ func TestTokenSource_Token_SubjectTokenProviderError(t *testing.T) {
221221
t.Parallel()
222222

223223
providerErr := errors.New("failed to get token from provider")
224-
config := &Config{
224+
config := &ExchangeConfig{
225225
TokenURL: "https://example.com/token",
226226
ClientID: "test-client-id",
227227
ClientSecret: "test-client-secret",
@@ -251,7 +251,7 @@ func TestTokenSource_Token_ContextCancellation(t *testing.T) {
251251
}))
252252
defer server.Close()
253253

254-
config := &Config{
254+
config := &ExchangeConfig{
255255
TokenURL: server.URL,
256256
ClientID: "test-client-id",
257257
ClientSecret: "test-client-secret",
@@ -800,7 +800,7 @@ func TestSubjectTokenProvider_Variants(t *testing.T) {
800800
}))
801801
defer server.Close()
802802

803-
config := &Config{
803+
config := &ExchangeConfig{
804804
TokenURL: server.URL,
805805
ClientID: "test-client-id",
806806
ClientSecret: "test-client-secret",
@@ -1036,10 +1036,10 @@ func TestExchangeToken_ScopeArray(t *testing.T) {
10361036
}
10371037

10381038
// TestConfig_TokenSource tests that TokenSource creates a valid tokenSource.
1039-
func TestConfig_TokenSource(t *testing.T) {
1039+
func TestExchangeConfig_TokenSource(t *testing.T) {
10401040
t.Parallel()
10411041

1042-
config := &Config{
1042+
config := &ExchangeConfig{
10431043
TokenURL: "https://example.com/token",
10441044
ClientID: "test-client-id",
10451045
ClientSecret: "test-client-secret",
@@ -1175,14 +1175,14 @@ func TestClientAuthentication_Fields(t *testing.T) {
11751175
}
11761176

11771177
// TestConfig_Fields tests Config struct fields.
1178-
func TestConfig_Fields(t *testing.T) {
1178+
func TestExchangeConfig_Fields(t *testing.T) {
11791179
t.Parallel()
11801180

11811181
provider := func() (string, error) {
11821182
return "token", nil
11831183
}
11841184

1185-
config := &Config{
1185+
config := &ExchangeConfig{
11861186
TokenURL: "https://example.com/token",
11871187
ClientID: "test-client-id",
11881188
ClientSecret: "test-client-secret",
Lines changed: 239 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,239 @@
1+
package tokenexchange
2+
3+
import (
4+
"encoding/json"
5+
"errors"
6+
"fmt"
7+
"net/http"
8+
"strings"
9+
10+
"github.com/golang-jwt/jwt/v5"
11+
12+
"github.com/stacklok/toolhive/pkg/auth"
13+
"github.com/stacklok/toolhive/pkg/logger"
14+
"github.com/stacklok/toolhive/pkg/transport/types"
15+
)
16+
17+
// Middleware type constant
18+
const (
19+
MiddlewareType = "tokenexchange"
20+
)
21+
22+
// Header injection strategy constants
23+
const (
24+
// HeaderStrategyReplace replaces the Authorization header with the exchanged token
25+
HeaderStrategyReplace = "replace"
26+
// HeaderStrategyCustom adds the exchanged token to a custom header
27+
HeaderStrategyCustom = "custom"
28+
)
29+
30+
var errUnknownStrategy = errors.New("unknown token injection strategy")
31+
32+
// MiddlewareParams represents the parameters for token exchange middleware
33+
type MiddlewareParams struct {
34+
TokenExchangeConfig *Config `json:"token_exchange_config,omitempty"`
35+
}
36+
37+
// Config holds configuration for token exchange middleware
38+
type Config struct {
39+
// TokenURL is the OAuth 2.0 token endpoint URL
40+
TokenURL string `json:"token_url"`
41+
42+
// ClientID is the OAuth 2.0 client identifier
43+
ClientID string `json:"client_id"`
44+
45+
// ClientSecret is the OAuth 2.0 client secret
46+
ClientSecret string `json:"client_secret"`
47+
48+
// Audience is the target audience for the exchanged token
49+
Audience string `json:"audience"`
50+
51+
// Scopes is the list of scopes to request for the exchanged token
52+
Scopes []string `json:"scopes,omitempty"`
53+
54+
// HeaderStrategy determines how to inject the token
55+
// Valid values: HeaderStrategyReplace (default), HeaderStrategyCustom
56+
HeaderStrategy string `json:"header_strategy,omitempty"`
57+
58+
// ExternalTokenHeaderName is the name of the custom header to use when HeaderStrategy is "custom"
59+
ExternalTokenHeaderName string `json:"external_token_header_name,omitempty"`
60+
}
61+
62+
// Middleware wraps token exchange middleware functionality
63+
type Middleware struct {
64+
middleware types.MiddlewareFunction
65+
}
66+
67+
// Handler returns the middleware function used by the proxy.
68+
func (m *Middleware) Handler() types.MiddlewareFunction {
69+
return m.middleware
70+
}
71+
72+
// Close cleans up any resources used by the middleware.
73+
func (*Middleware) Close() error {
74+
// Token exchange middleware doesn't need cleanup
75+
return nil
76+
}
77+
78+
// CreateMiddleware factory function for token exchange middleware
79+
func CreateMiddleware(config *types.MiddlewareConfig, runner types.MiddlewareRunner) error {
80+
var params MiddlewareParams
81+
if err := json.Unmarshal(config.Parameters, &params); err != nil {
82+
return fmt.Errorf("failed to unmarshal token exchange middleware parameters: %w", err)
83+
}
84+
85+
// Token exchange config is required when this middleware type is specified
86+
if params.TokenExchangeConfig == nil {
87+
return fmt.Errorf("token exchange configuration is required but not provided")
88+
}
89+
90+
// Validate configuration
91+
if err := validateTokenExchangeConfig(params.TokenExchangeConfig); err != nil {
92+
return fmt.Errorf("invalid token exchange configuration: %w", err)
93+
}
94+
95+
middleware, err := CreateTokenExchangeMiddlewareFromClaims(*params.TokenExchangeConfig)
96+
if err != nil {
97+
return fmt.Errorf("invalid token exchange middleware config: %w", err)
98+
}
99+
100+
tokenExchangeMw := &Middleware{
101+
middleware: middleware,
102+
}
103+
104+
// Add middleware to runner
105+
runner.AddMiddleware(tokenExchangeMw)
106+
107+
return nil
108+
}
109+
110+
// validateTokenExchangeConfig validates the token exchange configuration
111+
func validateTokenExchangeConfig(config *Config) error {
112+
if config.HeaderStrategy == HeaderStrategyCustom && config.ExternalTokenHeaderName == "" {
113+
return fmt.Errorf("external_token_header_name must be specified when header_strategy is '%s'", HeaderStrategyCustom)
114+
}
115+
116+
if config.HeaderStrategy != "" &&
117+
config.HeaderStrategy != HeaderStrategyReplace &&
118+
config.HeaderStrategy != HeaderStrategyCustom {
119+
return fmt.Errorf("invalid header_strategy: %s (valid values: '%s', '%s')",
120+
config.HeaderStrategy, HeaderStrategyReplace, HeaderStrategyCustom)
121+
}
122+
123+
return nil
124+
}
125+
126+
// injectionFunc is a function that injects a token into an HTTP request
127+
type injectionFunc func(*http.Request, string) error
128+
129+
// createReplaceInjector creates an injection function that replaces the Authorization header
130+
func createReplaceInjector() injectionFunc {
131+
return func(r *http.Request, token string) error {
132+
logger.Debugf("Token exchange successful, replacing Authorization header")
133+
r.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
134+
return nil
135+
}
136+
}
137+
138+
// createCustomInjector creates an injection function that adds the token to a custom header
139+
func createCustomInjector(headerName string) injectionFunc {
140+
// Validate header name at creation time
141+
if headerName == "" {
142+
return func(_ *http.Request, _ string) error {
143+
return fmt.Errorf("external_token_header_name must be specified when header_strategy is '%s'", HeaderStrategyCustom)
144+
}
145+
}
146+
147+
return func(r *http.Request, token string) error {
148+
logger.Debugf("Token exchange successful, adding token to custom header: %s", headerName)
149+
r.Header.Set(headerName, fmt.Sprintf("Bearer %s", token))
150+
return nil
151+
}
152+
}
153+
154+
// CreateTokenExchangeMiddlewareFromClaims creates a middleware that uses token claims
155+
// from the auth middleware to perform token exchange.
156+
// This is a public function for direct usage in proxy commands.
157+
func CreateTokenExchangeMiddlewareFromClaims(config Config) (types.MiddlewareFunction, error) {
158+
// Determine injection strategy at startup time
159+
strategy := config.HeaderStrategy
160+
if strategy == "" {
161+
strategy = HeaderStrategyReplace // Default to replace for backwards compatibility
162+
}
163+
164+
var injectToken injectionFunc
165+
switch strategy {
166+
case HeaderStrategyReplace:
167+
injectToken = createReplaceInjector()
168+
case HeaderStrategyCustom:
169+
injectToken = createCustomInjector(config.ExternalTokenHeaderName)
170+
default:
171+
return nil, fmt.Errorf("%w: invalid header injection strategy %s", errUnknownStrategy, strategy)
172+
}
173+
174+
// Create base exchange config at startup time with all static fields
175+
baseExchangeConfig := ExchangeConfig{
176+
TokenURL: config.TokenURL,
177+
ClientID: config.ClientID,
178+
ClientSecret: config.ClientSecret,
179+
Audience: config.Audience,
180+
Scopes: config.Scopes,
181+
// SubjectTokenProvider will be set per request
182+
}
183+
184+
return func(next http.Handler) http.Handler {
185+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
186+
// Get claims from the auth middleware
187+
claims, ok := r.Context().Value(auth.ClaimsContextKey{}).(jwt.MapClaims)
188+
if !ok {
189+
logger.Debug("No claims found in context, proceeding without token exchange")
190+
next.ServeHTTP(w, r)
191+
return
192+
}
193+
194+
// Extract the original token from the Authorization header
195+
authHeader := r.Header.Get("Authorization")
196+
if authHeader == "" || !strings.HasPrefix(authHeader, "Bearer ") {
197+
logger.Debug("No valid Bearer token found, proceeding without token exchange")
198+
next.ServeHTTP(w, r)
199+
return
200+
}
201+
202+
subjectToken := strings.TrimPrefix(authHeader, "Bearer ")
203+
if subjectToken == "" {
204+
logger.Debug("Empty Bearer token, proceeding without token exchange")
205+
next.ServeHTTP(w, r)
206+
return
207+
}
208+
209+
// Log some claim information for debugging
210+
if sub, exists := claims["sub"]; exists {
211+
logger.Debugf("Performing token exchange for subject: %v", sub)
212+
}
213+
214+
// Create a copy of the base config with the request-specific subject token
215+
exchangeConfig := baseExchangeConfig
216+
exchangeConfig.SubjectTokenProvider = func() (string, error) {
217+
return subjectToken, nil
218+
}
219+
220+
// Get token from token source
221+
tokenSource := exchangeConfig.TokenSource(r.Context())
222+
exchangedToken, err := tokenSource.Token()
223+
if err != nil {
224+
logger.Warnf("Token exchange failed: %v", err)
225+
http.Error(w, "Token exchange failed", http.StatusUnauthorized)
226+
return
227+
}
228+
229+
// Inject the exchanged token into the request using the pre-selected strategy
230+
if err := injectToken(r, exchangedToken.AccessToken); err != nil {
231+
logger.Warnf("Failed to inject token: %v", err)
232+
http.Error(w, "Token injection failed", http.StatusInternalServerError)
233+
return
234+
}
235+
236+
next.ServeHTTP(w, r)
237+
})
238+
}, nil
239+
}

0 commit comments

Comments
 (0)