Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 68 additions & 0 deletions cmd/thv/app/auth_flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import (

"github.com/spf13/cobra"

"github.com/stacklok/toolhive/pkg/auth/tokenexchange"
"github.com/stacklok/toolhive/pkg/logger"
"github.com/stacklok/toolhive/pkg/runner"
)

Expand All @@ -21,6 +23,56 @@ type RemoteAuthFlags struct {
RemoteAuthIssuer string
RemoteAuthAuthorizeURL string
RemoteAuthTokenURL string

// Token Exchange Configuration
TokenExchangeURL string
TokenExchangeClientID string
TokenExchangeClientSecret string
TokenExchangeClientSecretFile string
TokenExchangeAudience string
TokenExchangeScopes []string
TokenExchangeHeaderName string
}

// BuildTokenExchangeConfig creates a TokenExchangeConfig from the RemoteAuthFlags
// Returns nil if TokenExchangeURL is empty (token exchange is not configured)
func (f *RemoteAuthFlags) BuildTokenExchangeConfig() *tokenexchange.Config {
// Only create config if token exchange URL is provided
if f.TokenExchangeURL == "" {
return nil
}

// Resolve token exchange client secret from multiple sources
clientSecret, err := resolveSecretFromSources(
f.TokenExchangeClientSecret,
f.TokenExchangeClientSecretFile,
envTokenExchangeClientSecret,
"token exchange client secret",
)
if err != nil {
logger.Warnf("Failed to resolve token exchange client secret: %v", err)
clientSecret = ""
Comment on lines +54 to +55
Copy link

Copilot AI Oct 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The error is logged as a warning but then silently ignored by setting clientSecret to empty string. This could lead to configuration issues being overlooked. Consider either propagating the error or making the warning more explicit about the fallback behavior.

Suggested change
logger.Warnf("Failed to resolve token exchange client secret: %v", err)
clientSecret = ""
logger.Errorf("Failed to resolve token exchange client secret: %v. Token exchange will be disabled.", err)
return nil

Copilot uses AI. Check for mistakes.

}

// Determine header strategy based on whether custom header name is provided
var headerStrategy string
var externalTokenHeaderName string
if f.TokenExchangeHeaderName != "" {
headerStrategy = tokenexchange.HeaderStrategyCustom
externalTokenHeaderName = f.TokenExchangeHeaderName
} else {
headerStrategy = tokenexchange.HeaderStrategyReplace
}

return &tokenexchange.Config{
TokenURL: f.TokenExchangeURL,
ClientID: f.TokenExchangeClientID,
ClientSecret: clientSecret,
Audience: f.TokenExchangeAudience,
Scopes: f.TokenExchangeScopes,
HeaderStrategy: headerStrategy,
ExternalTokenHeaderName: externalTokenHeaderName,
}
}

// AddRemoteAuthFlags adds the common remote authentication flags to a command
Expand All @@ -47,4 +99,20 @@ func AddRemoteAuthFlags(cmd *cobra.Command, config *RemoteAuthFlags) {
"OAuth authorization endpoint URL (alternative to --remote-auth-issuer for non-OIDC OAuth)")
cmd.Flags().StringVar(&config.RemoteAuthTokenURL, "remote-auth-token-url", "",
"OAuth token endpoint URL (alternative to --remote-auth-issuer for non-OIDC OAuth)")

