diff --git a/cmd/thv/app/proxy.go b/cmd/thv/app/proxy.go index 291f1fe83..e20278014 100644 --- a/cmd/thv/app/proxy.go +++ b/cmd/thv/app/proxy.go @@ -170,6 +170,14 @@ func proxyCmdFunc(cmd *cobra.Command, args []string) error { return fmt.Errorf("invalid target URI: %w", err) } + // Validate OAuth callback port availability + if err := networking.ValidateCallbackPort( + remoteAuthFlags.RemoteAuthCallbackPort, + remoteAuthFlags.RemoteAuthClientID, + ); err != nil { + return err + } + // Select a port for the HTTP proxy (host port) port, err := networking.FindOrUsePort(proxyPort) if err != nil { diff --git a/cmd/thv/app/run_flags.go b/cmd/thv/app/run_flags.go index dd37bf7e1..b096032a7 100644 --- a/cmd/thv/app/run_flags.go +++ b/cmd/thv/app/run_flags.go @@ -475,10 +475,22 @@ func buildRunnerConfig( if remoteServerMetadata, ok := serverMetadata.(*registry.RemoteServerMetadata); ok { remoteAuthConfig := getRemoteAuthFromRemoteServerMetadata(remoteServerMetadata) + + // Validate OAuth callback port availability upfront for better user experience + if err := networking.ValidateCallbackPort(remoteAuthConfig.CallbackPort, remoteAuthConfig.ClientID); err != nil { + return nil, err + } + opts = append(opts, runner.WithRemoteAuth(remoteAuthConfig), runner.WithRemoteURL(remoteServerMetadata.URL)) } if runFlags.RemoteURL != "" { remoteAuthConfig := getRemoteAuthFromRunFlags(runFlags) + + // Validate OAuth callback port availability upfront for better user experience + if err := networking.ValidateCallbackPort(remoteAuthConfig.CallbackPort, remoteAuthConfig.ClientID); err != nil { + return nil, err + } + opts = append(opts, runner.WithRemoteAuth(remoteAuthConfig)) } diff --git a/pkg/auth/discovery/discovery.go b/pkg/auth/discovery/discovery.go index c3c72510f..de3a39b3d 100644 --- a/pkg/auth/discovery/discovery.go +++ b/pkg/auth/discovery/discovery.go @@ -395,6 +395,32 @@ func PerformOAuthFlow(ctx context.Context, issuer string, config *OAuthFlowConfi return nil, fmt.Errorf("OAuth flow config cannot be nil") } + // Resolve port availability BEFORE dynamic registration + // This ensures we register the OAuth client with the same port we'll actually use + + if shouldDynamicallyRegisterClient(config) { + // For dynamic registration, we can allow fallback to alternative ports + // since we can register the client with the actual port we'll use + port, err := networking.FindOrUsePort(config.CallbackPort) + if err != nil { + return nil, fmt.Errorf("failed to find available port: %w", err) + } + + if port != config.CallbackPort { + logger.Warnf("Specified auth callback port %d is unavailable, using port %d instead", config.CallbackPort, port) + } + config.CallbackPort = port + } else { + // For pre-registered clients, use strict port checking + // The user likely configured this port in their IdP/app + if !networking.IsAvailable(config.CallbackPort) { + return nil, fmt.Errorf( + "specified auth callback port %d is not available - please choose a different port or ensure it's not in use", + config.CallbackPort, + ) + } + } + // Handle dynamic client registration if needed if shouldDynamicallyRegisterClient(config) { if err := handleDynamicRegistration(ctx, issuer, config); err != nil { diff --git a/pkg/auth/discovery/discovery_test.go b/pkg/auth/discovery/discovery_test.go index b4651eb34..b96476138 100644 --- a/pkg/auth/discovery/discovery_test.go +++ b/pkg/auth/discovery/discovery_test.go @@ -2,13 +2,18 @@ package discovery import ( "context" + "net" "net/http" "net/http/httptest" "strings" "testing" "time" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/stacklok/toolhive/pkg/logger" + "github.com/stacklok/toolhive/pkg/networking" ) func init() { @@ -422,3 +427,325 @@ func TestDeriveIssuerFromURL(t *testing.T) { }) } } + +func TestPerformOAuthFlow_PortBehavior(t *testing.T) { + t.Parallel() + + // Test dynamic registration with available port + t.Run("dynamic registration with available port", func(t *testing.T) { + t.Parallel() + + config := &OAuthFlowConfig{ + ClientID: "", // No client ID triggers dynamic registration + ClientSecret: "", + CallbackPort: 0, // Use 0 to find an available port + Scopes: []string{"openid"}, + } + + // Create a mock OIDC discovery server + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if strings.HasSuffix(r.URL.Path, "/.well-known/openid_configuration") { + // Return OIDC discovery document + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{ + "issuer": "https://example.com", + "authorization_endpoint": "https://example.com/auth", + "token_endpoint": "https://example.com/token", + "registration_endpoint": "https://example.com/register" + }`)) + return + } + if strings.HasSuffix(r.URL.Path, "/register") { + // Return dynamic registration response + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusCreated) + w.Write([]byte(`{ + "client_id": "dynamic-client-id", + "client_secret": "dynamic-client-secret" + }`)) + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer server.Close() + + ctx := context.Background() + _, err := PerformOAuthFlow(ctx, server.URL, config) + + // For successful cases, we expect the OAuth flow to fail later + // (since we're not actually completing the full flow), but the + // port resolution should work correctly + if err != nil { + // Check if it's a port-related error (which we don't want) + if strings.Contains(err.Error(), "not available") { + t.Errorf("Unexpected port availability error: %v", err) + } + } + }) + + // Test dynamic registration with unavailable port - should fallback + t.Run("dynamic registration with unavailable port - should fallback", func(t *testing.T) { + t.Parallel() + + // Create a listener to make a port unavailable + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer listener.Close() + unavailablePort := listener.Addr().(*net.TCPAddr).Port + + config := &OAuthFlowConfig{ + ClientID: "", // No client ID triggers dynamic registration + ClientSecret: "", + CallbackPort: unavailablePort, // Use the unavailable port + Scopes: []string{"openid"}, + } + + // Create a mock OIDC discovery server + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if strings.HasSuffix(r.URL.Path, "/.well-known/openid_configuration") { + // Return OIDC discovery document + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{ + "issuer": "https://example.com", + "authorization_endpoint": "https://example.com/auth", + "token_endpoint": "https://example.com/token", + "registration_endpoint": "https://example.com/register" + }`)) + return + } + if strings.HasSuffix(r.URL.Path, "/register") { + // Return dynamic registration response + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusCreated) + w.Write([]byte(`{ + "client_id": "dynamic-client-id", + "client_secret": "dynamic-client-secret" + }`)) + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer server.Close() + + ctx := context.Background() + _, err = PerformOAuthFlow(ctx, server.URL, config) + + // Should not fail due to port unavailability (should fallback) + if err != nil { + // Check if it's a port-related error (which we don't want for dynamic registration) + if strings.Contains(err.Error(), "not available") { + t.Errorf("Dynamic registration should allow port fallback, but got port error: %v", err) + } + } + }) + + // Test pre-registered client with available port + t.Run("pre-registered client with available port", func(t *testing.T) { + t.Parallel() + + config := &OAuthFlowConfig{ + ClientID: "test-client", + ClientSecret: "test-secret", + CallbackPort: 0, // Use 0 to find an available port + Scopes: []string{"openid"}, + } + + // Create a mock OIDC discovery server + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if strings.HasSuffix(r.URL.Path, "/.well-known/openid_configuration") { + // Return OIDC discovery document + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{ + "issuer": "https://example.com", + "authorization_endpoint": "https://example.com/auth", + "token_endpoint": "https://example.com/token", + "registration_endpoint": "https://example.com/register" + }`)) + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer server.Close() + + ctx := context.Background() + _, err := PerformOAuthFlow(ctx, server.URL, config) + + // For successful cases, we expect the OAuth flow to fail later + // (since we're not actually completing the full flow), but the + // port resolution should work correctly + if err != nil { + // Check if it's a port-related error (which we don't want) + if strings.Contains(err.Error(), "not available") { + t.Errorf("Unexpected port availability error: %v", err) + } + } + }) + + // Test pre-registered client with unavailable port - should fail + t.Run("pre-registered client with unavailable port - should fail", func(t *testing.T) { + t.Parallel() + + // Create a listener to make a port unavailable + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer listener.Close() + unavailablePort := listener.Addr().(*net.TCPAddr).Port + + config := &OAuthFlowConfig{ + ClientID: "test-client", + ClientSecret: "test-secret", + CallbackPort: unavailablePort, // Use the unavailable port + Scopes: []string{"openid"}, + } + + // Create a mock OIDC discovery server + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if strings.HasSuffix(r.URL.Path, "/.well-known/openid_configuration") { + // Return OIDC discovery document + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{ + "issuer": "https://example.com", + "authorization_endpoint": "https://example.com/auth", + "token_endpoint": "https://example.com/token", + "registration_endpoint": "https://example.com/register" + }`)) + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer server.Close() + + // Verify the port is actually unavailable + if networking.IsAvailable(config.CallbackPort) { + t.Fatalf("Test setup error: Expected port %d to be unavailable, but it's available", config.CallbackPort) + } + + ctx := context.Background() + _, err = PerformOAuthFlow(ctx, server.URL, config) + + // Should fail due to port unavailability + require.Error(t, err) + assert.Contains(t, err.Error(), "not available") + }) +} + +func TestPerformOAuthFlow_PortFallbackBehavior(t *testing.T) { + t.Parallel() + + // Test that dynamic registration allows port fallback + t.Run("dynamic registration port fallback", func(t *testing.T) { + t.Parallel() + + // Create a listener to make a port unavailable + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer listener.Close() + unavailablePort := listener.Addr().(*net.TCPAddr).Port + + config := &OAuthFlowConfig{ + ClientID: "", // No client ID triggers dynamic registration + ClientSecret: "", + CallbackPort: unavailablePort, + Scopes: []string{"openid"}, + } + + // Create a mock OIDC discovery server + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if strings.HasSuffix(r.URL.Path, "/.well-known/openid_configuration") { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{ + "issuer": "https://example.com", + "authorization_endpoint": "https://example.com/auth", + "token_endpoint": "https://example.com/token", + "registration_endpoint": "https://example.com/register" + }`)) + return + } + if strings.HasSuffix(r.URL.Path, "/register") { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusCreated) + w.Write([]byte(`{ + "client_id": "dynamic-client-id", + "client_secret": "dynamic-client-secret" + }`)) + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer server.Close() + + ctx := context.Background() + _, err = PerformOAuthFlow(ctx, server.URL, config) + + // Should not fail due to port unavailability + // (it may fail later in the OAuth flow, but not due to port issues) + if err != nil && strings.Contains(err.Error(), "not available") { + t.Errorf("Dynamic registration should allow port fallback, but got port error: %v", err) + } + }) + + // Test that pre-registered clients fail on unavailable ports + t.Run("pre-registered client strict port checking", func(t *testing.T) { + t.Parallel() + + // Create a listener to make a port unavailable + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer listener.Close() + unavailablePort := listener.Addr().(*net.TCPAddr).Port + + config := &OAuthFlowConfig{ + ClientID: "test-client", + ClientSecret: "test-secret", + CallbackPort: unavailablePort, + Scopes: []string{"openid"}, + } + + ctx := context.Background() + _, err = PerformOAuthFlow(ctx, "https://example.com", config) + + // Should fail due to port unavailability + require.Error(t, err) + assert.Contains(t, err.Error(), "not available") + }) +} + +// TestPerformOAuthFlow_PortCheckingOnly tests just the port checking logic +// without going through the full OAuth flow +func TestPerformOAuthFlow_PortCheckingOnly(t *testing.T) { + t.Parallel() + + // Test that pre-registered clients fail on unavailable ports + t.Run("pre-registered client strict port checking", func(t *testing.T) { + t.Parallel() + + // Create a listener to make a port unavailable + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer listener.Close() + + unavailablePort := listener.Addr().(*net.TCPAddr).Port + + config := &OAuthFlowConfig{ + ClientID: "test-client", + ClientSecret: "test-secret", + CallbackPort: unavailablePort, + Scopes: []string{"openid"}, + } + + // Test the port checking logic directly + if shouldDynamicallyRegisterClient(config) { + t.Error("Expected shouldDynamicallyRegisterClient to return false for pre-registered client") + } + + // This should fail because the port is unavailable + if networking.IsAvailable(config.CallbackPort) { + t.Errorf("Expected port %d to be unavailable, but IsAvailable returned true", config.CallbackPort) + } + }) +} diff --git a/pkg/networking/port.go b/pkg/networking/port.go index 376a6fa39..c9a9cf1a3 100644 --- a/pkg/networking/port.go +++ b/pkg/networking/port.go @@ -141,3 +141,32 @@ func FindOrUsePort(port int) (int, error) { } return alt, nil } + +// ValidateCallbackPort validates that the specified callback port is available +// for pre-registered clients (with clientID), it returns an error if +// it's not available. +func ValidateCallbackPort(callbackPort int, clientID string) error { + // If port is 0, we'll find an available port later, so no need to validate + if callbackPort == 0 { + return nil + } + + // Check if this is a pre-registered client (has client credentials) + // For pre-registered clients, we need strict port checking + isPreRegisteredClient := IsPreRegisteredClient(clientID) + + if isPreRegisteredClient { + // For pre-registered clients, the port must be available + // The user likely configured this port in their IdP/app + if !IsAvailable(callbackPort) { + return fmt.Errorf("OAuth callback port %d is not available - please choose a different port", callbackPort) + } + } + + return nil +} + +// IsPreRegisteredClient determines if the OAuth client is pre-registered (has client ID) +func IsPreRegisteredClient(clientID string) bool { + return clientID != "" +}