Skip to content

Remote server support #1423

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 11 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
240 changes: 37 additions & 203 deletions cmd/thv/app/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"golang.org/x/oauth2"

"github.com/stacklok/toolhive/pkg/auth"
"github.com/stacklok/toolhive/pkg/auth/discovery"
"github.com/stacklok/toolhive/pkg/auth/oauth"
"github.com/stacklok/toolhive/pkg/logger"
"github.com/stacklok/toolhive/pkg/networking"
Expand Down Expand Up @@ -113,15 +114,6 @@ var (
remoteAuthTokenURL string
)

// Default timeout constants
const (
defaultOAuthTimeout = 5 * time.Minute
defaultHTTPTimeout = 30 * time.Second
defaultAuthDetectTimeout = 10 * time.Second
maxRetryAttempts = 3
retryBaseDelay = 2 * time.Second
)

// Environment variable names
const (
// #nosec G101 - this is an environment variable name, not a credential
Expand All @@ -145,7 +137,7 @@ func init() {
"Explicit resource URL for OAuth discovery endpoint (RFC 9728)")

// Add remote server authentication flags
proxyCmd.Flags().BoolVar(&enableRemoteAuth, "remote-auth", false, "Enable OAuth authentication to remote MCP server")
proxyCmd.Flags().BoolVar(&enableRemoteAuth, "remote-auth", false, "Enable OAuth/OIDC authentication to remote MCP server")
proxyCmd.Flags().StringVar(&remoteAuthIssuer, "remote-auth-issuer", "",
"OAuth/OIDC issuer URL for remote server authentication (e.g., https://accounts.google.com)")
proxyCmd.Flags().StringVar(&remoteAuthClientID, "remote-auth-client-id", "",
Expand Down Expand Up @@ -287,196 +279,6 @@ func proxyCmdFunc(cmd *cobra.Command, args []string) error {
return proxy.Stop(shutdownCtx)
}

// AuthInfo contains authentication information extracted from WWW-Authenticate header
type AuthInfo struct {
Realm string
Type string
}

// detectAuthenticationFromServer attempts to detect authentication requirements from the target server
func detectAuthenticationFromServer(ctx context.Context, targetURI string) (*AuthInfo, error) {
// Create a context with timeout for auth detection
detectCtx, cancel := context.WithTimeout(ctx, defaultAuthDetectTimeout)
defer cancel()

// Make a test request to the target server to see if it returns WWW-Authenticate
client := &http.Client{
Timeout: defaultAuthDetectTimeout,
Transport: &http.Transport{
TLSHandshakeTimeout: defaultHTTPTimeout / 3,
ResponseHeaderTimeout: defaultHTTPTimeout / 3,
},
}

req, err := http.NewRequestWithContext(detectCtx, http.MethodGet, targetURI, nil)
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}

resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to make request: %w", err)
}
defer resp.Body.Close()

// Check if we got a 401 Unauthorized with WWW-Authenticate header
if resp.StatusCode == http.StatusUnauthorized {
wwwAuth := resp.Header.Get("WWW-Authenticate")
if wwwAuth != "" {
return parseWWWAuthenticate(wwwAuth)
}
}

return nil, nil
}