// Token Exchange flags
cmd.Flags().StringVar(&config.TokenExchangeURL, "token-exchange-url", "",
"OAuth 2.0 token exchange endpoint URL (enables token exchange when provided)")
cmd.Flags().StringVar(&config.TokenExchangeClientID, "token-exchange-client-id", "",
"OAuth client ID for token exchange operations")
cmd.Flags().StringVar(&config.TokenExchangeClientSecret, "token-exchange-client-secret", "",
"OAuth client secret for token exchange operations")
cmd.Flags().StringVar(&config.TokenExchangeClientSecretFile, "token-exchange-client-secret-file", "",
"Path to file containing OAuth client secret for token exchange (alternative to --token-exchange-client-secret)")
cmd.Flags().StringVar(&config.TokenExchangeAudience, "token-exchange-audience", "",
"Target audience for exchanged tokens")
cmd.Flags().StringSliceVar(&config.TokenExchangeScopes, "token-exchange-scopes", []string{},
"Scopes to request for exchanged tokens")
cmd.Flags().StringVar(&config.TokenExchangeHeaderName, "token-exchange-header-name", "",
"Custom header name for injecting exchanged token (default: replaces Authorization header)")
}
89 changes: 69 additions & 20 deletions cmd/thv/app/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"github.com/stacklok/toolhive/pkg/auth"
"github.com/stacklok/toolhive/pkg/auth/discovery"
"github.com/stacklok/toolhive/pkg/auth/oauth"
"github.com/stacklok/toolhive/pkg/auth/tokenexchange"
"github.com/stacklok/toolhive/pkg/logger"
"github.com/stacklok/toolhive/pkg/networking"
"github.com/stacklok/toolhive/pkg/transport"
Expand Down Expand Up @@ -121,6 +122,8 @@ var (
const (
// #nosec G101 - this is an environment variable name, not a credential
envOAuthClientSecret = "TOOLHIVE_REMOTE_OAUTH_CLIENT_SECRET"
// #nosec G101 - this is an environment variable name, not a credential
envTokenExchangeClientSecret = "TOOLHIVE_TOKEN_EXCHANGE_CLIENT_SECRET"
)

func init() {
Expand Down Expand Up @@ -227,10 +230,9 @@ func proxyCmdFunc(cmd *cobra.Command, args []string) error {
}
middlewares = append(middlewares, authMiddleware)

// Add OAuth token injection middleware for outgoing requests if we have an access token
if tokenSource != nil {
tokenMiddleware := createTokenInjectionMiddleware(tokenSource)
middlewares = append(middlewares, tokenMiddleware)
// Add OAuth token injection or token exchange middleware for outgoing requests
if err := addExternalTokenMiddleware(&middlewares, tokenSource); err != nil {
return err
}

// Create the transparent proxy
Expand Down Expand Up @@ -346,43 +348,70 @@ func handleOutgoingAuthentication(ctx context.Context) (*oauth2.TokenSource, *oa
return nil, nil, nil // No authentication required
}

// resolveClientSecret resolves the OAuth client secret from multiple sources
// Priority: 1. Flag value, 2. File, 3. Environment variable
func resolveClientSecret() (string, error) {
// resolveSecretFromSources resolves a secret from multiple sources with priority ordering
// Priority: 1. Direct value (flag), 2. File path, 3. Environment variable
// Returns empty string if no source provides a value (not an error)
Copy link

Copilot AI Oct 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The function comment states it returns empty string when no source provides a value, but this behavior may not be clear to callers who expect secrets. Consider clarifying that callers should validate if an empty result is acceptable for their use case.

Suggested change
// Returns empty string if no source provides a value (not an error)
// Returns empty string if no source provides a value (not an error).
// Callers MUST validate whether an empty result is acceptable for their use case,
// especially when expecting secrets. An empty string means no secret was found.

Copilot uses AI. Check for mistakes.

func resolveSecretFromSources(directValue, filePath, envVarName, secretType string) (string, error) {
// 1. Check if provided directly via flag
if remoteAuthFlags.RemoteAuthClientSecret != "" {
logger.Debug("Using client secret from command-line flag")
return remoteAuthFlags.RemoteAuthClientSecret, nil
if directValue != "" {
logger.Debugf("Using %s from command-line flag", secretType)
return directValue, nil
}

// 2. Check if provided via file
if remoteAuthFlags.RemoteAuthClientSecretFile != "" {
if filePath != "" {
// Clean the file path to prevent path traversal
cleanPath := filepath.Clean(remoteAuthFlags.RemoteAuthClientSecretFile)
logger.Debugf("Reading client secret from file: %s", cleanPath)
cleanPath := filepath.Clean(filePath)
logger.Debugf("Reading %s from file: %s", secretType, cleanPath)
// #nosec G304 - file path is cleaned above
secretBytes, err := os.ReadFile(cleanPath)
if err != nil {
return "", fmt.Errorf("failed to read client secret file %s: %w", cleanPath, err)
return "", fmt.Errorf("failed to read %s file %s: %w", secretType, cleanPath, err)
}
secret := strings.TrimSpace(string(secretBytes))
if secret == "" {
return "", fmt.Errorf("client secret file %s is empty", cleanPath)
return "", fmt.Errorf("%s file %s is empty", secretType, cleanPath)
}
return secret, nil
}

// 3. Check environment variable
if secret := os.Getenv(envOAuthClientSecret); secret != "" {
logger.Debugf("Using client secret from %s environment variable", envOAuthClientSecret)
return secret, nil
if envVarName != "" {
if secret := os.Getenv(envVarName); secret != "" {
logger.Debugf("Using %s from %s environment variable", secretType, envVarName)
return secret, nil
}
}

// No client secret found - this is acceptable for PKCE flows
logger.Debug("No client secret provided - using PKCE flow")
// No secret found - return empty string (caller decides if this is an error)
logger.Debugf("No %s provided", secretType)
return "", nil
}

// resolveClientSecret resolves the OAuth client secret from multiple sources
// Priority: 1. Flag value, 2. File, 3. Environment variable
func resolveClientSecret() (string, error) {
secret, err := resolveSecretFromSources(
remoteAuthFlags.RemoteAuthClientSecret,
remoteAuthFlags.RemoteAuthClientSecretFile,
envOAuthClientSecret,
"client secret",
)
if err != nil {
return "", err
}
if secret == "" {
// No client secret found - this is acceptable for PKCE flows
logger.Debug("No client secret provided - using PKCE flow")
}
return secret, nil
}

// createTokenExchangeConfig creates a TokenExchangeConfig from remoteAuthFlags
func createTokenExchangeConfig() *tokenexchange.Config {
return remoteAuthFlags.BuildTokenExchangeConfig()
}

// createTokenInjectionMiddleware creates a middleware that injects the OAuth token into requests
func createTokenInjectionMiddleware(tokenSource *oauth2.TokenSource) types.MiddlewareFunction {
return func(next http.Handler) http.Handler {
Expand All @@ -399,6 +428,26 @@ func createTokenInjectionMiddleware(tokenSource *oauth2.TokenSource) types.Middl
}
}

// addExternalTokenMiddleware adds token exchange or token injection middleware to the middleware chain
func addExternalTokenMiddleware(middlewares *[]types.MiddlewareFunction, tokenSource *oauth2.TokenSource) error {
if remoteAuthFlags.TokenExchangeURL != "" {
// Use token exchange middleware when token exchange is configured
tokenExchangeConfig := createTokenExchangeConfig()
if tokenExchangeConfig != nil {
tokenExchangeMiddleware, err := tokenexchange.CreateTokenExchangeMiddlewareFromClaims(*tokenExchangeConfig)
if err != nil {
return fmt.Errorf("failed to create token exchange middleware: %v", err)
}
*middlewares = append(*middlewares, tokenExchangeMiddleware)
}
} else if tokenSource != nil {
// Fallback to direct token injection when no token exchange is configured
tokenMiddleware := createTokenInjectionMiddleware(tokenSource)
*middlewares = append(*middlewares, tokenMiddleware)
}
return nil
}

// validateProxyTargetURI validates that the target URI for the proxy is valid and does not contain a path
func validateProxyTargetURI(targetURI string) error {
// Parse the target URI
Expand Down
8 changes: 8 additions & 0 deletions cmd/thv/app/run_flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"github.com/spf13/cobra"

"github.com/stacklok/toolhive/pkg/auth"
"github.com/stacklok/toolhive/pkg/auth/tokenexchange"
"github.com/stacklok/toolhive/pkg/authz"
"github.com/stacklok/toolhive/pkg/cli"
cfg "github.com/stacklok/toolhive/pkg/config"
Expand Down Expand Up @@ -458,10 +459,12 @@ func buildRunnerConfig(
opts = append(opts, runner.WithToolsOverride(toolsOverride))
// Configure middleware from flags
// Use computed serverName and transportType for correct telemetry labels
tokenExchangeConfig := getTokenExchangeConfigFromRunFlags(runFlags)
opts = append(
opts,
runner.WithMiddlewareFromFlags(
oidcConfig,
tokenExchangeConfig,
runFlags.ToolsFilter,
toolsOverride,
telemetryConfig,
Expand Down Expand Up @@ -605,6 +608,11 @@ func getRemoteAuthFromRunFlags(runFlags *RunFlags) *runner.RemoteAuthConfig {
}
}

// getTokenExchangeConfigFromRunFlags creates TokenExchangeConfig from RunFlags
func getTokenExchangeConfigFromRunFlags(runFlags *RunFlags) *tokenexchange.Config {
return runFlags.RemoteAuthFlags.BuildTokenExchangeConfig()
}

// getOidcFromFlags extracts OIDC configuration from command flags
func getOidcFromFlags(cmd *cobra.Command) (string, string, string, string, string, string) {
oidcIssuer := GetStringFlagOrEmpty(cmd, "oidc-issuer")
Expand Down
51 changes: 29 additions & 22 deletions docs/cli/thv_proxy.md

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading
Loading