-
Notifications
You must be signed in to change notification settings - Fork 126
Integrate token exchange middleware #2143
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||
---|---|---|---|---|---|---|---|---|---|---|
|
@@ -18,6 +18,7 @@ import ( | |||||||||
"github.com/stacklok/toolhive/pkg/auth" | ||||||||||
"github.com/stacklok/toolhive/pkg/auth/discovery" | ||||||||||
"github.com/stacklok/toolhive/pkg/auth/oauth" | ||||||||||
"github.com/stacklok/toolhive/pkg/auth/tokenexchange" | ||||||||||
"github.com/stacklok/toolhive/pkg/logger" | ||||||||||
"github.com/stacklok/toolhive/pkg/networking" | ||||||||||
"github.com/stacklok/toolhive/pkg/transport" | ||||||||||
|
@@ -121,6 +122,8 @@ var ( | |||||||||
const ( | ||||||||||
// #nosec G101 - this is an environment variable name, not a credential | ||||||||||
envOAuthClientSecret = "TOOLHIVE_REMOTE_OAUTH_CLIENT_SECRET" | ||||||||||
// #nosec G101 - this is an environment variable name, not a credential | ||||||||||
envTokenExchangeClientSecret = "TOOLHIVE_TOKEN_EXCHANGE_CLIENT_SECRET" | ||||||||||
) | ||||||||||
|
||||||||||
func init() { | ||||||||||
|
@@ -227,10 +230,9 @@ func proxyCmdFunc(cmd *cobra.Command, args []string) error { | |||||||||
} | ||||||||||
middlewares = append(middlewares, authMiddleware) | ||||||||||
|
||||||||||
// Add OAuth token injection middleware for outgoing requests if we have an access token | ||||||||||
if tokenSource != nil { | ||||||||||
tokenMiddleware := createTokenInjectionMiddleware(tokenSource) | ||||||||||
middlewares = append(middlewares, tokenMiddleware) | ||||||||||
// Add OAuth token injection or token exchange middleware for outgoing requests | ||||||||||
if err := addExternalTokenMiddleware(&middlewares, tokenSource); err != nil { | ||||||||||
return err | ||||||||||
} | ||||||||||
|
||||||||||
// Create the transparent proxy | ||||||||||
|
@@ -346,43 +348,70 @@ func handleOutgoingAuthentication(ctx context.Context) (*oauth2.TokenSource, *oa | |||||||||
return nil, nil, nil // No authentication required | ||||||||||
} | ||||||||||
|
||||||||||
// resolveClientSecret resolves the OAuth client secret from multiple sources | ||||||||||
// Priority: 1. Flag value, 2. File, 3. Environment variable | ||||||||||
func resolveClientSecret() (string, error) { | ||||||||||
// resolveSecretFromSources resolves a secret from multiple sources with priority ordering | ||||||||||
// Priority: 1. Direct value (flag), 2. File path, 3. Environment variable | ||||||||||
// Returns empty string if no source provides a value (not an error) | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The function comment states it returns empty string when no source provides a value, but this behavior may not be clear to callers who expect secrets. Consider clarifying that callers should validate if an empty result is acceptable for their use case.
Suggested change
Copilot uses AI. Check for mistakes. Positive FeedbackNegative Feedback |
||||||||||
func resolveSecretFromSources(directValue, filePath, envVarName, secretType string) (string, error) { | ||||||||||
// 1. Check if provided directly via flag | ||||||||||
if remoteAuthFlags.RemoteAuthClientSecret != "" { | ||||||||||
logger.Debug("Using client secret from command-line flag") | ||||||||||
return remoteAuthFlags.RemoteAuthClientSecret, nil | ||||||||||
if directValue != "" { | ||||||||||
logger.Debugf("Using %s from command-line flag", secretType) | ||||||||||
return directValue, nil | ||||||||||
} | ||||||||||
|
||||||||||
// 2. Check if provided via file | ||||||||||
if remoteAuthFlags.RemoteAuthClientSecretFile != "" { | ||||||||||
if filePath != "" { | ||||||||||
// Clean the file path to prevent path traversal | ||||||||||
cleanPath := filepath.Clean(remoteAuthFlags.RemoteAuthClientSecretFile) | ||||||||||
logger.Debugf("Reading client secret from file: %s", cleanPath) | ||||||||||
cleanPath := filepath.Clean(filePath) | ||||||||||
logger.Debugf("Reading %s from file: %s", secretType, cleanPath) | ||||||||||
// #nosec G304 - file path is cleaned above | ||||||||||
secretBytes, err := os.ReadFile(cleanPath) | ||||||||||
if err != nil { | ||||||||||
return "", fmt.Errorf("failed to read client secret file %s: %w", cleanPath, err) | ||||||||||
return "", fmt.Errorf("failed to read %s file %s: %w", secretType, cleanPath, err) | ||||||||||
} | ||||||||||
secret := strings.TrimSpace(string(secretBytes)) | ||||||||||
if secret == "" { | ||||||||||
return "", fmt.Errorf("client secret file %s is empty", cleanPath) | ||||||||||
return "", fmt.Errorf("%s file %s is empty", secretType, cleanPath) | ||||||||||
} | ||||||||||
return secret, nil | ||||||||||
} | ||||||||||
|
||||||||||
// 3. Check environment variable | ||||||||||
if secret := os.Getenv(envOAuthClientSecret); secret != "" { | ||||||||||
logger.Debugf("Using client secret from %s environment variable", envOAuthClientSecret) | ||||||||||
return secret, nil | ||||||||||
if envVarName != "" { | ||||||||||
if secret := os.Getenv(envVarName); secret != "" { | ||||||||||
logger.Debugf("Using %s from %s environment variable", secretType, envVarName) | ||||||||||
return secret, nil | ||||||||||
} | ||||||||||
} | ||||||||||
|
||||||||||
// No client secret found - this is acceptable for PKCE flows | ||||||||||
logger.Debug("No client secret provided - using PKCE flow") | ||||||||||
// No secret found - return empty string (caller decides if this is an error) | ||||||||||
logger.Debugf("No %s provided", secretType) | ||||||||||
return "", nil | ||||||||||
} | ||||||||||
|
||||||||||
// resolveClientSecret resolves the OAuth client secret from multiple sources | ||||||||||
// Priority: 1. Flag value, 2. File, 3. Environment variable | ||||||||||
func resolveClientSecret() (string, error) { | ||||||||||
secret, err := resolveSecretFromSources( | ||||||||||
remoteAuthFlags.RemoteAuthClientSecret, | ||||||||||
remoteAuthFlags.RemoteAuthClientSecretFile, | ||||||||||
envOAuthClientSecret, | ||||||||||
"client secret", | ||||||||||
) | ||||||||||
if err != nil { | ||||||||||
return "", err | ||||||||||
} | ||||||||||
if secret == "" { | ||||||||||
// No client secret found - this is acceptable for PKCE flows | ||||||||||
logger.Debug("No client secret provided - using PKCE flow") | ||||||||||
} | ||||||||||
return secret, nil | ||||||||||
} | ||||||||||
|
||||||||||
// createTokenExchangeConfig creates a TokenExchangeConfig from remoteAuthFlags | ||||||||||
func createTokenExchangeConfig() *tokenexchange.Config { | ||||||||||
return remoteAuthFlags.BuildTokenExchangeConfig() | ||||||||||
} | ||||||||||
|
||||||||||
// createTokenInjectionMiddleware creates a middleware that injects the OAuth token into requests | ||||||||||
func createTokenInjectionMiddleware(tokenSource *oauth2.TokenSource) types.MiddlewareFunction { | ||||||||||
return func(next http.Handler) http.Handler { | ||||||||||
|
@@ -399,6 +428,26 @@ func createTokenInjectionMiddleware(tokenSource *oauth2.TokenSource) types.Middl | |||||||||
} | ||||||||||
} | ||||||||||
|
||||||||||
// addExternalTokenMiddleware adds token exchange or token injection middleware to the middleware chain | ||||||||||
func addExternalTokenMiddleware(middlewares *[]types.MiddlewareFunction, tokenSource *oauth2.TokenSource) error { | ||||||||||
if remoteAuthFlags.TokenExchangeURL != "" { | ||||||||||
// Use token exchange middleware when token exchange is configured | ||||||||||
tokenExchangeConfig := createTokenExchangeConfig() | ||||||||||
if tokenExchangeConfig != nil { | ||||||||||
tokenExchangeMiddleware, err := tokenexchange.CreateTokenExchangeMiddlewareFromClaims(*tokenExchangeConfig) | ||||||||||
if err != nil { | ||||||||||
return fmt.Errorf("failed to create token exchange middleware: %v", err) | ||||||||||
} | ||||||||||
*middlewares = append(*middlewares, tokenExchangeMiddleware) | ||||||||||
} | ||||||||||
} else if tokenSource != nil { | ||||||||||
// Fallback to direct token injection when no token exchange is configured | ||||||||||
tokenMiddleware := createTokenInjectionMiddleware(tokenSource) | ||||||||||
*middlewares = append(*middlewares, tokenMiddleware) | ||||||||||
} | ||||||||||
return nil | ||||||||||
} | ||||||||||
|
||||||||||
// validateProxyTargetURI validates that the target URI for the proxy is valid and does not contain a path | ||||||||||
func validateProxyTargetURI(targetURI string) error { | ||||||||||
// Parse the target URI | ||||||||||
|
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The error is logged as a warning but then silently ignored by setting clientSecret to empty string. This could lead to configuration issues being overlooked. Consider either propagating the error or making the warning more explicit about the fallback behavior.
Copilot uses AI. Check for mistakes.