// parseWWWAuthenticate parses the WWW-Authenticate header to extract realm and type
// Supports multiple authentication schemes and complex header formats
func parseWWWAuthenticate(header string) (*AuthInfo, error) {
// Trim whitespace and handle empty headers
header = strings.TrimSpace(header)
if header == "" {
return nil, fmt.Errorf("empty WWW-Authenticate header")
}

// Split by comma to handle multiple authentication schemes
schemes := strings.Split(header, ",")

for _, scheme := range schemes {
scheme = strings.TrimSpace(scheme)

// Check for Bearer authentication
if strings.HasPrefix(scheme, "Bearer") {
authInfo := &AuthInfo{Type: "Bearer"}

// Extract parameters after "Bearer"
params := strings.TrimSpace(strings.TrimPrefix(scheme, "Bearer"))
if params == "" {
// Simple "Bearer" without parameters
return authInfo, nil
}

// Parse parameters (realm, scope, etc.)
realm := extractParameter(params, "realm")
if realm != "" {
authInfo.Realm = realm
}

return authInfo, nil
}

// Check for other authentication types (Basic, Digest, etc.)
if strings.HasPrefix(scheme, "Basic") {
return &AuthInfo{Type: "Basic"}, nil
}

if strings.HasPrefix(scheme, "Digest") {
authInfo := &AuthInfo{Type: "Digest"}
realm := extractParameter(scheme, "realm")
if realm != "" {
authInfo.Realm = realm
}
return authInfo, nil
}
}

return nil, fmt.Errorf("no supported authentication type found in header: %s", header)
}

// extractParameter extracts a parameter value from an authentication header
func extractParameter(params, paramName string) string {
// Look for paramName=value or paramName="value"
parts := strings.Split(params, ",")
for _, part := range parts {
part = strings.TrimSpace(part)
if strings.HasPrefix(part, paramName+"=") {
value := strings.TrimPrefix(part, paramName+"=")
// Remove quotes if present
value = strings.Trim(value, `"`)
return value
}
}
return ""
}

// performOAuthFlow performs the OAuth authentication flow
func performOAuthFlow(ctx context.Context, issuer, clientID, clientSecret string,
scopes []string) (*oauth2.TokenSource, *oauth.Config, error) {
logger.Info("Starting OAuth authentication flow...")

var oauthConfig *oauth.Config
var err error

// Check if we have manual OAuth endpoints configured
if remoteAuthAuthorizeURL != "" && remoteAuthTokenURL != "" {
logger.Info("Using manual OAuth configuration")
oauthConfig, err = oauth.CreateOAuthConfigManual(
clientID,
clientSecret,
remoteAuthAuthorizeURL,
remoteAuthTokenURL,
scopes,
true, // Enable PKCE by default for security
remoteAuthCallbackPort,
)
} else {
// Fall back to OIDC discovery
logger.Info("Using OIDC discovery")
oauthConfig, err = oauth.CreateOAuthConfigFromOIDC(
ctx,
issuer,
clientID,
clientSecret,
scopes,
true, // Enable PKCE by default for security
remoteAuthCallbackPort,
)
}
if err != nil {
return nil, nil, fmt.Errorf("failed to create OAuth config: %w", err)
}

// Create OAuth flow
flow, err := oauth.NewFlow(oauthConfig)
if err != nil {
return nil, nil, fmt.Errorf("failed to create OAuth flow: %w", err)
}

// Create a context with timeout for the OAuth flow
// Use the configured timeout, defaulting to the constant if not set
oauthTimeout := remoteAuthTimeout
if oauthTimeout <= 0 {
oauthTimeout = defaultOAuthTimeout
}

oauthCtx, cancel := context.WithTimeout(ctx, oauthTimeout)
defer cancel()

// Start OAuth flow
tokenResult, err := flow.Start(oauthCtx, remoteAuthSkipBrowser)
if err != nil {
if oauthCtx.Err() == context.DeadlineExceeded {
return nil, nil, fmt.Errorf("OAuth flow timed out after %v - user did not complete authentication", oauthTimeout)
}
return nil, nil, fmt.Errorf("OAuth flow failed: %w", err)
}

logger.Info("OAuth authentication successful")

// Log token info (without exposing the actual token)
if tokenResult.Claims != nil {
if sub, ok := tokenResult.Claims["sub"].(string); ok {
logger.Infof("Authenticated as subject: %s", sub)
}
if email, ok := tokenResult.Claims["email"].(string); ok {
logger.Infof("Authenticated email: %s", email)
}
}

source := flow.TokenSource()
return &source, oauthConfig, nil
}

