diff --git a/cmd/thv/app/auth_flags.go b/cmd/thv/app/auth_flags.go index 20a9d677d..a8d45e197 100644 --- a/cmd/thv/app/auth_flags.go +++ b/cmd/thv/app/auth_flags.go @@ -5,6 +5,8 @@ import ( "github.com/spf13/cobra" + "github.com/stacklok/toolhive/pkg/auth/tokenexchange" + "github.com/stacklok/toolhive/pkg/logger" "github.com/stacklok/toolhive/pkg/runner" ) @@ -21,6 +23,58 @@ type RemoteAuthFlags struct { RemoteAuthIssuer string RemoteAuthAuthorizeURL string RemoteAuthTokenURL string + + // Token Exchange Configuration + TokenExchangeURL string + TokenExchangeClientID string + TokenExchangeClientSecret string + TokenExchangeClientSecretFile string + TokenExchangeAudience string + TokenExchangeScopes []string + TokenExchangeSubjectTokenType string + TokenExchangeHeaderName string +} + +// BuildTokenExchangeConfig creates a TokenExchangeConfig from the RemoteAuthFlags +// Returns nil if TokenExchangeURL is empty (token exchange is not configured) +func (f *RemoteAuthFlags) BuildTokenExchangeConfig() *tokenexchange.Config { + // Only create config if token exchange URL is provided + if f.TokenExchangeURL == "" { + return nil + } + + // Resolve token exchange client secret from multiple sources + clientSecret, err := resolveSecretFromSources( + f.TokenExchangeClientSecret, + f.TokenExchangeClientSecretFile, + envTokenExchangeClientSecret, + "token exchange client secret", + ) + if err != nil { + logger.Warnf("Failed to resolve token exchange client secret: %v", err) + clientSecret = "" + } + + // Determine header strategy based on whether custom header name is provided + var headerStrategy string + var externalTokenHeaderName string + if f.TokenExchangeHeaderName != "" { + headerStrategy = tokenexchange.HeaderStrategyCustom + externalTokenHeaderName = f.TokenExchangeHeaderName + } else { + headerStrategy = tokenexchange.HeaderStrategyReplace + } + + return &tokenexchange.Config{ + TokenURL: f.TokenExchangeURL, + ClientID: f.TokenExchangeClientID, + ClientSecret: clientSecret, + Audience: f.TokenExchangeAudience, + Scopes: f.TokenExchangeScopes, + SubjectTokenType: f.TokenExchangeSubjectTokenType, + HeaderStrategy: headerStrategy, + ExternalTokenHeaderName: externalTokenHeaderName, + } } // AddRemoteAuthFlags adds the common remote authentication flags to a command @@ -47,4 +101,22 @@ func AddRemoteAuthFlags(cmd *cobra.Command, config *RemoteAuthFlags) { "OAuth authorization endpoint URL (alternative to --remote-auth-issuer for non-OIDC OAuth)") cmd.Flags().StringVar(&config.RemoteAuthTokenURL, "remote-auth-token-url", "", "OAuth token endpoint URL (alternative to --remote-auth-issuer for non-OIDC OAuth)") + + // Token Exchange flags + cmd.Flags().StringVar(&config.TokenExchangeURL, "token-exchange-url", "", + "OAuth 2.0 token exchange endpoint URL (enables token exchange when provided)") + cmd.Flags().StringVar(&config.TokenExchangeClientID, "token-exchange-client-id", "", + "OAuth client ID for token exchange operations") + cmd.Flags().StringVar(&config.TokenExchangeClientSecret, "token-exchange-client-secret", "", + "OAuth client secret for token exchange operations") + cmd.Flags().StringVar(&config.TokenExchangeClientSecretFile, "token-exchange-client-secret-file", "", + "Path to file containing OAuth client secret for token exchange (alternative to --token-exchange-client-secret)") + cmd.Flags().StringVar(&config.TokenExchangeAudience, "token-exchange-audience", "", + "Target audience for exchanged tokens") + cmd.Flags().StringSliceVar(&config.TokenExchangeScopes, "token-exchange-scopes", []string{}, + "Scopes to request for exchanged tokens") + cmd.Flags().StringVar(&config.TokenExchangeSubjectTokenType, "token-exchange-subject-token-type", "", + "Type of subject token to exchange (default: urn:ietf:params:oauth:token-type:access_token, Google STS requires: urn:ietf:params:oauth:token-type:id_token)") + cmd.Flags().StringVar(&config.TokenExchangeHeaderName, "token-exchange-header-name", "", + "Custom header name for injecting exchanged token (default: replaces Authorization header)") } diff --git a/cmd/thv/app/proxy.go b/cmd/thv/app/proxy.go index 291f1fe83..9f9b3555f 100644 --- a/cmd/thv/app/proxy.go +++ b/cmd/thv/app/proxy.go @@ -18,6 +18,7 @@ import ( "github.com/stacklok/toolhive/pkg/auth" "github.com/stacklok/toolhive/pkg/auth/discovery" "github.com/stacklok/toolhive/pkg/auth/oauth" + "github.com/stacklok/toolhive/pkg/auth/tokenexchange" "github.com/stacklok/toolhive/pkg/logger" "github.com/stacklok/toolhive/pkg/networking" "github.com/stacklok/toolhive/pkg/transport" @@ -121,6 +122,8 @@ var ( const ( // #nosec G101 - this is an environment variable name, not a credential envOAuthClientSecret = "TOOLHIVE_REMOTE_OAUTH_CLIENT_SECRET" + // #nosec G101 - this is an environment variable name, not a credential + envTokenExchangeClientSecret = "TOOLHIVE_TOKEN_EXCHANGE_CLIENT_SECRET" ) func init() { @@ -180,16 +183,24 @@ func proxyCmdFunc(cmd *cobra.Command, args []string) error { // Handle OAuth authentication to the remote server if needed var tokenSource *oauth2.TokenSource var oauthConfig *oauth.Config + var idToken string var introspectionURL string if remoteAuthFlags.EnableRemoteAuth || shouldDetectAuth() { - tokenSource, oauthConfig, err = handleOutgoingAuthentication(ctx) + var result *discovery.OAuthFlowResult + result, err = handleOutgoingAuthentication(ctx) if err != nil { return fmt.Errorf("failed to authenticate to remote server: %w", err) } - if oauthConfig != nil { - introspectionURL = oauthConfig.IntrospectionEndpoint - logger.Infof("Using OAuth config with introspection URL: %s", introspectionURL) + if result != nil { + tokenSource = result.TokenSource + oauthConfig = result.Config + idToken = result.IDToken + + if oauthConfig != nil { + introspectionURL = oauthConfig.IntrospectionEndpoint + logger.Infof("Using OAuth config with introspection URL: %s", introspectionURL) + } } else { logger.Info("No OAuth configuration available, proceeding without outgoing authentication") } @@ -227,10 +238,9 @@ func proxyCmdFunc(cmd *cobra.Command, args []string) error { } middlewares = append(middlewares, authMiddleware) - // Add OAuth token injection middleware for outgoing requests if we have an access token - if tokenSource != nil { - tokenMiddleware := createTokenInjectionMiddleware(tokenSource) - middlewares = append(middlewares, tokenMiddleware) + // Add OAuth token injection or token exchange middleware for outgoing requests + if err := addExternalTokenMiddleware(&middlewares, tokenSource, idToken); err != nil { + return err } // Create the transparent proxy @@ -272,11 +282,11 @@ func shouldDetectAuth() bool { } // handleOutgoingAuthentication handles authentication to the remote MCP server -func handleOutgoingAuthentication(ctx context.Context) (*oauth2.TokenSource, *oauth.Config, error) { +func handleOutgoingAuthentication(ctx context.Context) (*discovery.OAuthFlowResult, error) { // Resolve client secret from multiple sources clientSecret, err := resolveClientSecret() if err != nil { - return nil, nil, fmt.Errorf("failed to resolve client secret: %w", err) + return nil, fmt.Errorf("failed to resolve client secret: %w", err) } if remoteAuthFlags.EnableRemoteAuth { @@ -286,12 +296,12 @@ func handleOutgoingAuthentication(ctx context.Context) (*oauth2.TokenSource, *oa hasManualConfig := remoteAuthFlags.RemoteAuthAuthorizeURL != "" && remoteAuthFlags.RemoteAuthTokenURL != "" if !hasOIDCConfig && !hasManualConfig { - return nil, nil, fmt.Errorf("either --remote-auth-issuer (for OIDC) or both --remote-auth-authorize-url " + + return nil, fmt.Errorf("either --remote-auth-issuer (for OIDC) or both --remote-auth-authorize-url " + "and --remote-auth-token-url (for OAuth) are required") } if hasOIDCConfig && hasManualConfig { - return nil, nil, fmt.Errorf("cannot specify both OIDC issuer and manual OAuth endpoints - choose one approach") + return nil, fmt.Errorf("cannot specify both OIDC issuer and manual OAuth endpoints - choose one approach") } flowConfig := &discovery.OAuthFlowConfig{ @@ -307,17 +317,17 @@ func handleOutgoingAuthentication(ctx context.Context) (*oauth2.TokenSource, *oa result, err := discovery.PerformOAuthFlow(ctx, remoteAuthFlags.RemoteAuthIssuer, flowConfig) if err != nil { - return nil, nil, err + return nil, err } - return result.TokenSource, result.Config, nil + return result, nil } // Try to detect authentication requirements from WWW-Authenticate header authInfo, err := discovery.DetectAuthenticationFromServer(ctx, proxyTargetURI, nil) if err != nil { logger.Debugf("Could not detect authentication from server: %v", err) - return nil, nil, nil // Not an error, just no auth detected + return nil, nil // Not an error, just no auth detected } if authInfo != nil { @@ -337,52 +347,79 @@ func handleOutgoingAuthentication(ctx context.Context) (*oauth2.TokenSource, *oa result, err := discovery.PerformOAuthFlow(ctx, authInfo.Realm, flowConfig) if err != nil { - return nil, nil, err + return nil, err } - return result.TokenSource, result.Config, nil + return result, nil } - return nil, nil, nil // No authentication required + return nil, nil // No authentication required } -// resolveClientSecret resolves the OAuth client secret from multiple sources -// Priority: 1. Flag value, 2. File, 3. Environment variable -func resolveClientSecret() (string, error) { +// resolveSecretFromSources resolves a secret from multiple sources with priority ordering +// Priority: 1. Direct value (flag), 2. File path, 3. Environment variable +// Returns empty string if no source provides a value (not an error) +func resolveSecretFromSources(directValue, filePath, envVarName, secretType string) (string, error) { // 1. Check if provided directly via flag - if remoteAuthFlags.RemoteAuthClientSecret != "" { - logger.Debug("Using client secret from command-line flag") - return remoteAuthFlags.RemoteAuthClientSecret, nil + if directValue != "" { + logger.Debugf("Using %s from command-line flag", secretType) + return directValue, nil } // 2. Check if provided via file - if remoteAuthFlags.RemoteAuthClientSecretFile != "" { + if filePath != "" { // Clean the file path to prevent path traversal - cleanPath := filepath.Clean(remoteAuthFlags.RemoteAuthClientSecretFile) - logger.Debugf("Reading client secret from file: %s", cleanPath) + cleanPath := filepath.Clean(filePath) + logger.Debugf("Reading %s from file: %s", secretType, cleanPath) // #nosec G304 - file path is cleaned above secretBytes, err := os.ReadFile(cleanPath) if err != nil { - return "", fmt.Errorf("failed to read client secret file %s: %w", cleanPath, err) + return "", fmt.Errorf("failed to read %s file %s: %w", secretType, cleanPath, err) } secret := strings.TrimSpace(string(secretBytes)) if secret == "" { - return "", fmt.Errorf("client secret file %s is empty", cleanPath) + return "", fmt.Errorf("%s file %s is empty", secretType, cleanPath) } return secret, nil } // 3. Check environment variable - if secret := os.Getenv(envOAuthClientSecret); secret != "" { - logger.Debugf("Using client secret from %s environment variable", envOAuthClientSecret) - return secret, nil + if envVarName != "" { + if secret := os.Getenv(envVarName); secret != "" { + logger.Debugf("Using %s from %s environment variable", secretType, envVarName) + return secret, nil + } } - // No client secret found - this is acceptable for PKCE flows - logger.Debug("No client secret provided - using PKCE flow") + // No secret found - return empty string (caller decides if this is an error) + logger.Debugf("No %s provided", secretType) return "", nil } +// resolveClientSecret resolves the OAuth client secret from multiple sources +// Priority: 1. Flag value, 2. File, 3. Environment variable +func resolveClientSecret() (string, error) { + secret, err := resolveSecretFromSources( + remoteAuthFlags.RemoteAuthClientSecret, + remoteAuthFlags.RemoteAuthClientSecretFile, + envOAuthClientSecret, + "client secret", + ) + if err != nil { + return "", err + } + if secret == "" { + // No client secret found - this is acceptable for PKCE flows + logger.Debug("No client secret provided - using PKCE flow") + } + return secret, nil +} + +// createTokenExchangeConfig creates a TokenExchangeConfig from remoteAuthFlags +func createTokenExchangeConfig() *tokenexchange.Config { + return remoteAuthFlags.BuildTokenExchangeConfig() +} + // createTokenInjectionMiddleware creates a middleware that injects the OAuth token into requests func createTokenInjectionMiddleware(tokenSource *oauth2.TokenSource) types.MiddlewareFunction { return func(next http.Handler) http.Handler { @@ -399,6 +436,49 @@ func createTokenInjectionMiddleware(tokenSource *oauth2.TokenSource) types.Middl } } +// addExternalTokenMiddleware adds token exchange or token injection middleware to the middleware chain +func addExternalTokenMiddleware(middlewares *[]types.MiddlewareFunction, tokenSource *oauth2.TokenSource, idToken string) error { + if remoteAuthFlags.TokenExchangeURL != "" { + // Use token exchange middleware when token exchange is configured + tokenExchangeConfig := createTokenExchangeConfig() + if tokenExchangeConfig != nil { + // Create subject token provider from tokenSource or idToken based on subject token type + var subjectTokenProvider tokenexchange.SubjectTokenProvider + + // Check if we should use ID token instead of access token + useIDToken := tokenExchangeConfig.SubjectTokenType == "urn:ietf:params:oauth:token-type:id_token" || + tokenExchangeConfig.SubjectTokenType == "urn:ietf:params:oauth:token-type:jwt" + + if useIDToken && idToken != "" { + // Use the ID token from OAuth flow + subjectTokenProvider = func() (string, error) { + return idToken, nil + } + } else if tokenSource != nil { + // Use access token from token source + subjectTokenProvider = func() (string, error) { + token, err := (*tokenSource).Token() + if err != nil { + return "", fmt.Errorf("failed to get token from source: %w", err) + } + return token.AccessToken, nil + } + } + + tokenExchangeMiddleware, err := tokenexchange.CreateTokenExchangeMiddlewareFromClaims(*tokenExchangeConfig, subjectTokenProvider) + if err != nil { + return fmt.Errorf("failed to create token exchange middleware: %v", err) + } + *middlewares = append(*middlewares, tokenExchangeMiddleware) + } + } else if tokenSource != nil { + // Fallback to direct token injection when no token exchange is configured + tokenMiddleware := createTokenInjectionMiddleware(tokenSource) + *middlewares = append(*middlewares, tokenMiddleware) + } + return nil +} + // validateProxyTargetURI validates that the target URI for the proxy is valid and does not contain a path func validateProxyTargetURI(targetURI string) error { // Parse the target URI diff --git a/cmd/thv/app/run_flags.go b/cmd/thv/app/run_flags.go index dd37bf7e1..e5fcb74dd 100644 --- a/cmd/thv/app/run_flags.go +++ b/cmd/thv/app/run_flags.go @@ -8,6 +8,7 @@ import ( "github.com/spf13/cobra" "github.com/stacklok/toolhive/pkg/auth" + "github.com/stacklok/toolhive/pkg/auth/tokenexchange" "github.com/stacklok/toolhive/pkg/authz" "github.com/stacklok/toolhive/pkg/cli" cfg "github.com/stacklok/toolhive/pkg/config" @@ -458,10 +459,12 @@ func buildRunnerConfig( opts = append(opts, runner.WithToolsOverride(toolsOverride)) // Configure middleware from flags // Use computed serverName and transportType for correct telemetry labels + tokenExchangeConfig := getTokenExchangeConfigFromRunFlags(runFlags) opts = append( opts, runner.WithMiddlewareFromFlags( oidcConfig, + tokenExchangeConfig, runFlags.ToolsFilter, toolsOverride, telemetryConfig, @@ -605,6 +608,11 @@ func getRemoteAuthFromRunFlags(runFlags *RunFlags) *runner.RemoteAuthConfig { } } +// getTokenExchangeConfigFromRunFlags creates TokenExchangeConfig from RunFlags +func getTokenExchangeConfigFromRunFlags(runFlags *RunFlags) *tokenexchange.Config { + return runFlags.RemoteAuthFlags.BuildTokenExchangeConfig() +} + // getOidcFromFlags extracts OIDC configuration from command flags func getOidcFromFlags(cmd *cobra.Command) (string, string, string, string, string, string) { oidcIssuer := GetStringFlagOrEmpty(cmd, "oidc-issuer") diff --git a/docs/cli/thv_proxy.md b/docs/cli/thv_proxy.md index f1ccf7b93..82542ab07 100644 --- a/docs/cli/thv_proxy.md +++ b/docs/cli/thv_proxy.md @@ -97,28 +97,35 @@ thv proxy [flags] SERVER_NAME ### Options ``` - -h, --help help for proxy - --host string Host for the HTTP proxy to listen on (IP or hostname) (default "127.0.0.1") - --oidc-audience string Expected audience for the token - --oidc-client-id string OIDC client ID - --oidc-client-secret string OIDC client secret (optional, for introspection) - --oidc-introspection-url string URL for token introspection endpoint - --oidc-issuer string OIDC issuer URL (e.g., https://accounts.google.com) - --oidc-jwks-url string URL to fetch the JWKS from - --port int Port for the HTTP proxy to listen on (host port) - --remote-auth Enable OAuth/OIDC authentication to remote MCP server - --remote-auth-authorize-url string OAuth authorization endpoint URL (alternative to --remote-auth-issuer for non-OIDC OAuth) - --remote-auth-callback-port int Port for OAuth callback server during remote authentication (default 8666) - --remote-auth-client-id string OAuth client ID for remote server authentication - --remote-auth-client-secret string OAuth client secret for remote server authentication (optional for PKCE) - --remote-auth-client-secret-file string Path to file containing OAuth client secret (alternative to --remote-auth-client-secret) - --remote-auth-issuer string OAuth/OIDC issuer URL for remote server authentication (e.g., https://accounts.google.com) - --remote-auth-scopes strings OAuth scopes to request for remote server authentication (defaults: OIDC uses 'openid,profile,email') - --remote-auth-skip-browser Skip opening browser for remote server OAuth flow - --remote-auth-timeout duration Timeout for OAuth authentication flow (e.g., 30s, 1m, 2m30s) (default 30s) - --remote-auth-token-url string OAuth token endpoint URL (alternative to --remote-auth-issuer for non-OIDC OAuth) - --resource-url string Explicit resource URL for OAuth discovery endpoint (RFC 9728) - --target-uri string URI for the target MCP server (e.g., http://localhost:8080) (required) + -h, --help help for proxy + --host string Host for the HTTP proxy to listen on (IP or hostname) (default "127.0.0.1") + --oidc-audience string Expected audience for the token + --oidc-client-id string OIDC client ID + --oidc-client-secret string OIDC client secret (optional, for introspection) + --oidc-introspection-url string URL for token introspection endpoint + --oidc-issuer string OIDC issuer URL (e.g., https://accounts.google.com) + --oidc-jwks-url string URL to fetch the JWKS from + --port int Port for the HTTP proxy to listen on (host port) + --remote-auth Enable OAuth/OIDC authentication to remote MCP server + --remote-auth-authorize-url string OAuth authorization endpoint URL (alternative to --remote-auth-issuer for non-OIDC OAuth) + --remote-auth-callback-port int Port for OAuth callback server during remote authentication (default 8666) + --remote-auth-client-id string OAuth client ID for remote server authentication + --remote-auth-client-secret string OAuth client secret for remote server authentication (optional for PKCE) + --remote-auth-client-secret-file string Path to file containing OAuth client secret (alternative to --remote-auth-client-secret) + --remote-auth-issuer string OAuth/OIDC issuer URL for remote server authentication (e.g., https://accounts.google.com) + --remote-auth-scopes strings OAuth scopes to request for remote server authentication (defaults: OIDC uses 'openid,profile,email') + --remote-auth-skip-browser Skip opening browser for remote server OAuth flow + --remote-auth-timeout duration Timeout for OAuth authentication flow (e.g., 30s, 1m, 2m30s) (default 30s) + --remote-auth-token-url string OAuth token endpoint URL (alternative to --remote-auth-issuer for non-OIDC OAuth) + --resource-url string Explicit resource URL for OAuth discovery endpoint (RFC 9728) + --target-uri string URI for the target MCP server (e.g., http://localhost:8080) (required) + --token-exchange-audience string Target audience for exchanged tokens + --token-exchange-client-id string OAuth client ID for token exchange operations + --token-exchange-client-secret string OAuth client secret for token exchange operations + --token-exchange-client-secret-file string Path to file containing OAuth client secret for token exchange (alternative to --token-exchange-client-secret) + --token-exchange-header-name string Custom header name for injecting exchanged token (default: replaces Authorization header) + --token-exchange-scopes strings Scopes to request for exchanged tokens + --token-exchange-url string OAuth 2.0 token exchange endpoint URL (enables token exchange when provided) ``` ### Options inherited from parent commands diff --git a/docs/cli/thv_run.md b/docs/cli/thv_run.md index 5338ab03f..8306f14cb 100644 --- a/docs/cli/thv_run.md +++ b/docs/cli/thv_run.md @@ -77,65 +77,72 @@ thv run [flags] SERVER_OR_IMAGE_OR_PROTOCOL [-- ARGS...] ### Options ``` - --audit-config string Path to the audit configuration file - --authz-config string Path to the authorization configuration file - --ca-cert string Path to a custom CA certificate file to use for container builds - --enable-audit Enable audit logging with default configuration - -e, --env stringArray Environment variables to pass to the MCP server (format: KEY=VALUE) - --env-file string Load environment variables from a single file - --env-file-dir string Load environment variables from all files in a directory - -f, --foreground Run in foreground mode (block until container exits) - --from-config string Load configuration from exported file - --group string Name of the group this workload belongs to (defaults to 'default' if not specified) (default "default") - -h, --help help for run - --host string Host for the HTTP proxy to listen on (IP or hostname) (default "127.0.0.1") - --ignore-globally Load global ignore patterns from ~/.config/toolhive/thvignore (default true) - --image-verification string Set image verification mode (warn, enabled, disabled) (default "warn") - --isolate-network Isolate the container network from the host (default: false) - --jwks-allow-private-ip Allow JWKS/OIDC endpoints on private IP addresses (use with caution) - --jwks-auth-token-file string Path to file containing bearer token for authenticating JWKS/OIDC requests - -l, --label stringArray Set labels on the container (format: key=value) - --name string Name of the MCP server (auto-generated from image if not provided) - --oidc-audience string Expected audience for the token - --oidc-client-id string OIDC client ID - --oidc-client-secret string OIDC client secret (optional, for introspection) - --oidc-introspection-url string URL for token introspection endpoint - --oidc-issuer string OIDC issuer URL (e.g., https://accounts.google.com) - --oidc-jwks-url string URL to fetch the JWKS from - --otel-enable-prometheus-metrics-path Enable Prometheus-style /metrics endpoint on the main transport port - --otel-endpoint string OpenTelemetry OTLP endpoint URL (e.g., https://api.honeycomb.io) - --otel-env-vars stringArray Environment variable names to include in OpenTelemetry spans (comma-separated: ENV1,ENV2) - --otel-headers stringArray OpenTelemetry OTLP headers in key=value format (e.g., x-honeycomb-team=your-api-key) - --otel-insecure Connect to the OpenTelemetry endpoint using HTTP instead of HTTPS - --otel-metrics-enabled Enable OTLP metrics export (when OTLP endpoint is configured) (default true) - --otel-sampling-rate float OpenTelemetry trace sampling rate (0.0-1.0) (default 0.1) - --otel-service-name string OpenTelemetry service name (defaults to toolhive-mcp-proxy) - --otel-tracing-enabled Enable distributed tracing (when OTLP endpoint is configured) (default true) - --permission-profile string Permission profile to use (none, network, or path to JSON file) - --print-resolved-overlays Debug: show resolved container paths for tmpfs overlays - --proxy-mode string Proxy mode for stdio transport (sse or streamable-http) (default "sse") - --proxy-port int Port for the HTTP proxy to listen on (host port) - --remote-auth Enable OAuth/OIDC authentication to remote MCP server - --remote-auth-authorize-url string OAuth authorization endpoint URL (alternative to --remote-auth-issuer for non-OIDC OAuth) - --remote-auth-callback-port int Port for OAuth callback server during remote authentication (default 8666) - --remote-auth-client-id string OAuth client ID for remote server authentication - --remote-auth-client-secret string OAuth client secret for remote server authentication (optional for PKCE) - --remote-auth-client-secret-file string Path to file containing OAuth client secret (alternative to --remote-auth-client-secret) - --remote-auth-issuer string OAuth/OIDC issuer URL for remote server authentication (e.g., https://accounts.google.com) - --remote-auth-scopes strings OAuth scopes to request for remote server authentication (defaults: OIDC uses 'openid,profile,email') - --remote-auth-skip-browser Skip opening browser for remote server OAuth flow - --remote-auth-timeout duration Timeout for OAuth authentication flow (e.g., 30s, 1m, 2m30s) (default 30s) - --remote-auth-token-url string OAuth token endpoint URL (alternative to --remote-auth-issuer for non-OIDC OAuth) - --resource-url string Explicit resource URL for OAuth discovery endpoint (RFC 9728) - --secret stringArray Specify a secret to be fetched from the secrets manager and set as an environment variable (format: NAME,target=TARGET) - --target-host string Host to forward traffic to (only applicable to SSE or Streamable HTTP transport) (default "127.0.0.1") - --target-port int Port for the container to expose (only applicable to SSE or Streamable HTTP transport) - --thv-ca-bundle string Path to CA certificate bundle for ToolHive HTTP operations (JWKS, OIDC discovery, etc.) - --tools stringArray Filter MCP server tools (comma-separated list of tool names) - --tools-override string Path to a JSON file containing overrides for MCP server tools names and descriptions - --transport string Transport mode (sse, streamable-http or stdio) - --trust-proxy-headers Trust X-Forwarded-* headers from reverse proxies (X-Forwarded-Proto, X-Forwarded-Host, X-Forwarded-Port, X-Forwarded-Prefix) - -v, --volume stringArray Mount a volume into the container (format: host-path:container-path[:ro]) + --audit-config string Path to the audit configuration file + --authz-config string Path to the authorization configuration file + --ca-cert string Path to a custom CA certificate file to use for container builds + --enable-audit Enable audit logging with default configuration + -e, --env stringArray Environment variables to pass to the MCP server (format: KEY=VALUE) + --env-file string Load environment variables from a single file + --env-file-dir string Load environment variables from all files in a directory + -f, --foreground Run in foreground mode (block until container exits) + --from-config string Load configuration from exported file + --group string Name of the group this workload belongs to (defaults to 'default' if not specified) (default "default") + -h, --help help for run + --host string Host for the HTTP proxy to listen on (IP or hostname) (default "127.0.0.1") + --ignore-globally Load global ignore patterns from ~/.config/toolhive/thvignore (default true) + --image-verification string Set image verification mode (warn, enabled, disabled) (default "warn") + --isolate-network Isolate the container network from the host (default: false) + --jwks-allow-private-ip Allow JWKS/OIDC endpoints on private IP addresses (use with caution) + --jwks-auth-token-file string Path to file containing bearer token for authenticating JWKS/OIDC requests + -l, --label stringArray Set labels on the container (format: key=value) + --name string Name of the MCP server (auto-generated from image if not provided) + --oidc-audience string Expected audience for the token + --oidc-client-id string OIDC client ID + --oidc-client-secret string OIDC client secret (optional, for introspection) + --oidc-introspection-url string URL for token introspection endpoint + --oidc-issuer string OIDC issuer URL (e.g., https://accounts.google.com) + --oidc-jwks-url string URL to fetch the JWKS from + --otel-enable-prometheus-metrics-path Enable Prometheus-style /metrics endpoint on the main transport port + --otel-endpoint string OpenTelemetry OTLP endpoint URL (e.g., https://api.honeycomb.io) + --otel-env-vars stringArray Environment variable names to include in OpenTelemetry spans (comma-separated: ENV1,ENV2) + --otel-headers stringArray OpenTelemetry OTLP headers in key=value format (e.g., x-honeycomb-team=your-api-key) + --otel-insecure Connect to the OpenTelemetry endpoint using HTTP instead of HTTPS + --otel-metrics-enabled Enable OTLP metrics export (when OTLP endpoint is configured) (default true) + --otel-sampling-rate float OpenTelemetry trace sampling rate (0.0-1.0) (default 0.1) + --otel-service-name string OpenTelemetry service name (defaults to toolhive-mcp-proxy) + --otel-tracing-enabled Enable distributed tracing (when OTLP endpoint is configured) (default true) + --permission-profile string Permission profile to use (none, network, or path to JSON file) + --print-resolved-overlays Debug: show resolved container paths for tmpfs overlays + --proxy-mode string Proxy mode for stdio transport (sse or streamable-http) (default "sse") + --proxy-port int Port for the HTTP proxy to listen on (host port) + --remote-auth Enable OAuth/OIDC authentication to remote MCP server + --remote-auth-authorize-url string OAuth authorization endpoint URL (alternative to --remote-auth-issuer for non-OIDC OAuth) + --remote-auth-callback-port int Port for OAuth callback server during remote authentication (default 8666) + --remote-auth-client-id string OAuth client ID for remote server authentication + --remote-auth-client-secret string OAuth client secret for remote server authentication (optional for PKCE) + --remote-auth-client-secret-file string Path to file containing OAuth client secret (alternative to --remote-auth-client-secret) + --remote-auth-issuer string OAuth/OIDC issuer URL for remote server authentication (e.g., https://accounts.google.com) + --remote-auth-scopes strings OAuth scopes to request for remote server authentication (defaults: OIDC uses 'openid,profile,email') + --remote-auth-skip-browser Skip opening browser for remote server OAuth flow + --remote-auth-timeout duration Timeout for OAuth authentication flow (e.g., 30s, 1m, 2m30s) (default 30s) + --remote-auth-token-url string OAuth token endpoint URL (alternative to --remote-auth-issuer for non-OIDC OAuth) + --resource-url string Explicit resource URL for OAuth discovery endpoint (RFC 9728) + --secret stringArray Specify a secret to be fetched from the secrets manager and set as an environment variable (format: NAME,target=TARGET) + --target-host string Host to forward traffic to (only applicable to SSE or Streamable HTTP transport) (default "127.0.0.1") + --target-port int Port for the container to expose (only applicable to SSE or Streamable HTTP transport) + --thv-ca-bundle string Path to CA certificate bundle for ToolHive HTTP operations (JWKS, OIDC discovery, etc.) + --token-exchange-audience string Target audience for exchanged tokens + --token-exchange-client-id string OAuth client ID for token exchange operations + --token-exchange-client-secret string OAuth client secret for token exchange operations + --token-exchange-client-secret-file string Path to file containing OAuth client secret for token exchange (alternative to --token-exchange-client-secret) + --token-exchange-header-name string Custom header name for injecting exchanged token (default: replaces Authorization header) + --token-exchange-scopes strings Scopes to request for exchanged tokens + --token-exchange-url string OAuth 2.0 token exchange endpoint URL (enables token exchange when provided) + --tools stringArray Filter MCP server tools (comma-separated list of tool names) + --tools-override string Path to a JSON file containing overrides for MCP server tools names and descriptions + --transport string Transport mode (sse, streamable-http or stdio) + --trust-proxy-headers Trust X-Forwarded-* headers from reverse proxies (X-Forwarded-Proto, X-Forwarded-Host, X-Forwarded-Port, X-Forwarded-Prefix) + -v, --volume stringArray Mount a volume into the container (format: host-path:container-path[:ro]) ``` ### Options inherited from parent commands diff --git a/pkg/api/v1/workload_service.go b/pkg/api/v1/workload_service.go index 53670d58f..9820d6659 100644 --- a/pkg/api/v1/workload_service.go +++ b/pkg/api/v1/workload_service.go @@ -229,6 +229,7 @@ func (s *WorkloadService) BuildFullRunConfig(ctx context.Context, req *createReq options = append(options, runner.WithMiddlewareFromFlags( nil, + nil, // tokenExchangeConfig - not supported via API yet req.ToolsFilter, toolsOverride, nil, diff --git a/pkg/auth/discovery/discovery.go b/pkg/auth/discovery/discovery.go index c3c72510f..91a36fd41 100644 --- a/pkg/auth/discovery/discovery.go +++ b/pkg/auth/discovery/discovery.go @@ -381,6 +381,7 @@ type OAuthFlowConfig struct { type OAuthFlowResult struct { TokenSource *oauth2.TokenSource Config *oauth.Config + IDToken string // OIDC ID token (JWT), if present } func shouldDynamicallyRegisterClient(config *OAuthFlowConfig) bool { @@ -525,6 +526,7 @@ func newOAuthFlow(ctx context.Context, oauthConfig *oauth.Config, config *OAuthF return &OAuthFlowResult{ TokenSource: &source, Config: oauthConfig, + IDToken: tokenResult.IDToken, }, nil } diff --git a/pkg/auth/tokenexchange/exchange.go b/pkg/auth/tokenexchange/exchange.go index 4786a6db8..bcc2abfb6 100644 --- a/pkg/auth/tokenexchange/exchange.go +++ b/pkg/auth/tokenexchange/exchange.go @@ -175,6 +175,14 @@ type ExchangeConfig struct { // Scopes is the list of scopes to request (optional per RFC 8693) Scopes []string + // SubjectTokenType specifies the type of the subject token being exchanged + // Common values: + // - "urn:ietf:params:oauth:token-type:access_token" (default) + // - "urn:ietf:params:oauth:token-type:id_token" (for OIDC ID tokens, required by Google STS) + // - "urn:ietf:params:oauth:token-type:jwt" + // If empty, defaults to access_token + SubjectTokenType string + // SubjectTokenProvider is a function that returns the subject token to exchange // we use a function to allow dynamic retrieval of the token (e.g. from request context) // and also to lazy-load the token only when needed, load from dynamic sources, etc. @@ -195,9 +203,9 @@ func (c *ExchangeConfig) Validate() error { return fmt.Errorf("SubjectTokenProvider is required") } - if c.ClientID == "" { - return fmt.Errorf("ClientID is required") - } + // ClientID is optional - some token exchange endpoints (like Google STS) + // don't require client credentials and rely on the trust relationship + // configured in the identity provider (e.g., Workload Identity Federation) // Validate URL format _, err := url.Parse(c.TokenURL) @@ -230,6 +238,12 @@ func (ts *tokenSource) Token() (*oauth2.Token, error) { return nil, fmt.Errorf("failed to get subject token: %w", err) } + // Determine subject token type (default to access_token if not specified) + subjectTokenType := conf.SubjectTokenType + if subjectTokenType == "" { + subjectTokenType = tokenTypeAccessToken + } + // Build the token exchange request request := &exchangeRequest{ GrantType: grantTypeTokenExchange, @@ -237,7 +251,7 @@ func (ts *tokenSource) Token() (*oauth2.Token, error) { Scope: conf.Scopes, RequestedTokenType: tokenTypeAccessToken, SubjectToken: subjectToken, - SubjectTokenType: tokenTypeAccessToken, + SubjectTokenType: subjectTokenType, } clientAuth := clientAuthentication{ diff --git a/pkg/auth/tokenexchange/middleware.go b/pkg/auth/tokenexchange/middleware.go index 1dc245da4..54f88f0a4 100644 --- a/pkg/auth/tokenexchange/middleware.go +++ b/pkg/auth/tokenexchange/middleware.go @@ -51,6 +51,13 @@ type Config struct { // Scopes is the list of scopes to request for the exchanged token Scopes []string `json:"scopes,omitempty"` + // SubjectTokenType specifies the type of the subject token being exchanged + // Common values: + // - "urn:ietf:params:oauth:token-type:access_token" (default) + // - "urn:ietf:params:oauth:token-type:id_token" (for OIDC ID tokens, required by Google STS) + // - "urn:ietf:params:oauth:token-type:jwt" + SubjectTokenType string `json:"subject_token_type,omitempty"` + // HeaderStrategy determines how to inject the token // Valid values: HeaderStrategyReplace (default), HeaderStrategyCustom HeaderStrategy string `json:"header_strategy,omitempty"` @@ -92,7 +99,7 @@ func CreateMiddleware(config *types.MiddlewareConfig, runner types.MiddlewareRun return fmt.Errorf("invalid token exchange configuration: %w", err) } - middleware, err := CreateTokenExchangeMiddlewareFromClaims(*params.TokenExchangeConfig) + middleware, err := CreateTokenExchangeMiddlewareFromClaims(*params.TokenExchangeConfig, nil) if err != nil { return fmt.Errorf("invalid token exchange middleware config: %w", err) } @@ -151,10 +158,18 @@ func createCustomInjector(headerName string) injectionFunc { } } -// CreateTokenExchangeMiddlewareFromClaims creates a middleware that uses token claims -// from the auth middleware to perform token exchange. +// SubjectTokenProvider is a function that provides the subject token for exchange. +// This is used when the token comes from an external source (e.g., OAuth flow) +// rather than from incoming request headers. +type SubjectTokenProvider func() (string, error) + +// CreateTokenExchangeMiddlewareFromClaims creates a middleware that performs token exchange. +// It supports two modes: +// 1. Header-based (subjectTokenProvider=nil): Extracts token from Authorization header (for OIDC validation) +// 2. Provider-based (subjectTokenProvider!=nil): Uses provided token source (for remote auth/OAuth flow) +// // This is a public function for direct usage in proxy commands. -func CreateTokenExchangeMiddlewareFromClaims(config Config) (types.MiddlewareFunction, error) { +func CreateTokenExchangeMiddlewareFromClaims(config Config, subjectTokenProvider SubjectTokenProvider) (types.MiddlewareFunction, error) { // Determine injection strategy at startup time strategy := config.HeaderStrategy if strategy == "" { @@ -173,11 +188,12 @@ func CreateTokenExchangeMiddlewareFromClaims(config Config) (types.MiddlewareFun // Create base exchange config at startup time with all static fields baseExchangeConfig := ExchangeConfig{ - TokenURL: config.TokenURL, - ClientID: config.ClientID, - ClientSecret: config.ClientSecret, - Audience: config.Audience, - Scopes: config.Scopes, + TokenURL: config.TokenURL, + ClientID: config.ClientID, + ClientSecret: config.ClientSecret, + Audience: config.Audience, + Scopes: config.Scopes, + SubjectTokenType: config.SubjectTokenType, // SubjectTokenProvider will be set per request } @@ -191,19 +207,32 @@ func CreateTokenExchangeMiddlewareFromClaims(config Config) (types.MiddlewareFun return } - // Extract the original token from the Authorization header - authHeader := r.Header.Get("Authorization") - if authHeader == "" || !strings.HasPrefix(authHeader, "Bearer ") { - logger.Debug("No valid Bearer token found, proceeding without token exchange") - next.ServeHTTP(w, r) - return - } - - subjectToken := strings.TrimPrefix(authHeader, "Bearer ") - if subjectToken == "" { - logger.Debug("Empty Bearer token, proceeding without token exchange") - next.ServeHTTP(w, r) - return + var tokenProvider SubjectTokenProvider + + // Determine token source based on whether external provider was given + if subjectTokenProvider != nil { + // Mode 2: Use provided token source (e.g., from OAuth flow during startup) + logger.Debug("Using provided token source for token exchange") + tokenProvider = subjectTokenProvider + } else { + // Mode 1: Extract token from Authorization header (OIDC validation scenario) + authHeader := r.Header.Get("Authorization") + if authHeader == "" || !strings.HasPrefix(authHeader, "Bearer ") { + logger.Debug("No valid Bearer token found, proceeding without token exchange") + next.ServeHTTP(w, r) + return + } + + subjectToken := strings.TrimPrefix(authHeader, "Bearer ") + if subjectToken == "" { + logger.Debug("Empty Bearer token, proceeding without token exchange") + next.ServeHTTP(w, r) + return + } + + tokenProvider = func() (string, error) { + return subjectToken, nil + } } // Log some claim information for debugging @@ -213,9 +242,7 @@ func CreateTokenExchangeMiddlewareFromClaims(config Config) (types.MiddlewareFun // Create a copy of the base config with the request-specific subject token exchangeConfig := baseExchangeConfig - exchangeConfig.SubjectTokenProvider = func() (string, error) { - return subjectToken, nil - } + exchangeConfig.SubjectTokenProvider = tokenProvider // Get token from token source tokenSource := exchangeConfig.TokenSource(r.Context()) diff --git a/pkg/auth/tokenexchange/middleware_test.go b/pkg/auth/tokenexchange/middleware_test.go index f4fd3439c..06359407e 100644 --- a/pkg/auth/tokenexchange/middleware_test.go +++ b/pkg/auth/tokenexchange/middleware_test.go @@ -290,7 +290,7 @@ func TestCreateTokenExchangeMiddlewareFromClaims_Success(t *testing.T) { ExternalTokenHeaderName: tt.customHeaderName, } - middleware, err := CreateTokenExchangeMiddlewareFromClaims(config) + middleware, err := CreateTokenExchangeMiddlewareFromClaims(config, nil) require.NoError(t, err) // Test handler verifies token injection @@ -383,7 +383,7 @@ func TestCreateTokenExchangeMiddlewareFromClaims_PassThrough(t *testing.T) { ClientSecret: "test-client-secret", } - middleware, err := CreateTokenExchangeMiddlewareFromClaims(config) + middleware, err := CreateTokenExchangeMiddlewareFromClaims(config, nil) require.NoError(t, err) handlerCalled := false @@ -471,7 +471,7 @@ func TestCreateTokenExchangeMiddlewareFromClaims_Failures(t *testing.T) { ExternalTokenHeaderName: tt.customHeaderName, } - middleware, err := CreateTokenExchangeMiddlewareFromClaims(config) + middleware, err := CreateTokenExchangeMiddlewareFromClaims(config, nil) require.NoError(t, err) testHandler := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { diff --git a/pkg/runner/config_builder.go b/pkg/runner/config_builder.go index 1538c96e6..49312425e 100644 --- a/pkg/runner/config_builder.go +++ b/pkg/runner/config_builder.go @@ -9,6 +9,7 @@ import ( "github.com/stacklok/toolhive/pkg/audit" "github.com/stacklok/toolhive/pkg/auth" + "github.com/stacklok/toolhive/pkg/auth/tokenexchange" "github.com/stacklok/toolhive/pkg/authz" rt "github.com/stacklok/toolhive/pkg/container/runtime" "github.com/stacklok/toolhive/pkg/ignore" @@ -430,6 +431,7 @@ func WithIgnoreConfig(ignoreConfig *ignore.Config) RunConfigBuilderOption { // WithMiddlewareFromFlags creates middleware configurations directly from flag values func WithMiddlewareFromFlags( oidcConfig *auth.TokenValidatorConfig, + tokenExchangeConfig *tokenexchange.Config, toolsFilter []string, toolsOverride map[string]ToolOverride, telemetryConfig *telemetry.Config, @@ -456,7 +458,7 @@ func WithMiddlewareFromFlags( middlewareConfigs = addToolFilterMiddlewares(middlewareConfigs, toolsFilter, toolsOverride) // Add core middlewares (always present) - middlewareConfigs = addCoreMiddlewares(middlewareConfigs, oidcConfig) + middlewareConfigs = addCoreMiddlewares(middlewareConfigs, oidcConfig, tokenExchangeConfig) // Add optional middlewares middlewareConfigs = addTelemetryMiddleware(middlewareConfigs, telemetryConfig, serverName, transportType) @@ -520,7 +522,9 @@ func addToolFilterMiddlewares( // addCoreMiddlewares adds core middlewares that are always present func addCoreMiddlewares( - middlewareConfigs []types.MiddlewareConfig, oidcConfig *auth.TokenValidatorConfig, + middlewareConfigs []types.MiddlewareConfig, + oidcConfig *auth.TokenValidatorConfig, + tokenExchangeConfig *tokenexchange.Config, ) []types.MiddlewareConfig { // Authentication middleware (always present) authParams := auth.MiddlewareParams{ @@ -530,6 +534,16 @@ func addCoreMiddlewares( middlewareConfigs = append(middlewareConfigs, *authConfig) } + // Token Exchange middleware (conditionally present) + if tokenExchangeConfig != nil { + tokenExchangeParams := tokenexchange.MiddlewareParams{ + TokenExchangeConfig: tokenExchangeConfig, + } + if tokenExchangeMwConfig, err := types.NewMiddlewareConfig(tokenexchange.MiddlewareType, tokenExchangeParams); err == nil { + middlewareConfigs = append(middlewareConfigs, *tokenExchangeMwConfig) + } + } + // MCP Parser middleware (always present) mcpParserParams := mcp.ParserMiddlewareParams{} if mcpParserConfig, err := types.NewMiddlewareConfig(mcp.ParserMiddlewareType, mcpParserParams); err == nil { diff --git a/pkg/runner/middleware.go b/pkg/runner/middleware.go index 4b1c5f053..d323649ac 100644 --- a/pkg/runner/middleware.go +++ b/pkg/runner/middleware.go @@ -5,6 +5,7 @@ import ( "github.com/stacklok/toolhive/pkg/audit" "github.com/stacklok/toolhive/pkg/auth" + "github.com/stacklok/toolhive/pkg/auth/tokenexchange" "github.com/stacklok/toolhive/pkg/authz" "github.com/stacklok/toolhive/pkg/mcp" "github.com/stacklok/toolhive/pkg/telemetry" @@ -15,6 +16,7 @@ import ( func GetSupportedMiddlewareFactories() map[string]types.MiddlewareFactory { return map[string]types.MiddlewareFactory{ auth.MiddlewareType: auth.CreateMiddleware, + tokenexchange.MiddlewareType: tokenexchange.CreateMiddleware, mcp.ParserMiddlewareType: mcp.CreateParserMiddleware, mcp.ToolFilterMiddlewareType: mcp.CreateToolFilterMiddleware, mcp.ToolCallFilterMiddlewareType: mcp.CreateToolCallFilterMiddleware,