// shouldDetectAuth determines if we should try to detect authentication requirements
func shouldDetectAuth() bool {
// Only try to detect auth if OAuth client ID is provided
Expand Down Expand Up @@ -511,11 +313,27 @@ func handleOutgoingAuthentication(ctx context.Context) (*oauth2.TokenSource, *oa
return nil, nil, fmt.Errorf("cannot specify both OIDC issuer and manual OAuth endpoints - choose one approach")
}

return performOAuthFlow(ctx, remoteAuthIssuer, remoteAuthClientID, clientSecret, remoteAuthScopes)
flowConfig := &discovery.OAuthFlowConfig{
ClientID: remoteAuthClientID,
ClientSecret: clientSecret,
AuthorizeURL: remoteAuthAuthorizeURL,
TokenURL: remoteAuthTokenURL,
Scopes: remoteAuthScopes,
CallbackPort: remoteAuthCallbackPort,
Timeout: remoteAuthTimeout,
SkipBrowser: remoteAuthSkipBrowser,
}

result, err := discovery.PerformOAuthFlow(ctx, remoteAuthIssuer, flowConfig)
if err != nil {
return nil, nil, err
}

return result.TokenSource, result.Config, nil
}

// Try to detect authentication requirements from WWW-Authenticate header
authInfo, err := detectAuthenticationFromServer(ctx, proxyTargetURI)
authInfo, err := discovery.DetectAuthenticationFromServer(ctx, proxyTargetURI, nil)
if err != nil {
logger.Debugf("Could not detect authentication from server: %v", err)
return nil, nil, nil // Not an error, just no auth detected
Expand All @@ -529,7 +347,23 @@ func handleOutgoingAuthentication(ctx context.Context) (*oauth2.TokenSource, *oa
}

// Perform OAuth flow with discovered configuration
return performOAuthFlow(ctx, authInfo.Realm, remoteAuthClientID, clientSecret, remoteAuthScopes)
flowConfig := &discovery.OAuthFlowConfig{
ClientID: remoteAuthClientID,
ClientSecret: clientSecret,
AuthorizeURL: remoteAuthAuthorizeURL,
TokenURL: remoteAuthTokenURL,
Scopes: remoteAuthScopes,
CallbackPort: remoteAuthCallbackPort,
Timeout: remoteAuthTimeout,
SkipBrowser: remoteAuthSkipBrowser,
}

result, err := discovery.PerformOAuthFlow(ctx, authInfo.Realm, flowConfig)
if err != nil {
return nil, nil, err
}

return result.TokenSource, result.Config, nil
}

return nil, nil, nil // No authentication required
Expand Down
14 changes: 9 additions & 5 deletions cmd/thv/app/proxy_tunnel.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"fmt"
"net/url"
"os/signal"
"strings"
"syscall"

"github.com/spf13/cobra"
Expand Down Expand Up @@ -126,11 +125,16 @@ func resolveTarget(ctx context.Context, target string) (string, error) {
}

func looksLikeURL(s string) bool {
// Parse the URL once
u, err := url.Parse(s)
if err != nil {
return false
}

// Fast-path for common schemes
if strings.HasPrefix(s, "http://") || strings.HasPrefix(s, "https://") {
if u.Scheme == "http" || u.Scheme == "https" {
return true
}
// Fallback parse check
u, err := url.Parse(s)
return err == nil && u.Scheme != "" && u.Host != ""
// Fallback check for other schemes
return u.Scheme != "" && u.Host != ""
}
10 changes: 8 additions & 2 deletions cmd/thv/app/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,12 +159,18 @@ func printTextServers(servers []registry.ServerMetadata) {
}
}

// ServerType constants
const (
ServerTypeRemote = "remote"
ServerTypeContainer = "container"
)

// getServerType returns the type of server (container or remote)
func getServerType(server registry.ServerMetadata) string {
if server.IsRemote() {
return "remote"
return ServerTypeRemote
}
return "container"
return ServerTypeContainer
}

// printTextServerInfo prints detailed information about a server in text format
Expand Down
Loading
Loading