From 372214344fe0de3ac7fb3f456be8d3e0168e00c4 Mon Sep 17 00:00:00 2001 From: Sam El-Borai Date: Sun, 30 Nov 2025 13:39:32 +0000 Subject: [PATCH 1/7] Implement multi-IDP support with provider abstraction Replaces Google-specific OAuth with a provider abstraction pattern supporting Google, Azure AD, GitHub, and generic OIDC providers. Changes: - Create internal/idp package with Provider interface and implementations - Update config from googleClientId/googleClientSecret to structured idp block - Add provider tracking to browser session cookies for multi-IDP readiness - Delete internal/googleauth package - Add comprehensive tests for idp package (including GitHub UserInfo tests) - Update all example configs and documentation - Refactor parseIDPConfig to use helper function for cleaner code Config format change: "idp": { "provider": "google|azure|github|oidc", "clientId": "...", "clientSecret": {"$env": "..."}, "redirectUri": "...", // Provider-specific: tenantId (azure), allowedOrgs (github), discoveryUrl (oidc) } Co-authored-by: Claude --- cmd/mcp-front/main.go | 25 +- config-admin-example.json | 9 +- config-inline-example.json | 9 +- config-inline-test.json | 9 +- config-oauth-firestore.example.json | 9 +- config-oauth.example.json | 9 +- config-oauth.json | 9 +- config-user-tokens-example.json | 9 +- docs-site/src/content/docs/configuration.md | 20 +- .../src/content/docs/examples/oauth-google.md | 19 +- docs-site/src/content/docs/index.mdx | 8 +- .../config/config.oauth-integration-test.json | 9 +- .../config/config.oauth-rfc8707-test.json | 9 +- ...config.oauth-service-integration-test.json | 9 +- .../config/config.oauth-service-test.json | 9 +- integration/config/config.oauth-test.json | 11 +- .../config/config.oauth-token-test.json | 9 +- .../config.oauth-usertoken-tools-test.json | 11 +- internal/browserauth/session.go | 5 +- internal/browserauth/session_test.go | 16 +- internal/config/load.go | 62 +++- internal/config/load_test.go | 125 ++++---- internal/config/types.go | 32 +- internal/config/unmarshal.go | 120 ++++++-- internal/config/unmarshal_test.go | 24 +- internal/config/validation.go | 83 ++++- internal/config/validation_test.go | 61 ++-- internal/googleauth/google.go | 121 -------- internal/googleauth/google_test.go | 240 --------------- internal/idp/azure.go | 25 ++ internal/idp/azure_test.go | 15 + internal/idp/factory.go | 51 ++++ internal/idp/factory_test.go | 105 +++++++ internal/idp/github.go | 217 +++++++++++++ internal/idp/github_test.go | 284 ++++++++++++++++++ internal/idp/google.go | 103 +++++++ internal/idp/google_test.go | 149 +++++++++ internal/idp/oidc.go | 183 +++++++++++ internal/idp/oidc_test.go | 218 ++++++++++++++ internal/idp/provider.go | 80 +++++ internal/idp/provider_test.go | 113 +++++++ internal/mcpfront.go | 34 ++- internal/oauthsession/session.go | 21 +- internal/server/auth_handlers.go | 44 +-- internal/server/auth_handlers_test.go | 96 ++++-- internal/server/http_test.go | 27 +- internal/server/middleware.go | 18 +- 47 files changed, 2221 insertions(+), 653 deletions(-) delete mode 100644 internal/googleauth/google.go delete mode 100644 internal/googleauth/google_test.go create mode 100644 internal/idp/azure.go create mode 100644 internal/idp/azure_test.go create mode 100644 internal/idp/factory.go create mode 100644 internal/idp/factory_test.go create mode 100644 internal/idp/github.go create mode 100644 internal/idp/github_test.go create mode 100644 internal/idp/google.go create mode 100644 internal/idp/google_test.go create mode 100644 internal/idp/oidc.go create mode 100644 internal/idp/oidc_test.go create mode 100644 internal/idp/provider.go create mode 100644 internal/idp/provider_test.go diff --git a/cmd/mcp-front/main.go b/cmd/mcp-front/main.go index 7fd26a3..1c1a910 100644 --- a/cmd/mcp-front/main.go +++ b/cmd/mcp-front/main.go @@ -22,17 +22,20 @@ func generateDefaultConfig(path string) error { "addr": ":8080", "name": "mcp-front", "auth": map[string]any{ - "kind": "oauth", - "issuer": "https://mcp.yourcompany.com", - "allowedDomains": []string{"yourcompany.com"}, - "allowedOrigins": []string{"https://claude.ai"}, - "tokenTtl": "24h", - "storage": "memory", - "googleClientId": map[string]string{"$env": "GOOGLE_CLIENT_ID"}, - "googleClientSecret": map[string]string{"$env": "GOOGLE_CLIENT_SECRET"}, - "googleRedirectUri": "https://mcp.yourcompany.com/oauth/callback", - "jwtSecret": map[string]string{"$env": "JWT_SECRET"}, - "encryptionKey": map[string]string{"$env": "ENCRYPTION_KEY"}, + "kind": "oauth", + "issuer": "https://mcp.yourcompany.com", + "allowedDomains": []string{"yourcompany.com"}, + "allowedOrigins": []string{"https://claude.ai"}, + "tokenTtl": "24h", + "storage": "memory", + "idp": map[string]any{ + "provider": "google", + "clientId": map[string]string{"$env": "GOOGLE_CLIENT_ID"}, + "clientSecret": map[string]string{"$env": "GOOGLE_CLIENT_SECRET"}, + "redirectUri": "https://mcp.yourcompany.com/oauth/callback", + }, + "jwtSecret": map[string]string{"$env": "JWT_SECRET"}, + "encryptionKey": map[string]string{"$env": "ENCRYPTION_KEY"}, }, }, "mcpServers": map[string]any{ diff --git a/config-admin-example.json b/config-admin-example.json index e0fd169..b35934c 100644 --- a/config-admin-example.json +++ b/config-admin-example.json @@ -7,13 +7,16 @@ "auth": { "kind": "oauth", "issuer": "https://mcp.example.com", + "idp": { + "provider": "google", + "clientId": {"$env": "GOOGLE_CLIENT_ID"}, + "clientSecret": {"$env": "GOOGLE_CLIENT_SECRET"}, + "redirectUri": "https://mcp.example.com/oauth/callback" + }, "allowedDomains": ["example.com"], "allowedOrigins": ["https://claude.ai"], "tokenTtl": "1h", "storage": "memory", - "googleClientId": {"$env": "GOOGLE_CLIENT_ID"}, - "googleClientSecret": {"$env": "GOOGLE_CLIENT_SECRET"}, - "googleRedirectUri": "https://mcp.example.com/oauth/callback", "jwtSecret": {"$env": "JWT_SECRET"}, "encryptionKey": {"$env": "ENCRYPTION_KEY"} }, diff --git a/config-inline-example.json b/config-inline-example.json index 44038cd..cfa2466 100644 --- a/config-inline-example.json +++ b/config-inline-example.json @@ -7,13 +7,16 @@ "auth": { "kind": "oauth", "issuer": "https://mcp.example.com", + "idp": { + "provider": "google", + "clientId": {"$env": "GOOGLE_CLIENT_ID"}, + "clientSecret": {"$env": "GOOGLE_CLIENT_SECRET"}, + "redirectUri": "https://mcp.example.com/oauth/callback" + }, "allowedDomains": ["example.com"], "allowedOrigins": ["https://claude.ai"], "tokenTtl": "1h", "storage": "memory", - "googleClientId": {"$env": "GOOGLE_CLIENT_ID"}, - "googleClientSecret": {"$env": "GOOGLE_CLIENT_SECRET"}, - "googleRedirectUri": "https://mcp.example.com/oauth/callback", "jwtSecret": {"$env": "JWT_SECRET"}, "encryptionKey": {"$env": "ENCRYPTION_KEY"} } diff --git a/config-inline-test.json b/config-inline-test.json index 1d005ba..9f7ac5a 100644 --- a/config-inline-test.json +++ b/config-inline-test.json @@ -7,13 +7,16 @@ "auth": { "kind": "oauth", "issuer": "http://localhost:8080", + "idp": { + "provider": "google", + "clientId": {"$env": "GOOGLE_CLIENT_ID"}, + "clientSecret": {"$env": "GOOGLE_CLIENT_SECRET"}, + "redirectUri": "http://localhost:8080/oauth/callback" + }, "allowedDomains": ["gmail.com"], "allowedOrigins": ["https://claude.ai"], "tokenTtl": "1h", "storage": "memory", - "googleClientId": {"$env": "GOOGLE_CLIENT_ID"}, - "googleClientSecret": {"$env": "GOOGLE_CLIENT_SECRET"}, - "googleRedirectUri": "http://localhost:8080/oauth/callback", "jwtSecret": {"$env": "JWT_SECRET"}, "encryptionKey": {"$env": "ENCRYPTION_KEY"} } diff --git a/config-oauth-firestore.example.json b/config-oauth-firestore.example.json index 72697a6..8099a34 100644 --- a/config-oauth-firestore.example.json +++ b/config-oauth-firestore.example.json @@ -8,13 +8,16 @@ "kind": "oauth", "issuer": {"$env": "OAUTH_ISSUER"}, "gcpProject": {"$env": "GCP_PROJECT"}, + "idp": { + "provider": "google", + "clientId": {"$env": "GOOGLE_CLIENT_ID"}, + "clientSecret": {"$env": "GOOGLE_CLIENT_SECRET"}, + "redirectUri": {"$env": "GOOGLE_REDIRECT_URI"} + }, "allowedDomains": ["yourcompany.com", "contractors.yourcompany.com"], "allowedOrigins": ["https://claude.ai", "https://yourcompany.com"], "tokenTtl": "24h", "storage": "firestore", - "googleClientId": {"$env": "GOOGLE_CLIENT_ID"}, - "googleClientSecret": {"$env": "GOOGLE_CLIENT_SECRET"}, - "googleRedirectUri": {"$env": "GOOGLE_REDIRECT_URI"}, "jwtSecret": {"$env": "JWT_SECRET"}, "encryptionKey": {"$env": "ENCRYPTION_KEY"} } diff --git a/config-oauth.example.json b/config-oauth.example.json index 8b2b1ff..caa5c96 100644 --- a/config-oauth.example.json +++ b/config-oauth.example.json @@ -8,13 +8,16 @@ "kind": "oauth", "issuer": {"$env": "OAUTH_ISSUER"}, "gcpProject": {"$env": "GCP_PROJECT"}, + "idp": { + "provider": "google", + "clientId": {"$env": "GOOGLE_CLIENT_ID"}, + "clientSecret": {"$env": "GOOGLE_CLIENT_SECRET"}, + "redirectUri": {"$env": "GOOGLE_REDIRECT_URI"} + }, "allowedDomains": ["yourcompany.com", "contractors.yourcompany.com"], "allowedOrigins": ["https://claude.ai", "https://yourcompany.com"], "tokenTtl": "24h", "storage": "memory", - "googleClientId": {"$env": "GOOGLE_CLIENT_ID"}, - "googleClientSecret": {"$env": "GOOGLE_CLIENT_SECRET"}, - "googleRedirectUri": {"$env": "GOOGLE_REDIRECT_URI"}, "jwtSecret": {"$env": "JWT_SECRET"}, "encryptionKey": {"$env": "ENCRYPTION_KEY"} } diff --git a/config-oauth.json b/config-oauth.json index dec7a3d..e0c9853 100644 --- a/config-oauth.json +++ b/config-oauth.json @@ -8,13 +8,16 @@ "kind": "oauth", "issuer": "https://mcp-internal.yourcompany.org", "gcpProject": {"$env": "GCP_PROJECT"}, + "idp": { + "provider": "google", + "clientId": {"$env": "GOOGLE_CLIENT_ID"}, + "clientSecret": {"$env": "GOOGLE_CLIENT_SECRET"}, + "redirectUri": "https://mcp-internal.yourcompany.org/oauth/callback" + }, "allowedDomains": ["yourcompany.com"], "allowedOrigins": ["https://claude.ai"], "tokenTtl": "24h", "storage": "memory", - "googleClientId": {"$env": "GOOGLE_CLIENT_ID"}, - "googleClientSecret": {"$env": "GOOGLE_CLIENT_SECRET"}, - "googleRedirectUri": "https://mcp-internal.yourcompany.org/oauth/callback", "jwtSecret": {"$env": "JWT_SECRET"}, "encryptionKey": {"$env": "ENCRYPTION_KEY"} } diff --git a/config-user-tokens-example.json b/config-user-tokens-example.json index bd93f36..4c4b42f 100644 --- a/config-user-tokens-example.json +++ b/config-user-tokens-example.json @@ -8,13 +8,16 @@ "kind": "oauth", "issuer": {"$env": "OAUTH_ISSUER"}, "gcpProject": {"$env": "GCP_PROJECT"}, + "idp": { + "provider": "google", + "clientId": {"$env": "GOOGLE_CLIENT_ID"}, + "clientSecret": {"$env": "GOOGLE_CLIENT_SECRET"}, + "redirectUri": {"$env": "GOOGLE_REDIRECT_URI"} + }, "allowedDomains": ["yourcompany.com"], "allowedOrigins": ["https://claude.ai"], "tokenTtl": "24h", "storage": "memory", - "googleClientId": {"$env": "GOOGLE_CLIENT_ID"}, - "googleClientSecret": {"$env": "GOOGLE_CLIENT_SECRET"}, - "googleRedirectUri": {"$env": "GOOGLE_REDIRECT_URI"}, "jwtSecret": {"$env": "JWT_SECRET"}, "encryptionKey": {"$env": "ENCRYPTION_KEY"} } diff --git a/docs-site/src/content/docs/configuration.md b/docs-site/src/content/docs/configuration.md index 1351707..7a2f908 100644 --- a/docs-site/src/content/docs/configuration.md +++ b/docs-site/src/content/docs/configuration.md @@ -49,13 +49,16 @@ For production, use OAuth with Google. Claude redirects users to Google for auth "auth": { "kind": "oauth", "issuer": "https://mcp.company.com", + "idp": { + "provider": "google", + "clientId": { "$env": "GOOGLE_CLIENT_ID" }, + "clientSecret": { "$env": "GOOGLE_CLIENT_SECRET" }, + "redirectUri": "https://mcp.company.com/oauth/callback" + }, "allowedDomains": ["company.com"], "allowedOrigins": ["https://claude.ai"], "tokenTtl": "24h", "storage": "memory", - "googleClientId": { "$env": "GOOGLE_CLIENT_ID" }, - "googleClientSecret": { "$env": "GOOGLE_CLIENT_SECRET" }, - "googleRedirectUri": "https://mcp.company.com/oauth/callback", "jwtSecret": { "$env": "JWT_SECRET" }, "encryptionKey": { "$env": "ENCRYPTION_KEY" } } @@ -66,7 +69,7 @@ The `issuer` should match your `baseURL`. `allowedDomains` restricts access to s `tokenTtl` controls how long JWT tokens are valid. Shorter times are more secure but require more frequent logins. -Security requirements: `googleClientSecret`, `jwtSecret`, and `encryptionKey` must be environment variables. The JWT secret must be at least 32 bytes. The encryption key must be exactly 32 bytes. +Security requirements: `idp.clientSecret`, `jwtSecret`, and `encryptionKey` must be environment variables. The JWT secret must be at least 32 bytes. The encryption key must be exactly 32 bytes. For production, set `storage` to "firestore" and add `gcpProject`, `firestoreDatabase`, and `firestoreCollection` fields. @@ -345,13 +348,16 @@ Set `MCP_FRONT_ENV=development` when testing OAuth locally. It allows http:// UR "auth": { "kind": "oauth", "issuer": "https://mcp.company.com", + "idp": { + "provider": "google", + "clientId": { "$env": "GOOGLE_CLIENT_ID" }, + "clientSecret": { "$env": "GOOGLE_CLIENT_SECRET" }, + "redirectUri": "https://mcp.company.com/oauth/callback" + }, "allowedDomains": ["company.com"], "allowedOrigins": ["https://claude.ai"], "tokenTtl": "4h", "storage": "firestore", - "googleClientId": { "$env": "GOOGLE_CLIENT_ID" }, - "googleClientSecret": { "$env": "GOOGLE_CLIENT_SECRET" }, - "googleRedirectUri": "https://mcp.company.com/oauth/callback", "jwtSecret": { "$env": "JWT_SECRET" }, "encryptionKey": { "$env": "ENCRYPTION_KEY" }, "gcpProject": { "$env": "GOOGLE_CLOUD_PROJECT" }, diff --git a/docs-site/src/content/docs/examples/oauth-google.md b/docs-site/src/content/docs/examples/oauth-google.md index d4598eb..d35aea1 100644 --- a/docs-site/src/content/docs/examples/oauth-google.md +++ b/docs-site/src/content/docs/examples/oauth-google.md @@ -39,15 +39,26 @@ Create `config.json`: ```json { - "version": "1.0", + "version": "v0.0.1-DEV_EDITION_EXPECT_CHANGES", "proxy": { "name": "Company MCP Proxy", - "baseUrl": "https://mcp.company.com", + "baseURL": "https://mcp.company.com", "addr": ":8080", "auth": { "kind": "oauth", "issuer": "https://mcp.company.com", - "allowedDomains": ["company.com"] + "idp": { + "provider": "google", + "clientId": { "$env": "GOOGLE_CLIENT_ID" }, + "clientSecret": { "$env": "GOOGLE_CLIENT_SECRET" }, + "redirectUri": "https://mcp.company.com/oauth/callback" + }, + "allowedDomains": ["company.com"], + "allowedOrigins": ["https://claude.ai"], + "tokenTtl": "24h", + "storage": "memory", + "jwtSecret": { "$env": "JWT_SECRET" }, + "encryptionKey": { "$env": "ENCRYPTION_KEY" } } }, "mcpServers": { @@ -80,6 +91,7 @@ Create `config.json`: export GOOGLE_CLIENT_ID="your-client-id.apps.googleusercontent.com" export GOOGLE_CLIENT_SECRET="your-client-secret" export JWT_SECRET=$(openssl rand -base64 32) +export ENCRYPTION_KEY=$(openssl rand -base64 32) ``` ## 4. Run with Docker @@ -89,6 +101,7 @@ docker run -p 8080:8080 \ -e GOOGLE_CLIENT_ID \ -e GOOGLE_CLIENT_SECRET \ -e JWT_SECRET \ + -e ENCRYPTION_KEY \ -v $(pwd)/config.json:/config.json \ ghcr.io/dgellow/mcp-front:latest ``` diff --git a/docs-site/src/content/docs/index.mdx b/docs-site/src/content/docs/index.mdx index c7275ee..f1d79fb 100644 --- a/docs-site/src/content/docs/index.mdx +++ b/docs-site/src/content/docs/index.mdx @@ -41,9 +41,13 @@ Claude redirects users to Google for authentication, and MCP Front validates the "auth": { "kind": "oauth", "issuer": "https://mcp.company.com", + "idp": { + "provider": "google", + "clientId": { "$env": "GOOGLE_CLIENT_ID" }, + "clientSecret": { "$env": "GOOGLE_CLIENT_SECRET" }, + "redirectUri": "https://mcp.company.com/oauth/callback" + }, "allowedDomains": ["company.com"], - "googleClientId": { "$env": "GOOGLE_CLIENT_ID" }, - "googleClientSecret": { "$env": "GOOGLE_CLIENT_SECRET" }, "jwtSecret": { "$env": "JWT_SECRET" }, "encryptionKey": { "$env": "ENCRYPTION_KEY" } } diff --git a/integration/config/config.oauth-integration-test.json b/integration/config/config.oauth-integration-test.json index 0a40a0f..bb24a06 100644 --- a/integration/config/config.oauth-integration-test.json +++ b/integration/config/config.oauth-integration-test.json @@ -8,6 +8,12 @@ "kind": "oauth", "issuer": "http://localhost:8080", "gcpProject": "test-project", + "idp": { + "provider": "google", + "clientId": "test-client-id", + "clientSecret": "test-client-secret-for-integration-testing", + "redirectUri": "http://localhost:8080/oauth/callback" + }, "allowedDomains": [ "test.com" ], @@ -16,9 +22,6 @@ ], "tokenTtl": "1h", "storage": "memory", - "googleClientId": "test-client-id", - "googleClientSecret": "test-client-secret-for-integration-testing", - "googleRedirectUri": "http://localhost:8080/oauth/callback", "jwtSecret": "test-jwt-secret-for-integration-testing-32-chars-long", "encryptionKey": "test-encryption-key-32-bytes-aes" } diff --git a/integration/config/config.oauth-rfc8707-test.json b/integration/config/config.oauth-rfc8707-test.json index f169a64..25ceb1d 100644 --- a/integration/config/config.oauth-rfc8707-test.json +++ b/integration/config/config.oauth-rfc8707-test.json @@ -8,13 +8,16 @@ "kind": "oauth", "issuer": "http://localhost:8080", "gcpProject": "test-project", + "idp": { + "provider": "google", + "clientId": {"$env": "GOOGLE_CLIENT_ID"}, + "clientSecret": {"$env": "GOOGLE_CLIENT_SECRET"}, + "redirectUri": "http://localhost:8080/oauth/callback" + }, "allowedDomains": ["test.com"], "allowedOrigins": ["https://claude.ai"], "tokenTtl": "1h", "storage": "memory", - "googleClientId": {"$env": "GOOGLE_CLIENT_ID"}, - "googleClientSecret": {"$env": "GOOGLE_CLIENT_SECRET"}, - "googleRedirectUri": "http://localhost:8080/oauth/callback", "jwtSecret": {"$env": "JWT_SECRET"}, "encryptionKey": {"$env": "ENCRYPTION_KEY"} } diff --git a/integration/config/config.oauth-service-integration-test.json b/integration/config/config.oauth-service-integration-test.json index e46b23c..09ca0a9 100644 --- a/integration/config/config.oauth-service-integration-test.json +++ b/integration/config/config.oauth-service-integration-test.json @@ -8,13 +8,16 @@ "kind": "oauth", "issuer": "http://localhost:8080", "gcpProject": "test-project", + "idp": { + "provider": "google", + "clientId": {"$env": "GOOGLE_CLIENT_ID"}, + "clientSecret": {"$env": "GOOGLE_CLIENT_SECRET"}, + "redirectUri": "http://localhost:8080/oauth/callback" + }, "allowedDomains": ["test.com"], "allowedOrigins": ["https://claude.ai"], "tokenTtl": "1h", "storage": "memory", - "googleClientId": {"$env": "GOOGLE_CLIENT_ID"}, - "googleClientSecret": {"$env": "GOOGLE_CLIENT_SECRET"}, - "googleRedirectUri": "http://localhost:8080/oauth/callback", "jwtSecret": {"$env": "JWT_SECRET"}, "encryptionKey": {"$env": "ENCRYPTION_KEY"} } diff --git a/integration/config/config.oauth-service-test.json b/integration/config/config.oauth-service-test.json index 9d2f6d4..04071c8 100644 --- a/integration/config/config.oauth-service-test.json +++ b/integration/config/config.oauth-service-test.json @@ -8,13 +8,16 @@ "kind": "oauth", "issuer": "http://localhost:8080", "gcpProject": "test-project", + "idp": { + "provider": "google", + "clientId": {"$env": "GOOGLE_CLIENT_ID"}, + "clientSecret": {"$env": "GOOGLE_CLIENT_SECRET"}, + "redirectUri": "http://localhost:8080/oauth/callback" + }, "allowedDomains": ["test.com"], "allowedOrigins": ["https://claude.ai"], "tokenTtl": "1h", "storage": "memory", - "googleClientId": {"$env": "GOOGLE_CLIENT_ID"}, - "googleClientSecret": {"$env": "GOOGLE_CLIENT_SECRET"}, - "googleRedirectUri": "http://localhost:8080/oauth/callback", "jwtSecret": {"$env": "JWT_SECRET"}, "encryptionKey": {"$env": "ENCRYPTION_KEY"} } diff --git a/integration/config/config.oauth-test.json b/integration/config/config.oauth-test.json index d08faca..399e6d2 100644 --- a/integration/config/config.oauth-test.json +++ b/integration/config/config.oauth-test.json @@ -8,13 +8,16 @@ "kind": "oauth", "issuer": "http://localhost:8080", "gcpProject": "test-project", + "idp": { + "provider": "google", + "clientId": {"$env": "GOOGLE_CLIENT_ID"}, + "clientSecret": {"$env": "GOOGLE_CLIENT_SECRET"}, + "redirectUri": "http://localhost:8080/oauth/callback" + }, "allowedDomains": ["test.com", "stainless.com", "claude.ai"], "allowedOrigins": ["https://claude.ai"], "tokenTtl": "1h", "storage": "memory", - "googleClientId": {"$env": "GOOGLE_CLIENT_ID"}, - "googleClientSecret": {"$env": "GOOGLE_CLIENT_SECRET"}, - "googleRedirectUri": "http://localhost:8080/oauth/callback", "jwtSecret": {"$env": "JWT_SECRET"}, "encryptionKey": {"$env": "ENCRYPTION_KEY"} } @@ -30,4 +33,4 @@ ] } } -} \ No newline at end of file +} diff --git a/integration/config/config.oauth-token-test.json b/integration/config/config.oauth-token-test.json index 5d39ce4..874a1c7 100644 --- a/integration/config/config.oauth-token-test.json +++ b/integration/config/config.oauth-token-test.json @@ -8,13 +8,16 @@ "kind": "oauth", "issuer": "http://localhost:8080", "gcpProject": "test-project", + "idp": { + "provider": "google", + "clientId": {"$env": "GOOGLE_CLIENT_ID"}, + "clientSecret": {"$env": "GOOGLE_CLIENT_SECRET"}, + "redirectUri": "http://localhost:8080/oauth/callback" + }, "allowedDomains": ["test.com"], "allowedOrigins": ["https://claude.ai"], "tokenTtl": "1h", "storage": "memory", - "googleClientId": {"$env": "GOOGLE_CLIENT_ID"}, - "googleClientSecret": {"$env": "GOOGLE_CLIENT_SECRET"}, - "googleRedirectUri": "http://localhost:8080/oauth/callback", "jwtSecret": {"$env": "JWT_SECRET"}, "encryptionKey": {"$env": "ENCRYPTION_KEY"} } diff --git a/integration/config/config.oauth-usertoken-tools-test.json b/integration/config/config.oauth-usertoken-tools-test.json index 9d2f6d4..9fafd1d 100644 --- a/integration/config/config.oauth-usertoken-tools-test.json +++ b/integration/config/config.oauth-usertoken-tools-test.json @@ -8,13 +8,16 @@ "kind": "oauth", "issuer": "http://localhost:8080", "gcpProject": "test-project", + "idp": { + "provider": "google", + "clientId": {"$env": "GOOGLE_CLIENT_ID"}, + "clientSecret": {"$env": "GOOGLE_CLIENT_SECRET"}, + "redirectUri": "http://localhost:8080/oauth/callback" + }, "allowedDomains": ["test.com"], "allowedOrigins": ["https://claude.ai"], "tokenTtl": "1h", "storage": "memory", - "googleClientId": {"$env": "GOOGLE_CLIENT_ID"}, - "googleClientSecret": {"$env": "GOOGLE_CLIENT_SECRET"}, - "googleRedirectUri": "http://localhost:8080/oauth/callback", "jwtSecret": {"$env": "JWT_SECRET"}, "encryptionKey": {"$env": "ENCRYPTION_KEY"} } @@ -44,4 +47,4 @@ } } } -} \ No newline at end of file +} diff --git a/internal/browserauth/session.go b/internal/browserauth/session.go index a1876d9..1563f5b 100644 --- a/internal/browserauth/session.go +++ b/internal/browserauth/session.go @@ -4,8 +4,9 @@ import "time" // SessionCookie represents the data stored in encrypted browser session cookies type SessionCookie struct { - Email string `json:"email"` - Expires time.Time `json:"expires"` + Email string `json:"email"` + Provider string `json:"provider"` // IDP that authenticated this user (e.g., "google", "azure", "github") + Expires time.Time `json:"expires"` } // AuthorizationState represents the OAuth authorization code flow state parameter diff --git a/internal/browserauth/session_test.go b/internal/browserauth/session_test.go index f7e15de..dace5bb 100644 --- a/internal/browserauth/session_test.go +++ b/internal/browserauth/session_test.go @@ -11,8 +11,9 @@ import ( func TestSessionCookie_MarshalUnmarshal(t *testing.T) { original := SessionCookie{ - Email: "user@example.com", - Expires: time.Now().Add(24 * time.Hour).Truncate(time.Second), + Email: "user@example.com", + Provider: "google", + Expires: time.Now().Add(24 * time.Hour).Truncate(time.Second), } // Marshal to JSON @@ -26,14 +27,16 @@ func TestSessionCookie_MarshalUnmarshal(t *testing.T) { // Truncate for comparison (JSON time serialization) assert.Equal(t, original.Email, unmarshaled.Email) + assert.Equal(t, original.Provider, unmarshaled.Provider) assert.WithinDuration(t, original.Expires, unmarshaled.Expires, time.Second) } func TestSessionCookie_Expiry(t *testing.T) { t.Run("not expired", func(t *testing.T) { session := SessionCookie{ - Email: "user@example.com", - Expires: time.Now().Add(1 * time.Hour), + Email: "user@example.com", + Provider: "google", + Expires: time.Now().Add(1 * time.Hour), } assert.True(t, session.Expires.After(time.Now())) @@ -41,8 +44,9 @@ func TestSessionCookie_Expiry(t *testing.T) { t.Run("expired", func(t *testing.T) { session := SessionCookie{ - Email: "user@example.com", - Expires: time.Now().Add(-1 * time.Hour), + Email: "user@example.com", + Provider: "google", + Expires: time.Now().Add(-1 * time.Hour), } assert.True(t, session.Expires.Before(time.Now())) diff --git a/internal/config/load.go b/internal/config/load.go index 326cec1..768ebf0 100644 --- a/internal/config/load.go +++ b/internal/config/load.go @@ -58,11 +58,11 @@ func validateRawConfig(rawConfig map[string]any) error { if proxy, ok := rawConfig["proxy"].(map[string]any); ok { if auth, ok := proxy["auth"].(map[string]any); ok { if kind, ok := auth["kind"].(string); ok && kind == "oauth" { + // Validate top-level auth secrets secrets := []struct { name string required bool }{ - {"googleClientSecret", true}, {"jwtSecret", true}, {"encryptionKey", true}, // Always required for OAuth } @@ -88,6 +88,20 @@ func validateRawConfig(rawConfig map[string]any) error { } } } + + // Validate IDP secret (clientSecret must be env ref) + if idp, ok := auth["idp"].(map[string]any); ok { + if clientSecret, exists := idp["clientSecret"]; exists { + if _, isString := clientSecret.(string); isString { + return fmt.Errorf("idp.clientSecret must use environment variable reference for security") + } + if refMap, isMap := clientSecret.(map[string]any); isMap { + if _, hasEnv := refMap["$env"]; !hasEnv { + return fmt.Errorf("idp.clientSecret must use {\"$env\": \"VAR_NAME\"} format") + } + } + } + } } } } @@ -148,24 +162,54 @@ func validateOAuthConfig(oauth *OAuthAuthConfig) error { if oauth.Issuer == "" { return fmt.Errorf("issuer is required") } - if oauth.GoogleClientID == "" { - return fmt.Errorf("googleClientId is required") + + // Validate IDP configuration + if oauth.IDP.Provider == "" { + return fmt.Errorf("idp.provider is required (google, azure, github, or oidc)") + } + if oauth.IDP.ClientID == "" { + return fmt.Errorf("idp.clientId is required") } - if oauth.GoogleClientSecret == "" { - return fmt.Errorf("googleClientSecret is required") + if oauth.IDP.ClientSecret == "" { + return fmt.Errorf("idp.clientSecret is required") } - if oauth.GoogleRedirectURI == "" { - return fmt.Errorf("googleRedirectUri is required") + if oauth.IDP.RedirectURI == "" { + return fmt.Errorf("idp.redirectUri is required") } + + // Provider-specific validation + switch oauth.IDP.Provider { + case "google": + // No additional validation needed + case "azure": + if oauth.IDP.TenantID == "" { + return fmt.Errorf("idp.tenantId is required for Azure AD") + } + case "github": + // No additional validation needed + case "oidc": + // Either discoveryUrl or all manual endpoints required + if oauth.IDP.DiscoveryURL == "" { + if oauth.IDP.AuthorizationURL == "" || oauth.IDP.TokenURL == "" || oauth.IDP.UserInfoURL == "" { + return fmt.Errorf("idp.discoveryUrl or all of (authorizationUrl, tokenUrl, userInfoUrl) required for OIDC") + } + } + default: + return fmt.Errorf("unsupported idp.provider: %s (must be google, azure, github, or oidc)", oauth.IDP.Provider) + } + if len(oauth.JWTSecret) < 32 { return fmt.Errorf("jwtSecret must be at least 32 characters (got %d). Generate with: openssl rand -base64 32", len(oauth.JWTSecret)) } if len(oauth.EncryptionKey) != 32 { return fmt.Errorf("encryptionKey must be exactly 32 characters (got %d). Generate with: openssl rand -base64 32 | head -c 32", len(oauth.EncryptionKey)) } - if len(oauth.AllowedDomains) == 0 { - return fmt.Errorf("at least one allowed domain is required") + + // Domain or org validation - at least one access control mechanism required + if len(oauth.AllowedDomains) == 0 && len(oauth.IDP.AllowedOrgs) == 0 { + return fmt.Errorf("at least one of allowedDomains or idp.allowedOrgs is required") } + if oauth.Storage == "firestore" { if oauth.GCPProject == "" { return fmt.Errorf("gcpProject is required when using firestore storage") diff --git a/internal/config/load_test.go b/internal/config/load_test.go index a79b8e3..4fb5659 100644 --- a/internal/config/load_test.go +++ b/internal/config/load_test.go @@ -41,17 +41,20 @@ func TestValidateConfig_UserTokensRequireOAuth(t *testing.T) { BaseURL: "https://test.example.com", Addr: ":8080", Auth: &OAuthAuthConfig{ - Kind: "oauth", - Issuer: "https://auth.example.com", - GoogleClientID: "test-client", - GoogleClientSecret: "test-secret", - GoogleRedirectURI: "https://test.example.com/callback", - JWTSecret: "test-jwt-secret-must-be-32-bytes-long", - EncryptionKey: "test-encryption-key-32-bytes-ok!", - AllowedDomains: []string{"example.com"}, - AllowedOrigins: []string{"https://test.example.com"}, - TokenTTL: time.Hour, - RefreshTokenTTL: 30 * 24 * time.Hour, + Kind: "oauth", + Issuer: "https://auth.example.com", + IDP: IDPConfig{ + Provider: "google", + ClientID: "test-client", + ClientSecret: "test-secret", + RedirectURI: "https://test.example.com/callback", + }, + JWTSecret: "test-jwt-secret-must-be-32-bytes-long", + EncryptionKey: "test-encryption-key-32-bytes-ok!", + AllowedDomains: []string{"example.com"}, + AllowedOrigins: []string{"https://test.example.com"}, + TokenTTL: time.Hour, + RefreshTokenTTL: 30 * 24 * time.Hour, }, }, MCPServers: map[string]*MCPClientConfig{ @@ -98,17 +101,20 @@ func TestValidateConfig_SessionConfig(t *testing.T) { BaseURL: "https://test.example.com", Addr: ":8080", Auth: &OAuthAuthConfig{ - Kind: "oauth", - Issuer: "https://auth.example.com", - GoogleClientID: "test-client", - GoogleClientSecret: "test-secret", - GoogleRedirectURI: "https://test.example.com/callback", - JWTSecret: "test-jwt-secret-must-be-32-bytes-long", - EncryptionKey: "test-encryption-key-32-bytes-ok!", - AllowedDomains: []string{"example.com"}, - AllowedOrigins: []string{"https://test.example.com"}, - TokenTTL: time.Hour, - RefreshTokenTTL: 30 * 24 * time.Hour, + Kind: "oauth", + Issuer: "https://auth.example.com", + IDP: IDPConfig{ + Provider: "google", + ClientID: "test-client", + ClientSecret: "test-secret", + RedirectURI: "https://test.example.com/callback", + }, + JWTSecret: "test-jwt-secret-must-be-32-bytes-long", + EncryptionKey: "test-encryption-key-32-bytes-ok!", + AllowedDomains: []string{"example.com"}, + AllowedOrigins: []string{"https://test.example.com"}, + TokenTTL: time.Hour, + RefreshTokenTTL: 30 * 24 * time.Hour, }, Sessions: &SessionConfig{ Timeout: 10 * time.Minute, @@ -128,17 +134,20 @@ func TestValidateConfig_SessionConfig(t *testing.T) { BaseURL: "https://test.example.com", Addr: ":8080", Auth: &OAuthAuthConfig{ - Kind: "oauth", - Issuer: "https://auth.example.com", - GoogleClientID: "test-client", - GoogleClientSecret: "test-secret", - GoogleRedirectURI: "https://test.example.com/callback", - JWTSecret: "test-jwt-secret-must-be-32-bytes-long", - EncryptionKey: "test-encryption-key-32-bytes-ok!", - AllowedDomains: []string{"example.com"}, - AllowedOrigins: []string{"https://test.example.com"}, - TokenTTL: time.Hour, - RefreshTokenTTL: 30 * 24 * time.Hour, + Kind: "oauth", + Issuer: "https://auth.example.com", + IDP: IDPConfig{ + Provider: "google", + ClientID: "test-client", + ClientSecret: "test-secret", + RedirectURI: "https://test.example.com/callback", + }, + JWTSecret: "test-jwt-secret-must-be-32-bytes-long", + EncryptionKey: "test-encryption-key-32-bytes-ok!", + AllowedDomains: []string{"example.com"}, + AllowedOrigins: []string{"https://test.example.com"}, + TokenTTL: time.Hour, + RefreshTokenTTL: 30 * 24 * time.Hour, }, Sessions: &SessionConfig{ Timeout: -1 * time.Minute, @@ -156,17 +165,20 @@ func TestValidateConfig_SessionConfig(t *testing.T) { BaseURL: "https://test.example.com", Addr: ":8080", Auth: &OAuthAuthConfig{ - Kind: "oauth", - Issuer: "https://auth.example.com", - GoogleClientID: "test-client", - GoogleClientSecret: "test-secret", - GoogleRedirectURI: "https://test.example.com/callback", - JWTSecret: "test-jwt-secret-must-be-32-bytes-long", - EncryptionKey: "test-encryption-key-32-bytes-ok!", - AllowedDomains: []string{"example.com"}, - AllowedOrigins: []string{"https://test.example.com"}, - TokenTTL: time.Hour, - RefreshTokenTTL: 30 * 24 * time.Hour, + Kind: "oauth", + Issuer: "https://auth.example.com", + IDP: IDPConfig{ + Provider: "google", + ClientID: "test-client", + ClientSecret: "test-secret", + RedirectURI: "https://test.example.com/callback", + }, + JWTSecret: "test-jwt-secret-must-be-32-bytes-long", + EncryptionKey: "test-encryption-key-32-bytes-ok!", + AllowedDomains: []string{"example.com"}, + AllowedOrigins: []string{"https://test.example.com"}, + TokenTTL: time.Hour, + RefreshTokenTTL: 30 * 24 * time.Hour, }, Sessions: &SessionConfig{ Timeout: 10 * time.Minute, @@ -184,17 +196,20 @@ func TestValidateConfig_SessionConfig(t *testing.T) { BaseURL: "https://test.example.com", Addr: ":8080", Auth: &OAuthAuthConfig{ - Kind: "oauth", - Issuer: "https://auth.example.com", - GoogleClientID: "test-client", - GoogleClientSecret: "test-secret", - GoogleRedirectURI: "https://test.example.com/callback", - JWTSecret: "test-jwt-secret-must-be-32-bytes-long", - EncryptionKey: "test-encryption-key-32-bytes-ok!", - AllowedDomains: []string{"example.com"}, - AllowedOrigins: []string{"https://test.example.com"}, - TokenTTL: time.Hour, - RefreshTokenTTL: 30 * 24 * time.Hour, + Kind: "oauth", + Issuer: "https://auth.example.com", + IDP: IDPConfig{ + Provider: "google", + ClientID: "test-client", + ClientSecret: "test-secret", + RedirectURI: "https://test.example.com/callback", + }, + JWTSecret: "test-jwt-secret-must-be-32-bytes-long", + EncryptionKey: "test-encryption-key-32-bytes-ok!", + AllowedDomains: []string{"example.com"}, + AllowedOrigins: []string{"https://test.example.com"}, + TokenTTL: time.Hour, + RefreshTokenTTL: 30 * 24 * time.Hour, }, Sessions: &SessionConfig{ Timeout: 0, diff --git a/internal/config/types.go b/internal/config/types.go index 515f202..337f16d 100644 --- a/internal/config/types.go +++ b/internal/config/types.go @@ -200,12 +200,39 @@ type AdminConfig struct { AdminEmails []string `json:"adminEmails"` } +// IDPConfig represents identity provider configuration. +type IDPConfig struct { + // Provider type: "google", "azure", "github", or "oidc" + Provider string `json:"provider"` + + // OAuth client configuration + ClientID string `json:"clientId"` + ClientSecret Secret `json:"clientSecret"` + RedirectURI string `json:"redirectUri"` + + // For OIDC: discovery URL or manual endpoint configuration + DiscoveryURL string `json:"discoveryUrl,omitempty"` + AuthorizationURL string `json:"authorizationUrl,omitempty"` + TokenURL string `json:"tokenUrl,omitempty"` + UserInfoURL string `json:"userInfoUrl,omitempty"` + + // Custom scopes (optional, defaults per provider) + Scopes []string `json:"scopes,omitempty"` + + // Azure-specific: tenant ID + TenantID string `json:"tenantId,omitempty"` + + // GitHub-specific: allowed organizations + AllowedOrgs []string `json:"allowedOrgs,omitempty"` +} + // OAuthAuthConfig represents OAuth 2.0 configuration with resolved values type OAuthAuthConfig struct { Kind AuthKind `json:"kind"` Issuer string `json:"issuer"` GCPProject string `json:"gcpProject"` - AllowedDomains []string `json:"allowedDomains"` // For Google OAuth email validation + IDP IDPConfig `json:"idp"` + AllowedDomains []string `json:"allowedDomains"` // For domain-based access control AllowedOrigins []string `json:"allowedOrigins"` // For CORS validation TokenTTL time.Duration `json:"tokenTtl"` RefreshTokenTTL time.Duration `json:"refreshTokenTtl"` @@ -213,9 +240,6 @@ type OAuthAuthConfig struct { Storage string `json:"storage"` // "memory" or "firestore" FirestoreDatabase string `json:"firestoreDatabase,omitempty"` // Optional: Firestore database name FirestoreCollection string `json:"firestoreCollection,omitempty"` // Optional: Firestore collection name - GoogleClientID string `json:"googleClientId"` - GoogleClientSecret Secret `json:"googleClientSecret"` - GoogleRedirectURI string `json:"googleRedirectUri"` JWTSecret Secret `json:"jwtSecret"` EncryptionKey Secret `json:"encryptionKey"` // DangerouslyAcceptIssuerAudience allows tokens with just the base issuer as audience diff --git a/internal/config/unmarshal.go b/internal/config/unmarshal.go index eb20321..0a03878 100644 --- a/internal/config/unmarshal.go +++ b/internal/config/unmarshal.go @@ -118,6 +118,7 @@ func (o *OAuthAuthConfig) UnmarshalJSON(data []byte) error { Kind AuthKind `json:"kind"` Issuer json.RawMessage `json:"issuer"` GCPProject json.RawMessage `json:"gcpProject"` + IDP json.RawMessage `json:"idp"` AllowedDomains []string `json:"allowedDomains"` AllowedOrigins []string `json:"allowedOrigins"` TokenTTL string `json:"tokenTtl"` @@ -126,9 +127,6 @@ func (o *OAuthAuthConfig) UnmarshalJSON(data []byte) error { Storage string `json:"storage"` FirestoreDatabase string `json:"firestoreDatabase,omitempty"` FirestoreCollection string `json:"firestoreCollection,omitempty"` - GoogleClientID json.RawMessage `json:"googleClientId"` - GoogleClientSecret json.RawMessage `json:"googleClientSecret"` - GoogleRedirectURI json.RawMessage `json:"googleRedirectUri"` JWTSecret json.RawMessage `json:"jwtSecret"` EncryptionKey json.RawMessage `json:"encryptionKey,omitempty"` DangerouslyAcceptIssuerAudience bool `json:"dangerouslyAcceptIssuerAudience,omitempty"` @@ -202,38 +200,13 @@ func (o *OAuthAuthConfig) UnmarshalJSON(data []byte) error { o.GCPProject = parsed.value } - if raw.GoogleClientID != nil { - parsed, err := ParseConfigValue(raw.GoogleClientID) + // Parse IDP config + if raw.IDP != nil { + idp, err := parseIDPConfig(raw.IDP) if err != nil { - return fmt.Errorf("parsing googleClientId: %w", err) + return fmt.Errorf("parsing idp: %w", err) } - if parsed.needsUserToken { - return fmt.Errorf("googleClientId cannot be a user token reference") - } - o.GoogleClientID = parsed.value - } - - if raw.GoogleRedirectURI != nil { - parsed, err := ParseConfigValue(raw.GoogleRedirectURI) - if err != nil { - return fmt.Errorf("parsing googleRedirectUri: %w", err) - } - if parsed.needsUserToken { - return fmt.Errorf("googleRedirectUri cannot be a user token reference") - } - o.GoogleRedirectURI = parsed.value - } - - // Parse secret fields - if raw.GoogleClientSecret != nil { - parsed, err := ParseConfigValue(raw.GoogleClientSecret) - if err != nil { - return fmt.Errorf("parsing googleClientSecret: %w", err) - } - if parsed.needsUserToken { - return fmt.Errorf("googleClientSecret cannot be a user token reference") - } - o.GoogleClientSecret = Secret(parsed.value) + o.IDP = *idp } if raw.JWTSecret != nil { @@ -274,6 +247,87 @@ func (o *OAuthAuthConfig) UnmarshalJSON(data []byte) error { return nil } +// parseIDPConfig parses the IDP configuration with env var references +func parseIDPConfig(data json.RawMessage) (*IDPConfig, error) { + type rawIDP struct { + Provider string `json:"provider"` + ClientID json.RawMessage `json:"clientId"` + ClientSecret json.RawMessage `json:"clientSecret"` + RedirectURI json.RawMessage `json:"redirectUri"` + DiscoveryURL json.RawMessage `json:"discoveryUrl,omitempty"` + AuthorizationURL json.RawMessage `json:"authorizationUrl,omitempty"` + TokenURL json.RawMessage `json:"tokenUrl,omitempty"` + UserInfoURL json.RawMessage `json:"userInfoUrl,omitempty"` + Scopes []string `json:"scopes,omitempty"` + TenantID json.RawMessage `json:"tenantId,omitempty"` + AllowedOrgs []string `json:"allowedOrgs,omitempty"` + } + + var raw rawIDP + if err := json.Unmarshal(data, &raw); err != nil { + return nil, err + } + + idp := &IDPConfig{ + Provider: raw.Provider, + Scopes: raw.Scopes, + AllowedOrgs: raw.AllowedOrgs, + } + + // Helper to parse a field that cannot be a user token reference + parseField := func(data json.RawMessage, fieldName string) (string, error) { + if data == nil { + return "", nil + } + parsed, err := ParseConfigValue(data) + if err != nil { + return "", fmt.Errorf("parsing %s: %w", fieldName, err) + } + if parsed.needsUserToken { + return "", fmt.Errorf("%s cannot be a user token reference", fieldName) + } + return parsed.value, nil + } + + var err error + + if idp.ClientID, err = parseField(raw.ClientID, "clientId"); err != nil { + return nil, err + } + + var clientSecret string + if clientSecret, err = parseField(raw.ClientSecret, "clientSecret"); err != nil { + return nil, err + } + idp.ClientSecret = Secret(clientSecret) + + if idp.RedirectURI, err = parseField(raw.RedirectURI, "redirectUri"); err != nil { + return nil, err + } + + if idp.DiscoveryURL, err = parseField(raw.DiscoveryURL, "discoveryUrl"); err != nil { + return nil, err + } + + if idp.AuthorizationURL, err = parseField(raw.AuthorizationURL, "authorizationUrl"); err != nil { + return nil, err + } + + if idp.TokenURL, err = parseField(raw.TokenURL, "tokenUrl"); err != nil { + return nil, err + } + + if idp.UserInfoURL, err = parseField(raw.UserInfoURL, "userInfoUrl"); err != nil { + return nil, err + } + + if idp.TenantID, err = parseField(raw.TenantID, "tenantId"); err != nil { + return nil, err + } + + return idp, nil +} + // UnmarshalJSON implements custom unmarshaling for ProxyConfig func (p *ProxyConfig) UnmarshalJSON(data []byte) error { // Use a raw type to parse references diff --git a/internal/config/unmarshal_test.go b/internal/config/unmarshal_test.go index dc0682a..bc1db55 100644 --- a/internal/config/unmarshal_test.go +++ b/internal/config/unmarshal_test.go @@ -255,9 +255,12 @@ func TestOAuthAuthConfig_UnmarshalJSON(t *testing.T) { "allowedOrigins": ["https://claude.ai", "https://example.com"], "tokenTtl": "1h", "storage": "firestore", - "googleClientId": "test-client-id", - "googleClientSecret": {"$env": "CLIENT_SECRET"}, - "googleRedirectUri": "https://example.com/callback", + "idp": { + "provider": "google", + "clientId": "test-client-id", + "clientSecret": {"$env": "CLIENT_SECRET"}, + "redirectUri": "https://example.com/callback" + }, "jwtSecret": {"$env": "JWT_SECRET"}, "encryptionKey": {"$env": "ENCRYPTION_KEY"} }` @@ -271,8 +274,10 @@ func TestOAuthAuthConfig_UnmarshalJSON(t *testing.T) { assert.Equal(t, "test-project", config.GCPProject) assert.Equal(t, []string{"example.com"}, config.AllowedDomains) assert.Equal(t, []string{"https://claude.ai", "https://example.com"}, config.AllowedOrigins) - assert.Equal(t, "test-client-id", config.GoogleClientID) - assert.Equal(t, Secret("test-secret-value"), config.GoogleClientSecret) + assert.Equal(t, "google", config.IDP.Provider) + assert.Equal(t, "test-client-id", config.IDP.ClientID) + assert.Equal(t, Secret("test-secret-value"), config.IDP.ClientSecret) + assert.Equal(t, "https://example.com/callback", config.IDP.RedirectURI) assert.Equal(t, Secret("this-is-a-very-long-jwt-secret-key"), config.JWTSecret) assert.Equal(t, Secret("exactly-32-bytes-long-encryptkey"), config.EncryptionKey) } @@ -406,9 +411,12 @@ func TestProxyConfig_SessionConfigIntegration(t *testing.T) { "auth": { "kind": "oauth", "issuer": "https://auth.example.com", - "googleClientId": "test-client", - "googleClientSecret": "test-secret", - "googleRedirectUri": "https://test.example.com/callback", + "idp": { + "provider": "google", + "clientId": "test-client", + "clientSecret": "test-secret", + "redirectUri": "https://test.example.com/callback" + }, "jwtSecret": "test-jwt-secret-must-be-32-bytes-long", "encryptionKey": "test-encryption-key-32-bytes-ok!", "allowedDomains": ["example.com"], diff --git a/internal/config/validation.go b/internal/config/validation.go index 40612d1..490a123 100644 --- a/internal/config/validation.go +++ b/internal/config/validation.go @@ -146,9 +146,6 @@ func validateAuthStructure(auth map[string]any, result *ValidationResult) { hint string }{ {"issuer", ""}, - {"googleClientId", ""}, - {"googleClientSecret", ""}, - {"googleRedirectUri", ""}, {"jwtSecret", "Hint: Must be at least 32 bytes long for HMAC-SHA256"}, {"encryptionKey", "Hint: Must be exactly 32 bytes for AES-256-GCM encryption"}, } @@ -164,12 +161,18 @@ func validateAuthStructure(auth map[string]any, result *ValidationResult) { }) } } - if domains, ok := auth["allowedDomains"].([]any); !ok || len(domains) == 0 { + + // Validate IDP configuration + idp, hasIDP := auth["idp"].(map[string]any) + if !hasIDP { result.Errors = append(result.Errors, ValidationError{ - Path: "proxy.auth.allowedDomains", - Message: "at least one allowed domain is required for OAuth", + Path: "proxy.auth.idp", + Message: "idp configuration is required for OAuth", }) + } else { + validateIDPStructure(idp, result) } + if origins, ok := auth["allowedOrigins"].([]any); !ok || len(origins) == 0 { result.Errors = append(result.Errors, ValidationError{ Path: "proxy.auth.allowedOrigins", @@ -184,6 +187,74 @@ func validateAuthStructure(auth map[string]any, result *ValidationResult) { } } +// validateIDPStructure checks identity provider configuration +func validateIDPStructure(idp map[string]any, result *ValidationResult) { + provider, ok := idp["provider"].(string) + if !ok { + result.Errors = append(result.Errors, ValidationError{ + Path: "proxy.auth.idp.provider", + Message: "provider is required. Options: google, azure, github, oidc", + }) + return + } + + // Check required fields for all providers + if _, ok := idp["clientId"]; !ok { + result.Errors = append(result.Errors, ValidationError{ + Path: "proxy.auth.idp.clientId", + Message: "clientId is required for IDP configuration", + }) + } + if _, ok := idp["clientSecret"]; !ok { + result.Errors = append(result.Errors, ValidationError{ + Path: "proxy.auth.idp.clientSecret", + Message: "clientSecret is required for IDP configuration", + }) + } + if _, ok := idp["redirectUri"]; !ok { + result.Errors = append(result.Errors, ValidationError{ + Path: "proxy.auth.idp.redirectUri", + Message: "redirectUri is required for IDP configuration", + }) + } + + // Provider-specific validation + switch provider { + case "google", "github": + // No additional required fields + case "azure": + if _, ok := idp["tenantId"]; !ok { + result.Errors = append(result.Errors, ValidationError{ + Path: "proxy.auth.idp.tenantId", + Message: "tenantId is required for Azure AD provider", + }) + } + case "oidc": + // Either discoveryUrl or manual endpoints required + hasDiscovery := false + if _, ok := idp["discoveryUrl"]; ok { + hasDiscovery = true + } + if !hasDiscovery { + // Check for manual endpoints + requiredEndpoints := []string{"authorizationUrl", "tokenUrl", "userInfoUrl"} + for _, endpoint := range requiredEndpoints { + if _, ok := idp[endpoint]; !ok { + result.Errors = append(result.Errors, ValidationError{ + Path: "proxy.auth.idp." + endpoint, + Message: fmt.Sprintf("%s is required for OIDC provider when discoveryUrl is not provided", endpoint), + }) + } + } + } + default: + result.Errors = append(result.Errors, ValidationError{ + Path: "proxy.auth.idp.provider", + Message: fmt.Sprintf("unknown provider '%s' - supported providers: google, azure, github, oidc", provider), + }) + } +} + // validateAdminStructure checks admin configuration structure func validateAdminStructure(admin map[string]any, result *ValidationResult) { enabled, ok := admin["enabled"].(bool) diff --git a/internal/config/validation_test.go b/internal/config/validation_test.go index 910bb00..a8b5694 100644 --- a/internal/config/validation_test.go +++ b/internal/config/validation_test.go @@ -51,9 +51,12 @@ func TestValidateFile(t *testing.T) { "auth": { "kind": "oauth", "issuer": "https://example.com", - "googleClientId": {"$env": "CLIENT_ID"}, - "googleClientSecret": {"$env": "CLIENT_SECRET"}, - "googleRedirectUri": "https://example.com/callback", + "idp": { + "provider": "google", + "clientId": {"$env": "CLIENT_ID"}, + "clientSecret": {"$env": "CLIENT_SECRET"}, + "redirectUri": "https://example.com/callback" + }, "jwtSecret": {"$env": "JWT_SECRET"}, "encryptionKey": {"$env": "ENCRYPTION_KEY"}, "allowedDomains": ["example.com"], @@ -225,9 +228,12 @@ func TestValidateFile(t *testing.T) { "auth": { "kind": "oauth", "issuer": "https://example.com", - "googleClientId": "id", - "googleClientSecret": "secret", - "googleRedirectUri": "https://example.com/callback", + "idp": { + "provider": "google", + "clientId": "id", + "clientSecret": "secret", + "redirectUri": "https://example.com/callback" + }, "jwtSecret": "secret", "encryptionKey": "key", "allowedDomains": ["example.com"], @@ -260,15 +266,12 @@ func TestValidateFile(t *testing.T) { }`, wantErrors: []string{ "issuer is required for OAuth", - "googleClientId is required for OAuth", - "googleClientSecret is required for OAuth", - "googleRedirectUri is required for OAuth", "jwtSecret is required for OAuth. Hint: Must be at least 32 bytes long for HMAC-SHA256", "encryptionKey is required for OAuth. Hint: Must be exactly 32 bytes for AES-256-GCM encryption", - "at least one allowed domain is required for OAuth", + "idp configuration is required for OAuth", "at least one allowed origin is required for OAuth (CORS configuration)", }, - wantErrCount: 8, + wantErrCount: 5, }, { name: "valid_manual_user_authentication", @@ -280,9 +283,12 @@ func TestValidateFile(t *testing.T) { "auth": { "kind": "oauth", "issuer": "https://example.com", - "googleClientId": "id", - "googleClientSecret": "secret", - "googleRedirectUri": "https://example.com/callback", + "idp": { + "provider": "google", + "clientId": "id", + "clientSecret": "secret", + "redirectUri": "https://example.com/callback" + }, "jwtSecret": "secret123456789012345678901234567890", "encryptionKey": "key12345678901234567890123456789", "allowedDomains": ["example.com"], @@ -315,9 +321,12 @@ func TestValidateFile(t *testing.T) { "auth": { "kind": "oauth", "issuer": "https://example.com", - "googleClientId": "id", - "googleClientSecret": "secret", - "googleRedirectUri": "https://example.com/callback", + "idp": { + "provider": "google", + "clientId": "id", + "clientSecret": "secret", + "redirectUri": "https://example.com/callback" + }, "jwtSecret": "secret123456789012345678901234567890", "encryptionKey": "key12345678901234567890123456789", "allowedDomains": ["example.com"], @@ -354,9 +363,12 @@ func TestValidateFile(t *testing.T) { "auth": { "kind": "oauth", "issuer": "https://example.com", - "googleClientId": "id", - "googleClientSecret": "secret", - "googleRedirectUri": "https://example.com/callback", + "idp": { + "provider": "google", + "clientId": "id", + "clientSecret": "secret", + "redirectUri": "https://example.com/callback" + }, "jwtSecret": "secret123456789012345678901234567890", "encryptionKey": "key12345678901234567890123456789", "allowedDomains": ["example.com"], @@ -515,9 +527,12 @@ func TestValidateFile_ImprovedErrorMessages(t *testing.T) { "auth": { "kind": "oauth", "issuer": "https://example.com", - "googleClientId": "id", - "googleClientSecret": "secret", - "googleRedirectUri": "https://example.com/callback", + "idp": { + "provider": "google", + "clientId": "id", + "clientSecret": "secret", + "redirectUri": "https://example.com/callback" + }, "jwtSecret": "secret123456789012345678901234567890", "encryptionKey": "key12345678901234567890123456789", "allowedDomains": ["example.com"], diff --git a/internal/googleauth/google.go b/internal/googleauth/google.go deleted file mode 100644 index 1667a49..0000000 --- a/internal/googleauth/google.go +++ /dev/null @@ -1,121 +0,0 @@ -package googleauth - -import ( - "context" - "encoding/json" - "fmt" - "net/http" - "os" - "slices" - "strings" - - "github.com/dgellow/mcp-front/internal/config" - emailutil "github.com/dgellow/mcp-front/internal/emailutil" - "golang.org/x/oauth2" - "golang.org/x/oauth2/google" -) - -// UserInfo represents Google user information -type UserInfo struct { - Email string `json:"email"` - HostedDomain string `json:"hd"` - Name string `json:"name"` - Picture string `json:"picture"` - VerifiedEmail bool `json:"verified_email"` -} - -// GoogleAuthURL generates a Google OAuth authorization URL -func GoogleAuthURL(oauthConfig config.OAuthAuthConfig, state string) string { - googleOAuth := newGoogleOAuth2Config(oauthConfig) - return googleOAuth.AuthCodeURL(state, - oauth2.AccessTypeOffline, - oauth2.ApprovalForce, - ) -} - -// ExchangeCodeForToken exchanges the authorization code for a token -func ExchangeCodeForToken(ctx context.Context, oauthConfig config.OAuthAuthConfig, code string) (*oauth2.Token, error) { - googleOAuth := newGoogleOAuth2Config(oauthConfig) - return googleOAuth.Exchange(ctx, code) -} - -// ValidateUser validates the Google OAuth token and checks domain membership -func ValidateUser(ctx context.Context, oauthConfig config.OAuthAuthConfig, token *oauth2.Token) (UserInfo, error) { - googleOAuth := newGoogleOAuth2Config(oauthConfig) - client := googleOAuth.Client(ctx, token) - userInfoURL := "https://www.googleapis.com/oauth2/v2/userinfo" - if customURL := os.Getenv("GOOGLE_USERINFO_URL"); customURL != "" { - userInfoURL = customURL - } - resp, err := client.Get(userInfoURL) - if err != nil { - return UserInfo{}, fmt.Errorf("failed to get user info: %w", err) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - return UserInfo{}, fmt.Errorf("failed to get user info: status %d", resp.StatusCode) - } - - var userInfo UserInfo - if err := json.NewDecoder(resp.Body).Decode(&userInfo); err != nil { - return UserInfo{}, fmt.Errorf("failed to decode user info: %w", err) - } - - // Validate domain if configured - if len(oauthConfig.AllowedDomains) > 0 { - userDomain := emailutil.ExtractDomain(userInfo.Email) - if !slices.Contains(oauthConfig.AllowedDomains, userDomain) { - return UserInfo{}, fmt.Errorf("domain '%s' is not allowed. Contact your administrator", userDomain) - } - } - - return userInfo, nil -} - -// ParseClientRequest parses MCP client registration metadata -func ParseClientRequest(metadata map[string]any) ([]string, []string, error) { - // Extract redirect URIs - redirectURIs := []string{} - if uris, ok := metadata["redirect_uris"].([]any); ok { - for _, uri := range uris { - if uriStr, ok := uri.(string); ok { - redirectURIs = append(redirectURIs, uriStr) - } - } - } - - if len(redirectURIs) == 0 { - return nil, nil, fmt.Errorf("no valid redirect URIs provided") - } - - // Extract scopes, default to read/write if not provided - scopes := []string{"read", "write"} // Default MCP scopes - if clientScopes, ok := metadata["scope"].(string); ok { - if strings.TrimSpace(clientScopes) != "" { - scopes = strings.Fields(clientScopes) - } - } - - return redirectURIs, scopes, nil -} - -// newGoogleOAuth2Config creates the OAuth2 config from our Config -func newGoogleOAuth2Config(oauthConfig config.OAuthAuthConfig) oauth2.Config { - // Use custom OAuth endpoints if provided (for testing) - endpoint := google.Endpoint - if authURL := os.Getenv("GOOGLE_OAUTH_AUTH_URL"); authURL != "" { - endpoint.AuthURL = authURL - } - if tokenURL := os.Getenv("GOOGLE_OAUTH_TOKEN_URL"); tokenURL != "" { - endpoint.TokenURL = tokenURL - } - - return oauth2.Config{ - ClientID: oauthConfig.GoogleClientID, - ClientSecret: string(oauthConfig.GoogleClientSecret), - RedirectURL: oauthConfig.GoogleRedirectURI, - Scopes: []string{"openid", "profile", "email"}, - Endpoint: endpoint, - } -} diff --git a/internal/googleauth/google_test.go b/internal/googleauth/google_test.go deleted file mode 100644 index 3f72d4a..0000000 --- a/internal/googleauth/google_test.go +++ /dev/null @@ -1,240 +0,0 @@ -package googleauth - -import ( - "context" - "encoding/json" - "net/http" - "net/http/httptest" - "testing" - "time" - - "github.com/dgellow/mcp-front/internal/config" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "golang.org/x/oauth2" -) - -func TestGoogleAuthURL(t *testing.T) { - oauthConfig := config.OAuthAuthConfig{ - GoogleClientID: "test-client-id", - GoogleClientSecret: config.Secret("test-client-secret"), - GoogleRedirectURI: "https://test.example.com/oauth/callback", - } - - state := "test-state-parameter" - authURL := GoogleAuthURL(oauthConfig, state) - - // Verify URL structure - assert.Contains(t, authURL, "https://accounts.google.com/o/oauth2/auth") - assert.Contains(t, authURL, "client_id=test-client-id") - assert.Contains(t, authURL, "redirect_uri=https%3A%2F%2Ftest.example.com%2Foauth%2Fcallback") - assert.Contains(t, authURL, "state=test-state-parameter") - assert.Contains(t, authURL, "access_type=offline") - assert.Contains(t, authURL, "prompt=consent") - assert.Contains(t, authURL, "scope=openid+profile+email") -} - -func TestExchangeCodeForToken(t *testing.T) { - // Create a mock OAuth token endpoint - tokenServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - assert.Equal(t, "POST", r.Method) - assert.Equal(t, "/token", r.URL.Path) - - // Parse form data - err := r.ParseForm() - require.NoError(t, err) - - assert.Equal(t, "test-code", r.FormValue("code")) - assert.Equal(t, "test-client-id", r.FormValue("client_id")) - assert.Equal(t, "test-client-secret", r.FormValue("client_secret")) - assert.Equal(t, "https://test.example.com/oauth/callback", r.FormValue("redirect_uri")) - assert.Equal(t, "authorization_code", r.FormValue("grant_type")) - - // Return mock token response - response := map[string]any{ - "access_token": "mock-access-token", - "refresh_token": "mock-refresh-token", - "token_type": "Bearer", - "expires_in": 3600, - } - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(response); err != nil { - t.Errorf("failed to encode response: %v", err) - } - })) - defer tokenServer.Close() - - // Set environment variable for custom token URL - t.Setenv("GOOGLE_OAUTH_TOKEN_URL", tokenServer.URL+"/token") - - oauthConfig := config.OAuthAuthConfig{ - GoogleClientID: "test-client-id", - GoogleClientSecret: config.Secret("test-client-secret"), - GoogleRedirectURI: "https://test.example.com/oauth/callback", - } - - token, err := ExchangeCodeForToken(context.Background(), oauthConfig, "test-code") - require.NoError(t, err) - require.NotNil(t, token) - - assert.Equal(t, "mock-access-token", token.AccessToken) - assert.Equal(t, "mock-refresh-token", token.RefreshToken) - assert.Equal(t, "Bearer", token.TokenType) - assert.WithinDuration(t, time.Now().Add(3600*time.Second), token.Expiry, 5*time.Second) -} - -func TestValidateUser(t *testing.T) { - t.Run("valid user in allowed domain", func(t *testing.T) { - // Create mock user info endpoint - userInfoServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - assert.Equal(t, "GET", r.Method) - assert.Contains(t, r.Header.Get("Authorization"), "Bearer mock-token") - - response := UserInfo{ - Email: "user@example.com", - HostedDomain: "example.com", - Name: "Test User", - Picture: "https://example.com/pic.jpg", - VerifiedEmail: true, - } - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(response); err != nil { - t.Errorf("failed to encode response: %v", err) - } - })) - defer userInfoServer.Close() - - t.Setenv("GOOGLE_USERINFO_URL", userInfoServer.URL) - - oauthConfig := config.OAuthAuthConfig{ - AllowedDomains: []string{"example.com", "test.com"}, - } - - token := &oauth2.Token{AccessToken: "mock-token"} - userInfo, err := ValidateUser(context.Background(), oauthConfig, token) - - require.NoError(t, err) - assert.Equal(t, "user@example.com", userInfo.Email) - assert.Equal(t, "example.com", userInfo.HostedDomain) - assert.Equal(t, "Test User", userInfo.Name) - assert.True(t, userInfo.VerifiedEmail) - }) - - t.Run("user from disallowed domain", func(t *testing.T) { - userInfoServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - response := UserInfo{ - Email: "user@unauthorized.com", - HostedDomain: "unauthorized.com", - VerifiedEmail: true, - } - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(response); err != nil { - t.Errorf("failed to encode response: %v", err) - } - })) - defer userInfoServer.Close() - - t.Setenv("GOOGLE_USERINFO_URL", userInfoServer.URL) - - oauthConfig := config.OAuthAuthConfig{ - AllowedDomains: []string{"example.com", "test.com"}, - } - - token := &oauth2.Token{AccessToken: "mock-token"} - _, err := ValidateUser(context.Background(), oauthConfig, token) - - require.Error(t, err) - assert.Contains(t, err.Error(), "domain 'unauthorized.com' is not allowed") - }) - - t.Run("no domain restrictions", func(t *testing.T) { - userInfoServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - response := UserInfo{ - Email: "user@anydomain.com", - VerifiedEmail: true, - } - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(response); err != nil { - t.Errorf("failed to encode response: %v", err) - } - })) - defer userInfoServer.Close() - - t.Setenv("GOOGLE_USERINFO_URL", userInfoServer.URL) - - oauthConfig := config.OAuthAuthConfig{ - AllowedDomains: []string{}, // Empty means allow all - } - - token := &oauth2.Token{AccessToken: "mock-token"} - userInfo, err := ValidateUser(context.Background(), oauthConfig, token) - - require.NoError(t, err) - assert.Equal(t, "user@anydomain.com", userInfo.Email) - }) -} - -func TestParseClientRequest(t *testing.T) { - t.Run("valid request with redirect URIs and scopes", func(t *testing.T) { - metadata := map[string]any{ - "redirect_uris": []any{ - "https://example.com/callback1", - "https://example.com/callback2", - }, - "scope": "read write admin", - } - - redirectURIs, scopes, err := ParseClientRequest(metadata) - require.NoError(t, err) - - assert.Equal(t, []string{ - "https://example.com/callback1", - "https://example.com/callback2", - }, redirectURIs) - assert.Equal(t, []string{"read", "write", "admin"}, scopes) - }) - - t.Run("default scopes when not provided", func(t *testing.T) { - metadata := map[string]any{ - "redirect_uris": []any{"https://example.com/callback"}, - } - - redirectURIs, scopes, err := ParseClientRequest(metadata) - require.NoError(t, err) - - assert.Equal(t, []string{"https://example.com/callback"}, redirectURIs) - assert.Equal(t, []string{"read", "write"}, scopes, "Should default to read/write") - }) - - t.Run("empty scope string uses default", func(t *testing.T) { - metadata := map[string]any{ - "redirect_uris": []any{"https://example.com/callback"}, - "scope": " ", // Whitespace only - } - - _, scopes, err := ParseClientRequest(metadata) - require.NoError(t, err) - - assert.Equal(t, []string{"read", "write"}, scopes) - }) - - t.Run("missing redirect URIs", func(t *testing.T) { - metadata := map[string]any{ - "scope": "read write", - } - - _, _, err := ParseClientRequest(metadata) - assert.Error(t, err) - assert.Contains(t, err.Error(), "no valid redirect URIs") - }) - - t.Run("empty redirect URIs array", func(t *testing.T) { - metadata := map[string]any{ - "redirect_uris": []any{}, - } - - _, _, err := ParseClientRequest(metadata) - assert.Error(t, err) - assert.Contains(t, err.Error(), "no valid redirect URIs") - }) -} diff --git a/internal/idp/azure.go b/internal/idp/azure.go new file mode 100644 index 0000000..ccd0682 --- /dev/null +++ b/internal/idp/azure.go @@ -0,0 +1,25 @@ +package idp + +import "fmt" + +// NewAzureProvider creates an Azure AD provider using OIDC discovery. +// Azure AD is OIDC-compliant, so we use the generic OIDC provider with Azure's tenant-specific discovery URL. +func NewAzureProvider(tenantID, clientID, clientSecret, redirectURI string) (*OIDCProvider, error) { + if tenantID == "" { + return nil, fmt.Errorf("tenantId is required for Azure AD") + } + + discoveryURL := fmt.Sprintf( + "https://login.microsoftonline.com/%s/v2.0/.well-known/openid-configuration", + tenantID, + ) + + return NewOIDCProvider(OIDCConfig{ + ProviderType: "azure", + DiscoveryURL: discoveryURL, + ClientID: clientID, + ClientSecret: clientSecret, + RedirectURI: redirectURI, + Scopes: []string{"openid", "email", "profile"}, + }) +} diff --git a/internal/idp/azure_test.go b/internal/idp/azure_test.go new file mode 100644 index 0000000..d163ce0 --- /dev/null +++ b/internal/idp/azure_test.go @@ -0,0 +1,15 @@ +package idp + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewAzureProvider_MissingTenantID(t *testing.T) { + _, err := NewAzureProvider("", "client-id", "client-secret", "https://example.com/callback") + + require.Error(t, err) + assert.Contains(t, err.Error(), "tenantId is required") +} diff --git a/internal/idp/factory.go b/internal/idp/factory.go new file mode 100644 index 0000000..2032dd0 --- /dev/null +++ b/internal/idp/factory.go @@ -0,0 +1,51 @@ +package idp + +import ( + "fmt" + + "github.com/dgellow/mcp-front/internal/config" +) + +// NewProvider creates a Provider based on the IDPConfig. +func NewProvider(cfg config.IDPConfig) (Provider, error) { + switch cfg.Provider { + case "google": + return NewGoogleProvider( + cfg.ClientID, + string(cfg.ClientSecret), + cfg.RedirectURI, + ), nil + + case "azure": + return NewAzureProvider( + cfg.TenantID, + cfg.ClientID, + string(cfg.ClientSecret), + cfg.RedirectURI, + ) + + case "github": + return NewGitHubProvider( + cfg.ClientID, + string(cfg.ClientSecret), + cfg.RedirectURI, + cfg.AllowedOrgs, + ), nil + + case "oidc": + return NewOIDCProvider(OIDCConfig{ + ProviderType: "oidc", + DiscoveryURL: cfg.DiscoveryURL, + AuthorizationURL: cfg.AuthorizationURL, + TokenURL: cfg.TokenURL, + UserInfoURL: cfg.UserInfoURL, + ClientID: cfg.ClientID, + ClientSecret: string(cfg.ClientSecret), + RedirectURI: cfg.RedirectURI, + Scopes: cfg.Scopes, + }) + + default: + return nil, fmt.Errorf("unknown provider type: %s", cfg.Provider) + } +} diff --git a/internal/idp/factory_test.go b/internal/idp/factory_test.go new file mode 100644 index 0000000..9b38681 --- /dev/null +++ b/internal/idp/factory_test.go @@ -0,0 +1,105 @@ +package idp + +import ( + "testing" + + "github.com/dgellow/mcp-front/internal/config" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewProvider(t *testing.T) { + tests := []struct { + name string + cfg config.IDPConfig + wantType string + wantErr bool + errContains string + skipCreation bool + }{ + { + name: "google_provider", + cfg: config.IDPConfig{ + Provider: "google", + ClientID: "test-client-id", + ClientSecret: config.Secret("test-client-secret"), + RedirectURI: "https://example.com/callback", + }, + wantType: "google", + wantErr: false, + }, + { + name: "github_provider", + cfg: config.IDPConfig{ + Provider: "github", + ClientID: "test-client-id", + ClientSecret: config.Secret("test-client-secret"), + RedirectURI: "https://example.com/callback", + }, + wantType: "github", + wantErr: false, + }, + { + name: "azure_provider_missing_tenant", + cfg: config.IDPConfig{ + Provider: "azure", + ClientID: "test-client-id", + ClientSecret: config.Secret("test-client-secret"), + RedirectURI: "https://example.com/callback", + }, + wantErr: true, + errContains: "tenantId is required", + }, + { + name: "oidc_provider_missing_endpoints", + cfg: config.IDPConfig{ + Provider: "oidc", + ClientID: "test-client-id", + ClientSecret: config.Secret("test-client-secret"), + RedirectURI: "https://example.com/callback", + }, + wantErr: true, + errContains: "discoveryUrl or all endpoints", + }, + { + name: "oidc_provider_with_direct_endpoints", + cfg: config.IDPConfig{ + Provider: "oidc", + ClientID: "test-client-id", + ClientSecret: config.Secret("test-client-secret"), + RedirectURI: "https://example.com/callback", + AuthorizationURL: "https://idp.example.com/authorize", + TokenURL: "https://idp.example.com/token", + UserInfoURL: "https://idp.example.com/userinfo", + }, + wantType: "oidc", + wantErr: false, + }, + { + name: "unknown_provider", + cfg: config.IDPConfig{ + Provider: "unknown", + }, + wantErr: true, + errContains: "unknown provider type", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + provider, err := NewProvider(tt.cfg) + + if tt.wantErr { + require.Error(t, err) + if tt.errContains != "" { + assert.Contains(t, err.Error(), tt.errContains) + } + return + } + + require.NoError(t, err) + require.NotNil(t, provider) + assert.Equal(t, tt.wantType, provider.Type()) + }) + } +} diff --git a/internal/idp/github.go b/internal/idp/github.go new file mode 100644 index 0000000..1503d65 --- /dev/null +++ b/internal/idp/github.go @@ -0,0 +1,217 @@ +package idp + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + + emailutil "github.com/dgellow/mcp-front/internal/emailutil" + "golang.org/x/oauth2" + "golang.org/x/oauth2/github" +) + +// GitHubProvider implements the Provider interface for GitHub OAuth. +// GitHub uses OAuth 2.0 (not OIDC) and has its own API for user info and org membership. +type GitHubProvider struct { + config oauth2.Config + apiBaseURL string // defaults to https://api.github.com, can be overridden for testing + allowedOrgs []string // organizations users must be members of (empty = no restriction) +} + +// githubUserResponse represents GitHub's user API response. +type githubUserResponse struct { + ID int64 `json:"id"` + Login string `json:"login"` + Email string `json:"email"` + Name string `json:"name"` + AvatarURL string `json:"avatar_url"` +} + +// githubEmailResponse represents an email from GitHub's emails API. +type githubEmailResponse struct { + Email string `json:"email"` + Primary bool `json:"primary"` + Verified bool `json:"verified"` +} + +// githubOrgResponse represents an org from GitHub's orgs API. +type githubOrgResponse struct { + Login string `json:"login"` +} + +// NewGitHubProvider creates a new GitHub OAuth provider. +// allowedOrgs specifies organizations users must be members of (empty = no restriction). +func NewGitHubProvider(clientID, clientSecret, redirectURI string, allowedOrgs []string) *GitHubProvider { + return &GitHubProvider{ + config: oauth2.Config{ + ClientID: clientID, + ClientSecret: clientSecret, + RedirectURL: redirectURI, + Scopes: []string{"user:email", "read:org"}, + Endpoint: github.Endpoint, + }, + apiBaseURL: "https://api.github.com", + allowedOrgs: allowedOrgs, + } +} + +// Type returns the provider type. +func (p *GitHubProvider) Type() string { + return "github" +} + +// AuthURL generates the authorization URL. +func (p *GitHubProvider) AuthURL(state string) string { + return p.config.AuthCodeURL(state) +} + +// ExchangeCode exchanges an authorization code for tokens. +func (p *GitHubProvider) ExchangeCode(ctx context.Context, code string) (*oauth2.Token, error) { + return p.config.Exchange(ctx, code) +} + +// UserInfo fetches user information from GitHub's API. +// Validates organization membership if allowedOrgs was configured at construction. +// TODO: Consider caching org membership to reduce API calls. +func (p *GitHubProvider) UserInfo(ctx context.Context, token *oauth2.Token, allowedDomains []string) (*UserInfo, error) { + client := p.config.Client(ctx, token) + + // Fetch user profile + user, err := p.fetchUser(client) + if err != nil { + return nil, err + } + + // Fetch primary email if not in profile + // GitHub only shows verified emails in user profile, so if email is present it's verified + email := user.Email + emailVerified := email != "" + if email == "" { + primaryEmail, verified, err := p.fetchPrimaryEmail(client) + if err != nil { + return nil, fmt.Errorf("failed to get user email: %w", err) + } + email = primaryEmail + emailVerified = verified + } + + domain := emailutil.ExtractDomain(email) + + // Validate domain if configured + if err := ValidateDomain(domain, allowedDomains); err != nil { + return nil, err + } + + // Fetch organizations only if org validation is configured + var orgs []string + if len(p.allowedOrgs) > 0 { + orgs, err = p.fetchOrganizations(client) + if err != nil { + return nil, fmt.Errorf("failed to get user organizations: %w", err) + } + + // Validate org membership + hasAllowedOrg := false + for _, org := range orgs { + for _, allowed := range p.allowedOrgs { + if org == allowed { + hasAllowedOrg = true + break + } + } + if hasAllowedOrg { + break + } + } + if !hasAllowedOrg { + return nil, fmt.Errorf("user is not a member of any allowed organization. Contact your administrator") + } + } + + return &UserInfo{ + ProviderType: "github", + Subject: fmt.Sprintf("%d", user.ID), + Email: email, + EmailVerified: emailVerified, + Name: user.Name, + Picture: user.AvatarURL, + Domain: domain, + Organizations: orgs, + }, nil +} + +func (p *GitHubProvider) fetchUser(client *http.Client) (*githubUserResponse, error) { + resp, err := client.Get(p.apiBaseURL + "/user") + if err != nil { + return nil, fmt.Errorf("failed to get user: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("failed to get user: status %d", resp.StatusCode) + } + + var user githubUserResponse + if err := json.NewDecoder(resp.Body).Decode(&user); err != nil { + return nil, fmt.Errorf("failed to decode user: %w", err) + } + + return &user, nil +} + +func (p *GitHubProvider) fetchPrimaryEmail(client *http.Client) (string, bool, error) { + resp, err := client.Get(p.apiBaseURL + "/user/emails") + if err != nil { + return "", false, fmt.Errorf("failed to get emails: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return "", false, fmt.Errorf("failed to get emails: status %d", resp.StatusCode) + } + + var emails []githubEmailResponse + if err := json.NewDecoder(resp.Body).Decode(&emails); err != nil { + return "", false, fmt.Errorf("failed to decode emails: %w", err) + } + + for _, email := range emails { + if email.Primary && email.Verified { + return email.Email, true, nil + } + } + + // Fallback to first verified email + for _, email := range emails { + if email.Verified { + return email.Email, true, nil + } + } + + return "", false, fmt.Errorf("no verified email found") +} + +func (p *GitHubProvider) fetchOrganizations(client *http.Client) ([]string, error) { + resp, err := client.Get(p.apiBaseURL + "/user/orgs") + if err != nil { + return nil, fmt.Errorf("failed to get organizations: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("failed to get organizations: status %d", resp.StatusCode) + } + + var orgs []githubOrgResponse + if err := json.NewDecoder(resp.Body).Decode(&orgs); err != nil { + return nil, fmt.Errorf("failed to decode organizations: %w", err) + } + + orgNames := make([]string, len(orgs)) + for i, org := range orgs { + orgNames[i] = org.Login + } + + return orgNames, nil +} diff --git a/internal/idp/github_test.go b/internal/idp/github_test.go new file mode 100644 index 0000000..71650de --- /dev/null +++ b/internal/idp/github_test.go @@ -0,0 +1,284 @@ +package idp + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/oauth2" +) + +func TestGitHubProvider_Type(t *testing.T) { + provider := NewGitHubProvider("client-id", "client-secret", "https://example.com/callback", nil) + assert.Equal(t, "github", provider.Type()) +} + +func TestGitHubProvider_AuthURL(t *testing.T) { + provider := NewGitHubProvider("client-id", "client-secret", "https://example.com/callback", nil) + + authURL := provider.AuthURL("test-state") + + assert.Contains(t, authURL, "github.com") + assert.Contains(t, authURL, "state=test-state") + assert.Contains(t, authURL, "client_id=client-id") +} + +func TestGitHubProvider_UserInfo(t *testing.T) { + tests := []struct { + name string + userResp githubUserResponse + emailsResp []githubEmailResponse + orgsResp []githubOrgResponse + allowedDomains []string + allowedOrgs []string + wantErr bool + errContains string + expectedEmail string + expectedEmailVerified bool + expectedDomain string + expectedOrgs []string + }{ + { + name: "user_with_public_email", + userResp: githubUserResponse{ + ID: 12345, + Login: "testuser", + Email: "user@company.com", + Name: "Test User", + AvatarURL: "https://github.com/avatar.jpg", + }, + expectedEmail: "user@company.com", + expectedEmailVerified: true, // Public emails in GitHub profile are verified + expectedDomain: "company.com", + expectedOrgs: nil, // Orgs not fetched when allowedOrgs is empty + }, + { + name: "user_without_public_email_fetches_from_api", + userResp: githubUserResponse{ + ID: 12345, + Login: "testuser", + Name: "Test User", + }, + emailsResp: []githubEmailResponse{ + {Email: "secondary@other.com", Primary: false, Verified: true}, + {Email: "primary@company.com", Primary: true, Verified: true}, + }, + expectedEmail: "primary@company.com", + expectedEmailVerified: true, + expectedDomain: "company.com", + expectedOrgs: nil, // Orgs not fetched when allowedOrgs is empty + }, + { + name: "user_with_unverified_primary_falls_back_to_verified", + userResp: githubUserResponse{ + ID: 12345, + Login: "testuser", + }, + emailsResp: []githubEmailResponse{ + {Email: "primary@company.com", Primary: true, Verified: false}, + {Email: "verified@company.com", Primary: false, Verified: true}, + }, + expectedEmail: "verified@company.com", + expectedEmailVerified: true, + expectedDomain: "company.com", + expectedOrgs: nil, // Orgs not fetched when allowedOrgs is empty + }, + { + name: "domain_validation_success", + userResp: githubUserResponse{ + ID: 12345, + Login: "testuser", + Email: "user@company.com", + }, + allowedDomains: []string{"company.com"}, + expectedEmail: "user@company.com", + expectedEmailVerified: true, + expectedDomain: "company.com", + expectedOrgs: nil, // Orgs not fetched when allowedOrgs is empty + }, + { + name: "domain_validation_failure", + userResp: githubUserResponse{ + ID: 12345, + Login: "testuser", + Email: "user@other.com", + }, + allowedDomains: []string{"company.com"}, + wantErr: true, + errContains: "domain 'other.com' is not allowed", + }, + { + name: "org_validation_success", + userResp: githubUserResponse{ + ID: 12345, + Login: "testuser", + Email: "user@gmail.com", + }, + orgsResp: []githubOrgResponse{{Login: "allowed-org"}, {Login: "other-org"}}, + allowedOrgs: []string{"allowed-org"}, + expectedEmail: "user@gmail.com", + expectedEmailVerified: true, + expectedDomain: "gmail.com", + expectedOrgs: []string{"allowed-org", "other-org"}, + }, + { + name: "org_validation_failure", + userResp: githubUserResponse{ + ID: 12345, + Login: "testuser", + Email: "user@gmail.com", + }, + orgsResp: []githubOrgResponse{{Login: "other-org"}}, + allowedOrgs: []string{"required-org"}, + wantErr: true, + errContains: "not a member of any allowed organization", + }, + { + name: "user_with_no_orgs_restriction", + userResp: githubUserResponse{ + ID: 12345, + Login: "testuser", + Email: "user@gmail.com", + }, + expectedEmail: "user@gmail.com", + expectedEmailVerified: true, + expectedDomain: "gmail.com", + expectedOrgs: nil, // Orgs not fetched when allowedOrgs is empty + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + + switch r.URL.Path { + case "/user": + err := json.NewEncoder(w).Encode(tt.userResp) + require.NoError(t, err) + case "/user/emails": + err := json.NewEncoder(w).Encode(tt.emailsResp) + require.NoError(t, err) + case "/user/orgs": + err := json.NewEncoder(w).Encode(tt.orgsResp) + require.NoError(t, err) + default: + w.WriteHeader(http.StatusNotFound) + } + })) + defer server.Close() + + // Create provider with test server endpoints and allowedOrgs + provider := &GitHubProvider{ + config: oauth2.Config{ + ClientID: "test-client", + ClientSecret: "test-secret", + RedirectURL: "https://example.com/callback", + Scopes: []string{"user:email", "read:org"}, + Endpoint: oauth2.Endpoint{ + AuthURL: server.URL + "/authorize", + TokenURL: server.URL + "/token", + }, + }, + apiBaseURL: server.URL, + allowedOrgs: tt.allowedOrgs, + } + + token := &oauth2.Token{AccessToken: "test-token"} + userInfo, err := provider.UserInfo(context.Background(), token, tt.allowedDomains) + + if tt.wantErr { + require.Error(t, err) + if tt.errContains != "" { + assert.Contains(t, err.Error(), tt.errContains) + } + return + } + + require.NoError(t, err) + require.NotNil(t, userInfo) + assert.Equal(t, "github", userInfo.ProviderType) + assert.Equal(t, tt.expectedEmail, userInfo.Email) + assert.Equal(t, tt.expectedEmailVerified, userInfo.EmailVerified) + assert.Equal(t, tt.expectedDomain, userInfo.Domain) + assert.Equal(t, tt.expectedOrgs, userInfo.Organizations) + }) + } +} + +func TestGitHubProvider_UserInfo_APIErrors(t *testing.T) { + tests := []struct { + name string + userStatus int + errContains string + }{ + { + name: "user_api_error", + userStatus: http.StatusInternalServerError, + errContains: "status 500", + }, + { + name: "user_unauthorized", + userStatus: http.StatusUnauthorized, + errContains: "status 401", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(tt.userStatus) + })) + defer server.Close() + + provider := &GitHubProvider{ + config: oauth2.Config{ + ClientID: "test-client", + ClientSecret: "test-secret", + }, + apiBaseURL: server.URL, + } + + token := &oauth2.Token{AccessToken: "test-token"} + _, err := provider.UserInfo(context.Background(), token, nil) + + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errContains) + }) + } +} + +func TestGitHubProvider_UserInfo_NoVerifiedEmail(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + + switch r.URL.Path { + case "/user": + err := json.NewEncoder(w).Encode(githubUserResponse{ID: 123, Login: "test"}) + require.NoError(t, err) + case "/user/emails": + err := json.NewEncoder(w).Encode([]githubEmailResponse{ + {Email: "unverified@example.com", Primary: true, Verified: false}, + }) + require.NoError(t, err) + default: + w.WriteHeader(http.StatusNotFound) + } + })) + defer server.Close() + + provider := &GitHubProvider{ + config: oauth2.Config{ClientID: "test"}, + apiBaseURL: server.URL, + } + + token := &oauth2.Token{AccessToken: "test-token"} + _, err := provider.UserInfo(context.Background(), token, nil) + + require.Error(t, err) + assert.Contains(t, err.Error(), "no verified email") +} diff --git a/internal/idp/google.go b/internal/idp/google.go new file mode 100644 index 0000000..87909a3 --- /dev/null +++ b/internal/idp/google.go @@ -0,0 +1,103 @@ +package idp + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + + emailutil "github.com/dgellow/mcp-front/internal/emailutil" + "golang.org/x/oauth2" + "golang.org/x/oauth2/google" +) + +// GoogleProvider implements the Provider interface for Google OAuth. +// Google has specific quirks like `hd` for hosted domain and `verified_email` field. +type GoogleProvider struct { + config oauth2.Config + userInfoURL string +} + +// googleUserInfoResponse represents Google's userinfo response. +// Note: Google uses `hd` for hosted domain and `verified_email` instead of OIDC standard `email_verified`. +type googleUserInfoResponse struct { + Sub string `json:"sub"` + Email string `json:"email"` + VerifiedEmail bool `json:"verified_email"` + Name string `json:"name"` + Picture string `json:"picture"` + HostedDomain string `json:"hd"` +} + +// NewGoogleProvider creates a new Google OAuth provider. +func NewGoogleProvider(clientID, clientSecret, redirectURI string) *GoogleProvider { + return &GoogleProvider{ + config: oauth2.Config{ + ClientID: clientID, + ClientSecret: clientSecret, + RedirectURL: redirectURI, + Scopes: []string{"openid", "profile", "email"}, + Endpoint: google.Endpoint, + }, + userInfoURL: "https://www.googleapis.com/oauth2/v2/userinfo", + } +} + +// Type returns the provider type. +func (p *GoogleProvider) Type() string { + return "google" +} + +// AuthURL generates the authorization URL. +func (p *GoogleProvider) AuthURL(state string) string { + return p.config.AuthCodeURL(state, + oauth2.AccessTypeOffline, + oauth2.ApprovalForce, + ) +} + +// ExchangeCode exchanges an authorization code for tokens. +func (p *GoogleProvider) ExchangeCode(ctx context.Context, code string) (*oauth2.Token, error) { + return p.config.Exchange(ctx, code) +} + +// UserInfo fetches user information from Google's userinfo endpoint. +func (p *GoogleProvider) UserInfo(ctx context.Context, token *oauth2.Token, allowedDomains []string) (*UserInfo, error) { + client := p.config.Client(ctx, token) + + resp, err := client.Get(p.userInfoURL) + if err != nil { + return nil, fmt.Errorf("failed to get user info: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("failed to get user info: status %d", resp.StatusCode) + } + + var googleUser googleUserInfoResponse + if err := json.NewDecoder(resp.Body).Decode(&googleUser); err != nil { + return nil, fmt.Errorf("failed to decode user info: %w", err) + } + + // Use Google's hosted domain if available, otherwise derive from email + domain := googleUser.HostedDomain + if domain == "" { + domain = emailutil.ExtractDomain(googleUser.Email) + } + + // Validate domain if configured + if err := ValidateDomain(domain, allowedDomains); err != nil { + return nil, err + } + + return &UserInfo{ + ProviderType: "google", + Subject: googleUser.Sub, + Email: googleUser.Email, + EmailVerified: googleUser.VerifiedEmail, + Name: googleUser.Name, + Picture: googleUser.Picture, + Domain: domain, + }, nil +} diff --git a/internal/idp/google_test.go b/internal/idp/google_test.go new file mode 100644 index 0000000..691012a --- /dev/null +++ b/internal/idp/google_test.go @@ -0,0 +1,149 @@ +package idp + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/oauth2" +) + +func TestGoogleProvider_Type(t *testing.T) { + provider := NewGoogleProvider("client-id", "client-secret", "https://example.com/callback") + assert.Equal(t, "google", provider.Type()) +} + +func TestGoogleProvider_AuthURL(t *testing.T) { + provider := NewGoogleProvider("client-id", "client-secret", "https://example.com/callback") + + authURL := provider.AuthURL("test-state") + + assert.Contains(t, authURL, "accounts.google.com") + assert.Contains(t, authURL, "state=test-state") + assert.Contains(t, authURL, "client_id=client-id") + assert.Contains(t, authURL, "redirect_uri=") + assert.Contains(t, authURL, "access_type=offline") +} + +func TestGoogleProvider_UserInfo(t *testing.T) { + tests := []struct { + name string + userInfoResp googleUserInfoResponse + allowedDomains []string + wantErr bool + errContains string + expectedDomain string + expectedSubject string + }{ + { + name: "valid_user_with_hosted_domain", + userInfoResp: googleUserInfoResponse{ + Sub: "12345", + Email: "user@company.com", + VerifiedEmail: true, + Name: "Test User", + Picture: "https://example.com/photo.jpg", + HostedDomain: "company.com", + }, + allowedDomains: []string{"company.com"}, + wantErr: false, + expectedDomain: "company.com", + expectedSubject: "12345", + }, + { + name: "valid_user_without_hosted_domain_derives_from_email", + userInfoResp: googleUserInfoResponse{ + Sub: "12345", + Email: "user@gmail.com", + VerifiedEmail: true, + Name: "Test User", + }, + allowedDomains: nil, + wantErr: false, + expectedDomain: "gmail.com", + expectedSubject: "12345", + }, + { + name: "domain_not_allowed", + userInfoResp: googleUserInfoResponse{ + Sub: "12345", + Email: "user@other.com", + VerifiedEmail: true, + Name: "Test User", + HostedDomain: "other.com", + }, + allowedDomains: []string{"company.com"}, + wantErr: true, + errContains: "domain 'other.com' is not allowed", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + err := json.NewEncoder(w).Encode(tt.userInfoResp) + require.NoError(t, err) + })) + defer server.Close() + + provider := &GoogleProvider{ + config: oauth2.Config{ + ClientID: "test-client", + ClientSecret: "test-secret", + RedirectURL: "https://example.com/callback", + Scopes: []string{"openid", "profile", "email"}, + Endpoint: oauth2.Endpoint{ + AuthURL: server.URL + "/authorize", + TokenURL: server.URL + "/token", + }, + }, + userInfoURL: server.URL, + } + token := &oauth2.Token{AccessToken: "test-token"} + + userInfo, err := provider.UserInfo(context.Background(), token, tt.allowedDomains) + + if tt.wantErr { + require.Error(t, err) + if tt.errContains != "" { + assert.Contains(t, err.Error(), tt.errContains) + } + return + } + + require.NoError(t, err) + require.NotNil(t, userInfo) + assert.Equal(t, "google", userInfo.ProviderType) + assert.Equal(t, tt.expectedSubject, userInfo.Subject) + assert.Equal(t, tt.expectedDomain, userInfo.Domain) + assert.Equal(t, tt.userInfoResp.Email, userInfo.Email) + assert.Equal(t, tt.userInfoResp.VerifiedEmail, userInfo.EmailVerified) + }) + } +} + +func TestGoogleProvider_UserInfo_ServerError(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer server.Close() + + provider := &GoogleProvider{ + config: oauth2.Config{ + ClientID: "test-client", + ClientSecret: "test-secret", + }, + userInfoURL: server.URL, + } + token := &oauth2.Token{AccessToken: "test-token"} + + _, err := provider.UserInfo(context.Background(), token, nil) + + require.Error(t, err) + assert.Contains(t, err.Error(), "status 500") +} diff --git a/internal/idp/oidc.go b/internal/idp/oidc.go new file mode 100644 index 0000000..0d3dbf6 --- /dev/null +++ b/internal/idp/oidc.go @@ -0,0 +1,183 @@ +package idp + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "time" + + emailutil "github.com/dgellow/mcp-front/internal/emailutil" + "golang.org/x/oauth2" +) + +// OIDCConfig configures a generic OIDC provider. +type OIDCConfig struct { + // ProviderType identifies this provider (e.g., "oidc", "google", "azure"). + ProviderType string + + // Discovery URL for OIDC discovery (optional if endpoints are provided directly). + DiscoveryURL string + + // Direct endpoint configuration (used if DiscoveryURL is not set). + AuthorizationURL string + TokenURL string + UserInfoURL string + + // OAuth client configuration. + ClientID string + ClientSecret string + RedirectURI string + Scopes []string +} + +// OIDCProvider implements the Provider interface for OIDC-compliant identity providers. +type OIDCProvider struct { + providerType string + config oauth2.Config + userInfoURL string +} + +// oidcDiscoveryDocument represents the OIDC discovery document. +type oidcDiscoveryDocument struct { + AuthorizationEndpoint string `json:"authorization_endpoint"` + TokenEndpoint string `json:"token_endpoint"` + UserInfoEndpoint string `json:"userinfo_endpoint"` + Issuer string `json:"issuer"` +} + +// oidcUserInfoResponse represents the standard OIDC userinfo response. +type oidcUserInfoResponse struct { + Sub string `json:"sub"` + Email string `json:"email"` + EmailVerified bool `json:"email_verified"` + Name string `json:"name"` + Picture string `json:"picture"` +} + +// NewOIDCProvider creates a new OIDC provider. +// TODO: Add OIDC discovery caching to avoid repeated network calls. +func NewOIDCProvider(cfg OIDCConfig) (*OIDCProvider, error) { + var authURL, tokenURL, userInfoURL string + + if cfg.DiscoveryURL != "" { + discovery, err := fetchOIDCDiscovery(cfg.DiscoveryURL) + if err != nil { + return nil, fmt.Errorf("failed to fetch OIDC discovery: %w", err) + } + authURL = discovery.AuthorizationEndpoint + tokenURL = discovery.TokenEndpoint + userInfoURL = discovery.UserInfoEndpoint + } else { + if cfg.AuthorizationURL == "" || cfg.TokenURL == "" || cfg.UserInfoURL == "" { + return nil, fmt.Errorf("either discoveryUrl or all endpoints (authorizationUrl, tokenUrl, userInfoUrl) must be provided") + } + authURL = cfg.AuthorizationURL + tokenURL = cfg.TokenURL + userInfoURL = cfg.UserInfoURL + } + + scopes := cfg.Scopes + if len(scopes) == 0 { + scopes = []string{"openid", "email", "profile"} + } + + providerType := cfg.ProviderType + if providerType == "" { + providerType = "oidc" + } + + return &OIDCProvider{ + providerType: providerType, + config: oauth2.Config{ + ClientID: cfg.ClientID, + ClientSecret: cfg.ClientSecret, + RedirectURL: cfg.RedirectURI, + Scopes: scopes, + Endpoint: oauth2.Endpoint{ + AuthURL: authURL, + TokenURL: tokenURL, + }, + }, + userInfoURL: userInfoURL, + }, nil +} + +func fetchOIDCDiscovery(discoveryURL string) (*oidcDiscoveryDocument, error) { + client := &http.Client{Timeout: 10 * time.Second} + resp, err := client.Get(discoveryURL) + if err != nil { + return nil, fmt.Errorf("failed to fetch discovery document: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("discovery endpoint returned status %d", resp.StatusCode) + } + + var discovery oidcDiscoveryDocument + if err := json.NewDecoder(resp.Body).Decode(&discovery); err != nil { + return nil, fmt.Errorf("failed to decode discovery document: %w", err) + } + + if discovery.AuthorizationEndpoint == "" || discovery.TokenEndpoint == "" || discovery.UserInfoEndpoint == "" { + return nil, fmt.Errorf("discovery document missing required endpoints") + } + + return &discovery, nil +} + +// Type returns the provider type. +func (p *OIDCProvider) Type() string { + return p.providerType +} + +// AuthURL generates the authorization URL. +func (p *OIDCProvider) AuthURL(state string) string { + return p.config.AuthCodeURL(state, + oauth2.AccessTypeOffline, + oauth2.ApprovalForce, + ) +} + +// ExchangeCode exchanges an authorization code for tokens. +func (p *OIDCProvider) ExchangeCode(ctx context.Context, code string) (*oauth2.Token, error) { + return p.config.Exchange(ctx, code) +} + +// UserInfo fetches user information from the OIDC userinfo endpoint. +// TODO: Add ID token validation as optimization (avoids network call). +func (p *OIDCProvider) UserInfo(ctx context.Context, token *oauth2.Token, allowedDomains []string) (*UserInfo, error) { + client := p.config.Client(ctx, token) + resp, err := client.Get(p.userInfoURL) + if err != nil { + return nil, fmt.Errorf("failed to get user info: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("failed to get user info: status %d", resp.StatusCode) + } + + var userInfoResp oidcUserInfoResponse + if err := json.NewDecoder(resp.Body).Decode(&userInfoResp); err != nil { + return nil, fmt.Errorf("failed to decode user info: %w", err) + } + + domain := emailutil.ExtractDomain(userInfoResp.Email) + + // Validate domain if configured + if err := ValidateDomain(domain, allowedDomains); err != nil { + return nil, err + } + + return &UserInfo{ + ProviderType: p.providerType, + Subject: userInfoResp.Sub, + Email: userInfoResp.Email, + EmailVerified: userInfoResp.EmailVerified, + Name: userInfoResp.Name, + Picture: userInfoResp.Picture, + Domain: domain, + }, nil +} diff --git a/internal/idp/oidc_test.go b/internal/idp/oidc_test.go new file mode 100644 index 0000000..b88d1dd --- /dev/null +++ b/internal/idp/oidc_test.go @@ -0,0 +1,218 @@ +package idp + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/oauth2" +) + +func TestNewOIDCProvider_WithDirectEndpoints(t *testing.T) { + provider, err := NewOIDCProvider(OIDCConfig{ + ProviderType: "custom", + AuthorizationURL: "https://idp.example.com/authorize", + TokenURL: "https://idp.example.com/token", + UserInfoURL: "https://idp.example.com/userinfo", + ClientID: "client-id", + ClientSecret: "client-secret", + RedirectURI: "https://example.com/callback", + }) + + require.NoError(t, err) + require.NotNil(t, provider) + assert.Equal(t, "custom", provider.Type()) +} + +func TestNewOIDCProvider_WithDiscovery(t *testing.T) { + discovery := oidcDiscoveryDocument{ + Issuer: "https://idp.example.com", + AuthorizationEndpoint: "https://idp.example.com/authorize", + TokenEndpoint: "https://idp.example.com/token", + UserInfoEndpoint: "https://idp.example.com/userinfo", + } + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + err := json.NewEncoder(w).Encode(discovery) + require.NoError(t, err) + })) + defer server.Close() + + provider, err := NewOIDCProvider(OIDCConfig{ + DiscoveryURL: server.URL, + ClientID: "client-id", + ClientSecret: "client-secret", + RedirectURI: "https://example.com/callback", + }) + + require.NoError(t, err) + require.NotNil(t, provider) + assert.Equal(t, "oidc", provider.Type()) +} + +func TestNewOIDCProvider_MissingEndpoints(t *testing.T) { + _, err := NewOIDCProvider(OIDCConfig{ + ClientID: "client-id", + ClientSecret: "client-secret", + RedirectURI: "https://example.com/callback", + }) + + require.Error(t, err) + assert.Contains(t, err.Error(), "discoveryUrl or all endpoints") +} + +func TestNewOIDCProvider_PartialEndpoints(t *testing.T) { + _, err := NewOIDCProvider(OIDCConfig{ + AuthorizationURL: "https://idp.example.com/authorize", + ClientID: "client-id", + ClientSecret: "client-secret", + RedirectURI: "https://example.com/callback", + }) + + require.Error(t, err) + assert.Contains(t, err.Error(), "discoveryUrl or all endpoints") +} + +func TestNewOIDCProvider_DiscoveryMissingEndpoints(t *testing.T) { + discovery := oidcDiscoveryDocument{ + Issuer: "https://idp.example.com", + } + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + err := json.NewEncoder(w).Encode(discovery) + require.NoError(t, err) + })) + defer server.Close() + + _, err := NewOIDCProvider(OIDCConfig{ + DiscoveryURL: server.URL, + ClientID: "client-id", + ClientSecret: "client-secret", + RedirectURI: "https://example.com/callback", + }) + + require.Error(t, err) + assert.Contains(t, err.Error(), "missing required endpoints") +} + +func TestOIDCProvider_AuthURL(t *testing.T) { + provider, err := NewOIDCProvider(OIDCConfig{ + AuthorizationURL: "https://idp.example.com/authorize", + TokenURL: "https://idp.example.com/token", + UserInfoURL: "https://idp.example.com/userinfo", + ClientID: "client-id", + ClientSecret: "client-secret", + RedirectURI: "https://example.com/callback", + }) + require.NoError(t, err) + + authURL := provider.AuthURL("test-state") + + assert.Contains(t, authURL, "https://idp.example.com/authorize") + assert.Contains(t, authURL, "state=test-state") + assert.Contains(t, authURL, "client_id=client-id") +} + +func TestOIDCProvider_UserInfo(t *testing.T) { + userInfoServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + resp := oidcUserInfoResponse{ + Sub: "12345", + Email: "user@example.com", + EmailVerified: true, + Name: "Test User", + Picture: "https://example.com/photo.jpg", + } + err := json.NewEncoder(w).Encode(resp) + require.NoError(t, err) + })) + defer userInfoServer.Close() + + provider, err := NewOIDCProvider(OIDCConfig{ + AuthorizationURL: "https://idp.example.com/authorize", + TokenURL: "https://idp.example.com/token", + UserInfoURL: userInfoServer.URL, + ClientID: "client-id", + ClientSecret: "client-secret", + RedirectURI: "https://example.com/callback", + }) + require.NoError(t, err) + + token := &oauth2.Token{AccessToken: "test-token"} + userInfo, err := provider.UserInfo(context.Background(), token, nil) + + require.NoError(t, err) + require.NotNil(t, userInfo) + assert.Equal(t, "oidc", userInfo.ProviderType) + assert.Equal(t, "12345", userInfo.Subject) + assert.Equal(t, "user@example.com", userInfo.Email) + assert.Equal(t, "example.com", userInfo.Domain) + assert.True(t, userInfo.EmailVerified) +} + +func TestOIDCProvider_UserInfo_DomainValidation(t *testing.T) { + userInfoServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + resp := oidcUserInfoResponse{ + Sub: "12345", + Email: "user@other.com", + } + err := json.NewEncoder(w).Encode(resp) + require.NoError(t, err) + })) + defer userInfoServer.Close() + + provider, err := NewOIDCProvider(OIDCConfig{ + AuthorizationURL: "https://idp.example.com/authorize", + TokenURL: "https://idp.example.com/token", + UserInfoURL: userInfoServer.URL, + ClientID: "client-id", + ClientSecret: "client-secret", + RedirectURI: "https://example.com/callback", + }) + require.NoError(t, err) + + token := &oauth2.Token{AccessToken: "test-token"} + _, err = provider.UserInfo(context.Background(), token, []string{"example.com"}) + + require.Error(t, err) + assert.Contains(t, err.Error(), "domain 'other.com' is not allowed") +} + +func TestOIDCProvider_DefaultScopes(t *testing.T) { + provider, err := NewOIDCProvider(OIDCConfig{ + AuthorizationURL: "https://idp.example.com/authorize", + TokenURL: "https://idp.example.com/token", + UserInfoURL: "https://idp.example.com/userinfo", + ClientID: "client-id", + ClientSecret: "client-secret", + RedirectURI: "https://example.com/callback", + }) + require.NoError(t, err) + + authURL := provider.AuthURL("test-state") + assert.Contains(t, authURL, "scope=openid") +} + +func TestOIDCProvider_CustomScopes(t *testing.T) { + provider, err := NewOIDCProvider(OIDCConfig{ + AuthorizationURL: "https://idp.example.com/authorize", + TokenURL: "https://idp.example.com/token", + UserInfoURL: "https://idp.example.com/userinfo", + ClientID: "client-id", + ClientSecret: "client-secret", + RedirectURI: "https://example.com/callback", + Scopes: []string{"openid", "custom_scope"}, + }) + require.NoError(t, err) + + authURL := provider.AuthURL("test-state") + assert.Contains(t, authURL, "scope=openid") + assert.Contains(t, authURL, "custom_scope") +} diff --git a/internal/idp/provider.go b/internal/idp/provider.go new file mode 100644 index 0000000..da10e1b --- /dev/null +++ b/internal/idp/provider.go @@ -0,0 +1,80 @@ +package idp + +import ( + "context" + "fmt" + "slices" + "strings" + + "golang.org/x/oauth2" +) + +// UserInfo represents user information from any identity provider. +// ProviderType is included for multi-IDP readiness. +type UserInfo struct { + ProviderType string `json:"provider_type"` + Subject string `json:"sub"` + Email string `json:"email"` + EmailVerified bool `json:"email_verified"` + Name string `json:"name"` + Picture string `json:"picture"` + Domain string `json:"domain"` + Organizations []string `json:"organizations,omitempty"` +} + +// Provider abstracts identity provider operations. +type Provider interface { + // Type returns the provider type identifier (e.g., "google", "azure", "github", "oidc"). + Type() string + + // AuthURL generates the authorization URL for the OAuth flow. + AuthURL(state string) string + + // ExchangeCode exchanges an authorization code for tokens. + ExchangeCode(ctx context.Context, code string) (*oauth2.Token, error) + + // UserInfo fetches user information and validates access. + // allowedDomains is used for domain-based access control. + // Provider-specific access control (e.g., GitHub org membership) is configured at construction. + UserInfo(ctx context.Context, token *oauth2.Token, allowedDomains []string) (*UserInfo, error) +} + +// ValidateDomain checks if the domain is in the allowed list. +// Returns nil if allowedDomains is empty (no restriction) or domain is allowed. +func ValidateDomain(domain string, allowedDomains []string) error { + if len(allowedDomains) == 0 { + return nil + } + if !slices.Contains(allowedDomains, domain) { + return fmt.Errorf("domain '%s' is not allowed. Contact your administrator", domain) + } + return nil +} + +// ParseClientRequest parses MCP client registration metadata. +// This is provider-agnostic as it deals with MCP client registration, not IDP. +func ParseClientRequest(metadata map[string]any) (redirectURIs []string, scopes []string, err error) { + // Extract redirect URIs + redirectURIs = []string{} + if uris, ok := metadata["redirect_uris"].([]any); ok { + for _, uri := range uris { + if uriStr, ok := uri.(string); ok { + redirectURIs = append(redirectURIs, uriStr) + } + } + } + + if len(redirectURIs) == 0 { + return nil, nil, fmt.Errorf("no valid redirect URIs provided") + } + + // Extract scopes, default to read/write if not provided + scopes = []string{"read", "write"} // Default MCP scopes + if clientScopes, ok := metadata["scope"].(string); ok { + if strings.TrimSpace(clientScopes) != "" { + scopes = strings.Fields(clientScopes) + } + } + + return redirectURIs, scopes, nil +} diff --git a/internal/idp/provider_test.go b/internal/idp/provider_test.go new file mode 100644 index 0000000..b4827d6 --- /dev/null +++ b/internal/idp/provider_test.go @@ -0,0 +1,113 @@ +package idp + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestParseClientRequest(t *testing.T) { + tests := []struct { + name string + metadata map[string]any + wantRedirectURIs []string + wantScopes []string + wantErr bool + errContains string + }{ + { + name: "valid_with_single_redirect_uri", + metadata: map[string]any{ + "redirect_uris": []any{"https://example.com/callback"}, + }, + wantRedirectURIs: []string{"https://example.com/callback"}, + wantScopes: []string{"read", "write"}, + wantErr: false, + }, + { + name: "valid_with_multiple_redirect_uris", + metadata: map[string]any{ + "redirect_uris": []any{ + "https://example.com/callback", + "https://example.com/callback2", + }, + }, + wantRedirectURIs: []string{ + "https://example.com/callback", + "https://example.com/callback2", + }, + wantScopes: []string{"read", "write"}, + wantErr: false, + }, + { + name: "valid_with_custom_scopes", + metadata: map[string]any{ + "redirect_uris": []any{"https://example.com/callback"}, + "scope": "openid profile email", + }, + wantRedirectURIs: []string{"https://example.com/callback"}, + wantScopes: []string{"openid", "profile", "email"}, + wantErr: false, + }, + { + name: "valid_with_empty_scope_uses_default", + metadata: map[string]any{ + "redirect_uris": []any{"https://example.com/callback"}, + "scope": " ", + }, + wantRedirectURIs: []string{"https://example.com/callback"}, + wantScopes: []string{"read", "write"}, + wantErr: false, + }, + { + name: "missing_redirect_uris", + metadata: map[string]any{}, + wantErr: true, + errContains: "no valid redirect URIs", + }, + { + name: "empty_redirect_uris", + metadata: map[string]any{ + "redirect_uris": []any{}, + }, + wantErr: true, + errContains: "no valid redirect URIs", + }, + { + name: "redirect_uris_wrong_type", + metadata: map[string]any{ + "redirect_uris": "https://example.com/callback", + }, + wantErr: true, + errContains: "no valid redirect URIs", + }, + { + name: "redirect_uri_non_string_elements_ignored", + metadata: map[string]any{ + "redirect_uris": []any{123, "https://example.com/callback", nil}, + }, + wantRedirectURIs: []string{"https://example.com/callback"}, + wantScopes: []string{"read", "write"}, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + redirectURIs, scopes, err := ParseClientRequest(tt.metadata) + + if tt.wantErr { + require.Error(t, err) + if tt.errContains != "" { + assert.Contains(t, err.Error(), tt.errContains) + } + return + } + + require.NoError(t, err) + assert.Equal(t, tt.wantRedirectURIs, redirectURIs) + assert.Equal(t, tt.wantScopes, scopes) + }) + } +} diff --git a/internal/mcpfront.go b/internal/mcpfront.go index e6f5b80..ed0d8f0 100644 --- a/internal/mcpfront.go +++ b/internal/mcpfront.go @@ -14,6 +14,7 @@ import ( "github.com/dgellow/mcp-front/internal/client" "github.com/dgellow/mcp-front/internal/config" "github.com/dgellow/mcp-front/internal/crypto" + "github.com/dgellow/mcp-front/internal/idp" "github.com/dgellow/mcp-front/internal/inline" "github.com/dgellow/mcp-front/internal/log" "github.com/dgellow/mcp-front/internal/oauth" @@ -52,7 +53,7 @@ func NewMCPFront(ctx context.Context, cfg config.Config) (*MCPFront, error) { } // Setup authentication (OAuth components and service client) - oauthProvider, sessionEncryptor, authConfig, serviceOAuthClient, err := setupAuthentication(ctx, cfg, store) + oauthProvider, idpProvider, sessionEncryptor, authConfig, serviceOAuthClient, err := setupAuthentication(ctx, cfg, store) if err != nil { return nil, fmt.Errorf("failed to setup authentication: %w", err) } @@ -98,6 +99,7 @@ func NewMCPFront(ctx context.Context, cfg config.Config) (*MCPFront, error) { cfg, store, oauthProvider, + idpProvider, sessionEncryptor, authConfig, serviceOAuthClient, @@ -227,32 +229,42 @@ func setupStorage(ctx context.Context, cfg config.Config) (storage.Storage, erro } // setupAuthentication creates individual OAuth components using clean constructors -func setupAuthentication(ctx context.Context, cfg config.Config, store storage.Storage) (fosite.OAuth2Provider, crypto.Encryptor, config.OAuthAuthConfig, *auth.ServiceOAuthClient, error) { +func setupAuthentication(ctx context.Context, cfg config.Config, store storage.Storage) (fosite.OAuth2Provider, idp.Provider, crypto.Encryptor, config.OAuthAuthConfig, *auth.ServiceOAuthClient, error) { oauthAuth := cfg.Proxy.Auth if oauthAuth == nil { // OAuth not configured - return nil, nil, config.OAuthAuthConfig{}, nil, nil + return nil, nil, nil, config.OAuthAuthConfig{}, nil, nil } log.LogDebug("initializing OAuth components") + // Create identity provider + idpProvider, err := idp.NewProvider(oauthAuth.IDP) + if err != nil { + return nil, nil, nil, config.OAuthAuthConfig{}, nil, fmt.Errorf("failed to create identity provider: %w", err) + } + + log.LogInfoWithFields("mcpfront", "Identity provider configured", map[string]any{ + "type": idpProvider.Type(), + }) + // Generate or validate JWT secret using clean constructor jwtSecret, err := oauth.GenerateJWTSecret(string(oauthAuth.JWTSecret)) if err != nil { - return nil, nil, config.OAuthAuthConfig{}, nil, fmt.Errorf("failed to setup JWT secret: %w", err) + return nil, nil, nil, config.OAuthAuthConfig{}, nil, fmt.Errorf("failed to setup JWT secret: %w", err) } // Create session encryptor using clean constructor encryptionKey := []byte(oauthAuth.EncryptionKey) sessionEncryptor, err := oauth.NewSessionEncryptor(encryptionKey) if err != nil { - return nil, nil, config.OAuthAuthConfig{}, nil, fmt.Errorf("failed to create session encryptor: %w", err) + return nil, nil, nil, config.OAuthAuthConfig{}, nil, fmt.Errorf("failed to create session encryptor: %w", err) } // Create OAuth provider using clean constructor oauthProvider, err := oauth.NewOAuthProvider(*oauthAuth, store, jwtSecret) if err != nil { - return nil, nil, config.OAuthAuthConfig{}, nil, fmt.Errorf("failed to create OAuth provider: %w", err) + return nil, nil, nil, config.OAuthAuthConfig{}, nil, fmt.Errorf("failed to create OAuth provider: %w", err) } // Create OAuth client for service authentication and token refresh @@ -279,7 +291,7 @@ func setupAuthentication(ctx context.Context, cfg config.Config, store storage.S } } - return oauthProvider, sessionEncryptor, *oauthAuth, serviceOAuthClient, nil + return oauthProvider, idpProvider, sessionEncryptor, *oauthAuth, serviceOAuthClient, nil } // buildHTTPHandler creates the complete HTTP handler with all routing and middleware @@ -287,6 +299,7 @@ func buildHTTPHandler( cfg config.Config, storage storage.Storage, oauthProvider fosite.OAuth2Provider, + idpProvider idp.Provider, sessionEncryptor crypto.Encryptor, authConfig config.OAuthAuthConfig, serviceOAuthClient *auth.ServiceOAuthClient, @@ -337,6 +350,7 @@ func buildHTTPHandler( authHandlers := server.NewAuthHandlers( oauthProvider, authConfig, + idpProvider, storage, sessionEncryptor, cfg.MCPServers, @@ -352,7 +366,7 @@ func buildHTTPHandler( // Returns 404 by default, or base issuer metadata if dangerouslyAcceptIssuerAudience is enabled mux.Handle(route("/.well-known/oauth-protected-resource"), server.ChainMiddleware(http.HandlerFunc(authHandlers.ProtectedResourceMetadataHandler), oauthMiddleware...)) mux.Handle(route("/authorize"), server.ChainMiddleware(http.HandlerFunc(authHandlers.AuthorizeHandler), oauthMiddleware...)) - mux.Handle(route("/oauth/callback"), server.ChainMiddleware(http.HandlerFunc(authHandlers.GoogleCallbackHandler), oauthMiddleware...)) + mux.Handle(route("/oauth/callback"), server.ChainMiddleware(http.HandlerFunc(authHandlers.IDPCallbackHandler), oauthMiddleware...)) mux.Handle(route("/token"), server.ChainMiddleware(http.HandlerFunc(authHandlers.TokenHandler), oauthMiddleware...)) mux.Handle(route("/register"), server.ChainMiddleware(http.HandlerFunc(authHandlers.RegisterHandler), oauthMiddleware...)) mux.Handle(route("/clients/{client_id}"), server.ChainMiddleware(http.HandlerFunc(authHandlers.ClientMetadataHandler), oauthMiddleware...)) @@ -361,7 +375,7 @@ func buildHTTPHandler( tokenMiddleware := []server.MiddlewareFunc{ corsMiddleware, tokenLogger, - server.NewBrowserSSOMiddleware(authConfig, sessionEncryptor, browserStateToken), + server.NewBrowserSSOMiddleware(authConfig, idpProvider, sessionEncryptor, browserStateToken), mcpRecover, } @@ -475,7 +489,7 @@ func buildHTTPHandler( // Add browser SSO if OAuth is enabled if oauthProvider != nil { // Reuse the same browserStateToken created earlier for consistency - adminMiddleware = append(adminMiddleware, server.NewBrowserSSOMiddleware(authConfig, sessionEncryptor, browserStateToken)) + adminMiddleware = append(adminMiddleware, server.NewBrowserSSOMiddleware(authConfig, idpProvider, sessionEncryptor, browserStateToken)) } // Add admin check middleware diff --git a/internal/oauthsession/session.go b/internal/oauthsession/session.go index f8f093f..b8cf56e 100644 --- a/internal/oauthsession/session.go +++ b/internal/oauthsession/session.go @@ -1,14 +1,31 @@ package oauthsession import ( - "github.com/dgellow/mcp-front/internal/googleauth" + "time" + + "github.com/dgellow/mcp-front/internal/idp" "github.com/ory/fosite" ) // Session extends DefaultSession with user information type Session struct { *fosite.DefaultSession - UserInfo googleauth.UserInfo `json:"user_info"` + UserInfo idp.UserInfo `json:"user_info"` +} + +// NewSession creates a new session with user info +func NewSession(userInfo idp.UserInfo) *Session { + return &Session{ + DefaultSession: &fosite.DefaultSession{ + ExpiresAt: map[fosite.TokenType]time.Time{ + fosite.AccessToken: time.Now().Add(time.Hour), + fosite.RefreshToken: time.Now().Add(24 * time.Hour), + }, + Username: userInfo.Email, + Subject: userInfo.Email, + }, + UserInfo: userInfo, + } } // Clone implements fosite.Session diff --git a/internal/server/auth_handlers.go b/internal/server/auth_handlers.go index 23a98d5..f8396ba 100644 --- a/internal/server/auth_handlers.go +++ b/internal/server/auth_handlers.go @@ -15,7 +15,7 @@ import ( "github.com/dgellow/mcp-front/internal/config" "github.com/dgellow/mcp-front/internal/crypto" "github.com/dgellow/mcp-front/internal/envutil" - "github.com/dgellow/mcp-front/internal/googleauth" + "github.com/dgellow/mcp-front/internal/idp" jsonwriter "github.com/dgellow/mcp-front/internal/json" "github.com/dgellow/mcp-front/internal/log" "github.com/dgellow/mcp-front/internal/oauth" @@ -28,6 +28,7 @@ import ( type AuthHandlers struct { oauthProvider fosite.OAuth2Provider authConfig config.OAuthAuthConfig + idpProvider idp.Provider storage storage.Storage sessionEncryptor crypto.Encryptor mcpServers map[string]*config.MCPClientConfig @@ -37,18 +38,19 @@ type AuthHandlers struct { // UpstreamOAuthState stores OAuth state during upstream authentication flow (MCP host → mcp-front) type UpstreamOAuthState struct { - UserInfo googleauth.UserInfo `json:"user_info"` - ClientID string `json:"client_id"` - RedirectURI string `json:"redirect_uri"` - Scopes []string `json:"scopes"` - State string `json:"state"` - ResponseType string `json:"response_type"` + UserInfo idp.UserInfo `json:"user_info"` + ClientID string `json:"client_id"` + RedirectURI string `json:"redirect_uri"` + Scopes []string `json:"scopes"` + State string `json:"state"` + ResponseType string `json:"response_type"` } // NewAuthHandlers creates new auth handlers with dependency injection func NewAuthHandlers( oauthProvider fosite.OAuth2Provider, authConfig config.OAuthAuthConfig, + idpProvider idp.Provider, storage storage.Storage, sessionEncryptor crypto.Encryptor, mcpServers map[string]*config.MCPClientConfig, @@ -57,6 +59,7 @@ func NewAuthHandlers( return &AuthHandlers{ oauthProvider: oauthProvider, authConfig: authConfig, + idpProvider: idpProvider, storage: storage, sessionEncryptor: sessionEncryptor, mcpServers: mcpServers, @@ -261,12 +264,12 @@ func (h *AuthHandlers) AuthorizeHandler(w http.ResponseWriter, r *http.Request) state := ar.GetState() h.storage.StoreAuthorizeRequest(state, ar) - authURL := googleauth.GoogleAuthURL(h.authConfig, state) + authURL := h.idpProvider.AuthURL(state) http.Redirect(w, r, authURL, http.StatusFound) } -// GoogleCallbackHandler handles the callback from Google OAuth -func (h *AuthHandlers) GoogleCallbackHandler(w http.ResponseWriter, r *http.Request) { +// IDPCallbackHandler handles the callback from the identity provider +func (h *AuthHandlers) IDPCallbackHandler(w http.ResponseWriter, r *http.Request) { ctx := r.Context() state := r.URL.Query().Get("state") @@ -274,7 +277,7 @@ func (h *AuthHandlers) GoogleCallbackHandler(w http.ResponseWriter, r *http.Requ if errMsg := r.URL.Query().Get("error"); errMsg != "" { errDesc := r.URL.Query().Get("error_description") - log.LogError("Google OAuth error: %s - %s", errMsg, errDesc) + log.LogError("OAuth error: %s - %s", errMsg, errDesc) jsonwriter.WriteBadRequest(w, fmt.Sprintf("Authentication failed: %s", errMsg)) return } @@ -315,7 +318,7 @@ func (h *AuthHandlers) GoogleCallbackHandler(w http.ResponseWriter, r *http.Requ ctx, cancel := context.WithTimeout(ctx, 30*time.Second) defer cancel() - token, err := googleauth.ExchangeCodeForToken(ctx, h.authConfig, code) + token, err := h.idpProvider.ExchangeCode(ctx, code) if err != nil { log.LogError("Failed to exchange code: %v", err) if !isBrowserFlow && ar != nil { @@ -326,8 +329,8 @@ func (h *AuthHandlers) GoogleCallbackHandler(w http.ResponseWriter, r *http.Requ return } - // Validate user - userInfo, err := googleauth.ValidateUser(ctx, h.authConfig, token) + // Validate user and fetch user info + userInfo, err := h.idpProvider.UserInfo(ctx, token, h.authConfig.AllowedDomains) if err != nil { log.LogError("User validation failed: %v", err) if !isBrowserFlow && ar != nil { @@ -354,8 +357,9 @@ func (h *AuthHandlers) GoogleCallbackHandler(w http.ResponseWriter, r *http.Requ sessionDuration := 24 * time.Hour sessionData := browserauth.SessionCookie{ - Email: userInfo.Email, - Expires: time.Now().Add(sessionDuration), + Email: userInfo.Email, + Provider: userInfo.ProviderType, + Expires: time.Now().Add(sessionDuration), } // Marshal session data to JSON @@ -408,7 +412,7 @@ func (h *AuthHandlers) GoogleCallbackHandler(w http.ResponseWriter, r *http.Requ } if needsServiceAuth { - stateData, err := h.signUpstreamOAuthState(ar, userInfo) + stateData, err := h.signUpstreamOAuthState(ar, *userInfo) if err != nil { log.LogError("Failed to sign OAuth state: %v", err) h.oauthProvider.WriteAuthorizeError(ctx, w, ar, fosite.ErrServerError.WithHint("Failed to prepare service authentication")) @@ -429,7 +433,7 @@ func (h *AuthHandlers) GoogleCallbackHandler(w http.ResponseWriter, r *http.Requ fosite.RefreshToken: time.Now().Add(h.authConfig.RefreshTokenTTL), }, }, - UserInfo: userInfo, + UserInfo: *userInfo, } // Accept the authorization request @@ -511,7 +515,7 @@ func (h *AuthHandlers) RegisterHandler(w http.ResponseWriter, r *http.Request) { } // Parse client request - redirectURIs, scopes, err := googleauth.ParseClientRequest(metadata) + redirectURIs, scopes, err := idp.ParseClientRequest(metadata) if err != nil { log.LogError("Client request parsing error: %v", err) jsonwriter.WriteBadRequest(w, err.Error()) @@ -560,7 +564,7 @@ func (h *AuthHandlers) RegisterHandler(w http.ResponseWriter, r *http.Request) { } // signUpstreamOAuthState signs upstream OAuth state for secure storage -func (h *AuthHandlers) signUpstreamOAuthState(ar fosite.AuthorizeRequester, userInfo googleauth.UserInfo) (string, error) { +func (h *AuthHandlers) signUpstreamOAuthState(ar fosite.AuthorizeRequester, userInfo idp.UserInfo) (string, error) { state := UpstreamOAuthState{ UserInfo: userInfo, ClientID: ar.GetClient().GetID(), diff --git a/internal/server/auth_handlers_test.go b/internal/server/auth_handlers_test.go index 6e2e98f..fcc6f98 100644 --- a/internal/server/auth_handlers_test.go +++ b/internal/server/auth_handlers_test.go @@ -1,6 +1,7 @@ package server import ( + "context" "encoding/json" "net/http" "net/http/httptest" @@ -12,12 +13,40 @@ import ( "github.com/dgellow/mcp-front/internal/browserauth" "github.com/dgellow/mcp-front/internal/config" "github.com/dgellow/mcp-front/internal/crypto" + "github.com/dgellow/mcp-front/internal/idp" "github.com/dgellow/mcp-front/internal/oauth" "github.com/dgellow/mcp-front/internal/storage" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "golang.org/x/oauth2" ) +// mockIDPProvider is a mock IDP provider for testing +type mockIDPProvider struct{} + +func (m *mockIDPProvider) Type() string { + return "mock" +} + +func (m *mockIDPProvider) AuthURL(state string) string { + return "https://auth.example.com/authorize?state=" + state +} + +func (m *mockIDPProvider) ExchangeCode(ctx context.Context, code string) (*oauth2.Token, error) { + return &oauth2.Token{AccessToken: "test-token"}, nil +} + +func (m *mockIDPProvider) UserInfo(ctx context.Context, token *oauth2.Token, allowedDomains []string) (*idp.UserInfo, error) { + return &idp.UserInfo{ + ProviderType: "mock", + Subject: "123", + Email: "test@example.com", + EmailVerified: true, + Name: "Test User", + Domain: "example.com", + }, nil +} + func TestAuthenticationBoundaries(t *testing.T) { tests := []struct { name string @@ -47,18 +76,21 @@ func TestAuthenticationBoundaries(t *testing.T) { // Setup test OAuth configuration oauthConfig := config.OAuthAuthConfig{ - Kind: config.AuthKindOAuth, - Issuer: "https://test.example.com", - GoogleClientID: "test-client-id", - GoogleClientSecret: config.Secret("test-client-secret"), - GoogleRedirectURI: "https://test.example.com/oauth/callback", - JWTSecret: config.Secret(strings.Repeat("a", 32)), - EncryptionKey: config.Secret(strings.Repeat("b", 32)), - TokenTTL: time.Hour, - RefreshTokenTTL: 30 * 24 * time.Hour, - Storage: "memory", - AllowedDomains: []string{"example.com"}, - AllowedOrigins: []string{"https://test.example.com"}, + Kind: config.AuthKindOAuth, + Issuer: "https://test.example.com", + IDP: config.IDPConfig{ + Provider: "google", + ClientID: "test-client-id", + ClientSecret: config.Secret("test-client-secret"), + RedirectURI: "https://test.example.com/oauth/callback", + }, + JWTSecret: config.Secret(strings.Repeat("a", 32)), + EncryptionKey: config.Secret(strings.Repeat("b", 32)), + TokenTTL: time.Hour, + RefreshTokenTTL: 30 * 24 * time.Hour, + Storage: "memory", + AllowedDomains: []string{"example.com"}, + AllowedOrigins: []string{"https://test.example.com"}, } // Create storage @@ -77,10 +109,14 @@ func TestAuthenticationBoundaries(t *testing.T) { // Create service OAuth client serviceOAuthClient := auth.NewServiceOAuthClient(store, "https://test.example.com", []byte(strings.Repeat("k", 32))) + // Create mock IDP provider for testing + mockIDP := &mockIDPProvider{} + // Create handlers authHandlers := NewAuthHandlers( oauthProvider, oauthConfig, + mockIDP, store, sessionEncryptor, map[string]*config.MCPClientConfig{}, @@ -103,7 +139,7 @@ func TestAuthenticationBoundaries(t *testing.T) { // Protected endpoints tokenMiddleware := []MiddlewareFunc{ corsMiddleware, - NewBrowserSSOMiddleware(oauthConfig, sessionEncryptor, &browserStateToken), + NewBrowserSSOMiddleware(oauthConfig, mockIDP, sessionEncryptor, &browserStateToken), } mux.Handle("/my/tokens", ChainMiddleware( @@ -155,8 +191,9 @@ func TestAuthenticationBoundaries(t *testing.T) { if tt.expectAuth { // Create session data sessionData := browserauth.SessionCookie{ - Email: "test@example.com", - Expires: time.Now().Add(24 * time.Hour), + Email: "test@example.com", + Provider: "mock", + Expires: time.Now().Add(24 * time.Hour), } jsonData, err := json.Marshal(sessionData) require.NoError(t, err) @@ -186,18 +223,21 @@ func TestAuthenticationBoundaries(t *testing.T) { func TestOAuthEndpointHandlers(t *testing.T) { oauthConfig := config.OAuthAuthConfig{ - Kind: config.AuthKindOAuth, - Issuer: "https://test.example.com", - GoogleClientID: "test-client-id", - GoogleClientSecret: config.Secret("test-client-secret"), - GoogleRedirectURI: "https://test.example.com/oauth/callback", - JWTSecret: config.Secret(strings.Repeat("a", 32)), - EncryptionKey: config.Secret(strings.Repeat("b", 32)), - TokenTTL: time.Hour, - RefreshTokenTTL: 30 * 24 * time.Hour, - Storage: "memory", - AllowedDomains: []string{"example.com"}, - AllowedOrigins: []string{"https://test.example.com"}, + Kind: config.AuthKindOAuth, + Issuer: "https://test.example.com", + IDP: config.IDPConfig{ + Provider: "google", + ClientID: "test-client-id", + ClientSecret: config.Secret("test-client-secret"), + RedirectURI: "https://test.example.com/oauth/callback", + }, + JWTSecret: config.Secret(strings.Repeat("a", 32)), + EncryptionKey: config.Secret(strings.Repeat("b", 32)), + TokenTTL: time.Hour, + RefreshTokenTTL: 30 * 24 * time.Hour, + Storage: "memory", + AllowedDomains: []string{"example.com"}, + AllowedOrigins: []string{"https://test.example.com"}, } store := storage.NewMemoryStorage() @@ -208,10 +248,12 @@ func TestOAuthEndpointHandlers(t *testing.T) { sessionEncryptor, err := oauth.NewSessionEncryptor([]byte(oauthConfig.EncryptionKey)) require.NoError(t, err) serviceOAuthClient := auth.NewServiceOAuthClient(store, "https://test.example.com", []byte(strings.Repeat("k", 32))) + mockIDP := &mockIDPProvider{} authHandlers := NewAuthHandlers( oauthProvider, oauthConfig, + mockIDP, store, sessionEncryptor, map[string]*config.MCPClientConfig{}, diff --git a/internal/server/http_test.go b/internal/server/http_test.go index 822b13b..955c9da 100644 --- a/internal/server/http_test.go +++ b/internal/server/http_test.go @@ -37,17 +37,20 @@ func TestHealthEndpoint(t *testing.T) { func TestOAuthEndpointsCORS(t *testing.T) { // Setup OAuth config oauthConfig := config.OAuthAuthConfig{ - Kind: config.AuthKindOAuth, - Issuer: "https://test.example.com", - GoogleClientID: "test-client-id", - GoogleClientSecret: config.Secret("test-client-secret"), - GoogleRedirectURI: "https://test.example.com/oauth/callback", - JWTSecret: config.Secret(strings.Repeat("a", 32)), - EncryptionKey: config.Secret(strings.Repeat("b", 32)), - TokenTTL: time.Hour, - RefreshTokenTTL: 30 * 24 * time.Hour, - Storage: "memory", - AllowedOrigins: []string{"http://localhost:6274"}, + Kind: config.AuthKindOAuth, + Issuer: "https://test.example.com", + IDP: config.IDPConfig{ + Provider: "google", + ClientID: "test-client-id", + ClientSecret: config.Secret("test-client-secret"), + RedirectURI: "https://test.example.com/oauth/callback", + }, + JWTSecret: config.Secret(strings.Repeat("a", 32)), + EncryptionKey: config.Secret(strings.Repeat("b", 32)), + TokenTTL: time.Hour, + RefreshTokenTTL: 30 * 24 * time.Hour, + Storage: "memory", + AllowedOrigins: []string{"http://localhost:6274"}, } store := storage.NewMemoryStorage() @@ -58,10 +61,12 @@ func TestOAuthEndpointsCORS(t *testing.T) { sessionEncryptor, err := oauth.NewSessionEncryptor([]byte(oauthConfig.EncryptionKey)) require.NoError(t, err) serviceOAuthClient := auth.NewServiceOAuthClient(store, "https://test.example.com", []byte(strings.Repeat("k", 32))) + mockIDP := &mockIDPProvider{} authHandlers := NewAuthHandlers( oauthProvider, oauthConfig, + mockIDP, store, sessionEncryptor, map[string]*config.MCPClientConfig{}, diff --git a/internal/server/middleware.go b/internal/server/middleware.go index bb04163..0e53ba0 100644 --- a/internal/server/middleware.go +++ b/internal/server/middleware.go @@ -14,7 +14,7 @@ import ( "github.com/dgellow/mcp-front/internal/config" "github.com/dgellow/mcp-front/internal/cookie" "github.com/dgellow/mcp-front/internal/crypto" - "github.com/dgellow/mcp-front/internal/googleauth" + "github.com/dgellow/mcp-front/internal/idp" jsonwriter "github.com/dgellow/mcp-front/internal/json" "github.com/dgellow/mcp-front/internal/log" "github.com/dgellow/mcp-front/internal/oauth" @@ -289,7 +289,7 @@ func adminMiddleware(adminConfig *config.AdminConfig, store storage.Storage) Mid } // NewBrowserSSOMiddleware creates middleware for browser-based SSO authentication -func NewBrowserSSOMiddleware(authConfig config.OAuthAuthConfig, sessionEncryptor crypto.Encryptor, browserStateToken *crypto.TokenSigner) MiddlewareFunc { +func NewBrowserSSOMiddleware(authConfig config.OAuthAuthConfig, idpProvider idp.Provider, sessionEncryptor crypto.Encryptor, browserStateToken *crypto.TokenSigner) MiddlewareFunc { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Check for session cookie @@ -301,8 +301,8 @@ func NewBrowserSSOMiddleware(authConfig config.OAuthAuthConfig, sessionEncryptor jsonwriter.WriteInternalServerError(w, "Failed to generate authentication state") return } - googleURL := googleauth.GoogleAuthURL(authConfig, state) - http.Redirect(w, r, googleURL, http.StatusFound) + authURL := idpProvider.AuthURL(state) + http.Redirect(w, r, authURL, http.StatusFound) return } @@ -313,8 +313,8 @@ func NewBrowserSSOMiddleware(authConfig config.OAuthAuthConfig, sessionEncryptor log.LogDebug("Invalid session cookie: %v", err) cookie.ClearSession(w) // Clear bad cookie state := generateBrowserState(browserStateToken, r.URL.String()) - googleURL := googleauth.GoogleAuthURL(authConfig, state) - http.Redirect(w, r, googleURL, http.StatusFound) + authURL := idpProvider.AuthURL(state) + http.Redirect(w, r, authURL, http.StatusFound) return } @@ -331,14 +331,14 @@ func NewBrowserSSOMiddleware(authConfig config.OAuthAuthConfig, sessionEncryptor if time.Now().After(sessionData.Expires) { log.LogDebug("Session expired for user %s", sessionData.Email) cookie.ClearSession(w) - // Redirect directly to Google OAuth + // Redirect directly to OAuth state := generateBrowserState(browserStateToken, r.URL.String()) if state == "" { jsonwriter.WriteInternalServerError(w, "Failed to generate authentication state") return } - googleURL := googleauth.GoogleAuthURL(authConfig, state) - http.Redirect(w, r, googleURL, http.StatusFound) + authURL := idpProvider.AuthURL(state) + http.Redirect(w, r, authURL, http.StatusFound) return } From 9ccd1ff23edff41909aeb8215053bb3dbab6f643 Mon Sep 17 00:00:00 2001 From: Claude Date: Tue, 2 Dec 2025 00:37:37 +0000 Subject: [PATCH 2/7] Move allowedDomains to construction time for consistency Previously allowedDomains was passed at call time to UserInfo(), while allowedOrgs was configured at construction. This inconsistency made the interface harder to understand and didn't follow the principle that access control should be configured when the provider is created. Changes: - Provider.UserInfo() now takes only context and token - All providers store allowedDomains internally - Factory accepts allowedDomains parameter - All tests updated to reflect new interface --- internal/idp/azure.go | 15 ++++++++------- internal/idp/azure_test.go | 2 +- internal/idp/factory.go | 7 ++++++- internal/idp/factory_test.go | 2 +- internal/idp/github.go | 21 +++++++++++---------- internal/idp/github_test.go | 15 ++++++++------- internal/idp/google.go | 14 ++++++++------ internal/idp/google_test.go | 11 ++++++----- internal/idp/oidc.go | 17 +++++++++++------ internal/idp/oidc_test.go | 5 +++-- internal/idp/provider.go | 6 +++--- internal/mcpfront.go | 2 +- internal/server/auth_handlers.go | 2 +- internal/server/auth_handlers_test.go | 2 +- 14 files changed, 69 insertions(+), 52 deletions(-) diff --git a/internal/idp/azure.go b/internal/idp/azure.go index ccd0682..b364649 100644 --- a/internal/idp/azure.go +++ b/internal/idp/azure.go @@ -4,7 +4,7 @@ import "fmt" // NewAzureProvider creates an Azure AD provider using OIDC discovery. // Azure AD is OIDC-compliant, so we use the generic OIDC provider with Azure's tenant-specific discovery URL. -func NewAzureProvider(tenantID, clientID, clientSecret, redirectURI string) (*OIDCProvider, error) { +func NewAzureProvider(tenantID, clientID, clientSecret, redirectURI string, allowedDomains []string) (*OIDCProvider, error) { if tenantID == "" { return nil, fmt.Errorf("tenantId is required for Azure AD") } @@ -15,11 +15,12 @@ func NewAzureProvider(tenantID, clientID, clientSecret, redirectURI string) (*OI ) return NewOIDCProvider(OIDCConfig{ - ProviderType: "azure", - DiscoveryURL: discoveryURL, - ClientID: clientID, - ClientSecret: clientSecret, - RedirectURI: redirectURI, - Scopes: []string{"openid", "email", "profile"}, + ProviderType: "azure", + DiscoveryURL: discoveryURL, + ClientID: clientID, + ClientSecret: clientSecret, + RedirectURI: redirectURI, + Scopes: []string{"openid", "email", "profile"}, + AllowedDomains: allowedDomains, }) } diff --git a/internal/idp/azure_test.go b/internal/idp/azure_test.go index d163ce0..b4ad312 100644 --- a/internal/idp/azure_test.go +++ b/internal/idp/azure_test.go @@ -8,7 +8,7 @@ import ( ) func TestNewAzureProvider_MissingTenantID(t *testing.T) { - _, err := NewAzureProvider("", "client-id", "client-secret", "https://example.com/callback") + _, err := NewAzureProvider("", "client-id", "client-secret", "https://example.com/callback", nil) require.Error(t, err) assert.Contains(t, err.Error(), "tenantId is required") diff --git a/internal/idp/factory.go b/internal/idp/factory.go index 2032dd0..1ece21f 100644 --- a/internal/idp/factory.go +++ b/internal/idp/factory.go @@ -7,13 +7,15 @@ import ( ) // NewProvider creates a Provider based on the IDPConfig. -func NewProvider(cfg config.IDPConfig) (Provider, error) { +// allowedDomains configures domain-based access control for all provider types. +func NewProvider(cfg config.IDPConfig, allowedDomains []string) (Provider, error) { switch cfg.Provider { case "google": return NewGoogleProvider( cfg.ClientID, string(cfg.ClientSecret), cfg.RedirectURI, + allowedDomains, ), nil case "azure": @@ -22,6 +24,7 @@ func NewProvider(cfg config.IDPConfig) (Provider, error) { cfg.ClientID, string(cfg.ClientSecret), cfg.RedirectURI, + allowedDomains, ) case "github": @@ -29,6 +32,7 @@ func NewProvider(cfg config.IDPConfig) (Provider, error) { cfg.ClientID, string(cfg.ClientSecret), cfg.RedirectURI, + allowedDomains, cfg.AllowedOrgs, ), nil @@ -43,6 +47,7 @@ func NewProvider(cfg config.IDPConfig) (Provider, error) { ClientSecret: string(cfg.ClientSecret), RedirectURI: cfg.RedirectURI, Scopes: cfg.Scopes, + AllowedDomains: allowedDomains, }) default: diff --git a/internal/idp/factory_test.go b/internal/idp/factory_test.go index 9b38681..11065f7 100644 --- a/internal/idp/factory_test.go +++ b/internal/idp/factory_test.go @@ -87,7 +87,7 @@ func TestNewProvider(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - provider, err := NewProvider(tt.cfg) + provider, err := NewProvider(tt.cfg, nil) if tt.wantErr { require.Error(t, err) diff --git a/internal/idp/github.go b/internal/idp/github.go index 1503d65..bc055e2 100644 --- a/internal/idp/github.go +++ b/internal/idp/github.go @@ -14,9 +14,10 @@ import ( // GitHubProvider implements the Provider interface for GitHub OAuth. // GitHub uses OAuth 2.0 (not OIDC) and has its own API for user info and org membership. type GitHubProvider struct { - config oauth2.Config - apiBaseURL string // defaults to https://api.github.com, can be overridden for testing - allowedOrgs []string // organizations users must be members of (empty = no restriction) + config oauth2.Config + apiBaseURL string // defaults to https://api.github.com, can be overridden for testing + allowedDomains []string // email domains users must belong to (empty = no restriction) + allowedOrgs []string // organizations users must be members of (empty = no restriction) } // githubUserResponse represents GitHub's user API response. @@ -41,8 +42,7 @@ type githubOrgResponse struct { } // NewGitHubProvider creates a new GitHub OAuth provider. -// allowedOrgs specifies organizations users must be members of (empty = no restriction). -func NewGitHubProvider(clientID, clientSecret, redirectURI string, allowedOrgs []string) *GitHubProvider { +func NewGitHubProvider(clientID, clientSecret, redirectURI string, allowedDomains, allowedOrgs []string) *GitHubProvider { return &GitHubProvider{ config: oauth2.Config{ ClientID: clientID, @@ -51,8 +51,9 @@ func NewGitHubProvider(clientID, clientSecret, redirectURI string, allowedOrgs [ Scopes: []string{"user:email", "read:org"}, Endpoint: github.Endpoint, }, - apiBaseURL: "https://api.github.com", - allowedOrgs: allowedOrgs, + apiBaseURL: "https://api.github.com", + allowedDomains: allowedDomains, + allowedOrgs: allowedOrgs, } } @@ -72,9 +73,9 @@ func (p *GitHubProvider) ExchangeCode(ctx context.Context, code string) (*oauth2 } // UserInfo fetches user information from GitHub's API. -// Validates organization membership if allowedOrgs was configured at construction. +// Validates domain and organization membership based on construction-time config. // TODO: Consider caching org membership to reduce API calls. -func (p *GitHubProvider) UserInfo(ctx context.Context, token *oauth2.Token, allowedDomains []string) (*UserInfo, error) { +func (p *GitHubProvider) UserInfo(ctx context.Context, token *oauth2.Token) (*UserInfo, error) { client := p.config.Client(ctx, token) // Fetch user profile @@ -99,7 +100,7 @@ func (p *GitHubProvider) UserInfo(ctx context.Context, token *oauth2.Token, allo domain := emailutil.ExtractDomain(email) // Validate domain if configured - if err := ValidateDomain(domain, allowedDomains); err != nil { + if err := ValidateDomain(domain, p.allowedDomains); err != nil { return nil, err } diff --git a/internal/idp/github_test.go b/internal/idp/github_test.go index 71650de..f42b5ad 100644 --- a/internal/idp/github_test.go +++ b/internal/idp/github_test.go @@ -13,12 +13,12 @@ import ( ) func TestGitHubProvider_Type(t *testing.T) { - provider := NewGitHubProvider("client-id", "client-secret", "https://example.com/callback", nil) + provider := NewGitHubProvider("client-id", "client-secret", "https://example.com/callback", nil, nil) assert.Equal(t, "github", provider.Type()) } func TestGitHubProvider_AuthURL(t *testing.T) { - provider := NewGitHubProvider("client-id", "client-secret", "https://example.com/callback", nil) + provider := NewGitHubProvider("client-id", "client-secret", "https://example.com/callback", nil, nil) authURL := provider.AuthURL("test-state") @@ -184,12 +184,13 @@ func TestGitHubProvider_UserInfo(t *testing.T) { TokenURL: server.URL + "/token", }, }, - apiBaseURL: server.URL, - allowedOrgs: tt.allowedOrgs, + apiBaseURL: server.URL, + allowedDomains: tt.allowedDomains, + allowedOrgs: tt.allowedOrgs, } token := &oauth2.Token{AccessToken: "test-token"} - userInfo, err := provider.UserInfo(context.Background(), token, tt.allowedDomains) + userInfo, err := provider.UserInfo(context.Background(), token) if tt.wantErr { require.Error(t, err) @@ -244,7 +245,7 @@ func TestGitHubProvider_UserInfo_APIErrors(t *testing.T) { } token := &oauth2.Token{AccessToken: "test-token"} - _, err := provider.UserInfo(context.Background(), token, nil) + _, err := provider.UserInfo(context.Background(), token) require.Error(t, err) assert.Contains(t, err.Error(), tt.errContains) @@ -277,7 +278,7 @@ func TestGitHubProvider_UserInfo_NoVerifiedEmail(t *testing.T) { } token := &oauth2.Token{AccessToken: "test-token"} - _, err := provider.UserInfo(context.Background(), token, nil) + _, err := provider.UserInfo(context.Background(), token) require.Error(t, err) assert.Contains(t, err.Error(), "no verified email") diff --git a/internal/idp/google.go b/internal/idp/google.go index 87909a3..6d017cb 100644 --- a/internal/idp/google.go +++ b/internal/idp/google.go @@ -14,8 +14,9 @@ import ( // GoogleProvider implements the Provider interface for Google OAuth. // Google has specific quirks like `hd` for hosted domain and `verified_email` field. type GoogleProvider struct { - config oauth2.Config - userInfoURL string + config oauth2.Config + userInfoURL string + allowedDomains []string } // googleUserInfoResponse represents Google's userinfo response. @@ -30,7 +31,7 @@ type googleUserInfoResponse struct { } // NewGoogleProvider creates a new Google OAuth provider. -func NewGoogleProvider(clientID, clientSecret, redirectURI string) *GoogleProvider { +func NewGoogleProvider(clientID, clientSecret, redirectURI string, allowedDomains []string) *GoogleProvider { return &GoogleProvider{ config: oauth2.Config{ ClientID: clientID, @@ -39,7 +40,8 @@ func NewGoogleProvider(clientID, clientSecret, redirectURI string) *GoogleProvid Scopes: []string{"openid", "profile", "email"}, Endpoint: google.Endpoint, }, - userInfoURL: "https://www.googleapis.com/oauth2/v2/userinfo", + userInfoURL: "https://www.googleapis.com/oauth2/v2/userinfo", + allowedDomains: allowedDomains, } } @@ -62,7 +64,7 @@ func (p *GoogleProvider) ExchangeCode(ctx context.Context, code string) (*oauth2 } // UserInfo fetches user information from Google's userinfo endpoint. -func (p *GoogleProvider) UserInfo(ctx context.Context, token *oauth2.Token, allowedDomains []string) (*UserInfo, error) { +func (p *GoogleProvider) UserInfo(ctx context.Context, token *oauth2.Token) (*UserInfo, error) { client := p.config.Client(ctx, token) resp, err := client.Get(p.userInfoURL) @@ -87,7 +89,7 @@ func (p *GoogleProvider) UserInfo(ctx context.Context, token *oauth2.Token, allo } // Validate domain if configured - if err := ValidateDomain(domain, allowedDomains); err != nil { + if err := ValidateDomain(domain, p.allowedDomains); err != nil { return nil, err } diff --git a/internal/idp/google_test.go b/internal/idp/google_test.go index 691012a..e091373 100644 --- a/internal/idp/google_test.go +++ b/internal/idp/google_test.go @@ -13,12 +13,12 @@ import ( ) func TestGoogleProvider_Type(t *testing.T) { - provider := NewGoogleProvider("client-id", "client-secret", "https://example.com/callback") + provider := NewGoogleProvider("client-id", "client-secret", "https://example.com/callback", nil) assert.Equal(t, "google", provider.Type()) } func TestGoogleProvider_AuthURL(t *testing.T) { - provider := NewGoogleProvider("client-id", "client-secret", "https://example.com/callback") + provider := NewGoogleProvider("client-id", "client-secret", "https://example.com/callback", nil) authURL := provider.AuthURL("test-state") @@ -102,11 +102,12 @@ func TestGoogleProvider_UserInfo(t *testing.T) { TokenURL: server.URL + "/token", }, }, - userInfoURL: server.URL, + userInfoURL: server.URL, + allowedDomains: tt.allowedDomains, } token := &oauth2.Token{AccessToken: "test-token"} - userInfo, err := provider.UserInfo(context.Background(), token, tt.allowedDomains) + userInfo, err := provider.UserInfo(context.Background(), token) if tt.wantErr { require.Error(t, err) @@ -142,7 +143,7 @@ func TestGoogleProvider_UserInfo_ServerError(t *testing.T) { } token := &oauth2.Token{AccessToken: "test-token"} - _, err := provider.UserInfo(context.Background(), token, nil) + _, err := provider.UserInfo(context.Background(), token) require.Error(t, err) assert.Contains(t, err.Error(), "status 500") diff --git a/internal/idp/oidc.go b/internal/idp/oidc.go index 0d3dbf6..8bb4fa0 100644 --- a/internal/idp/oidc.go +++ b/internal/idp/oidc.go @@ -29,13 +29,17 @@ type OIDCConfig struct { ClientSecret string RedirectURI string Scopes []string + + // Access control. + AllowedDomains []string } // OIDCProvider implements the Provider interface for OIDC-compliant identity providers. type OIDCProvider struct { - providerType string - config oauth2.Config - userInfoURL string + providerType string + config oauth2.Config + userInfoURL string + allowedDomains []string } // oidcDiscoveryDocument represents the OIDC discovery document. @@ -99,7 +103,8 @@ func NewOIDCProvider(cfg OIDCConfig) (*OIDCProvider, error) { TokenURL: tokenURL, }, }, - userInfoURL: userInfoURL, + userInfoURL: userInfoURL, + allowedDomains: cfg.AllowedDomains, }, nil } @@ -147,7 +152,7 @@ func (p *OIDCProvider) ExchangeCode(ctx context.Context, code string) (*oauth2.T // UserInfo fetches user information from the OIDC userinfo endpoint. // TODO: Add ID token validation as optimization (avoids network call). -func (p *OIDCProvider) UserInfo(ctx context.Context, token *oauth2.Token, allowedDomains []string) (*UserInfo, error) { +func (p *OIDCProvider) UserInfo(ctx context.Context, token *oauth2.Token) (*UserInfo, error) { client := p.config.Client(ctx, token) resp, err := client.Get(p.userInfoURL) if err != nil { @@ -167,7 +172,7 @@ func (p *OIDCProvider) UserInfo(ctx context.Context, token *oauth2.Token, allowe domain := emailutil.ExtractDomain(userInfoResp.Email) // Validate domain if configured - if err := ValidateDomain(domain, allowedDomains); err != nil { + if err := ValidateDomain(domain, p.allowedDomains); err != nil { return nil, err } diff --git a/internal/idp/oidc_test.go b/internal/idp/oidc_test.go index b88d1dd..9d93c1c 100644 --- a/internal/idp/oidc_test.go +++ b/internal/idp/oidc_test.go @@ -145,7 +145,7 @@ func TestOIDCProvider_UserInfo(t *testing.T) { require.NoError(t, err) token := &oauth2.Token{AccessToken: "test-token"} - userInfo, err := provider.UserInfo(context.Background(), token, nil) + userInfo, err := provider.UserInfo(context.Background(), token) require.NoError(t, err) require.NotNil(t, userInfo) @@ -175,11 +175,12 @@ func TestOIDCProvider_UserInfo_DomainValidation(t *testing.T) { ClientID: "client-id", ClientSecret: "client-secret", RedirectURI: "https://example.com/callback", + AllowedDomains: []string{"example.com"}, }) require.NoError(t, err) token := &oauth2.Token{AccessToken: "test-token"} - _, err = provider.UserInfo(context.Background(), token, []string{"example.com"}) + _, err = provider.UserInfo(context.Background(), token) require.Error(t, err) assert.Contains(t, err.Error(), "domain 'other.com' is not allowed") diff --git a/internal/idp/provider.go b/internal/idp/provider.go index da10e1b..d01c908 100644 --- a/internal/idp/provider.go +++ b/internal/idp/provider.go @@ -23,6 +23,7 @@ type UserInfo struct { } // Provider abstracts identity provider operations. +// Access control (allowed domains, orgs) is configured at construction time. type Provider interface { // Type returns the provider type identifier (e.g., "google", "azure", "github", "oidc"). Type() string @@ -34,9 +35,8 @@ type Provider interface { ExchangeCode(ctx context.Context, code string) (*oauth2.Token, error) // UserInfo fetches user information and validates access. - // allowedDomains is used for domain-based access control. - // Provider-specific access control (e.g., GitHub org membership) is configured at construction. - UserInfo(ctx context.Context, token *oauth2.Token, allowedDomains []string) (*UserInfo, error) + // Returns error if user doesn't meet access requirements (domain, org membership). + UserInfo(ctx context.Context, token *oauth2.Token) (*UserInfo, error) } // ValidateDomain checks if the domain is in the allowed list. diff --git a/internal/mcpfront.go b/internal/mcpfront.go index ed0d8f0..eb97bf4 100644 --- a/internal/mcpfront.go +++ b/internal/mcpfront.go @@ -239,7 +239,7 @@ func setupAuthentication(ctx context.Context, cfg config.Config, store storage.S log.LogDebug("initializing OAuth components") // Create identity provider - idpProvider, err := idp.NewProvider(oauthAuth.IDP) + idpProvider, err := idp.NewProvider(oauthAuth.IDP, oauthAuth.AllowedDomains) if err != nil { return nil, nil, nil, config.OAuthAuthConfig{}, nil, fmt.Errorf("failed to create identity provider: %w", err) } diff --git a/internal/server/auth_handlers.go b/internal/server/auth_handlers.go index f8396ba..fcde3ae 100644 --- a/internal/server/auth_handlers.go +++ b/internal/server/auth_handlers.go @@ -330,7 +330,7 @@ func (h *AuthHandlers) IDPCallbackHandler(w http.ResponseWriter, r *http.Request } // Validate user and fetch user info - userInfo, err := h.idpProvider.UserInfo(ctx, token, h.authConfig.AllowedDomains) + userInfo, err := h.idpProvider.UserInfo(ctx, token) if err != nil { log.LogError("User validation failed: %v", err) if !isBrowserFlow && ar != nil { diff --git a/internal/server/auth_handlers_test.go b/internal/server/auth_handlers_test.go index fcc6f98..e0e3a2f 100644 --- a/internal/server/auth_handlers_test.go +++ b/internal/server/auth_handlers_test.go @@ -36,7 +36,7 @@ func (m *mockIDPProvider) ExchangeCode(ctx context.Context, code string) (*oauth return &oauth2.Token{AccessToken: "test-token"}, nil } -func (m *mockIDPProvider) UserInfo(ctx context.Context, token *oauth2.Token, allowedDomains []string) (*idp.UserInfo, error) { +func (m *mockIDPProvider) UserInfo(ctx context.Context, token *oauth2.Token) (*idp.UserInfo, error) { return &idp.UserInfo{ ProviderType: "mock", Subject: "123", From 0e099bd411a1f02b4af19fbfeb31ecd32a5ef7f6 Mon Sep 17 00:00:00 2001 From: Sam El-Borai Date: Thu, 12 Feb 2026 10:16:54 +0100 Subject: [PATCH 3/7] Separate authentication from authorization in IDP providers MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Providers now only fetch identity — access control (domain/org checks) is centralized in a single validateAccess method on the auth handler. This means IDP errors (unreachable provider) produce ErrServerError while policy rejections produce ErrAccessDenied, and adding new access rules no longer requires touching every provider. GitHub always fetches orgs now (scope was already requested unconditionally). UserInfo struct renamed to Identity to avoid collision with the method name. ParseClientRequest relocated to oauth package as ParseClientRegistration. Deleted deprecated ProtectedResourceMetadata and inlined the workaround logic into the handler that used it. --- internal/idp/azure.go | 15 ++- internal/idp/azure_test.go | 2 +- internal/idp/factory.go | 8 +- internal/idp/factory_test.go | 2 +- internal/idp/github.go | 54 ++-------- internal/idp/github_test.go | 101 ++++-------------- internal/idp/google.go | 21 ++-- internal/idp/google_test.go | 53 +++------ internal/idp/oidc.go | 23 ++-- internal/idp/oidc_test.go | 44 ++------ internal/idp/provider.go | 57 ++-------- internal/mcpfront.go | 2 +- internal/oauth/client_registration.go | 32 ++++++ .../client_registration_test.go} | 6 +- internal/oauth/metadata.go | 28 +---- internal/oauth/metadata_test.go | 68 ------------ internal/oauth/provider.go | 4 +- internal/oauthsession/session.go | 21 +--- internal/server/auth_handlers.go | 87 +++++++++++---- internal/server/auth_handlers_test.go | 89 ++++++++++++++- 20 files changed, 279 insertions(+), 438 deletions(-) create mode 100644 internal/oauth/client_registration.go rename internal/{idp/provider_test.go => oauth/client_registration_test.go} (95%) diff --git a/internal/idp/azure.go b/internal/idp/azure.go index b364649..ccd0682 100644 --- a/internal/idp/azure.go +++ b/internal/idp/azure.go @@ -4,7 +4,7 @@ import "fmt" // NewAzureProvider creates an Azure AD provider using OIDC discovery. // Azure AD is OIDC-compliant, so we use the generic OIDC provider with Azure's tenant-specific discovery URL. -func NewAzureProvider(tenantID, clientID, clientSecret, redirectURI string, allowedDomains []string) (*OIDCProvider, error) { +func NewAzureProvider(tenantID, clientID, clientSecret, redirectURI string) (*OIDCProvider, error) { if tenantID == "" { return nil, fmt.Errorf("tenantId is required for Azure AD") } @@ -15,12 +15,11 @@ func NewAzureProvider(tenantID, clientID, clientSecret, redirectURI string, allo ) return NewOIDCProvider(OIDCConfig{ - ProviderType: "azure", - DiscoveryURL: discoveryURL, - ClientID: clientID, - ClientSecret: clientSecret, - RedirectURI: redirectURI, - Scopes: []string{"openid", "email", "profile"}, - AllowedDomains: allowedDomains, + ProviderType: "azure", + DiscoveryURL: discoveryURL, + ClientID: clientID, + ClientSecret: clientSecret, + RedirectURI: redirectURI, + Scopes: []string{"openid", "email", "profile"}, }) } diff --git a/internal/idp/azure_test.go b/internal/idp/azure_test.go index b4ad312..d163ce0 100644 --- a/internal/idp/azure_test.go +++ b/internal/idp/azure_test.go @@ -8,7 +8,7 @@ import ( ) func TestNewAzureProvider_MissingTenantID(t *testing.T) { - _, err := NewAzureProvider("", "client-id", "client-secret", "https://example.com/callback", nil) + _, err := NewAzureProvider("", "client-id", "client-secret", "https://example.com/callback") require.Error(t, err) assert.Contains(t, err.Error(), "tenantId is required") diff --git a/internal/idp/factory.go b/internal/idp/factory.go index 1ece21f..60004e1 100644 --- a/internal/idp/factory.go +++ b/internal/idp/factory.go @@ -7,15 +7,13 @@ import ( ) // NewProvider creates a Provider based on the IDPConfig. -// allowedDomains configures domain-based access control for all provider types. -func NewProvider(cfg config.IDPConfig, allowedDomains []string) (Provider, error) { +func NewProvider(cfg config.IDPConfig) (Provider, error) { switch cfg.Provider { case "google": return NewGoogleProvider( cfg.ClientID, string(cfg.ClientSecret), cfg.RedirectURI, - allowedDomains, ), nil case "azure": @@ -24,7 +22,6 @@ func NewProvider(cfg config.IDPConfig, allowedDomains []string) (Provider, error cfg.ClientID, string(cfg.ClientSecret), cfg.RedirectURI, - allowedDomains, ) case "github": @@ -32,8 +29,6 @@ func NewProvider(cfg config.IDPConfig, allowedDomains []string) (Provider, error cfg.ClientID, string(cfg.ClientSecret), cfg.RedirectURI, - allowedDomains, - cfg.AllowedOrgs, ), nil case "oidc": @@ -47,7 +42,6 @@ func NewProvider(cfg config.IDPConfig, allowedDomains []string) (Provider, error ClientSecret: string(cfg.ClientSecret), RedirectURI: cfg.RedirectURI, Scopes: cfg.Scopes, - AllowedDomains: allowedDomains, }) default: diff --git a/internal/idp/factory_test.go b/internal/idp/factory_test.go index 11065f7..9b38681 100644 --- a/internal/idp/factory_test.go +++ b/internal/idp/factory_test.go @@ -87,7 +87,7 @@ func TestNewProvider(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - provider, err := NewProvider(tt.cfg, nil) + provider, err := NewProvider(tt.cfg) if tt.wantErr { require.Error(t, err) diff --git a/internal/idp/github.go b/internal/idp/github.go index bc055e2..e737cd1 100644 --- a/internal/idp/github.go +++ b/internal/idp/github.go @@ -14,10 +14,8 @@ import ( // GitHubProvider implements the Provider interface for GitHub OAuth. // GitHub uses OAuth 2.0 (not OIDC) and has its own API for user info and org membership. type GitHubProvider struct { - config oauth2.Config - apiBaseURL string // defaults to https://api.github.com, can be overridden for testing - allowedDomains []string // email domains users must belong to (empty = no restriction) - allowedOrgs []string // organizations users must be members of (empty = no restriction) + config oauth2.Config + apiBaseURL string // defaults to https://api.github.com, can be overridden for testing } // githubUserResponse represents GitHub's user API response. @@ -42,7 +40,7 @@ type githubOrgResponse struct { } // NewGitHubProvider creates a new GitHub OAuth provider. -func NewGitHubProvider(clientID, clientSecret, redirectURI string, allowedDomains, allowedOrgs []string) *GitHubProvider { +func NewGitHubProvider(clientID, clientSecret, redirectURI string) *GitHubProvider { return &GitHubProvider{ config: oauth2.Config{ ClientID: clientID, @@ -51,9 +49,7 @@ func NewGitHubProvider(clientID, clientSecret, redirectURI string, allowedDomain Scopes: []string{"user:email", "read:org"}, Endpoint: github.Endpoint, }, - apiBaseURL: "https://api.github.com", - allowedDomains: allowedDomains, - allowedOrgs: allowedOrgs, + apiBaseURL: "https://api.github.com", } } @@ -72,13 +68,11 @@ func (p *GitHubProvider) ExchangeCode(ctx context.Context, code string) (*oauth2 return p.config.Exchange(ctx, code) } -// UserInfo fetches user information from GitHub's API. -// Validates domain and organization membership based on construction-time config. -// TODO: Consider caching org membership to reduce API calls. -func (p *GitHubProvider) UserInfo(ctx context.Context, token *oauth2.Token) (*UserInfo, error) { +// UserInfo fetches user identity from GitHub's API. +// Always fetches organizations so the authorization layer can check membership. +func (p *GitHubProvider) UserInfo(ctx context.Context, token *oauth2.Token) (*Identity, error) { client := p.config.Client(ctx, token) - // Fetch user profile user, err := p.fetchUser(client) if err != nil { return nil, err @@ -99,38 +93,12 @@ func (p *GitHubProvider) UserInfo(ctx context.Context, token *oauth2.Token) (*Us domain := emailutil.ExtractDomain(email) - // Validate domain if configured - if err := ValidateDomain(domain, p.allowedDomains); err != nil { - return nil, err - } - - // Fetch organizations only if org validation is configured - var orgs []string - if len(p.allowedOrgs) > 0 { - orgs, err = p.fetchOrganizations(client) - if err != nil { - return nil, fmt.Errorf("failed to get user organizations: %w", err) - } - - // Validate org membership - hasAllowedOrg := false - for _, org := range orgs { - for _, allowed := range p.allowedOrgs { - if org == allowed { - hasAllowedOrg = true - break - } - } - if hasAllowedOrg { - break - } - } - if !hasAllowedOrg { - return nil, fmt.Errorf("user is not a member of any allowed organization. Contact your administrator") - } + orgs, err := p.fetchOrganizations(client) + if err != nil { + return nil, fmt.Errorf("failed to get user organizations: %w", err) } - return &UserInfo{ + return &Identity{ ProviderType: "github", Subject: fmt.Sprintf("%d", user.ID), Email: email, diff --git a/internal/idp/github_test.go b/internal/idp/github_test.go index f42b5ad..e532ef3 100644 --- a/internal/idp/github_test.go +++ b/internal/idp/github_test.go @@ -13,12 +13,12 @@ import ( ) func TestGitHubProvider_Type(t *testing.T) { - provider := NewGitHubProvider("client-id", "client-secret", "https://example.com/callback", nil, nil) + provider := NewGitHubProvider("client-id", "client-secret", "https://example.com/callback") assert.Equal(t, "github", provider.Type()) } func TestGitHubProvider_AuthURL(t *testing.T) { - provider := NewGitHubProvider("client-id", "client-secret", "https://example.com/callback", nil, nil) + provider := NewGitHubProvider("client-id", "client-secret", "https://example.com/callback") authURL := provider.AuthURL("test-state") @@ -33,10 +33,6 @@ func TestGitHubProvider_UserInfo(t *testing.T) { userResp githubUserResponse emailsResp []githubEmailResponse orgsResp []githubOrgResponse - allowedDomains []string - allowedOrgs []string - wantErr bool - errContains string expectedEmail string expectedEmailVerified bool expectedDomain string @@ -51,10 +47,11 @@ func TestGitHubProvider_UserInfo(t *testing.T) { Name: "Test User", AvatarURL: "https://github.com/avatar.jpg", }, + orgsResp: []githubOrgResponse{{Login: "my-org"}}, expectedEmail: "user@company.com", - expectedEmailVerified: true, // Public emails in GitHub profile are verified + expectedEmailVerified: true, expectedDomain: "company.com", - expectedOrgs: nil, // Orgs not fetched when allowedOrgs is empty + expectedOrgs: []string{"my-org"}, }, { name: "user_without_public_email_fetches_from_api", @@ -67,10 +64,11 @@ func TestGitHubProvider_UserInfo(t *testing.T) { {Email: "secondary@other.com", Primary: false, Verified: true}, {Email: "primary@company.com", Primary: true, Verified: true}, }, + orgsResp: []githubOrgResponse{}, expectedEmail: "primary@company.com", expectedEmailVerified: true, expectedDomain: "company.com", - expectedOrgs: nil, // Orgs not fetched when allowedOrgs is empty + expectedOrgs: []string{}, }, { name: "user_with_unverified_primary_falls_back_to_verified", @@ -82,72 +80,24 @@ func TestGitHubProvider_UserInfo(t *testing.T) { {Email: "primary@company.com", Primary: true, Verified: false}, {Email: "verified@company.com", Primary: false, Verified: true}, }, + orgsResp: []githubOrgResponse{}, expectedEmail: "verified@company.com", expectedEmailVerified: true, expectedDomain: "company.com", - expectedOrgs: nil, // Orgs not fetched when allowedOrgs is empty - }, - { - name: "domain_validation_success", - userResp: githubUserResponse{ - ID: 12345, - Login: "testuser", - Email: "user@company.com", - }, - allowedDomains: []string{"company.com"}, - expectedEmail: "user@company.com", - expectedEmailVerified: true, - expectedDomain: "company.com", - expectedOrgs: nil, // Orgs not fetched when allowedOrgs is empty + expectedOrgs: []string{}, }, { - name: "domain_validation_failure", - userResp: githubUserResponse{ - ID: 12345, - Login: "testuser", - Email: "user@other.com", - }, - allowedDomains: []string{"company.com"}, - wantErr: true, - errContains: "domain 'other.com' is not allowed", - }, - { - name: "org_validation_success", + name: "orgs_always_populated", userResp: githubUserResponse{ ID: 12345, Login: "testuser", Email: "user@gmail.com", }, - orgsResp: []githubOrgResponse{{Login: "allowed-org"}, {Login: "other-org"}}, - allowedOrgs: []string{"allowed-org"}, + orgsResp: []githubOrgResponse{{Login: "org-a"}, {Login: "org-b"}}, expectedEmail: "user@gmail.com", expectedEmailVerified: true, expectedDomain: "gmail.com", - expectedOrgs: []string{"allowed-org", "other-org"}, - }, - { - name: "org_validation_failure", - userResp: githubUserResponse{ - ID: 12345, - Login: "testuser", - Email: "user@gmail.com", - }, - orgsResp: []githubOrgResponse{{Login: "other-org"}}, - allowedOrgs: []string{"required-org"}, - wantErr: true, - errContains: "not a member of any allowed organization", - }, - { - name: "user_with_no_orgs_restriction", - userResp: githubUserResponse{ - ID: 12345, - Login: "testuser", - Email: "user@gmail.com", - }, - expectedEmail: "user@gmail.com", - expectedEmailVerified: true, - expectedDomain: "gmail.com", - expectedOrgs: nil, // Orgs not fetched when allowedOrgs is empty + expectedOrgs: []string{"org-a", "org-b"}, }, } @@ -172,7 +122,6 @@ func TestGitHubProvider_UserInfo(t *testing.T) { })) defer server.Close() - // Create provider with test server endpoints and allowedOrgs provider := &GitHubProvider{ config: oauth2.Config{ ClientID: "test-client", @@ -184,29 +133,19 @@ func TestGitHubProvider_UserInfo(t *testing.T) { TokenURL: server.URL + "/token", }, }, - apiBaseURL: server.URL, - allowedDomains: tt.allowedDomains, - allowedOrgs: tt.allowedOrgs, + apiBaseURL: server.URL, } token := &oauth2.Token{AccessToken: "test-token"} - userInfo, err := provider.UserInfo(context.Background(), token) - - if tt.wantErr { - require.Error(t, err) - if tt.errContains != "" { - assert.Contains(t, err.Error(), tt.errContains) - } - return - } + identity, err := provider.UserInfo(context.Background(), token) require.NoError(t, err) - require.NotNil(t, userInfo) - assert.Equal(t, "github", userInfo.ProviderType) - assert.Equal(t, tt.expectedEmail, userInfo.Email) - assert.Equal(t, tt.expectedEmailVerified, userInfo.EmailVerified) - assert.Equal(t, tt.expectedDomain, userInfo.Domain) - assert.Equal(t, tt.expectedOrgs, userInfo.Organizations) + require.NotNil(t, identity) + assert.Equal(t, "github", identity.ProviderType) + assert.Equal(t, tt.expectedEmail, identity.Email) + assert.Equal(t, tt.expectedEmailVerified, identity.EmailVerified) + assert.Equal(t, tt.expectedDomain, identity.Domain) + assert.Equal(t, tt.expectedOrgs, identity.Organizations) }) } } diff --git a/internal/idp/google.go b/internal/idp/google.go index 6d017cb..b05ca68 100644 --- a/internal/idp/google.go +++ b/internal/idp/google.go @@ -14,9 +14,8 @@ import ( // GoogleProvider implements the Provider interface for Google OAuth. // Google has specific quirks like `hd` for hosted domain and `verified_email` field. type GoogleProvider struct { - config oauth2.Config - userInfoURL string - allowedDomains []string + config oauth2.Config + userInfoURL string } // googleUserInfoResponse represents Google's userinfo response. @@ -31,7 +30,7 @@ type googleUserInfoResponse struct { } // NewGoogleProvider creates a new Google OAuth provider. -func NewGoogleProvider(clientID, clientSecret, redirectURI string, allowedDomains []string) *GoogleProvider { +func NewGoogleProvider(clientID, clientSecret, redirectURI string) *GoogleProvider { return &GoogleProvider{ config: oauth2.Config{ ClientID: clientID, @@ -40,8 +39,7 @@ func NewGoogleProvider(clientID, clientSecret, redirectURI string, allowedDomain Scopes: []string{"openid", "profile", "email"}, Endpoint: google.Endpoint, }, - userInfoURL: "https://www.googleapis.com/oauth2/v2/userinfo", - allowedDomains: allowedDomains, + userInfoURL: "https://www.googleapis.com/oauth2/v2/userinfo", } } @@ -63,8 +61,8 @@ func (p *GoogleProvider) ExchangeCode(ctx context.Context, code string) (*oauth2 return p.config.Exchange(ctx, code) } -// UserInfo fetches user information from Google's userinfo endpoint. -func (p *GoogleProvider) UserInfo(ctx context.Context, token *oauth2.Token) (*UserInfo, error) { +// UserInfo fetches user identity from Google's userinfo endpoint. +func (p *GoogleProvider) UserInfo(ctx context.Context, token *oauth2.Token) (*Identity, error) { client := p.config.Client(ctx, token) resp, err := client.Get(p.userInfoURL) @@ -88,12 +86,7 @@ func (p *GoogleProvider) UserInfo(ctx context.Context, token *oauth2.Token) (*Us domain = emailutil.ExtractDomain(googleUser.Email) } - // Validate domain if configured - if err := ValidateDomain(domain, p.allowedDomains); err != nil { - return nil, err - } - - return &UserInfo{ + return &Identity{ ProviderType: "google", Subject: googleUser.Sub, Email: googleUser.Email, diff --git a/internal/idp/google_test.go b/internal/idp/google_test.go index e091373..ff618fb 100644 --- a/internal/idp/google_test.go +++ b/internal/idp/google_test.go @@ -13,12 +13,12 @@ import ( ) func TestGoogleProvider_Type(t *testing.T) { - provider := NewGoogleProvider("client-id", "client-secret", "https://example.com/callback", nil) + provider := NewGoogleProvider("client-id", "client-secret", "https://example.com/callback") assert.Equal(t, "google", provider.Type()) } func TestGoogleProvider_AuthURL(t *testing.T) { - provider := NewGoogleProvider("client-id", "client-secret", "https://example.com/callback", nil) + provider := NewGoogleProvider("client-id", "client-secret", "https://example.com/callback") authURL := provider.AuthURL("test-state") @@ -33,14 +33,11 @@ func TestGoogleProvider_UserInfo(t *testing.T) { tests := []struct { name string userInfoResp googleUserInfoResponse - allowedDomains []string - wantErr bool - errContains string expectedDomain string expectedSubject string }{ { - name: "valid_user_with_hosted_domain", + name: "user_with_hosted_domain", userInfoResp: googleUserInfoResponse{ Sub: "12345", Email: "user@company.com", @@ -49,37 +46,20 @@ func TestGoogleProvider_UserInfo(t *testing.T) { Picture: "https://example.com/photo.jpg", HostedDomain: "company.com", }, - allowedDomains: []string{"company.com"}, - wantErr: false, expectedDomain: "company.com", expectedSubject: "12345", }, { - name: "valid_user_without_hosted_domain_derives_from_email", + name: "user_without_hosted_domain_derives_from_email", userInfoResp: googleUserInfoResponse{ Sub: "12345", Email: "user@gmail.com", VerifiedEmail: true, Name: "Test User", }, - allowedDomains: nil, - wantErr: false, expectedDomain: "gmail.com", expectedSubject: "12345", }, - { - name: "domain_not_allowed", - userInfoResp: googleUserInfoResponse{ - Sub: "12345", - Email: "user@other.com", - VerifiedEmail: true, - Name: "Test User", - HostedDomain: "other.com", - }, - allowedDomains: []string{"company.com"}, - wantErr: true, - errContains: "domain 'other.com' is not allowed", - }, } for _, tt := range tests { @@ -102,28 +82,19 @@ func TestGoogleProvider_UserInfo(t *testing.T) { TokenURL: server.URL + "/token", }, }, - userInfoURL: server.URL, - allowedDomains: tt.allowedDomains, + userInfoURL: server.URL, } token := &oauth2.Token{AccessToken: "test-token"} - userInfo, err := provider.UserInfo(context.Background(), token) - - if tt.wantErr { - require.Error(t, err) - if tt.errContains != "" { - assert.Contains(t, err.Error(), tt.errContains) - } - return - } + identity, err := provider.UserInfo(context.Background(), token) require.NoError(t, err) - require.NotNil(t, userInfo) - assert.Equal(t, "google", userInfo.ProviderType) - assert.Equal(t, tt.expectedSubject, userInfo.Subject) - assert.Equal(t, tt.expectedDomain, userInfo.Domain) - assert.Equal(t, tt.userInfoResp.Email, userInfo.Email) - assert.Equal(t, tt.userInfoResp.VerifiedEmail, userInfo.EmailVerified) + require.NotNil(t, identity) + assert.Equal(t, "google", identity.ProviderType) + assert.Equal(t, tt.expectedSubject, identity.Subject) + assert.Equal(t, tt.expectedDomain, identity.Domain) + assert.Equal(t, tt.userInfoResp.Email, identity.Email) + assert.Equal(t, tt.userInfoResp.VerifiedEmail, identity.EmailVerified) }) } } diff --git a/internal/idp/oidc.go b/internal/idp/oidc.go index 8bb4fa0..4d48877 100644 --- a/internal/idp/oidc.go +++ b/internal/idp/oidc.go @@ -30,16 +30,13 @@ type OIDCConfig struct { RedirectURI string Scopes []string - // Access control. - AllowedDomains []string } // OIDCProvider implements the Provider interface for OIDC-compliant identity providers. type OIDCProvider struct { - providerType string - config oauth2.Config - userInfoURL string - allowedDomains []string + providerType string + config oauth2.Config + userInfoURL string } // oidcDiscoveryDocument represents the OIDC discovery document. @@ -103,8 +100,7 @@ func NewOIDCProvider(cfg OIDCConfig) (*OIDCProvider, error) { TokenURL: tokenURL, }, }, - userInfoURL: userInfoURL, - allowedDomains: cfg.AllowedDomains, + userInfoURL: userInfoURL, }, nil } @@ -150,9 +146,9 @@ func (p *OIDCProvider) ExchangeCode(ctx context.Context, code string) (*oauth2.T return p.config.Exchange(ctx, code) } -// UserInfo fetches user information from the OIDC userinfo endpoint. +// UserInfo fetches user identity from the OIDC userinfo endpoint. // TODO: Add ID token validation as optimization (avoids network call). -func (p *OIDCProvider) UserInfo(ctx context.Context, token *oauth2.Token) (*UserInfo, error) { +func (p *OIDCProvider) UserInfo(ctx context.Context, token *oauth2.Token) (*Identity, error) { client := p.config.Client(ctx, token) resp, err := client.Get(p.userInfoURL) if err != nil { @@ -171,12 +167,7 @@ func (p *OIDCProvider) UserInfo(ctx context.Context, token *oauth2.Token) (*User domain := emailutil.ExtractDomain(userInfoResp.Email) - // Validate domain if configured - if err := ValidateDomain(domain, p.allowedDomains); err != nil { - return nil, err - } - - return &UserInfo{ + return &Identity{ ProviderType: p.providerType, Subject: userInfoResp.Sub, Email: userInfoResp.Email, diff --git a/internal/idp/oidc_test.go b/internal/idp/oidc_test.go index 9d93c1c..ced5b9f 100644 --- a/internal/idp/oidc_test.go +++ b/internal/idp/oidc_test.go @@ -145,45 +145,15 @@ func TestOIDCProvider_UserInfo(t *testing.T) { require.NoError(t, err) token := &oauth2.Token{AccessToken: "test-token"} - userInfo, err := provider.UserInfo(context.Background(), token) + identity, err := provider.UserInfo(context.Background(), token) require.NoError(t, err) - require.NotNil(t, userInfo) - assert.Equal(t, "oidc", userInfo.ProviderType) - assert.Equal(t, "12345", userInfo.Subject) - assert.Equal(t, "user@example.com", userInfo.Email) - assert.Equal(t, "example.com", userInfo.Domain) - assert.True(t, userInfo.EmailVerified) -} - -func TestOIDCProvider_UserInfo_DomainValidation(t *testing.T) { - userInfoServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - resp := oidcUserInfoResponse{ - Sub: "12345", - Email: "user@other.com", - } - err := json.NewEncoder(w).Encode(resp) - require.NoError(t, err) - })) - defer userInfoServer.Close() - - provider, err := NewOIDCProvider(OIDCConfig{ - AuthorizationURL: "https://idp.example.com/authorize", - TokenURL: "https://idp.example.com/token", - UserInfoURL: userInfoServer.URL, - ClientID: "client-id", - ClientSecret: "client-secret", - RedirectURI: "https://example.com/callback", - AllowedDomains: []string{"example.com"}, - }) - require.NoError(t, err) - - token := &oauth2.Token{AccessToken: "test-token"} - _, err = provider.UserInfo(context.Background(), token) - - require.Error(t, err) - assert.Contains(t, err.Error(), "domain 'other.com' is not allowed") + require.NotNil(t, identity) + assert.Equal(t, "oidc", identity.ProviderType) + assert.Equal(t, "12345", identity.Subject) + assert.Equal(t, "user@example.com", identity.Email) + assert.Equal(t, "example.com", identity.Domain) + assert.True(t, identity.EmailVerified) } func TestOIDCProvider_DefaultScopes(t *testing.T) { diff --git a/internal/idp/provider.go b/internal/idp/provider.go index d01c908..e700112 100644 --- a/internal/idp/provider.go +++ b/internal/idp/provider.go @@ -2,16 +2,14 @@ package idp import ( "context" - "fmt" - "slices" - "strings" "golang.org/x/oauth2" ) -// UserInfo represents user information from any identity provider. -// ProviderType is included for multi-IDP readiness. -type UserInfo struct { +// Identity represents user identity as reported by an identity provider. +// Providers populate this with identity information only — access control +// (domain, org checks) is handled by the authorization layer. +type Identity struct { ProviderType string `json:"provider_type"` Subject string `json:"sub"` Email string `json:"email"` @@ -23,7 +21,6 @@ type UserInfo struct { } // Provider abstracts identity provider operations. -// Access control (allowed domains, orgs) is configured at construction time. type Provider interface { // Type returns the provider type identifier (e.g., "google", "azure", "github", "oidc"). Type() string @@ -34,47 +31,7 @@ type Provider interface { // ExchangeCode exchanges an authorization code for tokens. ExchangeCode(ctx context.Context, code string) (*oauth2.Token, error) - // UserInfo fetches user information and validates access. - // Returns error if user doesn't meet access requirements (domain, org membership). - UserInfo(ctx context.Context, token *oauth2.Token) (*UserInfo, error) -} - -// ValidateDomain checks if the domain is in the allowed list. -// Returns nil if allowedDomains is empty (no restriction) or domain is allowed. -func ValidateDomain(domain string, allowedDomains []string) error { - if len(allowedDomains) == 0 { - return nil - } - if !slices.Contains(allowedDomains, domain) { - return fmt.Errorf("domain '%s' is not allowed. Contact your administrator", domain) - } - return nil -} - -// ParseClientRequest parses MCP client registration metadata. -// This is provider-agnostic as it deals with MCP client registration, not IDP. -func ParseClientRequest(metadata map[string]any) (redirectURIs []string, scopes []string, err error) { - // Extract redirect URIs - redirectURIs = []string{} - if uris, ok := metadata["redirect_uris"].([]any); ok { - for _, uri := range uris { - if uriStr, ok := uri.(string); ok { - redirectURIs = append(redirectURIs, uriStr) - } - } - } - - if len(redirectURIs) == 0 { - return nil, nil, fmt.Errorf("no valid redirect URIs provided") - } - - // Extract scopes, default to read/write if not provided - scopes = []string{"read", "write"} // Default MCP scopes - if clientScopes, ok := metadata["scope"].(string); ok { - if strings.TrimSpace(clientScopes) != "" { - scopes = strings.Fields(clientScopes) - } - } - - return redirectURIs, scopes, nil + // UserInfo fetches user identity from the provider. + // Returns identity information only — no access control validation. + UserInfo(ctx context.Context, token *oauth2.Token) (*Identity, error) } diff --git a/internal/mcpfront.go b/internal/mcpfront.go index eb97bf4..ed0d8f0 100644 --- a/internal/mcpfront.go +++ b/internal/mcpfront.go @@ -239,7 +239,7 @@ func setupAuthentication(ctx context.Context, cfg config.Config, store storage.S log.LogDebug("initializing OAuth components") // Create identity provider - idpProvider, err := idp.NewProvider(oauthAuth.IDP, oauthAuth.AllowedDomains) + idpProvider, err := idp.NewProvider(oauthAuth.IDP) if err != nil { return nil, nil, nil, config.OAuthAuthConfig{}, nil, fmt.Errorf("failed to create identity provider: %w", err) } diff --git a/internal/oauth/client_registration.go b/internal/oauth/client_registration.go new file mode 100644 index 0000000..75510df --- /dev/null +++ b/internal/oauth/client_registration.go @@ -0,0 +1,32 @@ +package oauth + +import ( + "fmt" + "strings" +) + +// ParseClientRegistration parses MCP client registration metadata. +// This is provider-agnostic as it deals with MCP client registration, not IDP. +func ParseClientRegistration(metadata map[string]any) (redirectURIs []string, scopes []string, err error) { + redirectURIs = []string{} + if uris, ok := metadata["redirect_uris"].([]any); ok { + for _, uri := range uris { + if uriStr, ok := uri.(string); ok { + redirectURIs = append(redirectURIs, uriStr) + } + } + } + + if len(redirectURIs) == 0 { + return nil, nil, fmt.Errorf("no valid redirect URIs provided") + } + + scopes = []string{"read", "write"} + if clientScopes, ok := metadata["scope"].(string); ok { + if strings.TrimSpace(clientScopes) != "" { + scopes = strings.Fields(clientScopes) + } + } + + return redirectURIs, scopes, nil +} diff --git a/internal/idp/provider_test.go b/internal/oauth/client_registration_test.go similarity index 95% rename from internal/idp/provider_test.go rename to internal/oauth/client_registration_test.go index b4827d6..4aacd68 100644 --- a/internal/idp/provider_test.go +++ b/internal/oauth/client_registration_test.go @@ -1,4 +1,4 @@ -package idp +package oauth import ( "testing" @@ -7,7 +7,7 @@ import ( "github.com/stretchr/testify/require" ) -func TestParseClientRequest(t *testing.T) { +func TestParseClientRegistration(t *testing.T) { tests := []struct { name string metadata map[string]any @@ -95,7 +95,7 @@ func TestParseClientRequest(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - redirectURIs, scopes, err := ParseClientRequest(tt.metadata) + redirectURIs, scopes, err := ParseClientRegistration(tt.metadata) if tt.wantErr { require.Error(t, err) diff --git a/internal/oauth/metadata.go b/internal/oauth/metadata.go index e750415..078906e 100644 --- a/internal/oauth/metadata.go +++ b/internal/oauth/metadata.go @@ -53,29 +53,9 @@ func AuthorizationServerMetadata(issuer string) (map[string]any, error) { }, nil } -// ProtectedResourceMetadata builds OAuth 2.0 Protected Resource Metadata per RFC 9728 -// https://datatracker.ietf.org/doc/html/rfc9728 -// -// Deprecated: Use ServiceProtectedResourceMetadata for per-service metadata endpoints. -// This function returns the base issuer as the resource, which doesn't support -// per-service audience validation required by RFC 8707. -func ProtectedResourceMetadata(issuer string) (map[string]any, error) { - authzServerURL, err := urlutil.JoinPath(issuer, ".well-known", "oauth-authorization-server") - if err != nil { - return nil, err - } - - return map[string]any{ - "resource": issuer, - "authorization_servers": []string{ - issuer, - }, - "_links": map[string]any{ - "oauth-authorization-server": map[string]string{ - "href": authzServerURL, - }, - }, - }, nil +// AuthorizationServerMetadataURI returns the well-known URI for the authorization server metadata. +func AuthorizationServerMetadataURI(issuer string) (string, error) { + return urlutil.JoinPath(issuer, ".well-known", "oauth-authorization-server") } // ServiceProtectedResourceMetadata builds OAuth 2.0 Protected Resource Metadata per RFC 9728 @@ -92,7 +72,7 @@ func ServiceProtectedResourceMetadata(issuer string, serviceName string) (map[st return nil, err } - authzServerURL, err := urlutil.JoinPath(issuer, ".well-known", "oauth-authorization-server") + authzServerURL, err := AuthorizationServerMetadataURI(issuer) if err != nil { return nil, err } diff --git a/internal/oauth/metadata_test.go b/internal/oauth/metadata_test.go index db9e8df..4367c09 100644 --- a/internal/oauth/metadata_test.go +++ b/internal/oauth/metadata_test.go @@ -83,74 +83,6 @@ func TestAuthorizationServerMetadata(t *testing.T) { } } -func TestProtectedResourceMetadata(t *testing.T) { - tests := []struct { - name string - issuer string - wantErr bool - }{ - { - name: "valid issuer", - issuer: "https://example.com", - wantErr: false, - }, - { - name: "issuer with path", - issuer: "https://example.com/oauth", - wantErr: false, - }, - { - name: "invalid issuer", - issuer: "://invalid", - wantErr: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - metadata, err := ProtectedResourceMetadata(tt.issuer) - if (err != nil) != tt.wantErr { - t.Errorf("ProtectedResourceMetadata() error = %v, wantErr %v", err, tt.wantErr) - return - } - - if tt.wantErr { - return - } - - // Verify resource field - if metadata["resource"] != tt.issuer { - t.Errorf("resource = %v, want %v", metadata["resource"], tt.issuer) - } - - // Verify authorization_servers array - authzServers, ok := metadata["authorization_servers"].([]string) - if !ok || len(authzServers) == 0 { - t.Error("authorization_servers is missing or empty") - } - - if authzServers[0] != tt.issuer { - t.Errorf("authorization_servers[0] = %v, want %v", authzServers[0], tt.issuer) - } - - // Verify _links structure - links, ok := metadata["_links"].(map[string]any) - if !ok { - t.Error("_links is missing or wrong type") - } - - authzServerLink, ok := links["oauth-authorization-server"].(map[string]string) - if !ok { - t.Error("oauth-authorization-server link is missing or wrong type") - } - - if authzServerLink["href"] == "" { - t.Error("oauth-authorization-server href is empty") - } - }) - } -} - func TestServiceProtectedResourceMetadata(t *testing.T) { tests := []struct { name string diff --git a/internal/oauth/provider.go b/internal/oauth/provider.go index 8ceca50..62ffebc 100644 --- a/internal/oauth/provider.go +++ b/internal/oauth/provider.go @@ -181,8 +181,8 @@ func NewValidateTokenMiddleware(provider fosite.OAuth2Provider, issuer string, a var userEmail string if accessRequest != nil { if reqSession, ok := accessRequest.GetSession().(*oauthsession.Session); ok { - if reqSession.UserInfo.Email != "" { - userEmail = reqSession.UserInfo.Email + if reqSession.Identity.Email != "" { + userEmail = reqSession.Identity.Email } } } diff --git a/internal/oauthsession/session.go b/internal/oauthsession/session.go index b8cf56e..5b7b5b8 100644 --- a/internal/oauthsession/session.go +++ b/internal/oauthsession/session.go @@ -1,8 +1,6 @@ package oauthsession import ( - "time" - "github.com/dgellow/mcp-front/internal/idp" "github.com/ory/fosite" ) @@ -10,28 +8,13 @@ import ( // Session extends DefaultSession with user information type Session struct { *fosite.DefaultSession - UserInfo idp.UserInfo `json:"user_info"` -} - -// NewSession creates a new session with user info -func NewSession(userInfo idp.UserInfo) *Session { - return &Session{ - DefaultSession: &fosite.DefaultSession{ - ExpiresAt: map[fosite.TokenType]time.Time{ - fosite.AccessToken: time.Now().Add(time.Hour), - fosite.RefreshToken: time.Now().Add(24 * time.Hour), - }, - Username: userInfo.Email, - Subject: userInfo.Email, - }, - UserInfo: userInfo, - } + Identity idp.Identity `json:"identity"` } // Clone implements fosite.Session func (s *Session) Clone() fosite.Session { return &Session{ DefaultSession: s.DefaultSession.Clone().(*fosite.DefaultSession), - UserInfo: s.UserInfo, + Identity: s.Identity, } } diff --git a/internal/server/auth_handlers.go b/internal/server/auth_handlers.go index fcde3ae..4ff73fb 100644 --- a/internal/server/auth_handlers.go +++ b/internal/server/auth_handlers.go @@ -7,6 +7,7 @@ import ( "fmt" "net/http" "net/url" + "slices" "strings" "time" @@ -38,7 +39,7 @@ type AuthHandlers struct { // UpstreamOAuthState stores OAuth state during upstream authentication flow (MCP host → mcp-front) type UpstreamOAuthState struct { - UserInfo idp.UserInfo `json:"user_info"` + Identity idp.Identity `json:"identity"` ClientID string `json:"client_id"` RedirectURI string `json:"redirect_uri"` Scopes []string `json:"scopes"` @@ -101,18 +102,31 @@ func (h *AuthHandlers) ProtectedResourceMetadataHandler(w http.ResponseWriter, r return } - // Workaround mode: return base issuer metadata for broken clients + // Workaround mode: return base issuer metadata for broken clients. + // This intentionally uses the issuer as the resource (no per-service scoping) + // because broken clients don't implement RFC 8707 resource indicators. + issuer := h.authConfig.Issuer log.LogWarnWithFields("oauth", "Serving base protected resource metadata (dangerouslyAcceptIssuerAudience enabled)", map[string]any{ - "issuer": h.authConfig.Issuer, + "issuer": issuer, }) - metadata, err := oauth.ProtectedResourceMetadata(h.authConfig.Issuer) + authzServerURL, err := oauth.AuthorizationServerMetadataURI(issuer) if err != nil { log.LogError("Failed to build protected resource metadata: %v", err) jsonwriter.WriteInternalServerError(w, "Internal server error") return } + metadata := map[string]any{ + "resource": issuer, + "authorization_servers": []string{issuer}, + "_links": map[string]any{ + "oauth-authorization-server": map[string]string{ + "href": authzServerURL, + }, + }, + } + w.Header().Set("Content-Type", "application/json") if err := json.NewEncoder(w).Encode(metadata); err != nil { log.LogError("Failed to encode protected resource metadata: %v", err) @@ -329,10 +343,21 @@ func (h *AuthHandlers) IDPCallbackHandler(w http.ResponseWriter, r *http.Request return } - // Validate user and fetch user info - userInfo, err := h.idpProvider.UserInfo(ctx, token) + // Fetch user identity from IDP + identity, err := h.idpProvider.UserInfo(ctx, token) if err != nil { - log.LogError("User validation failed: %v", err) + log.LogError("Failed to fetch user identity: %v", err) + if !isBrowserFlow && ar != nil { + h.oauthProvider.WriteAuthorizeError(ctx, w, ar, fosite.ErrServerError.WithHint("Failed to fetch user identity")) + } else { + jsonwriter.WriteInternalServerError(w, "Authentication failed") + } + return + } + + // Validate access (domain/org restrictions) + if err := h.validateAccess(identity); err != nil { + log.LogError("Access denied: %v", err) if !isBrowserFlow && ar != nil { h.oauthProvider.WriteAuthorizeError(ctx, w, ar, fosite.ErrAccessDenied.WithHint(err.Error())) } else { @@ -341,12 +366,12 @@ func (h *AuthHandlers) IDPCallbackHandler(w http.ResponseWriter, r *http.Request return } - log.Logf("User authenticated: %s", userInfo.Email) + log.Logf("User authenticated: %s", identity.Email) // Store user in database - if err := h.storage.UpsertUser(ctx, userInfo.Email); err != nil { + if err := h.storage.UpsertUser(ctx, identity.Email); err != nil { log.LogWarnWithFields("auth", "Failed to track user", map[string]any{ - "email": userInfo.Email, + "email": identity.Email, "error": err.Error(), }) } @@ -357,8 +382,8 @@ func (h *AuthHandlers) IDPCallbackHandler(w http.ResponseWriter, r *http.Request sessionDuration := 24 * time.Hour sessionData := browserauth.SessionCookie{ - Email: userInfo.Email, - Provider: userInfo.ProviderType, + Email: identity.Email, + Provider: identity.ProviderType, Expires: time.Now().Add(sessionDuration), } @@ -390,7 +415,7 @@ func (h *AuthHandlers) IDPCallbackHandler(w http.ResponseWriter, r *http.Request }) log.LogInfoWithFields("auth", "Browser SSO session created", map[string]any{ - "user": userInfo.Email, + "user": identity.Email, "duration": sessionDuration, "returnURL": returnURL, }) @@ -412,7 +437,7 @@ func (h *AuthHandlers) IDPCallbackHandler(w http.ResponseWriter, r *http.Request } if needsServiceAuth { - stateData, err := h.signUpstreamOAuthState(ar, *userInfo) + stateData, err := h.signUpstreamOAuthState(ar, *identity) if err != nil { log.LogError("Failed to sign OAuth state: %v", err) h.oauthProvider.WriteAuthorizeError(ctx, w, ar, fosite.ErrServerError.WithHint("Failed to prepare service authentication")) @@ -433,7 +458,7 @@ func (h *AuthHandlers) IDPCallbackHandler(w http.ResponseWriter, r *http.Request fosite.RefreshToken: time.Now().Add(h.authConfig.RefreshTokenTTL), }, }, - UserInfo: *userInfo, + Identity: *identity, } // Accept the authorization request @@ -515,7 +540,7 @@ func (h *AuthHandlers) RegisterHandler(w http.ResponseWriter, r *http.Request) { } // Parse client request - redirectURIs, scopes, err := idp.ParseClientRequest(metadata) + redirectURIs, scopes, err := oauth.ParseClientRegistration(metadata) if err != nil { log.LogError("Client request parsing error: %v", err) jsonwriter.WriteBadRequest(w, err.Error()) @@ -564,9 +589,9 @@ func (h *AuthHandlers) RegisterHandler(w http.ResponseWriter, r *http.Request) { } // signUpstreamOAuthState signs upstream OAuth state for secure storage -func (h *AuthHandlers) signUpstreamOAuthState(ar fosite.AuthorizeRequester, userInfo idp.UserInfo) (string, error) { +func (h *AuthHandlers) signUpstreamOAuthState(ar fosite.AuthorizeRequester, identity idp.Identity) (string, error) { state := UpstreamOAuthState{ - UserInfo: userInfo, + Identity: identity, ClientID: ar.GetClient().GetID(), RedirectURI: ar.GetRedirectURI().String(), Scopes: ar.GetRequestedScopes(), @@ -586,6 +611,28 @@ func (h *AuthHandlers) verifyUpstreamOAuthState(signedState string) (*UpstreamOA return &state, nil } +// validateAccess checks whether an authenticated identity meets the configured access policy. +func (h *AuthHandlers) validateAccess(identity *idp.Identity) error { + if len(h.authConfig.AllowedDomains) > 0 && + !slices.Contains(h.authConfig.AllowedDomains, identity.Domain) { + return fmt.Errorf("domain '%s' is not allowed. Contact your administrator", identity.Domain) + } + if len(h.authConfig.IDP.AllowedOrgs) > 0 && + !hasOverlap(identity.Organizations, h.authConfig.IDP.AllowedOrgs) { + return fmt.Errorf("user is not a member of any allowed organization. Contact your administrator") + } + return nil +} + +func hasOverlap(a, b []string) bool { + for _, x := range a { + if slices.Contains(b, x) { + return true + } + } + return false +} + // ServiceSelectionHandler shows the interstitial page for selecting services to connect func (h *AuthHandlers) ServiceSelectionHandler(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodGet { @@ -606,7 +653,7 @@ func (h *AuthHandlers) ServiceSelectionHandler(w http.ResponseWriter, r *http.Re return } - userEmail := upstreamOAuthState.UserInfo.Email + userEmail := upstreamOAuthState.Identity.Email // Prepare template data returnURL := fmt.Sprintf("/oauth/services?state=%s", url.QueryEscape(signedState)) @@ -748,7 +795,7 @@ func (h *AuthHandlers) CompleteOAuthHandler(w http.ResponseWriter, r *http.Reque fosite.RefreshToken: time.Now().Add(h.authConfig.RefreshTokenTTL), }, }, - UserInfo: upstreamOAuthState.UserInfo, + Identity: upstreamOAuthState.Identity, } ar.SetSession(session) diff --git a/internal/server/auth_handlers_test.go b/internal/server/auth_handlers_test.go index e0e3a2f..23db193 100644 --- a/internal/server/auth_handlers_test.go +++ b/internal/server/auth_handlers_test.go @@ -36,8 +36,8 @@ func (m *mockIDPProvider) ExchangeCode(ctx context.Context, code string) (*oauth return &oauth2.Token{AccessToken: "test-token"}, nil } -func (m *mockIDPProvider) UserInfo(ctx context.Context, token *oauth2.Token) (*idp.UserInfo, error) { - return &idp.UserInfo{ +func (m *mockIDPProvider) UserInfo(ctx context.Context, token *oauth2.Token) (*idp.Identity, error) { + return &idp.Identity{ ProviderType: "mock", Subject: "123", Email: "test@example.com", @@ -420,3 +420,88 @@ func TestBearerTokenAuth(t *testing.T) { }) } } + +func TestValidateAccess(t *testing.T) { + tests := []struct { + name string + allowedDomains []string + allowedOrgs []string + identity *idp.Identity + wantErr bool + errContains string + }{ + { + name: "no_restrictions", + allowedDomains: nil, + allowedOrgs: nil, + identity: &idp.Identity{Domain: "any.com", Organizations: []string{"any-org"}}, + wantErr: false, + }, + { + name: "domain_allowed", + allowedDomains: []string{"company.com"}, + identity: &idp.Identity{Domain: "company.com"}, + wantErr: false, + }, + { + name: "domain_rejected", + allowedDomains: []string{"company.com"}, + identity: &idp.Identity{Domain: "other.com"}, + wantErr: true, + errContains: "domain 'other.com' is not allowed", + }, + { + name: "org_allowed", + allowedOrgs: []string{"allowed-org"}, + identity: &idp.Identity{Domain: "any.com", Organizations: []string{"allowed-org", "other-org"}}, + wantErr: false, + }, + { + name: "org_rejected", + allowedOrgs: []string{"required-org"}, + identity: &idp.Identity{Domain: "any.com", Organizations: []string{"other-org"}}, + wantErr: true, + errContains: "not a member of any allowed organization", + }, + { + name: "domain_and_org_both_pass", + allowedDomains: []string{"company.com"}, + allowedOrgs: []string{"my-org"}, + identity: &idp.Identity{Domain: "company.com", Organizations: []string{"my-org"}}, + wantErr: false, + }, + { + name: "domain_fails_before_org_check", + allowedDomains: []string{"company.com"}, + allowedOrgs: []string{"my-org"}, + identity: &idp.Identity{Domain: "other.com", Organizations: []string{"my-org"}}, + wantErr: true, + errContains: "domain 'other.com' is not allowed", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + h := &AuthHandlers{ + authConfig: config.OAuthAuthConfig{ + AllowedDomains: tt.allowedDomains, + IDP: config.IDPConfig{ + AllowedOrgs: tt.allowedOrgs, + }, + }, + } + + err := h.validateAccess(tt.identity) + + if tt.wantErr { + require.Error(t, err) + if tt.errContains != "" { + assert.Contains(t, err.Error(), tt.errContains) + } + return + } + + require.NoError(t, err) + }) + } +} From 586d04718b7175d5af5d70b8f60f596b94efb17e Mon Sep 17 00:00:00 2001 From: Sam El-Borai Date: Thu, 12 Feb 2026 10:39:15 +0100 Subject: [PATCH 4/7] Codebase quality pass: fix bugs, unify patterns, reduce complexity Fix duplicate isStdioServer with diverging implementations by adding IsStdio() method on MCPClientConfig. Unify CSRF strategy across admin and token handlers using stateless HMAC-based tokens (removes sync.Map approach). Make StoreAuthorizeRequest return an error so Firestore failures surface clearly. Use cookie.SetSession consistently (fixes SameSite mismatch from Strict to Lax for OAuth callbacks). Split http.go into focused files (user_token_service.go, session_handler.go). Extract handleBrowserCallback and handleOAuthClientCallback from 190-line IDPCallbackHandler. Add GetUser to Storage interface for O(1) admin checks. Consolidate browserauth + oauthsession into internal/session, move envutil.IsDev into config package. Replace hardcoded Firestore collection names with constants. Delete unused GetErrorName. Update CLAUDE.md to reference ./scripts/ instead of make. --- CLAUDE.md | 9 +- internal/adminauth/admin.go | 10 +- internal/browserauth/session.go | 16 -- .../{envutil/envutil.go => config/env.go} | 2 +- internal/config/types.go | 5 + internal/cookie/cookie.go | 6 +- internal/crypto/csrf.go | 61 +++++ internal/jsonrpc/errors.go | 21 -- internal/mcpfront.go | 8 +- internal/oauth/provider.go | 13 +- internal/oauthsession/session.go | 20 -- internal/server/admin_handlers.go | 71 +---- internal/server/admin_handlers_test.go | 16 +- internal/server/auth_handlers.go | 114 ++++---- internal/server/auth_handlers_test.go | 8 +- internal/server/http.go | 259 ------------------ internal/server/mcp_handler.go | 8 +- internal/server/middleware.go | 6 +- internal/server/session_handler.go | 167 +++++++++++ internal/server/token_handlers.go | 31 +-- internal/server/user_token_service.go | 102 +++++++ internal/session/session.go | 35 +++ .../{browserauth => session}/session_test.go | 25 +- internal/storage/firestore.go | 47 +++- internal/storage/memory.go | 16 +- internal/storage/storage.go | 3 +- 26 files changed, 542 insertions(+), 537 deletions(-) delete mode 100644 internal/browserauth/session.go rename internal/{envutil/envutil.go => config/env.go} (94%) create mode 100644 internal/crypto/csrf.go delete mode 100644 internal/oauthsession/session.go create mode 100644 internal/server/session_handler.go create mode 100644 internal/server/user_token_service.go create mode 100644 internal/session/session.go rename internal/{browserauth => session}/session_test.go (73%) diff --git a/CLAUDE.md b/CLAUDE.md index ec5900f..e0c17dd 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -245,15 +245,16 @@ When refactoring for better design: ```bash # Build everything -make build +./scripts/build # Format everything -make format +./scripts/format # Lint everything -make lint +./scripts/lint # Test mcp-front +./scripts/test go test ./internal/... -v go test ./integration -v @@ -261,7 +262,7 @@ go test ./integration -v ./mcp-front -config config.json # Start docs dev server -make doc +./scripts/docs ``` ## Documentation Site Guidelines diff --git a/internal/adminauth/admin.go b/internal/adminauth/admin.go index 5f35c18..977c1a5 100644 --- a/internal/adminauth/admin.go +++ b/internal/adminauth/admin.go @@ -22,8 +22,16 @@ func IsAdmin(ctx context.Context, email string, adminConfig *config.AdminConfig, return true } - // Check if user is a promoted admin in storage + // Check if user is a promoted admin in storage. + // Try exact match first (O(1)), fall back to case-insensitive scan + // since emails may be stored with different casing. if store != nil { + user, err := store.GetUser(ctx, normalizedEmail) + if err == nil { + return user.IsAdmin + } + + // Fallback: case-insensitive scan for emails stored with different casing users, err := store.GetAllUsers(ctx) if err == nil { for _, user := range users { diff --git a/internal/browserauth/session.go b/internal/browserauth/session.go deleted file mode 100644 index 1563f5b..0000000 --- a/internal/browserauth/session.go +++ /dev/null @@ -1,16 +0,0 @@ -package browserauth - -import "time" - -// SessionCookie represents the data stored in encrypted browser session cookies -type SessionCookie struct { - Email string `json:"email"` - Provider string `json:"provider"` // IDP that authenticated this user (e.g., "google", "azure", "github") - Expires time.Time `json:"expires"` -} - -// AuthorizationState represents the OAuth authorization code flow state parameter -type AuthorizationState struct { - Nonce string `json:"nonce"` - ReturnURL string `json:"return_url"` -} diff --git a/internal/envutil/envutil.go b/internal/config/env.go similarity index 94% rename from internal/envutil/envutil.go rename to internal/config/env.go index ef5acfc..6297707 100644 --- a/internal/envutil/envutil.go +++ b/internal/config/env.go @@ -1,4 +1,4 @@ -package envutil +package config import ( "os" diff --git a/internal/config/types.go b/internal/config/types.go index 337f16d..87d714e 100644 --- a/internal/config/types.go +++ b/internal/config/types.go @@ -187,6 +187,11 @@ type MCPClientConfig struct { InlineConfig json.RawMessage `json:"inline,omitempty"` } +// IsStdio returns true if this is a stdio-based MCP server +func (c *MCPClientConfig) IsStdio() bool { + return c.TransportType == MCPClientTypeStdio +} + // SessionConfig represents session management configuration type SessionConfig struct { Timeout time.Duration diff --git a/internal/cookie/cookie.go b/internal/cookie/cookie.go index 3be28d4..0c38ea8 100644 --- a/internal/cookie/cookie.go +++ b/internal/cookie/cookie.go @@ -4,7 +4,7 @@ import ( "net/http" "time" - "github.com/dgellow/mcp-front/internal/envutil" + "github.com/dgellow/mcp-front/internal/config" "github.com/dgellow/mcp-front/internal/log" ) @@ -16,7 +16,7 @@ const ( // SetSession sets a session cookie with appropriate security settings func SetSession(w http.ResponseWriter, value string, maxAge time.Duration) { - secure := !envutil.IsDev() + secure := !config.IsDev() http.SetCookie(w, &http.Cookie{ Name: SessionCookie, Value: value, @@ -41,7 +41,7 @@ func SetCSRF(w http.ResponseWriter, value string) { Value: value, Path: "/", HttpOnly: false, // CSRF tokens need to be readable by JavaScript - Secure: !envutil.IsDev(), + Secure: !config.IsDev(), SameSite: http.SameSiteStrictMode, MaxAge: int((24 * time.Hour).Seconds()), // 24 hours }) diff --git a/internal/crypto/csrf.go b/internal/crypto/csrf.go new file mode 100644 index 0000000..53a9dbf --- /dev/null +++ b/internal/crypto/csrf.go @@ -0,0 +1,61 @@ +package crypto + +import ( + "fmt" + "strconv" + "strings" + "time" +) + +// CSRFProtection provides stateless HMAC-based CSRF token generation and validation. +// Tokens are self-contained: nonce:timestamp:signature, with configurable expiry. +type CSRFProtection struct { + signingKey []byte + ttl time.Duration +} + +// NewCSRFProtection creates a new CSRF protection instance +func NewCSRFProtection(signingKey []byte, ttl time.Duration) CSRFProtection { + return CSRFProtection{ + signingKey: signingKey, + ttl: ttl, + } +} + +// Generate creates a new CSRF token +func (c *CSRFProtection) Generate() (string, error) { + nonce := GenerateSecureToken() + if nonce == "" { + return "", fmt.Errorf("failed to generate nonce") + } + + timestamp := strconv.FormatInt(time.Now().Unix(), 10) + data := nonce + ":" + timestamp + signature := SignData(data, c.signingKey) + + return fmt.Sprintf("%s:%s:%s", nonce, timestamp, signature), nil +} + +// Validate checks if a CSRF token is valid and not expired +func (c *CSRFProtection) Validate(token string) bool { + parts := strings.SplitN(token, ":", 3) + if len(parts) != 3 { + return false + } + + nonce := parts[0] + timestampStr := parts[1] + signature := parts[2] + + timestamp, err := strconv.ParseInt(timestampStr, 10, 64) + if err != nil { + return false + } + + if time.Since(time.Unix(timestamp, 0)) > c.ttl { + return false + } + + data := nonce + ":" + timestampStr + return ValidateSignedData(data, signature, c.signingKey) +} diff --git a/internal/jsonrpc/errors.go b/internal/jsonrpc/errors.go index ee3196c..efdf152 100644 --- a/internal/jsonrpc/errors.go +++ b/internal/jsonrpc/errors.go @@ -47,24 +47,3 @@ func NewStandardError(code int) *Error { } } -// GetErrorName returns a human-readable name for standard error codes -func GetErrorName(code int) string { - switch code { - case ParseError: - return "PARSE_ERROR" - case InvalidRequest: - return "INVALID_REQUEST" - case MethodNotFound: - return "METHOD_NOT_FOUND" - case InvalidParams: - return "INVALID_PARAMS" - case InternalError: - return "INTERNAL_ERROR" - default: - // Check if it's an MCP-specific error code - if code >= -32099 && code <= -32000 { - return "SERVER_ERROR" - } - return "UNKNOWN_ERROR" - } -} diff --git a/internal/mcpfront.go b/internal/mcpfront.go index ed0d8f0..4d890a8 100644 --- a/internal/mcpfront.go +++ b/internal/mcpfront.go @@ -380,7 +380,7 @@ func buildHTTPHandler( } // Create token handlers - tokenHandlers := server.NewTokenHandlers(storage, cfg.MCPServers, true, serviceOAuthClient) + tokenHandlers := server.NewTokenHandlers(storage, cfg.MCPServers, true, serviceOAuthClient, []byte(authConfig.EncryptionKey)) // Token management UI endpoints mux.Handle(route("/my/tokens"), server.ChainMiddleware(http.HandlerFunc(tokenHandlers.ListTokensHandler), tokenMiddleware...)) @@ -421,7 +421,7 @@ func buildHTTPHandler( } } else { // For stdio/SSE servers - if isStdioServer(serverConfig) { + if serverConfig.IsStdio() { sseServer, mcpServer, err = buildStdioSSEServer(serverName, baseURL, sessionManager) if err != nil { return nil, fmt.Errorf("failed to create SSE server for %s: %w", serverName, err) @@ -606,7 +606,3 @@ func buildStdioSSEServer(serverName, baseURL string, sessionManager *client.Stdi return sseServer, mcpServer, nil } -// isStdioServer checks if this is a stdio-based server -func isStdioServer(cfg *config.MCPClientConfig) bool { - return cfg.TransportType == config.MCPClientTypeStdio -} diff --git a/internal/oauth/provider.go b/internal/oauth/provider.go index 62ffebc..09eb568 100644 --- a/internal/oauth/provider.go +++ b/internal/oauth/provider.go @@ -12,10 +12,9 @@ import ( "github.com/dgellow/mcp-front/internal/config" "github.com/dgellow/mcp-front/internal/crypto" - "github.com/dgellow/mcp-front/internal/envutil" jsonwriter "github.com/dgellow/mcp-front/internal/json" "github.com/dgellow/mcp-front/internal/log" - "github.com/dgellow/mcp-front/internal/oauthsession" + "github.com/dgellow/mcp-front/internal/session" "github.com/dgellow/mcp-front/internal/storage" "github.com/dgellow/mcp-front/internal/urlutil" "github.com/ory/fosite" @@ -48,8 +47,8 @@ func NewOAuthProvider(oauthConfig config.OAuthAuthConfig, store storage.Storage, // Determine min parameter entropy based on environment minEntropy := 8 // Production default - enforce secure state parameters (8+ chars) - log.Logf("OAuth provider initialization - MCP_FRONT_ENV=%s, isDevelopmentMode=%v", os.Getenv("MCP_FRONT_ENV"), envutil.IsDev()) - if envutil.IsDev() { + log.Logf("OAuth provider initialization - MCP_FRONT_ENV=%s, isDevelopmentMode=%v", os.Getenv("MCP_FRONT_ENV"), config.IsDev()) + if config.IsDev() { minEntropy = 0 // Development mode - allow empty state parameters log.LogWarn("Development mode enabled - OAuth security checks relaxed (state parameter entropy: %d)", minEntropy) } @@ -158,8 +157,8 @@ func NewValidateTokenMiddleware(provider fosite.OAuth2Provider, issuer string, a // - This is documented fosite behavior, not a bug // - The actual session data must be retrieved from the returned AccessRequester // See: https://github.com/ory/fosite/issues/256 - session := &oauthsession.Session{DefaultSession: &fosite.DefaultSession{}} - _, accessRequest, err := provider.IntrospectToken(ctx, token, fosite.AccessToken, session) + oauthSession := &session.OAuthSession{DefaultSession: &fosite.DefaultSession{}} + _, accessRequest, err := provider.IntrospectToken(ctx, token, fosite.AccessToken, oauthSession) if err != nil { jsonwriter.WriteUnauthorizedRFC9728(w, "Invalid or expired token", metadataURI) return @@ -180,7 +179,7 @@ func NewValidateTokenMiddleware(provider fosite.OAuth2Provider, issuer string, a // This is the correct way to retrieve session data after token introspection var userEmail string if accessRequest != nil { - if reqSession, ok := accessRequest.GetSession().(*oauthsession.Session); ok { + if reqSession, ok := accessRequest.GetSession().(*session.OAuthSession); ok { if reqSession.Identity.Email != "" { userEmail = reqSession.Identity.Email } diff --git a/internal/oauthsession/session.go b/internal/oauthsession/session.go deleted file mode 100644 index 5b7b5b8..0000000 --- a/internal/oauthsession/session.go +++ /dev/null @@ -1,20 +0,0 @@ -package oauthsession - -import ( - "github.com/dgellow/mcp-front/internal/idp" - "github.com/ory/fosite" -) - -// Session extends DefaultSession with user information -type Session struct { - *fosite.DefaultSession - Identity idp.Identity `json:"identity"` -} - -// Clone implements fosite.Session -func (s *Session) Clone() fosite.Session { - return &Session{ - DefaultSession: s.DefaultSession.Clone().(*fosite.DefaultSession), - Identity: s.Identity, - } -} diff --git a/internal/server/admin_handlers.go b/internal/server/admin_handlers.go index b76db11..4c1c20d 100644 --- a/internal/server/admin_handlers.go +++ b/internal/server/admin_handlers.go @@ -4,8 +4,6 @@ import ( "fmt" "net/http" "net/url" - "strconv" - "strings" "time" "github.com/dgellow/mcp-front/internal/adminauth" @@ -23,7 +21,7 @@ type AdminHandlers struct { storage storage.Storage config config.Config sessionManager *client.StdioSessionManager - encryptionKey []byte // For HMAC-based CSRF tokens + csrf crypto.CSRFProtection } // NewAdminHandlers creates a new admin handlers instance @@ -32,67 +30,10 @@ func NewAdminHandlers(storage storage.Storage, config config.Config, sessionMana storage: storage, config: config, sessionManager: sessionManager, - encryptionKey: []byte(encryptionKey), + csrf: crypto.NewCSRFProtection([]byte(encryptionKey), 15*time.Minute), } } -// generateCSRFToken creates a new HMAC-based CSRF token -func (h *AdminHandlers) generateCSRFToken() (string, error) { - // Generate random nonce - nonce := crypto.GenerateSecureToken() - if nonce == "" { - return "", fmt.Errorf("failed to generate nonce") - } - - // Add timestamp (Unix seconds) - timestamp := strconv.FormatInt(time.Now().Unix(), 10) - - // Create data to sign: nonce:timestamp - data := nonce + ":" + timestamp - - // Sign with HMAC - signature := crypto.SignData(data, h.encryptionKey) - - // Return format: nonce:timestamp:signature - return fmt.Sprintf("%s:%s:%s", nonce, timestamp, signature), nil -} - -// validateCSRFToken checks if a CSRF token is valid -func (h *AdminHandlers) validateCSRFToken(token string) bool { - // Parse token format: nonce:timestamp:signature - parts := strings.SplitN(token, ":", 3) - if len(parts) != 3 { - log.LogDebug("Invalid CSRF token format") - return false - } - - nonce := parts[0] - timestampStr := parts[1] - signature := parts[2] - - // Verify timestamp (15 minute expiry) - timestamp, err := strconv.ParseInt(timestampStr, 10, 64) - if err != nil { - log.LogDebug("Invalid CSRF token timestamp: %v", err) - return false - } - - now := time.Now().Unix() - if now-timestamp > 900 { // 15 minutes - log.LogDebug("CSRF token expired") - return false - } - - // Verify HMAC signature - data := nonce + ":" + timestampStr - if !crypto.ValidateSignedData(data, signature, h.encryptionKey) { - log.LogDebug("Invalid CSRF token signature") - return false - } - - return true -} - // DashboardHandler shows the admin dashboard func (h *AdminHandlers) DashboardHandler(w http.ResponseWriter, r *http.Request) { // Only accept GET @@ -152,7 +93,7 @@ func (h *AdminHandlers) DashboardHandler(w http.ResponseWriter, r *http.Request) currentLogLevel := log.GetLogLevel() // Generate CSRF token - csrfToken, err := h.generateCSRFToken() + csrfToken, err := h.csrf.Generate() if err != nil { log.LogErrorWithFields("admin", "Failed to generate CSRF token", map[string]any{ "error": err.Error(), @@ -209,7 +150,7 @@ func (h *AdminHandlers) UserActionHandler(w http.ResponseWriter, r *http.Request } // Validate CSRF - if !h.validateCSRFToken(r.FormValue("csrf_token")) { + if !h.csrf.Validate(r.FormValue("csrf_token")) { jsonwriter.WriteForbidden(w, "Invalid CSRF token") return } @@ -379,7 +320,7 @@ func (h *AdminHandlers) SessionActionHandler(w http.ResponseWriter, r *http.Requ } // Validate CSRF - if !h.validateCSRFToken(r.FormValue("csrf_token")) { + if !h.csrf.Validate(r.FormValue("csrf_token")) { jsonwriter.WriteForbidden(w, "Invalid CSRF token") return } @@ -472,7 +413,7 @@ func (h *AdminHandlers) LoggingActionHandler(w http.ResponseWriter, r *http.Requ } // Validate CSRF - if !h.validateCSRFToken(r.FormValue("csrf_token")) { + if !h.csrf.Validate(r.FormValue("csrf_token")) { jsonwriter.WriteForbidden(w, "Invalid CSRF token") return } diff --git a/internal/server/admin_handlers_test.go b/internal/server/admin_handlers_test.go index b3587fa..a2aa8ee 100644 --- a/internal/server/admin_handlers_test.go +++ b/internal/server/admin_handlers_test.go @@ -27,7 +27,7 @@ func TestAdminHandlers_CSRF(t *testing.T) { t.Run("generate and validate CSRF token", func(t *testing.T) { // Generate token - token, err := handlers.generateCSRFToken() + token, err := handlers.csrf.Generate() if err != nil { t.Fatalf("Failed to generate CSRF token: %v", err) } @@ -38,13 +38,13 @@ func TestAdminHandlers_CSRF(t *testing.T) { } // Token should be valid immediately - if !handlers.validateCSRFToken(token) { + if !handlers.csrf.Validate(token) { t.Error("Token should be valid immediately after generation") } // Token should not be valid twice (though with HMAC it actually can be) // With HMAC-based tokens, they can be validated multiple times - if !handlers.validateCSRFToken(token) { + if !handlers.csrf.Validate(token) { t.Error("HMAC token should be valid on second validation") } }) @@ -58,7 +58,7 @@ func TestAdminHandlers_CSRF(t *testing.T) { } for _, token := range invalidTokens { - if handlers.validateCSRFToken(token) { + if handlers.csrf.Validate(token) { t.Errorf("Token '%s' should be invalid", token) } } @@ -69,7 +69,7 @@ func TestAdminHandlers_CSRF(t *testing.T) { // An expired token would have a timestamp from > 15 minutes ago expiredToken := "test-nonce:0:invalid-signature" - if handlers.validateCSRFToken(expiredToken) { + if handlers.csrf.Validate(expiredToken) { t.Error("Expired token should be invalid") } }) @@ -79,18 +79,18 @@ func TestAdminHandlers_CSRF(t *testing.T) { handlers2 := NewAdminHandlers(storage, cfg, sessionManager, "different-encryption-key-32bytes") // Generate token with first handler - token1, err := handlers.generateCSRFToken() + token1, err := handlers.csrf.Generate() if err != nil { t.Fatalf("Failed to generate token: %v", err) } // Token from handler1 should not validate with handler2 - if handlers2.validateCSRFToken(token1) { + if handlers2.csrf.Validate(token1) { t.Error("Token should not validate with different encryption key") } // But should still validate with original handler - if !handlers.validateCSRFToken(token1) { + if !handlers.csrf.Validate(token1) { t.Error("Token should validate with original handler") } }) diff --git a/internal/server/auth_handlers.go b/internal/server/auth_handlers.go index 4ff73fb..cdb94e8 100644 --- a/internal/server/auth_handlers.go +++ b/internal/server/auth_handlers.go @@ -12,15 +12,14 @@ import ( "time" "github.com/dgellow/mcp-front/internal/auth" - "github.com/dgellow/mcp-front/internal/browserauth" "github.com/dgellow/mcp-front/internal/config" + "github.com/dgellow/mcp-front/internal/cookie" "github.com/dgellow/mcp-front/internal/crypto" - "github.com/dgellow/mcp-front/internal/envutil" "github.com/dgellow/mcp-front/internal/idp" jsonwriter "github.com/dgellow/mcp-front/internal/json" "github.com/dgellow/mcp-front/internal/log" "github.com/dgellow/mcp-front/internal/oauth" - "github.com/dgellow/mcp-front/internal/oauthsession" + "github.com/dgellow/mcp-front/internal/session" "github.com/dgellow/mcp-front/internal/storage" "github.com/ory/fosite" ) @@ -224,7 +223,7 @@ func (h *AuthHandlers) AuthorizeHandler(w http.ResponseWriter, r *http.Request) // In development mode, generate a secure state parameter if missing // This works around bugs in OAuth clients that don't send state stateParam := r.URL.Query().Get("state") - if envutil.IsDev() && len(stateParam) == 0 { + if config.IsDev() && len(stateParam) == 0 { generatedState := crypto.GenerateSecureToken() log.LogWarn("Development mode: generating state parameter '%s' for buggy client", generatedState) q := r.URL.Query() @@ -276,13 +275,18 @@ func (h *AuthHandlers) AuthorizeHandler(w http.ResponseWriter, r *http.Request) } state := ar.GetState() - h.storage.StoreAuthorizeRequest(state, ar) + if err := h.storage.StoreAuthorizeRequest(state, ar); err != nil { + log.LogError("Failed to store authorize request: %v", err) + h.oauthProvider.WriteAuthorizeError(ctx, w, ar, fosite.ErrServerError.WithHint("Failed to store authorization request")) + return + } authURL := h.idpProvider.AuthURL(state) http.Redirect(w, r, authURL, http.StatusFound) } -// IDPCallbackHandler handles the callback from the identity provider +// IDPCallbackHandler handles the callback from the identity provider. +// It dispatches to handleBrowserCallback or handleOAuthClientCallback based on the flow type. func (h *AuthHandlers) IDPCallbackHandler(w http.ResponseWriter, r *http.Request) { ctx := r.Context() @@ -310,7 +314,7 @@ func (h *AuthHandlers) IDPCallbackHandler(w http.ResponseWriter, r *http.Request isBrowserFlow = true stateToken := strings.TrimPrefix(state, "browser:") - var browserState browserauth.AuthorizationState + var browserState session.AuthorizationState if err := h.oauthStateToken.Verify(stateToken, &browserState); err != nil { log.LogError("Invalid browser state: %v", err) jsonwriter.WriteBadRequest(w, "Invalid state parameter") @@ -318,7 +322,6 @@ func (h *AuthHandlers) IDPCallbackHandler(w http.ResponseWriter, r *http.Request } returnURL = browserState.ReturnURL } else { - // OAuth client flow - retrieve stored authorize request var found bool ar, found = h.storage.GetAuthorizeRequest(state) if !found { @@ -343,7 +346,6 @@ func (h *AuthHandlers) IDPCallbackHandler(w http.ResponseWriter, r *http.Request return } - // Fetch user identity from IDP identity, err := h.idpProvider.UserInfo(ctx, token) if err != nil { log.LogError("Failed to fetch user identity: %v", err) @@ -355,7 +357,6 @@ func (h *AuthHandlers) IDPCallbackHandler(w http.ResponseWriter, r *http.Request return } - // Validate access (domain/org restrictions) if err := h.validateAccess(identity); err != nil { log.LogError("Access denied: %v", err) if !isBrowserFlow && ar != nil { @@ -368,7 +369,6 @@ func (h *AuthHandlers) IDPCallbackHandler(w http.ResponseWriter, r *http.Request log.Logf("User authenticated: %s", identity.Email) - // Store user in database if err := h.storage.UpsertUser(ctx, identity.Email); err != nil { log.LogWarnWithFields("auth", "Failed to track user", map[string]any{ "email": identity.Email, @@ -377,55 +377,51 @@ func (h *AuthHandlers) IDPCallbackHandler(w http.ResponseWriter, r *http.Request } if isBrowserFlow { - // Browser SSO flow - set encrypted session cookie - // Browser sessions should last longer than API tokens for better UX - sessionDuration := 24 * time.Hour - - sessionData := browserauth.SessionCookie{ - Email: identity.Email, - Provider: identity.ProviderType, - Expires: time.Now().Add(sessionDuration), - } - - // Marshal session data to JSON - jsonData, err := json.Marshal(sessionData) - if err != nil { - log.LogError("Failed to marshal session data: %v", err) - jsonwriter.WriteInternalServerError(w, "Failed to create session") - return - } + h.handleBrowserCallback(w, r, identity, returnURL) + } else { + h.handleOAuthClientCallback(ctx, w, r, ar, identity) + } +} - // Encrypt session data - encryptedData, err := h.sessionEncryptor.Encrypt(string(jsonData)) - if err != nil { - log.LogError("Failed to encrypt session: %v", err) - jsonwriter.WriteInternalServerError(w, "Failed to create session") - return - } +// handleBrowserCallback handles the browser SSO callback flow: creates an encrypted +// session cookie and redirects to the return URL. +func (h *AuthHandlers) handleBrowserCallback(w http.ResponseWriter, r *http.Request, identity *idp.Identity, returnURL string) { + sessionDuration := 24 * time.Hour - // Set secure session cookie - http.SetCookie(w, &http.Cookie{ - Name: "mcp_session", - Value: encryptedData, - Path: "/", - HttpOnly: true, - Secure: !envutil.IsDev(), - SameSite: http.SameSiteStrictMode, - MaxAge: int(sessionDuration.Seconds()), - }) + sessionData := session.BrowserCookie{ + Email: identity.Email, + Provider: identity.ProviderType, + Expires: time.Now().Add(sessionDuration), + } - log.LogInfoWithFields("auth", "Browser SSO session created", map[string]any{ - "user": identity.Email, - "duration": sessionDuration, - "returnURL": returnURL, - }) + jsonData, err := json.Marshal(sessionData) + if err != nil { + log.LogError("Failed to marshal session data: %v", err) + jsonwriter.WriteInternalServerError(w, "Failed to create session") + return + } - // Redirect to return URL - http.Redirect(w, r, returnURL, http.StatusFound) + encryptedData, err := h.sessionEncryptor.Encrypt(string(jsonData)) + if err != nil { + log.LogError("Failed to encrypt session: %v", err) + jsonwriter.WriteInternalServerError(w, "Failed to create session") return } - // OAuth client flow - check if any services need OAuth + cookie.SetSession(w, encryptedData, sessionDuration) + + log.LogInfoWithFields("auth", "Browser SSO session created", map[string]any{ + "user": identity.Email, + "duration": sessionDuration, + "returnURL": returnURL, + }) + + http.Redirect(w, r, returnURL, http.StatusFound) +} + +// handleOAuthClientCallback handles the OAuth client callback flow: checks for +// service auth needs, creates a fosite session, and issues the authorize response. +func (h *AuthHandlers) handleOAuthClientCallback(ctx context.Context, w http.ResponseWriter, r *http.Request, ar fosite.AuthorizeRequester, identity *idp.Identity) { needsServiceAuth := false for _, serverConfig := range h.mcpServers { if serverConfig.RequiresUserToken && @@ -448,10 +444,7 @@ func (h *AuthHandlers) IDPCallbackHandler(w http.ResponseWriter, r *http.Request return } - // Create session for token issuance - // Note: Audience claims are stored in the authorize request (ar.GetGrantedAudience()) - // and will be automatically propagated to access tokens by fosite - session := &oauthsession.Session{ + session := &session.OAuthSession{ DefaultSession: &fosite.DefaultSession{ ExpiresAt: map[fosite.TokenType]time.Time{ fosite.AccessToken: time.Now().Add(h.authConfig.TokenTTL), @@ -461,7 +454,6 @@ func (h *AuthHandlers) IDPCallbackHandler(w http.ResponseWriter, r *http.Request Identity: *identity, } - // Accept the authorization request response, err := h.oauthProvider.NewAuthorizeResponse(ctx, ar, session) if err != nil { log.LogError("Authorize response error: %v", err) @@ -480,7 +472,7 @@ func (h *AuthHandlers) TokenHandler(w http.ResponseWriter, r *http.Request) { // Create session for the token exchange // Note: We create our custom Session type here, and fosite will populate it // with the session data from the authorization code during NewAccessRequest - session := &oauthsession.Session{DefaultSession: &fosite.DefaultSession{}} + session := &session.OAuthSession{DefaultSession: &fosite.DefaultSession{}} // Handle token request - this retrieves the session from the authorization code accessRequest, err := h.oauthProvider.NewAccessRequest(ctx, r, session) @@ -775,7 +767,7 @@ func (h *AuthHandlers) CompleteOAuthHandler(w http.ResponseWriter, r *http.Reque Client: client, RequestedScope: upstreamOAuthState.Scopes, GrantedScope: upstreamOAuthState.Scopes, - Session: &oauthsession.Session{DefaultSession: &fosite.DefaultSession{}}, + Session: &session.OAuthSession{DefaultSession: &fosite.DefaultSession{}}, }, } @@ -788,7 +780,7 @@ func (h *AuthHandlers) CompleteOAuthHandler(w http.ResponseWriter, r *http.Reque ar.RedirectURI = redirectURI // Create session with user info - session := &oauthsession.Session{ + session := &session.OAuthSession{ DefaultSession: &fosite.DefaultSession{ ExpiresAt: map[fosite.TokenType]time.Time{ fosite.AccessToken: time.Now().Add(h.authConfig.TokenTTL), diff --git a/internal/server/auth_handlers_test.go b/internal/server/auth_handlers_test.go index 23db193..a8cd625 100644 --- a/internal/server/auth_handlers_test.go +++ b/internal/server/auth_handlers_test.go @@ -10,11 +10,11 @@ import ( "time" "github.com/dgellow/mcp-front/internal/auth" - "github.com/dgellow/mcp-front/internal/browserauth" "github.com/dgellow/mcp-front/internal/config" "github.com/dgellow/mcp-front/internal/crypto" "github.com/dgellow/mcp-front/internal/idp" "github.com/dgellow/mcp-front/internal/oauth" + "github.com/dgellow/mcp-front/internal/session" "github.com/dgellow/mcp-front/internal/storage" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -123,7 +123,7 @@ func TestAuthenticationBoundaries(t *testing.T) { serviceOAuthClient, ) - tokenHandlers := NewTokenHandlers(store, map[string]*config.MCPClientConfig{}, true, serviceOAuthClient) + tokenHandlers := NewTokenHandlers(store, map[string]*config.MCPClientConfig{}, true, serviceOAuthClient, []byte(oauthConfig.EncryptionKey)) // Build mux with middlewares mux := http.NewServeMux() @@ -190,7 +190,7 @@ func TestAuthenticationBoundaries(t *testing.T) { // Test with valid session cookie (if auth is expected) if tt.expectAuth { // Create session data - sessionData := browserauth.SessionCookie{ + sessionData := session.BrowserCookie{ Email: "test@example.com", Provider: "mock", Expires: time.Now().Add(24 * time.Hour), @@ -314,7 +314,7 @@ func TestOAuthEndpointHandlers(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/oauth/services", nil) // Add valid session cookie - sessionData := browserauth.SessionCookie{ + sessionData := session.BrowserCookie{ Email: "test@example.com", Expires: time.Now().Add(24 * time.Hour), } diff --git a/internal/server/http.go b/internal/server/http.go index 3bbe7dc..6a0dc68 100644 --- a/internal/server/http.go +++ b/internal/server/http.go @@ -4,86 +4,10 @@ import ( "context" "errors" "net/http" - "strings" - "time" - "github.com/dgellow/mcp-front/internal/auth" - "github.com/dgellow/mcp-front/internal/client" - "github.com/dgellow/mcp-front/internal/config" "github.com/dgellow/mcp-front/internal/log" - "github.com/dgellow/mcp-front/internal/storage" - mcpserver "github.com/mark3labs/mcp-go/server" ) -// UserTokenService handles user token retrieval and OAuth refresh -type UserTokenService struct { - storage storage.Storage - serviceOAuthClient *auth.ServiceOAuthClient -} - -// NewUserTokenService creates a new user token service -func NewUserTokenService(storage storage.Storage, serviceOAuthClient *auth.ServiceOAuthClient) *UserTokenService { - return &UserTokenService{ - storage: storage, - serviceOAuthClient: serviceOAuthClient, - } -} - -// GetUserToken retrieves and formats a user token for a service, handling OAuth refresh. -// -// Token refresh strategy: Optimistic continuation on failure. -// If refresh fails, we log a warning and continue with the current token. The external -// service will reject the expired token with 401, giving the user a clear error. -// This is acceptable because: (1) refresh failures are rare (network issues, revoked -// tokens), and (2) forcing users to re-auth is better than silently hiding auth issues. -func (uts *UserTokenService) GetUserToken(ctx context.Context, userEmail, serviceName string, serviceConfig *config.MCPClientConfig) (string, error) { - storedToken, err := uts.storage.GetUserToken(ctx, userEmail, serviceName) - if err != nil { - return "", err - } - - switch storedToken.Type { - case storage.TokenTypeManual: - // Token is already in storedToken.Value, formatUserToken will handle it - break - case storage.TokenTypeOAuth: - if storedToken.OAuthData != nil && uts.serviceOAuthClient != nil { - if err := uts.serviceOAuthClient.RefreshToken(ctx, userEmail, serviceName, serviceConfig); err != nil { - log.LogWarnWithFields("user_token", "Failed to refresh OAuth token", map[string]any{ - "service": serviceName, - "user": userEmail, - "error": err.Error(), - }) - // Continue with current token - the service will handle auth failure - } else { - // Re-fetch the updated token after refresh - refreshedToken, err := uts.storage.GetUserToken(ctx, userEmail, serviceName) - if err != nil { - log.LogErrorWithFields("user_token", "Failed to fetch token after successful refresh", map[string]any{ - "service": serviceName, - "user": userEmail, - "error": err.Error(), - }) - // Continue with original token - the service will handle auth failure - } else { - storedToken = refreshedToken - var expiresAt time.Time - if refreshedToken.OAuthData != nil { - expiresAt = refreshedToken.OAuthData.ExpiresAt - } - log.LogInfoWithFields("user_token", "OAuth token refreshed and updated", map[string]any{ - "service": serviceName, - "user": userEmail, - "expiresAt": expiresAt, - }) - } - } - } - } - - return formatUserToken(storedToken, serviceConfig.UserAuthentication), nil -} - // HTTPServer manages the HTTP server lifecycle type HTTPServer struct { server *http.Server @@ -99,8 +23,6 @@ func NewHTTPServer(handler http.Handler, addr string) *HTTPServer { } } -// Handler builders and mux assembly - // HealthHandler handles health check requests type HealthHandler struct{} @@ -143,184 +65,3 @@ func (h *HTTPServer) Stop(ctx context.Context) error { }) return nil } - -// isStdioServer checks if this is a stdio-based server -func isStdioServer(config *config.MCPClientConfig) bool { - return config.Command != "" -} - -// formatUserToken formats a stored token according to the user authentication configuration -func formatUserToken(storedToken *storage.StoredToken, auth *config.UserAuthentication) string { - if storedToken == nil { - return "" - } - - if storedToken.Type == storage.TokenTypeOAuth && storedToken.OAuthData != nil { - token := storedToken.OAuthData.AccessToken - if auth.TokenFormat != "" && auth.TokenFormat != "{{token}}" { - return strings.ReplaceAll(auth.TokenFormat, "{{token}}", token) - } - return token - } - - token := storedToken.Value - if auth != nil && auth.TokenFormat != "" && auth.TokenFormat != "{{token}}" { - return strings.ReplaceAll(auth.TokenFormat, "{{token}}", token) - } - return token -} - -// SessionHandlerKey is the context key for session handlers -type SessionHandlerKey struct{} - -// SessionRequestHandler handles session-specific logic for a request -type SessionRequestHandler struct { - h *MCPHandler - userEmail string - config *config.MCPClientConfig - mcpServer *mcpserver.MCPServer // The shared MCP server -} - -// NewSessionRequestHandler creates a new session request handler with all dependencies -func NewSessionRequestHandler(h *MCPHandler, userEmail string, config *config.MCPClientConfig, mcpServer *mcpserver.MCPServer) *SessionRequestHandler { - return &SessionRequestHandler{ - h: h, - userEmail: userEmail, - config: config, - mcpServer: mcpServer, - } -} - -// GetUserEmail returns the user email for this session -func (s *SessionRequestHandler) GetUserEmail() string { - return s.userEmail -} - -// GetServerName returns the server name for this session -func (s *SessionRequestHandler) GetServerName() string { - return s.h.serverName -} - -// GetStorage returns the storage interface -func (s *SessionRequestHandler) GetStorage() storage.Storage { - return s.h.storage -} - -// HandleSessionRegistration handles the registration of a new session -func HandleSessionRegistration( - sessionCtx context.Context, - session mcpserver.ClientSession, - handler *SessionRequestHandler, - sessionManager *client.StdioSessionManager, -) { - // Create stdio process for this session - key := client.SessionKey{ - UserEmail: handler.userEmail, - ServerName: handler.h.serverName, - SessionID: session.SessionID(), - } - - log.LogDebugWithFields("server", "Registering session", map[string]any{ - "sessionID": session.SessionID(), - "server": handler.h.serverName, - "user": handler.userEmail, - }) - - log.LogTraceWithFields("server", "Session registration started", map[string]any{ - "sessionID": session.SessionID(), - "server": handler.h.serverName, - "user": handler.userEmail, - "requiresUserToken": handler.config.RequiresUserToken, - "transportType": handler.config.TransportType, - "command": handler.config.Command, - }) - - var userToken string - if handler.config.RequiresUserToken && handler.userEmail != "" && handler.h.storage != nil { - storedToken, err := handler.h.storage.GetUserToken(sessionCtx, handler.userEmail, handler.h.serverName) - if err != nil { - log.LogDebugWithFields("server", "No user token found", map[string]any{ - "server": handler.h.serverName, - "user": handler.userEmail, - }) - } else if storedToken != nil { - if handler.config.UserAuthentication != nil { - userToken = formatUserToken(storedToken, handler.config.UserAuthentication) - } else { - userToken = storedToken.Value - } - } - } - - stdioSession, err := sessionManager.GetOrCreateSession( - sessionCtx, - key, - handler.config, - handler.h.info, - handler.h.setupBaseURL, - userToken, - ) - if err != nil { - log.LogErrorWithFields("server", "Failed to create stdio session", map[string]any{ - "error": err.Error(), - "sessionID": session.SessionID(), - "server": handler.h.serverName, - "user": handler.userEmail, - }) - return - } - - // Discover and register capabilities from the stdio process - if err := stdioSession.DiscoverAndRegisterCapabilities( - sessionCtx, - handler.mcpServer, - handler.userEmail, - handler.config.RequiresUserToken, - handler.h.storage, - handler.h.serverName, - handler.h.setupBaseURL, - handler.config.UserAuthentication, - session, - ); err != nil { - log.LogErrorWithFields("server", "Failed to discover and register capabilities", map[string]any{ - "error": err.Error(), - "sessionID": session.SessionID(), - "server": handler.h.serverName, - "user": handler.userEmail, - }) - if err := sessionManager.RemoveSession(key); err != nil { - log.LogErrorWithFields("server", "Failed to remove session on capability failure", map[string]any{ - "sessionID": session.SessionID(), - "server": handler.h.serverName, - "user": handler.userEmail, - "error": err.Error(), - }) - } - return - } - - if handler.userEmail != "" { - if handler.h.storage != nil { - activeSession := storage.ActiveSession{ - SessionID: session.SessionID(), - UserEmail: handler.userEmail, - ServerName: handler.h.serverName, - Created: time.Now(), - LastActive: time.Now(), - } - if err := handler.h.storage.TrackSession(sessionCtx, activeSession); err != nil { - log.LogWarnWithFields("server", "Failed to track session", map[string]any{ - "error": err.Error(), - "sessionID": session.SessionID(), - "user": handler.userEmail, - }) - } - } - } - - log.LogInfoWithFields("server", "Session successfully created and connected", map[string]any{ - "sessionID": session.SessionID(), - "server": handler.h.serverName, - "user": handler.userEmail, - }) -} diff --git a/internal/server/mcp_handler.go b/internal/server/mcp_handler.go index 1f517ec..c0bc2ee 100644 --- a/internal/server/mcp_handler.go +++ b/internal/server/mcp_handler.go @@ -122,7 +122,7 @@ func (h *MCPHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { log.LogInfoWithFields("mcp", "Handling message request", map[string]any{ "path": r.URL.Path, "server": h.serverName, - "isStdio": isStdioServer(serverConfig), + "isStdio": serverConfig.IsStdio(), "user": userEmail, "remoteAddr": r.RemoteAddr, "contentLength": r.ContentLength, @@ -133,7 +133,7 @@ func (h *MCPHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { log.LogInfoWithFields("mcp", "Handling SSE request", map[string]any{ "path": r.URL.Path, "server": h.serverName, - "isStdio": isStdioServer(serverConfig), + "isStdio": serverConfig.IsStdio(), "user": userEmail, "remoteAddr": r.RemoteAddr, "userAgent": r.UserAgent(), @@ -168,7 +168,7 @@ func (h *MCPHandler) trackUserAccess(ctx context.Context, userEmail string) { func (h *MCPHandler) handleSSERequest(ctx context.Context, w http.ResponseWriter, r *http.Request, userEmail string, config *config.MCPClientConfig) { h.trackUserAccess(ctx, userEmail) - if !isStdioServer(config) { + if !config.IsStdio() { // For non-stdio servers, handle normally h.handleNonStdioSSERequest(ctx, w, r, userEmail, config) return @@ -206,7 +206,7 @@ func (h *MCPHandler) handleSSERequest(ctx context.Context, w http.ResponseWriter func (h *MCPHandler) handleMessageRequest(ctx context.Context, w http.ResponseWriter, r *http.Request, userEmail string, config *config.MCPClientConfig) { h.trackUserAccess(ctx, userEmail) - if isStdioServer(config) { + if config.IsStdio() { sessionID := r.URL.Query().Get("sessionId") if sessionID == "" { jsonrpc.WriteError(w, nil, jsonrpc.InvalidParams, "missing sessionId") diff --git a/internal/server/middleware.go b/internal/server/middleware.go index 0e53ba0..c207e44 100644 --- a/internal/server/middleware.go +++ b/internal/server/middleware.go @@ -10,7 +10,6 @@ import ( "time" "github.com/dgellow/mcp-front/internal/adminauth" - "github.com/dgellow/mcp-front/internal/browserauth" "github.com/dgellow/mcp-front/internal/config" "github.com/dgellow/mcp-front/internal/cookie" "github.com/dgellow/mcp-front/internal/crypto" @@ -19,6 +18,7 @@ import ( "github.com/dgellow/mcp-front/internal/log" "github.com/dgellow/mcp-front/internal/oauth" "github.com/dgellow/mcp-front/internal/servicecontext" + "github.com/dgellow/mcp-front/internal/session" "github.com/dgellow/mcp-front/internal/storage" "golang.org/x/crypto/bcrypt" ) @@ -319,7 +319,7 @@ func NewBrowserSSOMiddleware(authConfig config.OAuthAuthConfig, idpProvider idp. } // Parse session data - var sessionData browserauth.SessionCookie + var sessionData session.BrowserCookie if err := json.NewDecoder(strings.NewReader(decrypted)).Decode(&sessionData); err != nil { // Invalid format cookie.ClearSession(w) @@ -353,7 +353,7 @@ func NewBrowserSSOMiddleware(authConfig config.OAuthAuthConfig, idpProvider idp. // generateBrowserState creates a secure state parameter for browser SSO func generateBrowserState(browserStateToken *crypto.TokenSigner, returnURL string) string { - state := browserauth.AuthorizationState{ + state := session.AuthorizationState{ Nonce: crypto.GenerateSecureToken(), ReturnURL: returnURL, } diff --git a/internal/server/session_handler.go b/internal/server/session_handler.go new file mode 100644 index 0000000..58334f6 --- /dev/null +++ b/internal/server/session_handler.go @@ -0,0 +1,167 @@ +package server + +import ( + "context" + "time" + + "github.com/dgellow/mcp-front/internal/client" + "github.com/dgellow/mcp-front/internal/config" + "github.com/dgellow/mcp-front/internal/log" + "github.com/dgellow/mcp-front/internal/storage" + mcpserver "github.com/mark3labs/mcp-go/server" +) + +// SessionHandlerKey is the context key for session handlers +type SessionHandlerKey struct{} + +// SessionRequestHandler handles session-specific logic for a request +type SessionRequestHandler struct { + h *MCPHandler + userEmail string + config *config.MCPClientConfig + mcpServer *mcpserver.MCPServer // The shared MCP server +} + +// NewSessionRequestHandler creates a new session request handler with all dependencies +func NewSessionRequestHandler(h *MCPHandler, userEmail string, config *config.MCPClientConfig, mcpServer *mcpserver.MCPServer) *SessionRequestHandler { + return &SessionRequestHandler{ + h: h, + userEmail: userEmail, + config: config, + mcpServer: mcpServer, + } +} + +// GetUserEmail returns the user email for this session +func (s *SessionRequestHandler) GetUserEmail() string { + return s.userEmail +} + +// GetServerName returns the server name for this session +func (s *SessionRequestHandler) GetServerName() string { + return s.h.serverName +} + +// GetStorage returns the storage interface +func (s *SessionRequestHandler) GetStorage() storage.Storage { + return s.h.storage +} + +// HandleSessionRegistration handles the registration of a new session +func HandleSessionRegistration( + sessionCtx context.Context, + session mcpserver.ClientSession, + handler *SessionRequestHandler, + sessionManager *client.StdioSessionManager, +) { + // Create stdio process for this session + key := client.SessionKey{ + UserEmail: handler.userEmail, + ServerName: handler.h.serverName, + SessionID: session.SessionID(), + } + + log.LogDebugWithFields("server", "Registering session", map[string]any{ + "sessionID": session.SessionID(), + "server": handler.h.serverName, + "user": handler.userEmail, + }) + + log.LogTraceWithFields("server", "Session registration started", map[string]any{ + "sessionID": session.SessionID(), + "server": handler.h.serverName, + "user": handler.userEmail, + "requiresUserToken": handler.config.RequiresUserToken, + "transportType": handler.config.TransportType, + "command": handler.config.Command, + }) + + var userToken string + if handler.config.RequiresUserToken && handler.userEmail != "" && handler.h.storage != nil { + storedToken, err := handler.h.storage.GetUserToken(sessionCtx, handler.userEmail, handler.h.serverName) + if err != nil { + log.LogDebugWithFields("server", "No user token found", map[string]any{ + "server": handler.h.serverName, + "user": handler.userEmail, + }) + } else if storedToken != nil { + if handler.config.UserAuthentication != nil { + userToken = formatUserToken(storedToken, handler.config.UserAuthentication) + } else { + userToken = storedToken.Value + } + } + } + + stdioSession, err := sessionManager.GetOrCreateSession( + sessionCtx, + key, + handler.config, + handler.h.info, + handler.h.setupBaseURL, + userToken, + ) + if err != nil { + log.LogErrorWithFields("server", "Failed to create stdio session", map[string]any{ + "error": err.Error(), + "sessionID": session.SessionID(), + "server": handler.h.serverName, + "user": handler.userEmail, + }) + return + } + + // Discover and register capabilities from the stdio process + if err := stdioSession.DiscoverAndRegisterCapabilities( + sessionCtx, + handler.mcpServer, + handler.userEmail, + handler.config.RequiresUserToken, + handler.h.storage, + handler.h.serverName, + handler.h.setupBaseURL, + handler.config.UserAuthentication, + session, + ); err != nil { + log.LogErrorWithFields("server", "Failed to discover and register capabilities", map[string]any{ + "error": err.Error(), + "sessionID": session.SessionID(), + "server": handler.h.serverName, + "user": handler.userEmail, + }) + if err := sessionManager.RemoveSession(key); err != nil { + log.LogErrorWithFields("server", "Failed to remove session on capability failure", map[string]any{ + "sessionID": session.SessionID(), + "server": handler.h.serverName, + "user": handler.userEmail, + "error": err.Error(), + }) + } + return + } + + if handler.userEmail != "" { + if handler.h.storage != nil { + activeSession := storage.ActiveSession{ + SessionID: session.SessionID(), + UserEmail: handler.userEmail, + ServerName: handler.h.serverName, + Created: time.Now(), + LastActive: time.Now(), + } + if err := handler.h.storage.TrackSession(sessionCtx, activeSession); err != nil { + log.LogWarnWithFields("server", "Failed to track session", map[string]any{ + "error": err.Error(), + "sessionID": session.SessionID(), + "user": handler.userEmail, + }) + } + } + } + + log.LogInfoWithFields("server", "Session successfully created and connected", map[string]any{ + "sessionID": session.SessionID(), + "server": handler.h.serverName, + "user": handler.userEmail, + }) +} diff --git a/internal/server/token_handlers.go b/internal/server/token_handlers.go index b2cbdaa..cf33a00 100644 --- a/internal/server/token_handlers.go +++ b/internal/server/token_handlers.go @@ -4,7 +4,6 @@ import ( "fmt" "net/http" "strings" - "sync" "time" "github.com/dgellow/mcp-front/internal/auth" @@ -20,40 +19,22 @@ import ( type TokenHandlers struct { tokenStore storage.UserTokenStore mcpServers map[string]*config.MCPClientConfig - csrfTokens sync.Map // Thread-safe CSRF token storage + csrf crypto.CSRFProtection oauthEnabled bool serviceOAuthClient *auth.ServiceOAuthClient } // NewTokenHandlers creates a new token handlers instance -func NewTokenHandlers(tokenStore storage.UserTokenStore, mcpServers map[string]*config.MCPClientConfig, oauthEnabled bool, serviceOAuthClient *auth.ServiceOAuthClient) *TokenHandlers { +func NewTokenHandlers(tokenStore storage.UserTokenStore, mcpServers map[string]*config.MCPClientConfig, oauthEnabled bool, serviceOAuthClient *auth.ServiceOAuthClient, csrfKey []byte) *TokenHandlers { return &TokenHandlers{ tokenStore: tokenStore, mcpServers: mcpServers, oauthEnabled: oauthEnabled, serviceOAuthClient: serviceOAuthClient, + csrf: crypto.NewCSRFProtection(csrfKey, 15*time.Minute), } } -// generateCSRFToken creates a new CSRF token -func (h *TokenHandlers) generateCSRFToken() (string, error) { - token := crypto.GenerateSecureToken() - if token == "" { - return "", fmt.Errorf("failed to generate CSRF token") - } - h.csrfTokens.Store(token, true) - return token, nil -} - -// validateCSRFToken checks if a CSRF token is valid -func (h *TokenHandlers) validateCSRFToken(token string) bool { - if _, exists := h.csrfTokens.LoadAndDelete(token); exists { - // One-time use via LoadAndDelete - return true - } - return false -} - // ListTokensHandler shows the token management page func (h *TokenHandlers) ListTokensHandler(w http.ResponseWriter, r *http.Request) { // Only accept GET @@ -137,7 +118,7 @@ func (h *TokenHandlers) ListTokensHandler(w http.ResponseWriter, r *http.Request } // Generate CSRF token - csrfToken, err := h.generateCSRFToken() + csrfToken, err := h.csrf.Generate() if err != nil { log.LogErrorWithFields("token", "Failed to generate CSRF token", map[string]any{ "error": err.Error(), @@ -188,7 +169,7 @@ func (h *TokenHandlers) SetTokenHandler(w http.ResponseWriter, r *http.Request) // Validate CSRF token csrfToken := r.FormValue("csrf_token") - if !h.validateCSRFToken(csrfToken) { + if !h.csrf.Validate(csrfToken) { jsonwriter.WriteForbidden(w, "Invalid CSRF token") return } @@ -304,7 +285,7 @@ func (h *TokenHandlers) DeleteTokenHandler(w http.ResponseWriter, r *http.Reques // Validate CSRF token csrfToken := r.FormValue("csrf_token") - if !h.validateCSRFToken(csrfToken) { + if !h.csrf.Validate(csrfToken) { jsonwriter.WriteForbidden(w, "Invalid CSRF token") return } diff --git a/internal/server/user_token_service.go b/internal/server/user_token_service.go new file mode 100644 index 0000000..7e14d57 --- /dev/null +++ b/internal/server/user_token_service.go @@ -0,0 +1,102 @@ +package server + +import ( + "context" + "strings" + "time" + + "github.com/dgellow/mcp-front/internal/auth" + "github.com/dgellow/mcp-front/internal/config" + "github.com/dgellow/mcp-front/internal/log" + "github.com/dgellow/mcp-front/internal/storage" +) + +// UserTokenService handles user token retrieval and OAuth refresh +type UserTokenService struct { + storage storage.Storage + serviceOAuthClient *auth.ServiceOAuthClient +} + +// NewUserTokenService creates a new user token service +func NewUserTokenService(storage storage.Storage, serviceOAuthClient *auth.ServiceOAuthClient) *UserTokenService { + return &UserTokenService{ + storage: storage, + serviceOAuthClient: serviceOAuthClient, + } +} + +// GetUserToken retrieves and formats a user token for a service, handling OAuth refresh. +// +// Token refresh strategy: Optimistic continuation on failure. +// If refresh fails, we log a warning and continue with the current token. The external +// service will reject the expired token with 401, giving the user a clear error. +// This is acceptable because: (1) refresh failures are rare (network issues, revoked +// tokens), and (2) forcing users to re-auth is better than silently hiding auth issues. +func (uts *UserTokenService) GetUserToken(ctx context.Context, userEmail, serviceName string, serviceConfig *config.MCPClientConfig) (string, error) { + storedToken, err := uts.storage.GetUserToken(ctx, userEmail, serviceName) + if err != nil { + return "", err + } + + switch storedToken.Type { + case storage.TokenTypeManual: + // Token is already in storedToken.Value, formatUserToken will handle it + break + case storage.TokenTypeOAuth: + if storedToken.OAuthData != nil && uts.serviceOAuthClient != nil { + if err := uts.serviceOAuthClient.RefreshToken(ctx, userEmail, serviceName, serviceConfig); err != nil { + log.LogWarnWithFields("user_token", "Failed to refresh OAuth token", map[string]any{ + "service": serviceName, + "user": userEmail, + "error": err.Error(), + }) + // Continue with current token - the service will handle auth failure + } else { + // Re-fetch the updated token after refresh + refreshedToken, err := uts.storage.GetUserToken(ctx, userEmail, serviceName) + if err != nil { + log.LogErrorWithFields("user_token", "Failed to fetch token after successful refresh", map[string]any{ + "service": serviceName, + "user": userEmail, + "error": err.Error(), + }) + // Continue with original token - the service will handle auth failure + } else { + storedToken = refreshedToken + var expiresAt time.Time + if refreshedToken.OAuthData != nil { + expiresAt = refreshedToken.OAuthData.ExpiresAt + } + log.LogInfoWithFields("user_token", "OAuth token refreshed and updated", map[string]any{ + "service": serviceName, + "user": userEmail, + "expiresAt": expiresAt, + }) + } + } + } + } + + return formatUserToken(storedToken, serviceConfig.UserAuthentication), nil +} + +// formatUserToken formats a stored token according to the user authentication configuration +func formatUserToken(storedToken *storage.StoredToken, auth *config.UserAuthentication) string { + if storedToken == nil { + return "" + } + + if storedToken.Type == storage.TokenTypeOAuth && storedToken.OAuthData != nil { + token := storedToken.OAuthData.AccessToken + if auth.TokenFormat != "" && auth.TokenFormat != "{{token}}" { + return strings.ReplaceAll(auth.TokenFormat, "{{token}}", token) + } + return token + } + + token := storedToken.Value + if auth != nil && auth.TokenFormat != "" && auth.TokenFormat != "{{token}}" { + return strings.ReplaceAll(auth.TokenFormat, "{{token}}", token) + } + return token +} diff --git a/internal/session/session.go b/internal/session/session.go new file mode 100644 index 0000000..226b160 --- /dev/null +++ b/internal/session/session.go @@ -0,0 +1,35 @@ +package session + +import ( + "time" + + "github.com/dgellow/mcp-front/internal/idp" + "github.com/ory/fosite" +) + +// OAuthSession extends DefaultSession with user information for the OAuth flow +type OAuthSession struct { + *fosite.DefaultSession + Identity idp.Identity `json:"identity"` +} + +// Clone implements fosite.Session +func (s *OAuthSession) Clone() fosite.Session { + return &OAuthSession{ + DefaultSession: s.DefaultSession.Clone().(*fosite.DefaultSession), + Identity: s.Identity, + } +} + +// BrowserCookie represents the data stored in encrypted browser session cookies +type BrowserCookie struct { + Email string `json:"email"` + Provider string `json:"provider"` // IDP that authenticated this user (e.g., "google", "azure", "github") + Expires time.Time `json:"expires"` +} + +// AuthorizationState represents the OAuth authorization code flow state parameter +type AuthorizationState struct { + Nonce string `json:"nonce"` + ReturnURL string `json:"return_url"` +} diff --git a/internal/browserauth/session_test.go b/internal/session/session_test.go similarity index 73% rename from internal/browserauth/session_test.go rename to internal/session/session_test.go index dace5bb..0294040 100644 --- a/internal/browserauth/session_test.go +++ b/internal/session/session_test.go @@ -1,4 +1,4 @@ -package browserauth +package session import ( "encoding/json" @@ -9,47 +9,42 @@ import ( "github.com/stretchr/testify/require" ) -func TestSessionCookie_MarshalUnmarshal(t *testing.T) { - original := SessionCookie{ +func TestBrowserCookie_MarshalUnmarshal(t *testing.T) { + original := BrowserCookie{ Email: "user@example.com", Provider: "google", Expires: time.Now().Add(24 * time.Hour).Truncate(time.Second), } - // Marshal to JSON data, err := json.Marshal(original) require.NoError(t, err) - // Unmarshal back - var unmarshaled SessionCookie + var unmarshaled BrowserCookie err = json.Unmarshal(data, &unmarshaled) require.NoError(t, err) - // Truncate for comparison (JSON time serialization) assert.Equal(t, original.Email, unmarshaled.Email) assert.Equal(t, original.Provider, unmarshaled.Provider) assert.WithinDuration(t, original.Expires, unmarshaled.Expires, time.Second) } -func TestSessionCookie_Expiry(t *testing.T) { +func TestBrowserCookie_Expiry(t *testing.T) { t.Run("not expired", func(t *testing.T) { - session := SessionCookie{ + s := BrowserCookie{ Email: "user@example.com", Provider: "google", Expires: time.Now().Add(1 * time.Hour), } - - assert.True(t, session.Expires.After(time.Now())) + assert.True(t, s.Expires.After(time.Now())) }) t.Run("expired", func(t *testing.T) { - session := SessionCookie{ + s := BrowserCookie{ Email: "user@example.com", Provider: "google", Expires: time.Now().Add(-1 * time.Hour), } - - assert.True(t, session.Expires.Before(time.Now())) + assert.True(t, s.Expires.Before(time.Now())) }) } @@ -59,11 +54,9 @@ func TestAuthorizationState_MarshalUnmarshal(t *testing.T) { ReturnURL: "/my/tokens", } - // Marshal to JSON data, err := json.Marshal(original) require.NoError(t, err) - // Unmarshal back var unmarshaled AuthorizationState err = json.Unmarshal(data, &unmarshaled) require.NoError(t, err) diff --git a/internal/storage/firestore.go b/internal/storage/firestore.go index 1311ee6..0b49554 100644 --- a/internal/storage/firestore.go +++ b/internal/storage/firestore.go @@ -16,6 +16,11 @@ import ( "google.golang.org/grpc/status" ) +const ( + usersCollection = "mcp_front_users" + sessionsCollection = "mcp_front_sessions" +) + // FirestoreStorage implements OAuth client storage using Google Cloud Firestore. // // Error handling strategy: @@ -195,8 +200,9 @@ func (s *FirestoreStorage) loadClientsFromFirestore(ctx context.Context) error { } // StoreAuthorizeRequest stores an authorize request with state (in memory only - short-lived) -func (s *FirestoreStorage) StoreAuthorizeRequest(state string, req fosite.AuthorizeRequester) { +func (s *FirestoreStorage) StoreAuthorizeRequest(state string, req fosite.AuthorizeRequester) error { s.stateCache.Store(state, req) + return nil } // GetAuthorizeRequest retrieves an authorize request by state (one-time use) @@ -533,7 +539,7 @@ func (s *FirestoreStorage) UpsertUser(ctx context.Context, email string) error { } // Try to get existing user first - doc, err := s.client.Collection("mcp_front_users").Doc(email).Get(ctx) + doc, err := s.client.Collection(usersCollection).Doc(email).Get(ctx) if err == nil { // User exists, update LastSeen _, err = doc.Ref.Update(ctx, []firestore.Update{ @@ -547,16 +553,35 @@ func (s *FirestoreStorage) UpsertUser(ctx context.Context, email string) error { userDoc.FirstSeen = time.Now() userDoc.Enabled = true userDoc.IsAdmin = false - _, err = s.client.Collection("mcp_front_users").Doc(email).Set(ctx, userDoc) + _, err = s.client.Collection(usersCollection).Doc(email).Set(ctx, userDoc) return err } return err } +// GetUser returns a single user by email +func (s *FirestoreStorage) GetUser(ctx context.Context, email string) (*UserInfo, error) { + doc, err := s.client.Collection(usersCollection).Doc(email).Get(ctx) + if err != nil { + if status.Code(err) == codes.NotFound { + return nil, ErrUserNotFound + } + return nil, fmt.Errorf("failed to get user from Firestore: %w", err) + } + + var userDoc UserDoc + if err := doc.DataTo(&userDoc); err != nil { + return nil, fmt.Errorf("failed to unmarshal user: %w", err) + } + + user := UserInfo(userDoc) + return &user, nil +} + // GetAllUsers returns all users func (s *FirestoreStorage) GetAllUsers(ctx context.Context) ([]UserInfo, error) { - iter := s.client.Collection("mcp_front_users").Documents(ctx) + iter := s.client.Collection(usersCollection).Documents(ctx) defer iter.Stop() var users []UserInfo @@ -583,7 +608,7 @@ func (s *FirestoreStorage) GetAllUsers(ctx context.Context) ([]UserInfo, error) // UpdateUserStatus updates a user's enabled status func (s *FirestoreStorage) UpdateUserStatus(ctx context.Context, email string, enabled bool) error { - _, err := s.client.Collection("mcp_front_users").Doc(email).Update(ctx, []firestore.Update{ + _, err := s.client.Collection(usersCollection).Doc(email).Update(ctx, []firestore.Update{ {Path: "enabled", Value: enabled}, }) if status.Code(err) == codes.NotFound { @@ -595,7 +620,7 @@ func (s *FirestoreStorage) UpdateUserStatus(ctx context.Context, email string, e // DeleteUser removes a user from storage func (s *FirestoreStorage) DeleteUser(ctx context.Context, email string) error { // Delete user document - _, err := s.client.Collection("mcp_front_users").Doc(email).Delete(ctx) + _, err := s.client.Collection(usersCollection).Doc(email).Delete(ctx) if err != nil && status.Code(err) != codes.NotFound { return err } @@ -625,7 +650,7 @@ func (s *FirestoreStorage) DeleteUser(ctx context.Context, email string) error { // SetUserAdmin updates a user's admin status func (s *FirestoreStorage) SetUserAdmin(ctx context.Context, email string, isAdmin bool) error { - _, err := s.client.Collection("mcp_front_users").Doc(email).Update(ctx, []firestore.Update{ + _, err := s.client.Collection(usersCollection).Doc(email).Update(ctx, []firestore.Update{ {Path: "is_admin", Value: isAdmin}, }) if status.Code(err) == codes.NotFound { @@ -645,7 +670,7 @@ func (s *FirestoreStorage) TrackSession(ctx context.Context, session ActiveSessi } // Check if session exists - doc, err := s.client.Collection("mcp_front_sessions").Doc(session.SessionID).Get(ctx) + doc, err := s.client.Collection(sessionsCollection).Doc(session.SessionID).Get(ctx) if err == nil { // Session exists, update LastActive _, err = doc.Ref.Update(ctx, []firestore.Update{ @@ -657,7 +682,7 @@ func (s *FirestoreStorage) TrackSession(ctx context.Context, session ActiveSessi // Session doesn't exist, create new if status.Code(err) == codes.NotFound { sessionDoc.Created = time.Now() - _, err = s.client.Collection("mcp_front_sessions").Doc(session.SessionID).Set(ctx, sessionDoc) + _, err = s.client.Collection(sessionsCollection).Doc(session.SessionID).Set(ctx, sessionDoc) return err } @@ -666,7 +691,7 @@ func (s *FirestoreStorage) TrackSession(ctx context.Context, session ActiveSessi // GetActiveSessions returns all active sessions func (s *FirestoreStorage) GetActiveSessions(ctx context.Context) ([]ActiveSession, error) { - iter := s.client.Collection("mcp_front_sessions").Documents(ctx) + iter := s.client.Collection(sessionsCollection).Documents(ctx) defer iter.Stop() var sessions []ActiveSession @@ -693,7 +718,7 @@ func (s *FirestoreStorage) GetActiveSessions(ctx context.Context) ([]ActiveSessi // RevokeSession removes a session func (s *FirestoreStorage) RevokeSession(ctx context.Context, sessionID string) error { - _, err := s.client.Collection("mcp_front_sessions").Doc(sessionID).Delete(ctx) + _, err := s.client.Collection(sessionsCollection).Doc(sessionID).Delete(ctx) if err != nil && status.Code(err) != codes.NotFound { return err } diff --git a/internal/storage/memory.go b/internal/storage/memory.go index 6a1dc79..63c2c54 100644 --- a/internal/storage/memory.go +++ b/internal/storage/memory.go @@ -43,8 +43,9 @@ func NewMemoryStorage() *MemoryStorage { } // StoreAuthorizeRequest stores an authorize request with state -func (s *MemoryStorage) StoreAuthorizeRequest(state string, req fosite.AuthorizeRequester) { +func (s *MemoryStorage) StoreAuthorizeRequest(state string, req fosite.AuthorizeRequester) error { s.stateCache.Store(state, req) + return nil } // GetAuthorizeRequest retrieves an authorize request by state (one-time use) @@ -204,6 +205,19 @@ func (s *MemoryStorage) UpsertUser(ctx context.Context, email string) error { return nil } +// GetUser returns a single user by email +func (s *MemoryStorage) GetUser(ctx context.Context, email string) (*UserInfo, error) { + s.usersMutex.RLock() + defer s.usersMutex.RUnlock() + + user, exists := s.users[email] + if !exists { + return nil, ErrUserNotFound + } + userCopy := *user + return &userCopy, nil +} + // GetAllUsers returns all users func (s *MemoryStorage) GetAllUsers(ctx context.Context) ([]UserInfo, error) { s.usersMutex.RLock() diff --git a/internal/storage/storage.go b/internal/storage/storage.go index 1f9a780..e44cb02 100644 --- a/internal/storage/storage.go +++ b/internal/storage/storage.go @@ -76,7 +76,7 @@ type Storage interface { fosite.Storage // OAuth state management - StoreAuthorizeRequest(state string, req fosite.AuthorizeRequester) + StoreAuthorizeRequest(state string, req fosite.AuthorizeRequester) error GetAuthorizeRequest(state string) (fosite.AuthorizeRequester, bool) // OAuth client management @@ -89,6 +89,7 @@ type Storage interface { // User tracking (upserted when users access MCP endpoints) UpsertUser(ctx context.Context, email string) error + GetUser(ctx context.Context, email string) (*UserInfo, error) GetAllUsers(ctx context.Context) ([]UserInfo, error) UpdateUserStatus(ctx context.Context, email string, enabled bool) error DeleteUser(ctx context.Context, email string) error From e0a6415472a299b051b407d0b247f5817004aa35 Mon Sep 17 00:00:00 2001 From: Sam El-Borai Date: Thu, 12 Feb 2026 12:30:47 +0100 Subject: [PATCH 5/7] Migrate integration tests from mcp/postgres to Google MCP Toolbox Replace the removed mcp/postgres Docker image with Google's MCP Toolbox for Databases. Convert 8 static JSON test configs to programmatic generation via Go builder functions so the image reference lives in one place (ToolboxImage constant). Key changes: - GoogleProvider now accepts optional endpoint overrides for authorizationUrl, tokenUrl, and userInfoUrl, replacing the broken GOOGLE_OAUTH_* env vars that were never actually read - Integration tests use execute_sql tool instead of query (toolbox API) - Remove dead code: TestEnvironment, SetupTestEnvironment, execDockerCompose helpers - Pull toolbox image in TestMain to prevent first-test timeout - Update all example configs and --config-init default --- cmd/mcp-front/main.go | 18 +- config-oauth-firestore.example.json | 15 +- config-oauth.example.json | 15 +- config-oauth.json | 22 +- config-token.example.json | 9 +- config-user-tokens-example.json | 15 +- integration/base_path_test.go | 7 +- integration/basic_auth_test.go | 7 +- integration/config/config.base-path-test.json | 28 -- .../config/config.basic-auth-test.json | 34 --- integration/config/config.demo-token.json | 9 +- .../config/config.oauth-integration-test.json | 44 --- .../config/config.oauth-rfc8707-test.json | 5 +- ...config.oauth-service-integration-test.json | 5 +- .../config/config.oauth-service-test.json | 50 ---- integration/config/config.oauth-test.json | 36 --- .../config/config.oauth-token-test.json | 5 +- .../config.oauth-usertoken-tools-test.json | 50 ---- integration/config/config.test.json | 31 --- integration/integration_test.go | 10 +- integration/isolation_test.go | 50 ++-- integration/main_test.go | 9 + integration/oauth_test.go | 56 ++-- integration/security_test.go | 16 +- integration/test_utils.go | 253 ++++++++++++------ internal/idp/factory.go | 3 + internal/idp/google.go | 21 +- internal/idp/google_test.go | 4 +- 28 files changed, 378 insertions(+), 449 deletions(-) delete mode 100644 integration/config/config.base-path-test.json delete mode 100644 integration/config/config.basic-auth-test.json delete mode 100644 integration/config/config.oauth-integration-test.json delete mode 100644 integration/config/config.oauth-service-test.json delete mode 100644 integration/config/config.oauth-test.json delete mode 100644 integration/config/config.oauth-usertoken-tools-test.json delete mode 100644 integration/config/config.test.json diff --git a/cmd/mcp-front/main.go b/cmd/mcp-front/main.go index 1c1a910..2cdb86d 100644 --- a/cmd/mcp-front/main.go +++ b/cmd/mcp-front/main.go @@ -43,9 +43,21 @@ func generateDefaultConfig(path string) error { "transportType": "stdio", "command": "docker", "args": []any{ - "run", "--rm", "-i", - "mcp/postgres:latest", - map[string]string{"$env": "POSTGRES_URL"}, + "run", "--rm", "-i", "--network", "host", + "-e", "POSTGRES_HOST", + "-e", "POSTGRES_PORT", + "-e", "POSTGRES_DATABASE", + "-e", "POSTGRES_USER", + "-e", "POSTGRES_PASSWORD", + "us-central1-docker.pkg.dev/database-toolbox/toolbox/toolbox:latest", + "--stdio", "--prebuilt", "postgres", + }, + "env": map[string]any{ + "POSTGRES_HOST": map[string]string{"$env": "POSTGRES_HOST"}, + "POSTGRES_PORT": map[string]string{"$env": "POSTGRES_PORT"}, + "POSTGRES_DATABASE": map[string]string{"$env": "POSTGRES_DATABASE"}, + "POSTGRES_USER": map[string]string{"$env": "POSTGRES_USER"}, + "POSTGRES_PASSWORD": map[string]string{"$env": "POSTGRES_PASSWORD"}, }, }, }, diff --git a/config-oauth-firestore.example.json b/config-oauth-firestore.example.json index 8099a34..aaa7d4d 100644 --- a/config-oauth-firestore.example.json +++ b/config-oauth-firestore.example.json @@ -28,11 +28,20 @@ "command": "docker", "args": [ "run", "--rm", "-i", "--network", "host", - "mcp/postgres", - {"$env": "DATABASE_URL"} + "-e", "POSTGRES_HOST", + "-e", "POSTGRES_PORT", + "-e", "POSTGRES_DATABASE", + "-e", "POSTGRES_USER", + "-e", "POSTGRES_PASSWORD", + "us-central1-docker.pkg.dev/database-toolbox/toolbox/toolbox:latest", + "--stdio", "--prebuilt", "postgres" ], "env": { - "PGPASSWORD": {"$env": "POSTGRES_PASSWORD"} + "POSTGRES_HOST": {"$env": "POSTGRES_HOST"}, + "POSTGRES_PORT": {"$env": "POSTGRES_PORT"}, + "POSTGRES_DATABASE": {"$env": "POSTGRES_DATABASE"}, + "POSTGRES_USER": {"$env": "POSTGRES_USER"}, + "POSTGRES_PASSWORD": {"$env": "POSTGRES_PASSWORD"} } }, "notion": { diff --git a/config-oauth.example.json b/config-oauth.example.json index caa5c96..80b522d 100644 --- a/config-oauth.example.json +++ b/config-oauth.example.json @@ -28,11 +28,20 @@ "command": "docker", "args": [ "run", "--rm", "-i", "--network", "host", - "mcp/postgres", - {"$env": "DATABASE_URL"} + "-e", "POSTGRES_HOST", + "-e", "POSTGRES_PORT", + "-e", "POSTGRES_DATABASE", + "-e", "POSTGRES_USER", + "-e", "POSTGRES_PASSWORD", + "us-central1-docker.pkg.dev/database-toolbox/toolbox/toolbox:latest", + "--stdio", "--prebuilt", "postgres" ], "env": { - "PGPASSWORD": {"$env": "POSTGRES_PASSWORD"} + "POSTGRES_HOST": {"$env": "POSTGRES_HOST"}, + "POSTGRES_PORT": {"$env": "POSTGRES_PORT"}, + "POSTGRES_DATABASE": {"$env": "POSTGRES_DATABASE"}, + "POSTGRES_USER": {"$env": "POSTGRES_USER"}, + "POSTGRES_PASSWORD": {"$env": "POSTGRES_PASSWORD"} } }, "notion": { diff --git a/config-oauth.json b/config-oauth.json index e0c9853..a49ba37 100644 --- a/config-oauth.json +++ b/config-oauth.json @@ -25,12 +25,24 @@ "mcpServers": { "postgres": { "transportType": "stdio", - "command": "docker", + "command": "docker", "args": [ - "run", "--rm", "-i", - "mcp/postgres:latest", - "postgresql://user:password@localhost:5432/database" - ] + "run", "--rm", "-i", "--network", "host", + "-e", "POSTGRES_HOST", + "-e", "POSTGRES_PORT", + "-e", "POSTGRES_DATABASE", + "-e", "POSTGRES_USER", + "-e", "POSTGRES_PASSWORD", + "us-central1-docker.pkg.dev/database-toolbox/toolbox/toolbox:latest", + "--stdio", "--prebuilt", "postgres" + ], + "env": { + "POSTGRES_HOST": {"$env": "POSTGRES_HOST"}, + "POSTGRES_PORT": {"$env": "POSTGRES_PORT"}, + "POSTGRES_DATABASE": {"$env": "POSTGRES_DATABASE"}, + "POSTGRES_USER": {"$env": "POSTGRES_USER"}, + "POSTGRES_PASSWORD": {"$env": "POSTGRES_PASSWORD"} + } }, "notion": { "transportType": "stdio", diff --git a/config-token.example.json b/config-token.example.json index b13702e..fc07d09 100644 --- a/config-token.example.json +++ b/config-token.example.json @@ -11,8 +11,13 @@ "command": "docker", "args": [ "run", "--rm", "-i", "--network", "host", - "mcp/postgres", - "postgresql://testuser:testpass@localhost:5432/testdb" + "-e", "POSTGRES_HOST=localhost", + "-e", "POSTGRES_PORT=5432", + "-e", "POSTGRES_DATABASE=testdb", + "-e", "POSTGRES_USER=testuser", + "-e", "POSTGRES_PASSWORD=testpass", + "us-central1-docker.pkg.dev/database-toolbox/toolbox/toolbox:latest", + "--stdio", "--prebuilt", "postgres" ], "serviceAuths": [ { diff --git a/config-user-tokens-example.json b/config-user-tokens-example.json index 4c4b42f..8abab60 100644 --- a/config-user-tokens-example.json +++ b/config-user-tokens-example.json @@ -60,11 +60,20 @@ "command": "docker", "args": [ "run", "--rm", "-i", "--network", "host", - "mcp/postgres", - {"$env": "DATABASE_URL"} + "-e", "POSTGRES_HOST", + "-e", "POSTGRES_PORT", + "-e", "POSTGRES_DATABASE", + "-e", "POSTGRES_USER", + "-e", "POSTGRES_PASSWORD", + "us-central1-docker.pkg.dev/database-toolbox/toolbox/toolbox:latest", + "--stdio", "--prebuilt", "postgres" ], "env": { - "PGPASSWORD": {"$env": "POSTGRES_PASSWORD"} + "POSTGRES_HOST": {"$env": "POSTGRES_HOST"}, + "POSTGRES_PORT": {"$env": "POSTGRES_PORT"}, + "POSTGRES_DATABASE": {"$env": "POSTGRES_DATABASE"}, + "POSTGRES_USER": {"$env": "POSTGRES_USER"}, + "POSTGRES_PASSWORD": {"$env": "POSTGRES_PASSWORD"} } } } diff --git a/integration/base_path_test.go b/integration/base_path_test.go index 88b8e13..ca0ec3c 100644 --- a/integration/base_path_test.go +++ b/integration/base_path_test.go @@ -11,7 +11,12 @@ import ( func TestBasePathRouting(t *testing.T) { waitForDB(t) - startMCPFront(t, "config/config.base-path-test.json") + cfg := buildTestConfig( + "http://localhost:8080/mcp-api", "mcp-front-base-path-test", + nil, + map[string]any{"postgres": testPostgresServer(withBearerTokens("test-token"))}, + ) + startMCPFront(t, writeTestConfig(t, cfg)) waitForMCPFront(t) initialContainers := getMCPContainers() diff --git a/integration/basic_auth_test.go b/integration/basic_auth_test.go index 1f31a91..4368f1a 100644 --- a/integration/basic_auth_test.go +++ b/integration/basic_auth_test.go @@ -10,8 +10,11 @@ import ( ) func TestBasicAuth(t *testing.T) { - // Start mcp-front with basic auth config - startMCPFront(t, "config/config.basic-auth-test.json", + cfg := buildTestConfig("http://localhost:8080", "mcp-front-basic-auth-test", + nil, + map[string]any{"postgres": testPostgresServer(withBasicAuth("admin", "ADMIN_PASSWORD"), withBasicAuth("user", "USER_PASSWORD"))}, + ) + startMCPFront(t, writeTestConfig(t, cfg), "ADMIN_PASSWORD=adminpass123", "USER_PASSWORD=userpass456", ) diff --git a/integration/config/config.base-path-test.json b/integration/config/config.base-path-test.json deleted file mode 100644 index 5b60230..0000000 --- a/integration/config/config.base-path-test.json +++ /dev/null @@ -1,28 +0,0 @@ -{ - "version": "v0.0.1-DEV_EDITION_EXPECT_CHANGES", - "proxy": { - "baseURL": "http://localhost:8080/mcp-api", - "addr": ":8080", - "name": "mcp-front-base-path-test" - }, - "mcpServers": { - "postgres": { - "transportType": "stdio", - "command": "docker", - "args": [ - "run", - "-i", - "--network", - "host", - "mcp/postgres", - "postgresql://testuser:testpass@localhost:15432/testdb" - ], - "serviceAuths": [ - { - "type": "bearer", - "tokens": ["test-token"] - } - ] - } - } -} diff --git a/integration/config/config.basic-auth-test.json b/integration/config/config.basic-auth-test.json deleted file mode 100644 index 8ef2e53..0000000 --- a/integration/config/config.basic-auth-test.json +++ /dev/null @@ -1,34 +0,0 @@ -{ - "version": "v0.0.1-DEV_EDITION_EXPECT_CHANGES", - "proxy": { - "baseURL": "http://localhost:8080", - "addr": ":8080", - "name": "mcp-front-basic-auth-test" - }, - "mcpServers": { - "postgres": { - "transportType": "stdio", - "command": "docker", - "args": [ - "run", - "-i", - "--network", - "host", - "mcp/postgres", - "postgresql://testuser:testpass@localhost:15432/testdb" - ], - "serviceAuths": [ - { - "type": "basic", - "username": "admin", - "password": {"$env": "ADMIN_PASSWORD"} - }, - { - "type": "basic", - "username": "user", - "password": {"$env": "USER_PASSWORD"} - } - ] - } - } -} \ No newline at end of file diff --git a/integration/config/config.demo-token.json b/integration/config/config.demo-token.json index a9b07b5..7610c11 100644 --- a/integration/config/config.demo-token.json +++ b/integration/config/config.demo-token.json @@ -11,8 +11,13 @@ "command": "docker", "args": [ "run", "--rm", "-i", "--network", "host", - "mcp/postgres", - "postgresql://testuser:testpass@localhost:15432/testdb" + "-e", "POSTGRES_HOST=localhost", + "-e", "POSTGRES_PORT=15432", + "-e", "POSTGRES_DATABASE=testdb", + "-e", "POSTGRES_USER=testuser", + "-e", "POSTGRES_PASSWORD=testpass", + "us-central1-docker.pkg.dev/database-toolbox/toolbox/toolbox:latest", + "--stdio", "--prebuilt", "postgres" ], "options": { "logEnabled": true diff --git a/integration/config/config.oauth-integration-test.json b/integration/config/config.oauth-integration-test.json deleted file mode 100644 index bb24a06..0000000 --- a/integration/config/config.oauth-integration-test.json +++ /dev/null @@ -1,44 +0,0 @@ -{ - "version": "v0.0.1-DEV_EDITION_EXPECT_CHANGES", - "proxy": { - "baseURL": "http://localhost:8080", - "addr": ":8080", - "name": "mcp-front-oauth-test", - "auth": { - "kind": "oauth", - "issuer": "http://localhost:8080", - "gcpProject": "test-project", - "idp": { - "provider": "google", - "clientId": "test-client-id", - "clientSecret": "test-client-secret-for-integration-testing", - "redirectUri": "http://localhost:8080/oauth/callback" - }, - "allowedDomains": [ - "test.com" - ], - "allowedOrigins": [ - "https://claude.ai" - ], - "tokenTtl": "1h", - "storage": "memory", - "jwtSecret": "test-jwt-secret-for-integration-testing-32-chars-long", - "encryptionKey": "test-encryption-key-32-bytes-aes" - } - }, - "mcpServers": { - "postgres": { - "transportType": "stdio", - "command": "docker", - "args": [ - "run", - "--rm", - "-i", - "--network", - "host", - "mcp/postgres", - "postgresql://testuser:testpass@localhost:15432/testdb" - ] - } - } -} \ No newline at end of file diff --git a/integration/config/config.oauth-rfc8707-test.json b/integration/config/config.oauth-rfc8707-test.json index 25ceb1d..75d1162 100644 --- a/integration/config/config.oauth-rfc8707-test.json +++ b/integration/config/config.oauth-rfc8707-test.json @@ -12,7 +12,10 @@ "provider": "google", "clientId": {"$env": "GOOGLE_CLIENT_ID"}, "clientSecret": {"$env": "GOOGLE_CLIENT_SECRET"}, - "redirectUri": "http://localhost:8080/oauth/callback" + "redirectUri": "http://localhost:8080/oauth/callback", + "authorizationUrl": "http://localhost:9090/auth", + "tokenUrl": "http://localhost:9090/token", + "userInfoUrl": "http://localhost:9090/userinfo" }, "allowedDomains": ["test.com"], "allowedOrigins": ["https://claude.ai"], diff --git a/integration/config/config.oauth-service-integration-test.json b/integration/config/config.oauth-service-integration-test.json index 09ca0a9..e1c4e61 100644 --- a/integration/config/config.oauth-service-integration-test.json +++ b/integration/config/config.oauth-service-integration-test.json @@ -12,7 +12,10 @@ "provider": "google", "clientId": {"$env": "GOOGLE_CLIENT_ID"}, "clientSecret": {"$env": "GOOGLE_CLIENT_SECRET"}, - "redirectUri": "http://localhost:8080/oauth/callback" + "redirectUri": "http://localhost:8080/oauth/callback", + "authorizationUrl": "http://localhost:9090/auth", + "tokenUrl": "http://localhost:9090/token", + "userInfoUrl": "http://localhost:9090/userinfo" }, "allowedDomains": ["test.com"], "allowedOrigins": ["https://claude.ai"], diff --git a/integration/config/config.oauth-service-test.json b/integration/config/config.oauth-service-test.json deleted file mode 100644 index 04071c8..0000000 --- a/integration/config/config.oauth-service-test.json +++ /dev/null @@ -1,50 +0,0 @@ -{ - "version": "v0.0.1-DEV_EDITION_EXPECT_CHANGES", - "proxy": { - "baseURL": "http://localhost:8080", - "addr": ":8080", - "name": "mcp-front-oauth-usertoken-test", - "auth": { - "kind": "oauth", - "issuer": "http://localhost:8080", - "gcpProject": "test-project", - "idp": { - "provider": "google", - "clientId": {"$env": "GOOGLE_CLIENT_ID"}, - "clientSecret": {"$env": "GOOGLE_CLIENT_SECRET"}, - "redirectUri": "http://localhost:8080/oauth/callback" - }, - "allowedDomains": ["test.com"], - "allowedOrigins": ["https://claude.ai"], - "tokenTtl": "1h", - "storage": "memory", - "jwtSecret": {"$env": "JWT_SECRET"}, - "encryptionKey": {"$env": "ENCRYPTION_KEY"} - } - }, - "mcpServers": { - "postgres": { - "transportType": "stdio", - "command": "docker", - "args": [ - "run", - "--rm", - "-i", - "--network", - "host", - "mcp/postgres", - "postgresql://testuser:testpass@localhost:15432/testdb" - ], - "env": { - "USER_TOKEN": {"$userToken": "{{token}}"} - }, - "requiresUserToken": true, - "userAuthentication": { - "type": "manual", - "displayName": "Test Service", - "instructions": "Enter your test token", - "helpUrl": "https://example.com/help" - } - } - } -} \ No newline at end of file diff --git a/integration/config/config.oauth-test.json b/integration/config/config.oauth-test.json deleted file mode 100644 index 399e6d2..0000000 --- a/integration/config/config.oauth-test.json +++ /dev/null @@ -1,36 +0,0 @@ -{ - "version": "v0.0.1-DEV_EDITION_EXPECT_CHANGES", - "proxy": { - "baseURL": "http://localhost:8080", - "addr": ":8080", - "name": "mcp-front-oauth-test", - "auth": { - "kind": "oauth", - "issuer": "http://localhost:8080", - "gcpProject": "test-project", - "idp": { - "provider": "google", - "clientId": {"$env": "GOOGLE_CLIENT_ID"}, - "clientSecret": {"$env": "GOOGLE_CLIENT_SECRET"}, - "redirectUri": "http://localhost:8080/oauth/callback" - }, - "allowedDomains": ["test.com", "stainless.com", "claude.ai"], - "allowedOrigins": ["https://claude.ai"], - "tokenTtl": "1h", - "storage": "memory", - "jwtSecret": {"$env": "JWT_SECRET"}, - "encryptionKey": {"$env": "ENCRYPTION_KEY"} - } - }, - "mcpServers": { - "postgres": { - "transportType": "stdio", - "command": "docker", - "args": [ - "run", "--rm", "-i", "--network", "host", - "mcp/postgres", - "postgresql://testuser:testpass@localhost:15432/testdb" - ] - } - } -} diff --git a/integration/config/config.oauth-token-test.json b/integration/config/config.oauth-token-test.json index 874a1c7..f3874b7 100644 --- a/integration/config/config.oauth-token-test.json +++ b/integration/config/config.oauth-token-test.json @@ -12,7 +12,10 @@ "provider": "google", "clientId": {"$env": "GOOGLE_CLIENT_ID"}, "clientSecret": {"$env": "GOOGLE_CLIENT_SECRET"}, - "redirectUri": "http://localhost:8080/oauth/callback" + "redirectUri": "http://localhost:8080/oauth/callback", + "authorizationUrl": "http://localhost:9090/auth", + "tokenUrl": "http://localhost:9090/token", + "userInfoUrl": "http://localhost:9090/userinfo" }, "allowedDomains": ["test.com"], "allowedOrigins": ["https://claude.ai"], diff --git a/integration/config/config.oauth-usertoken-tools-test.json b/integration/config/config.oauth-usertoken-tools-test.json deleted file mode 100644 index 9fafd1d..0000000 --- a/integration/config/config.oauth-usertoken-tools-test.json +++ /dev/null @@ -1,50 +0,0 @@ -{ - "version": "v0.0.1-DEV_EDITION_EXPECT_CHANGES", - "proxy": { - "baseURL": "http://localhost:8080", - "addr": ":8080", - "name": "mcp-front-oauth-usertoken-test", - "auth": { - "kind": "oauth", - "issuer": "http://localhost:8080", - "gcpProject": "test-project", - "idp": { - "provider": "google", - "clientId": {"$env": "GOOGLE_CLIENT_ID"}, - "clientSecret": {"$env": "GOOGLE_CLIENT_SECRET"}, - "redirectUri": "http://localhost:8080/oauth/callback" - }, - "allowedDomains": ["test.com"], - "allowedOrigins": ["https://claude.ai"], - "tokenTtl": "1h", - "storage": "memory", - "jwtSecret": {"$env": "JWT_SECRET"}, - "encryptionKey": {"$env": "ENCRYPTION_KEY"} - } - }, - "mcpServers": { - "postgres": { - "transportType": "stdio", - "command": "docker", - "args": [ - "run", - "--rm", - "-i", - "--network", - "host", - "mcp/postgres", - "postgresql://testuser:testpass@localhost:15432/testdb" - ], - "env": { - "USER_TOKEN": {"$userToken": "{{token}}"} - }, - "requiresUserToken": true, - "userAuthentication": { - "type": "manual", - "displayName": "Test Service", - "instructions": "Enter your test token", - "helpUrl": "https://example.com/help" - } - } - } -} diff --git a/integration/config/config.test.json b/integration/config/config.test.json deleted file mode 100644 index 773d78c..0000000 --- a/integration/config/config.test.json +++ /dev/null @@ -1,31 +0,0 @@ -{ - "version": "v0.0.1-DEV_EDITION_EXPECT_CHANGES", - "proxy": { - "baseURL": "http://localhost:8080", - "addr": ":8080", - "name": "mcp-front-test" - }, - "mcpServers": { - "postgres": { - "transportType": "stdio", - "command": "docker", - "args": [ - "run", - "-i", - "--network", - "host", - "mcp/postgres", - "postgresql://testuser:testpass@localhost:15432/testdb" - ], - "options": { - "logEnabled": true - }, - "serviceAuths": [ - { - "type": "bearer", - "tokens": ["test-token", "alt-test-token"] - } - ] - } - } -} \ No newline at end of file diff --git a/integration/integration_test.go b/integration/integration_test.go index 72ff8b4..2e46e84 100644 --- a/integration/integration_test.go +++ b/integration/integration_test.go @@ -15,11 +15,11 @@ func TestIntegration(t *testing.T) { waitForDB(t) trace(t, "Starting mcp-front") - startMCPFront(t, "config/config.test.json", - "GOOGLE_OAUTH_AUTH_URL=http://localhost:9090/auth", - "GOOGLE_OAUTH_TOKEN_URL=http://localhost:9090/token", - "GOOGLE_USERINFO_URL=http://localhost:9090/userinfo", + cfg := buildTestConfig("http://localhost:8080", "mcp-front-test", + nil, + map[string]any{"postgres": testPostgresServer(withBearerTokens("test-token", "alt-test-token"), withLogEnabled())}, ) + startMCPFront(t, writeTestConfig(t, cfg)) waitForMCPFront(t) trace(t, "mcp-front is ready") @@ -49,7 +49,7 @@ func TestIntegration(t *testing.T) { t.Log("Connected to MCP server with session") queryParams := map[string]any{ - "name": "query", + "name": "execute_sql", "arguments": map[string]any{ "sql": "SELECT COUNT(*) as user_count FROM users", }, diff --git a/integration/isolation_test.go b/integration/isolation_test.go index 41d5d0f..1036e91 100644 --- a/integration/isolation_test.go +++ b/integration/isolation_test.go @@ -18,12 +18,16 @@ func TestMultiUserSessionIsolation(t *testing.T) { // Start mcp-front with bearer token auth trace(t, "Starting mcp-front") - startMCPFront(t, "config/config.test.json") + cfg := buildTestConfig("http://localhost:8080", "mcp-front-test", + nil, + map[string]any{"postgres": testPostgresServer(withBearerTokens("test-token", "alt-test-token"), withLogEnabled())}, + ) + startMCPFront(t, writeTestConfig(t, cfg)) waitForMCPFront(t) // Get initial container count initialContainers := getMCPContainers() - t.Logf("Initial mcp/postgres containers: %d", len(initialContainers)) + t.Logf("Initial toolbox containers: %d", len(initialContainers)) // Create two clients with different auth tokens client1 := NewMCPSSEClient("http://localhost:8080") @@ -69,7 +73,7 @@ func TestMultiUserSessionIsolation(t *testing.T) { } query1Result, err := client1.SendMCPRequest("tools/call", map[string]any{ - "name": "query", + "name": "execute_sql", "arguments": map[string]any{ "sql": "SELECT 'user1-query1' as test_id, COUNT(*) as count FROM users", }, @@ -116,13 +120,13 @@ func TestMultiUserSessionIsolation(t *testing.T) { // Verify that client1 and client2 have different containers if client1Container != "" && client2Container != "" && client1Container == client2Container { t.Errorf("CRITICAL: Both users are using the same Docker container! Container ID: %s", client1Container) - t.Error("This indicates session isolation is NOT working - users are sharing the same mcp/postgres instance") + t.Error("This indicates session isolation is NOT working - users are sharing the same toolbox instance") } else if client1Container != "" && client2Container != "" { t.Logf("Confirmed different stdio processes: User1 container=%s, User2 container=%s", client1Container, client2Container) } query2Result, err := client2.SendMCPRequest("tools/call", map[string]any{ - "name": "query", + "name": "execute_sql", "arguments": map[string]any{ "sql": "SELECT 'user2-query1' as test_id, COUNT(*) as count FROM orders", }, @@ -135,7 +139,7 @@ func TestMultiUserSessionIsolation(t *testing.T) { // Step 3: First user sends another query t.Log("\nStep 3: First user sends another query") query3Result, err := client1.SendMCPRequest("tools/call", map[string]any{ - "name": "query", + "name": "execute_sql", "arguments": map[string]any{ "sql": "SELECT 'user1-query2' as test_id, current_timestamp as ts", }, @@ -148,7 +152,7 @@ func TestMultiUserSessionIsolation(t *testing.T) { // Step 4: First user sends another query t.Log("\nStep 4: First user sends another query") query4Result, err := client1.SendMCPRequest("tools/call", map[string]any{ - "name": "query", + "name": "execute_sql", "arguments": map[string]any{ "sql": "SELECT 'user1-query3' as test_id, version() as db_version", }, @@ -161,7 +165,7 @@ func TestMultiUserSessionIsolation(t *testing.T) { // Step 5: Second user sends a query t.Log("\nStep 5: Second user sends a query") query5Result, err := client2.SendMCPRequest("tools/call", map[string]any{ - "name": "query", + "name": "execute_sql", "arguments": map[string]any{ "sql": "SELECT 'user2-query2' as test_id, current_database() as db_name", }, @@ -208,12 +212,16 @@ func TestSessionCleanupAfterTimeout(t *testing.T) { // Start mcp-front with test timeout configuration trace(t, "Starting mcp-front with test session timeout") - startMCPFront(t, "config/config.test.json") + cfg := buildTestConfig("http://localhost:8080", "mcp-front-test", + nil, + map[string]any{"postgres": testPostgresServer(withBearerTokens("test-token", "alt-test-token"), withLogEnabled())}, + ) + startMCPFront(t, writeTestConfig(t, cfg)) waitForMCPFront(t) // Get initial container count initialContainers := getMCPContainers() - t.Logf("Initial mcp/postgres containers: %d", len(initialContainers)) + t.Logf("Initial toolbox containers: %d", len(initialContainers)) // Create a client and connect client := NewMCPSSEClient("http://localhost:8080") @@ -237,7 +245,7 @@ func TestSessionCleanupAfterTimeout(t *testing.T) { // Send a query to ensure session is active _, err = client.SendMCPRequest("tools/call", map[string]any{ - "name": "query", + "name": "execute_sql", "arguments": map[string]any{ "sql": "SELECT 'test' as test_id", }, @@ -277,12 +285,16 @@ func TestSessionTimerReset(t *testing.T) { // Start mcp-front with test timeout configuration trace(t, "Starting mcp-front with test session timeout") - startMCPFront(t, "config/config.test.json") + cfg := buildTestConfig("http://localhost:8080", "mcp-front-test", + nil, + map[string]any{"postgres": testPostgresServer(withBearerTokens("test-token", "alt-test-token"), withLogEnabled())}, + ) + startMCPFront(t, writeTestConfig(t, cfg)) waitForMCPFront(t) // Get initial container count initialContainers := getMCPContainers() - t.Logf("Initial mcp/postgres containers: %d", len(initialContainers)) + t.Logf("Initial toolbox containers: %d", len(initialContainers)) // Create a client and connect client := NewMCPSSEClient("http://localhost:8080") @@ -308,7 +320,7 @@ func TestSessionTimerReset(t *testing.T) { for i := range 3 { t.Logf("Sending keepalive query %d/3...", i+1) _, err := client.SendMCPRequest("tools/call", map[string]any{ - "name": "query", + "name": "execute_sql", "arguments": map[string]any{ "sql": "SELECT 'keepalive' as status, NOW() as timestamp", }, @@ -357,12 +369,16 @@ func TestMultiUserTimerIndependence(t *testing.T) { // Start mcp-front with test timeout configuration trace(t, "Starting mcp-front with test session timeout") - startMCPFront(t, "config/config.test.json") + cfg := buildTestConfig("http://localhost:8080", "mcp-front-test", + nil, + map[string]any{"postgres": testPostgresServer(withBearerTokens("test-token", "alt-test-token"), withLogEnabled())}, + ) + startMCPFront(t, writeTestConfig(t, cfg)) waitForMCPFront(t) // Get initial container count initialContainers := getMCPContainers() - t.Logf("Initial mcp/postgres containers: %d", len(initialContainers)) + t.Logf("Initial toolbox containers: %d", len(initialContainers)) // Create two clients client1 := NewMCPSSEClient("http://localhost:8080") @@ -411,7 +427,7 @@ func TestMultiUserTimerIndependence(t *testing.T) { for i := range 4 { time.Sleep(4 * time.Second) _, err := client2.SendMCPRequest("tools/call", map[string]any{ - "name": "query", + "name": "execute_sql", "arguments": map[string]any{ "sql": "SELECT 'client2-keepalive' as status", }, diff --git a/integration/main_test.go b/integration/main_test.go index 49eb876..5673290 100644 --- a/integration/main_test.go +++ b/integration/main_test.go @@ -22,6 +22,15 @@ func TestMain(m *testing.M) { os.Exit(1) } + // Pull the toolbox image so the first test doesn't timeout on image pull + fmt.Println("Pulling toolbox image...") + pullCmd := exec.Command("docker", "pull", ToolboxImage) + pullCmd.Stdout = os.Stdout + pullCmd.Stderr = os.Stderr + if err := pullCmd.Run(); err != nil { + fmt.Printf("Warning: failed to pull toolbox image: %v\n", err) + } + // Set up local log file for mcp-front output logFile := "mcp-front-test.log" os.Setenv("MCP_LOG_FILE", logFile) diff --git a/integration/oauth_test.go b/integration/oauth_test.go index 0b2949f..ed15878 100644 --- a/integration/oauth_test.go +++ b/integration/oauth_test.go @@ -26,16 +26,16 @@ import ( // TestBasicOAuthFlow tests the basic OAuth server functionality func TestBasicOAuthFlow(t *testing.T) { - // Start mcp-front with OAuth config - startMCPFront(t, "config/config.oauth-test.json", + cfg := buildTestConfig("http://localhost:8080", "mcp-front-oauth-test", + testOAuthConfigFromEnv(), + map[string]any{"postgres": testPostgresServer()}, + ) + startMCPFront(t, writeTestConfig(t, cfg), "JWT_SECRET=test-jwt-secret-32-bytes-exactly!", "ENCRYPTION_KEY=test-encryption-key-32-bytes-ok!", "GOOGLE_CLIENT_ID=test-client-id-for-oauth", "GOOGLE_CLIENT_SECRET=test-client-secret-for-oauth", "MCP_FRONT_ENV=development", - "GOOGLE_OAUTH_AUTH_URL=http://localhost:9090/auth", - "GOOGLE_OAUTH_TOKEN_URL=http://localhost:9090/token", - "GOOGLE_USERINFO_URL=http://localhost:9090/userinfo", ) // Wait for startup @@ -84,6 +84,12 @@ func TestBasicOAuthFlow(t *testing.T) { // TestJWTSecretValidation tests JWT secret length requirements func TestJWTSecretValidation(t *testing.T) { + oauthCfg := buildTestConfig("http://localhost:8080", "mcp-front-oauth-test", + testOAuthConfigFromEnv(), + map[string]any{"postgres": testPostgresServer()}, + ) + configPath := writeTestConfig(t, oauthCfg) + tests := []struct { name string secret string @@ -98,8 +104,7 @@ func TestJWTSecretValidation(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - // Start mcp-front with specific JWT secret - mcpCmd := exec.Command("../cmd/mcp-front/mcp-front", "-config", "config/config.oauth-test.json") + mcpCmd := exec.Command("../cmd/mcp-front/mcp-front", "-config", configPath) mcpCmd.Env = []string{ "PATH=" + os.Getenv("PATH"), "JWT_SECRET=" + tt.secret, @@ -799,15 +804,15 @@ func TestCORSHeaders(t *testing.T) { // TestToolAdvertisementWithUserTokens tests that tools are advertised even without user tokens // but fail gracefully when invoked without the required token, and succeed with the token func TestToolAdvertisementWithUserTokens(t *testing.T) { - // Start OAuth server with user token configuration - startMCPFront(t, "config/config.oauth-usertoken-tools-test.json", + cfg := buildTestConfig("http://localhost:8080", "mcp-front-oauth-usertoken-test", + testOAuthConfigFromEnv(), + map[string]any{"postgres": testPostgresServer(withUserToken())}, + ) + startMCPFront(t, writeTestConfig(t, cfg), "JWT_SECRET=demo-jwt-secret-32-bytes-exactly!", "ENCRYPTION_KEY=test-encryption-key-32-bytes-ok!", "GOOGLE_CLIENT_ID=test-client-id-oauth", "GOOGLE_CLIENT_SECRET=test-client-secret-oauth", - "GOOGLE_OAUTH_AUTH_URL=http://localhost:9090/auth", - "GOOGLE_OAUTH_TOKEN_URL=http://localhost:9090/token", - "GOOGLE_USERINFO_URL=http://localhost:9090/userinfo", "MCP_FRONT_ENV=development", "LOG_LEVEL=debug", ) @@ -851,7 +856,7 @@ func TestToolAdvertisementWithUserTokens(t *testing.T) { } } - assert.Contains(t, toolNames, "query", "Should have query tool") + assert.Contains(t, toolNames, "execute_sql", "Should have execute_sql tool") t.Logf("Successfully advertised tools without user token: %v", toolNames) }) @@ -867,7 +872,7 @@ func TestToolAdvertisementWithUserTokens(t *testing.T) { // Try to invoke a tool without user token queryParams := map[string]any{ - "name": "query", + "name": "execute_sql", "arguments": map[string]any{ "sql": "SELECT 1", }, @@ -1001,7 +1006,7 @@ func TestToolAdvertisementWithUserTokens(t *testing.T) { // Call the query tool with a simple query queryParams := map[string]any{ - "name": "query", + "name": "execute_sql", "arguments": map[string]any{ "sql": "SELECT 1 as test", }, @@ -1034,8 +1039,13 @@ func TestToolAdvertisementWithUserTokens(t *testing.T) { // Helper functions func startOAuthServer(t *testing.T, env map[string]string) *exec.Cmd { - // Start with OAuth config - mcpCmd := exec.Command("../cmd/mcp-front/mcp-front", "-config", "config/config.oauth-test.json") + cfg := buildTestConfig("http://localhost:8080", "mcp-front-oauth-test", + testOAuthConfigFromEnv(), + map[string]any{"postgres": testPostgresServer()}, + ) + configPath := writeTestConfig(t, cfg) + + mcpCmd := exec.Command("../cmd/mcp-front/mcp-front", "-config", configPath) // Set default environment mcpCmd.Env = []string{ @@ -1044,9 +1054,6 @@ func startOAuthServer(t *testing.T, env map[string]string) *exec.Cmd { "ENCRYPTION_KEY=test-encryption-key-32-bytes-ok!", "GOOGLE_CLIENT_ID=test-client-id-oauth", "GOOGLE_CLIENT_SECRET=test-client-secret-oauth", - "GOOGLE_OAUTH_AUTH_URL=http://localhost:9090/auth", - "GOOGLE_OAUTH_TOKEN_URL=http://localhost:9090/token", - "GOOGLE_USERINFO_URL=http://localhost:9090/userinfo", } // Override with provided env @@ -1085,9 +1092,6 @@ func startOAuthServerWithTokenConfig(t *testing.T) *exec.Cmd { "ENCRYPTION_KEY=test-encryption-key-32-bytes-ok!", "GOOGLE_CLIENT_ID=test-client-id-oauth", "GOOGLE_CLIENT_SECRET=test-client-secret-oauth", - "GOOGLE_OAUTH_AUTH_URL=http://localhost:9090/auth", - "GOOGLE_OAUTH_TOKEN_URL=http://localhost:9090/token", - "GOOGLE_USERINFO_URL=http://localhost:9090/userinfo", "MCP_FRONT_ENV=development", } @@ -1198,9 +1202,6 @@ func TestServiceOAuthIntegration(t *testing.T) { "GOOGLE_CLIENT_SECRET=test-client-secret-oauth", "TEST_SERVICE_CLIENT_ID=service-client-id", "TEST_SERVICE_CLIENT_SECRET=service-client-secret", - "GOOGLE_OAUTH_AUTH_URL=http://localhost:9090/auth", - "GOOGLE_OAUTH_TOKEN_URL=http://localhost:9090/token", - "GOOGLE_USERINFO_URL=http://localhost:9090/userinfo", "MCP_FRONT_ENV=development", ) @@ -1396,9 +1397,6 @@ func TestRFC8707ResourceIndicators(t *testing.T) { "GOOGLE_CLIENT_ID=test-client-id-for-oauth", "GOOGLE_CLIENT_SECRET=test-client-secret-for-oauth", "MCP_FRONT_ENV=development", - "GOOGLE_OAUTH_AUTH_URL=http://localhost:9090/auth", - "GOOGLE_OAUTH_TOKEN_URL=http://localhost:9090/token", - "GOOGLE_USERINFO_URL=http://localhost:9090/userinfo", ) waitForMCPFront(t) diff --git a/integration/security_test.go b/integration/security_test.go index 2305c92..013978d 100644 --- a/integration/security_test.go +++ b/integration/security_test.go @@ -15,7 +15,11 @@ func TestSecurityScenarios(t *testing.T) { waitForDB(t) // Start mcp-front - startMCPFront(t, "config/config.test.json") + cfg := buildTestConfig("http://localhost:8080", "mcp-front-test", + nil, + map[string]any{"postgres": testPostgresServer(withBearerTokens("test-token", "alt-test-token"), withLogEnabled())}, + ) + startMCPFront(t, writeTestConfig(t, cfg)) // Wait for server to be ready waitForMCPFront(t) @@ -78,7 +82,7 @@ func TestSecurityScenarios(t *testing.T) { }) t.Run("SQLInjectionAttempts", func(t *testing.T) { - t.Skip("Skipping SQL injection tests, it's not a responsibility of mcp-front to guard mcp/postgres") + t.Skip("Skipping SQL injection tests, it's not a responsibility of mcp-front to guard toolbox") client := NewMCPSSEClient("http://localhost:8080") _ = client.Authenticate() @@ -101,7 +105,7 @@ func TestSecurityScenarios(t *testing.T) { // Try to inject via the query parameter _, err := client.SendMCPRequest("tools/call", map[string]any{ - "name": "query", + "name": "execute_sql", "arguments": map[string]any{ "query": payload, }, @@ -297,7 +301,11 @@ func TestFailureScenarios(t *testing.T) { // Database is already started by TestMain, just wait for readiness waitForDB(t) - startMCPFront(t, "config/config.test.json") + cfg := buildTestConfig("http://localhost:8080", "mcp-front-test", + nil, + map[string]any{"postgres": testPostgresServer(withBearerTokens("test-token", "alt-test-token"), withLogEnabled())}, + ) + startMCPFront(t, writeTestConfig(t, cfg)) // Wait for server to be ready waitForMCPFront(t) diff --git a/integration/test_utils.go b/integration/test_utils.go index 7d21f61..bca2817 100644 --- a/integration/test_utils.go +++ b/integration/test_utils.go @@ -19,24 +19,172 @@ import ( "time" ) -// getDockerComposeCommand returns the appropriate docker compose command -func getDockerComposeCommand() string { - // Check if docker compose v2 is available - cmd := exec.Command("docker", "compose", "version") - if err := cmd.Run(); err == nil { - return "docker compose" +// ToolboxImage is the Docker image for the MCP Toolbox for Databases. +// Used as the MCP server backing integration tests. All test configs +// that reference a postgres MCP server should use this image. +const ToolboxImage = "us-central1-docker.pkg.dev/database-toolbox/toolbox/toolbox:latest" + +// testPostgresDockerArgs returns the Docker args for running the toolbox +// as a stdio MCP server against the test postgres database. +func testPostgresDockerArgs() []string { + return []string{ + "run", "--rm", "-i", "--network", "host", + "-e", "POSTGRES_HOST=localhost", + "-e", "POSTGRES_PORT=15432", + "-e", "POSTGRES_DATABASE=testdb", + "-e", "POSTGRES_USER=testuser", + "-e", "POSTGRES_PASSWORD=testpass", + ToolboxImage, + "--stdio", "--prebuilt", "postgres", + } +} + +// testPostgresServer returns an MCP server config for the test postgres database. +// Options can customize auth, logging, etc. +func testPostgresServer(opts ...serverOption) map[string]any { + args := make([]any, len(testPostgresDockerArgs())) + for i, a := range testPostgresDockerArgs() { + args[i] = a + } + s := map[string]any{ + "transportType": "stdio", + "command": "docker", + "args": args, + } + for _, opt := range opts { + opt(s) + } + return s +} + +type serverOption func(map[string]any) + +func withBearerTokens(tokens ...string) serverOption { + return func(s map[string]any) { + s["serviceAuths"] = []map[string]any{ + {"type": "bearer", "tokens": tokens}, + } + } +} + +func withBasicAuth(username, passwordEnvVar string) serverOption { + return func(s map[string]any) { + auths, _ := s["serviceAuths"].([]map[string]any) + auths = append(auths, map[string]any{ + "type": "basic", + "username": username, + "password": map[string]string{"$env": passwordEnvVar}, + }) + s["serviceAuths"] = auths + } +} + +func withLogEnabled() serverOption { + return func(s map[string]any) { + s["options"] = map[string]any{"logEnabled": true} + } +} + +func withUserToken() serverOption { + return func(s map[string]any) { + s["env"] = map[string]any{ + "USER_TOKEN": map[string]string{"$userToken": "{{token}}"}, + } + s["requiresUserToken"] = true + s["userAuthentication"] = map[string]any{ + "type": "manual", + "displayName": "Test Service", + "instructions": "Enter your test token", + "helpUrl": "https://example.com/help", + } + } +} + +// testOAuthConfig returns a standard OAuth auth config for testing. +// Uses hardcoded values suitable for integration tests with the fake GCP server. +func testOAuthConfig() map[string]any { + return map[string]any{ + "kind": "oauth", + "issuer": "http://localhost:8080", + "gcpProject": "test-project", + "idp": map[string]any{ + "provider": "google", + "clientId": "test-client-id", + "clientSecret": "test-client-secret-for-integration-testing", + "redirectUri": "http://localhost:8080/oauth/callback", + "authorizationUrl": "http://localhost:9090/auth", + "tokenUrl": "http://localhost:9090/token", + "userInfoUrl": "http://localhost:9090/userinfo", + }, + "allowedDomains": []string{"test.com"}, + "allowedOrigins": []string{"https://claude.ai"}, + "tokenTtl": "1h", + "storage": "memory", + "jwtSecret": "test-jwt-secret-for-integration-testing-32-chars-long", + "encryptionKey": "test-encryption-key-32-bytes-aes", + } +} + +// testOAuthConfigFromEnv returns an OAuth auth config that reads secrets from env vars. +func testOAuthConfigFromEnv() map[string]any { + return map[string]any{ + "kind": "oauth", + "issuer": "http://localhost:8080", + "gcpProject": "test-project", + "idp": map[string]any{ + "provider": "google", + "clientId": map[string]string{"$env": "GOOGLE_CLIENT_ID"}, + "clientSecret": map[string]string{"$env": "GOOGLE_CLIENT_SECRET"}, + "redirectUri": "http://localhost:8080/oauth/callback", + "authorizationUrl": "http://localhost:9090/auth", + "tokenUrl": "http://localhost:9090/token", + "userInfoUrl": "http://localhost:9090/userinfo", + }, + "allowedDomains": []string{"test.com", "stainless.com", "claude.ai"}, + "allowedOrigins": []string{"https://claude.ai"}, + "tokenTtl": "1h", + "storage": "memory", + "jwtSecret": map[string]string{"$env": "JWT_SECRET"}, + "encryptionKey": map[string]string{"$env": "ENCRYPTION_KEY"}, } - return "docker-compose" } -// execDockerCompose executes docker compose with the given arguments -func execDockerCompose(args ...string) *exec.Cmd { - dcCmd := getDockerComposeCommand() - if dcCmd == "docker compose" { - allArgs := append([]string{"compose"}, args...) - return exec.Command("docker", allArgs...) +// writeTestConfig writes a config map to a temporary JSON file and returns its path. +// The file is automatically cleaned up when the test finishes. +func writeTestConfig(t *testing.T, cfg map[string]any) string { + t.Helper() + data, err := json.MarshalIndent(cfg, "", " ") + if err != nil { + t.Fatalf("Failed to marshal test config: %v", err) + } + f, err := os.CreateTemp(t.TempDir(), "config-*.json") + if err != nil { + t.Fatalf("Failed to create temp config file: %v", err) + } + if _, err := f.Write(data); err != nil { + t.Fatalf("Failed to write temp config: %v", err) + } + if err := f.Close(); err != nil { + t.Fatalf("Failed to close temp config: %v", err) + } + return f.Name() +} + +// buildTestConfig builds a complete mcp-front config map. +func buildTestConfig(baseURL, name string, auth map[string]any, mcpServers map[string]any) map[string]any { + proxy := map[string]any{ + "baseURL": baseURL, + "addr": ":8080", + "name": name, + } + if auth != nil { + proxy["auth"] = auth + } + return map[string]any{ + "version": "v0.0.1-DEV_EDITION_EXPECT_CHANGES", + "proxy": proxy, + "mcpServers": mcpServers, } - return exec.Command("docker-compose", args...) } // MCPSSEClient simulates an MCP client for testing @@ -620,79 +768,6 @@ func (s *FakeServiceOAuthServer) Stop() error { return s.server.Shutdown(ctx) } -// TestEnvironment manages the complete test environment -type TestEnvironment struct { - dbCmd *exec.Cmd - mcpCmd *exec.Cmd - fakeGCP *FakeGCPServer - client *MCPSSEClient -} - -// SetupTestEnvironment creates and starts all components needed for testing -func SetupTestEnvironment(t *testing.T) *TestEnvironment { - env := &TestEnvironment{} - - // Start test database - t.Log("🚀 Starting test database...") - env.dbCmd = execDockerCompose("-f", "config/docker-compose.test.yml", "up", "-d") - if err := env.dbCmd.Run(); err != nil { - t.Fatalf("Failed to start test database: %v", err) - } - - time.Sleep(10 * time.Second) - - // Start mock GCP server - t.Log("🚀 Starting mock GCP server...") - env.fakeGCP = NewFakeGCPServer("9090") - if err := env.fakeGCP.Start(); err != nil { - t.Fatalf("Failed to start mock GCP server: %v", err) - } - - // Start mcp-front - t.Log("🚀 Starting mcp-front...") - env.mcpCmd = exec.Command("../cmd/mcp-front/mcp-front", "-config", "config/config.test.json") - - // Capture stderr to log file if MCP_LOG_FILE is set - if logFile := os.Getenv("MCP_LOG_FILE"); logFile != "" { - f, err := os.OpenFile(logFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644) - if err == nil { - env.mcpCmd.Stderr = f - env.mcpCmd.Stdout = f - t.Cleanup(func() { f.Close() }) - } - } - - if err := env.mcpCmd.Start(); err != nil { - t.Fatalf("Failed to start mcp-front: %v", err) - } - - time.Sleep(15 * time.Second) - - // Create and authenticate client - env.client = NewMCPSSEClient("http://localhost:8080") - if err := env.client.Authenticate(); err != nil { - t.Fatalf("Authentication failed: %v", err) - } - - return env -} - -// Cleanup stops all test environment components -func (env *TestEnvironment) Cleanup() { - if env.mcpCmd != nil && env.mcpCmd.Process != nil { - _ = env.mcpCmd.Process.Kill() - } - - if env.fakeGCP != nil { - _ = env.fakeGCP.Stop() - } - - if env.dbCmd != nil { - downCmd := execDockerCompose("-f", "config/docker-compose.test.yml", "down", "-v") - _ = downCmd.Run() - } -} - // TestConfig holds all timeout configurations for integration tests type TestConfig struct { SessionTimeout string @@ -865,9 +940,9 @@ func waitForMCPFront(t *testing.T) { t.Fatal("mcp-front failed to become ready after 10 seconds") } -// getMCPContainers returns a list of running mcp/postgres container IDs +// getMCPContainers returns a list of running toolbox container IDs func getMCPContainers() []string { - cmd := exec.Command("docker", "ps", "--format", "{{.ID}}", "--filter", "ancestor=mcp/postgres") + cmd := exec.Command("docker", "ps", "--format", "{{.ID}}", "--filter", "ancestor="+ToolboxImage) output, err := cmd.Output() if err != nil { return nil diff --git a/internal/idp/factory.go b/internal/idp/factory.go index 60004e1..ff55d8d 100644 --- a/internal/idp/factory.go +++ b/internal/idp/factory.go @@ -14,6 +14,9 @@ func NewProvider(cfg config.IDPConfig) (Provider, error) { cfg.ClientID, string(cfg.ClientSecret), cfg.RedirectURI, + cfg.AuthorizationURL, + cfg.TokenURL, + cfg.UserInfoURL, ), nil case "azure": diff --git a/internal/idp/google.go b/internal/idp/google.go index b05ca68..82e86e2 100644 --- a/internal/idp/google.go +++ b/internal/idp/google.go @@ -30,16 +30,31 @@ type googleUserInfoResponse struct { } // NewGoogleProvider creates a new Google OAuth provider. -func NewGoogleProvider(clientID, clientSecret, redirectURI string) *GoogleProvider { +// Optional endpoint overrides (authorizationURL, tokenURL, userInfoURL) allow +// pointing at a non-Google server — useful for testing or corporate proxies. +func NewGoogleProvider(clientID, clientSecret, redirectURI, authorizationURL, tokenURL, userInfoURL string) *GoogleProvider { + endpoint := google.Endpoint + if authorizationURL != "" { + endpoint.AuthURL = authorizationURL + } + if tokenURL != "" { + endpoint.TokenURL = tokenURL + } + + uiURL := "https://www.googleapis.com/oauth2/v2/userinfo" + if userInfoURL != "" { + uiURL = userInfoURL + } + return &GoogleProvider{ config: oauth2.Config{ ClientID: clientID, ClientSecret: clientSecret, RedirectURL: redirectURI, Scopes: []string{"openid", "profile", "email"}, - Endpoint: google.Endpoint, + Endpoint: endpoint, }, - userInfoURL: "https://www.googleapis.com/oauth2/v2/userinfo", + userInfoURL: uiURL, } } diff --git a/internal/idp/google_test.go b/internal/idp/google_test.go index ff618fb..662c194 100644 --- a/internal/idp/google_test.go +++ b/internal/idp/google_test.go @@ -13,12 +13,12 @@ import ( ) func TestGoogleProvider_Type(t *testing.T) { - provider := NewGoogleProvider("client-id", "client-secret", "https://example.com/callback") + provider := NewGoogleProvider("client-id", "client-secret", "https://example.com/callback", "", "", "") assert.Equal(t, "google", provider.Type()) } func TestGoogleProvider_AuthURL(t *testing.T) { - provider := NewGoogleProvider("client-id", "client-secret", "https://example.com/callback") + provider := NewGoogleProvider("client-id", "client-secret", "https://example.com/callback", "", "", "") authURL := provider.AuthURL("test-state") From 9b37a218dea7a22c19aa31faa3c0558493e03639 Mon Sep 17 00:00:00 2001 From: Sam El-Borai Date: Thu, 12 Feb 2026 12:31:02 +0100 Subject: [PATCH 6/7] chore: formatting --- integration/security_test.go | 8 ++++---- integration/test_utils.go | 8 ++++---- internal/idp/oidc.go | 1 - internal/jsonrpc/errors.go | 1 - internal/mcpfront.go | 1 - 5 files changed, 8 insertions(+), 11 deletions(-) diff --git a/integration/security_test.go b/integration/security_test.go index 013978d..653ca14 100644 --- a/integration/security_test.go +++ b/integration/security_test.go @@ -302,10 +302,10 @@ func TestFailureScenarios(t *testing.T) { waitForDB(t) cfg := buildTestConfig("http://localhost:8080", "mcp-front-test", - nil, - map[string]any{"postgres": testPostgresServer(withBearerTokens("test-token", "alt-test-token"), withLogEnabled())}, - ) - startMCPFront(t, writeTestConfig(t, cfg)) + nil, + map[string]any{"postgres": testPostgresServer(withBearerTokens("test-token", "alt-test-token"), withLogEnabled())}, + ) + startMCPFront(t, writeTestConfig(t, cfg)) // Wait for server to be ready waitForMCPFront(t) diff --git a/integration/test_utils.go b/integration/test_utils.go index bca2817..6557f13 100644 --- a/integration/test_utils.go +++ b/integration/test_utils.go @@ -104,8 +104,8 @@ func withUserToken() serverOption { // Uses hardcoded values suitable for integration tests with the fake GCP server. func testOAuthConfig() map[string]any { return map[string]any{ - "kind": "oauth", - "issuer": "http://localhost:8080", + "kind": "oauth", + "issuer": "http://localhost:8080", "gcpProject": "test-project", "idp": map[string]any{ "provider": "google", @@ -128,8 +128,8 @@ func testOAuthConfig() map[string]any { // testOAuthConfigFromEnv returns an OAuth auth config that reads secrets from env vars. func testOAuthConfigFromEnv() map[string]any { return map[string]any{ - "kind": "oauth", - "issuer": "http://localhost:8080", + "kind": "oauth", + "issuer": "http://localhost:8080", "gcpProject": "test-project", "idp": map[string]any{ "provider": "google", diff --git a/internal/idp/oidc.go b/internal/idp/oidc.go index 4d48877..807ac05 100644 --- a/internal/idp/oidc.go +++ b/internal/idp/oidc.go @@ -29,7 +29,6 @@ type OIDCConfig struct { ClientSecret string RedirectURI string Scopes []string - } // OIDCProvider implements the Provider interface for OIDC-compliant identity providers. diff --git a/internal/jsonrpc/errors.go b/internal/jsonrpc/errors.go index efdf152..8fcfcf1 100644 --- a/internal/jsonrpc/errors.go +++ b/internal/jsonrpc/errors.go @@ -46,4 +46,3 @@ func NewStandardError(code int) *Error { Message: message, } } - diff --git a/internal/mcpfront.go b/internal/mcpfront.go index 4d890a8..ec70773 100644 --- a/internal/mcpfront.go +++ b/internal/mcpfront.go @@ -605,4 +605,3 @@ func buildStdioSSEServer(serverName, baseURL string, sessionManager *client.Stdi return sseServer, mcpServer, nil } - From 05078eaff9527519b0f7ce6cdec5e2009cd9f13c Mon Sep 17 00:00:00 2001 From: Sam El-Borai Date: Thu, 12 Feb 2026 23:09:01 +0100 Subject: [PATCH 7/7] Multi-IDP integration tests and integration directory reorganization MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add endpoint overrides to GitHub and Azure providers so they can point at local fake servers during testing, matching the existing Google provider pattern. Create FakeGitHubServer (port 9092) and FakeOIDCServer (port 9093) alongside FakeGCPServer to simulate all four supported IDPs. New integration tests exercise the full OAuth flow for GitHub, OIDC, and Azure providers end-to-end (register client → authorize → IDP redirect → callback → token exchange → MCP tools/list). Also tests org denial for GitHub and domain denial across all three providers. Reorganize integration directory: split test_utils.go (993 lines) into test_helpers.go, test_clients.go, test_fakes.go. Split oauth_test.go (1576 lines) into oauth_flow_test.go, oauth_user_tokens_test.go, oauth_service_test.go, oauth_rfc8707_test.go, oauth_idp_test.go. --- integration/integration_test.go | 46 +- integration/main_test.go | 24 +- integration/oauth_flow_test.go | 789 +++++++++++++ integration/oauth_idp_test.go | 256 ++++ integration/oauth_rfc8707_test.go | 198 ++++ integration/oauth_service_test.go | 88 ++ integration/oauth_test.go | 1575 ------------------------- integration/oauth_user_tokens_test.go | 536 +++++++++ integration/streamable_client.go | 1 - integration/test_clients.go | 442 +++++++ integration/test_fakes.go | 355 ++++++ integration/test_helpers.go | 463 ++++++++ integration/test_utils.go | 992 ---------------- internal/idp/azure.go | 27 +- internal/idp/azure_test.go | 2 +- internal/idp/factory.go | 6 + internal/idp/github.go | 21 +- internal/idp/github_test.go | 4 +- 18 files changed, 3222 insertions(+), 2603 deletions(-) create mode 100644 integration/oauth_flow_test.go create mode 100644 integration/oauth_idp_test.go create mode 100644 integration/oauth_rfc8707_test.go create mode 100644 integration/oauth_service_test.go delete mode 100644 integration/oauth_test.go create mode 100644 integration/oauth_user_tokens_test.go delete mode 100644 integration/streamable_client.go create mode 100644 integration/test_clients.go create mode 100644 integration/test_fakes.go create mode 100644 integration/test_helpers.go delete mode 100644 integration/test_utils.go diff --git a/integration/integration_test.go b/integration/integration_test.go index 2e46e84..2bd6473 100644 --- a/integration/integration_test.go +++ b/integration/integration_test.go @@ -75,23 +75,31 @@ func TestIntegration(t *testing.T) { assert.NotEmpty(t, content, "Query result missing content") t.Log("Query executed successfully") - // Test resources list - resourcesResult, err := client.SendMCPRequest("resources/list", map[string]any{}) - require.NoError(t, err, "Failed to list resources") - - t.Logf("Resources response: %+v", resourcesResult) - - // Check for error in resources response - errorMap, hasError = resourcesResult["error"].(map[string]any) - assert.False(t, hasError, "Resources list returned error: %v", errorMap) - - // Verify we got resources - resultMap, ok = resourcesResult["result"].(map[string]any) - require.True(t, ok, "Expected result in resources response") - - resources, ok := resultMap["resources"].([]any) - require.True(t, ok, "Expected resources array in result") - assert.NotEmpty(t, resources, "Expected at least one resource") - t.Logf("Found %d resources", len(resources)) - + // Test tools list + toolsResult, err := client.SendMCPRequest("tools/list", map[string]any{}) + require.NoError(t, err, "Failed to list tools") + + t.Logf("Tools response: %+v", toolsResult) + + errorMap, hasError = toolsResult["error"].(map[string]any) + assert.False(t, hasError, "Tools list returned error: %v", errorMap) + + resultMap, ok = toolsResult["result"].(map[string]any) + require.True(t, ok, "Expected result in tools response") + + tools, ok := resultMap["tools"].([]any) + require.True(t, ok, "Expected tools array in result") + assert.NotEmpty(t, tools, "Expected at least one tool") + t.Logf("Found %d tools", len(tools)) + + // Verify execute_sql tool is present + var toolNames []string + for _, tool := range tools { + if toolMap, ok := tool.(map[string]any); ok { + if name, ok := toolMap["name"].(string); ok { + toolNames = append(toolNames, name) + } + } + } + assert.Contains(t, toolNames, "execute_sql", "Should have execute_sql tool") } diff --git a/integration/main_test.go b/integration/main_test.go index 5673290..39527d7 100644 --- a/integration/main_test.go +++ b/integration/main_test.go @@ -62,7 +62,7 @@ func TestMain(m *testing.M) { os.Exit(exitCode) }() - // Start fake GCP server for OAuth + // Start fake GCP server for OAuth (port 9090) fakeGCP := NewFakeGCPServer("9090") err := fakeGCP.Start() if err != nil { @@ -74,6 +74,28 @@ func TestMain(m *testing.M) { _ = fakeGCP.Stop() }() + // Start fake GitHub server (port 9092) + fakeGitHub := NewFakeGitHubServer("9092", []string{"test-org", "another-org"}) + if err := fakeGitHub.Start(); err != nil { + fmt.Printf("Failed to start fake GitHub server: %v\n", err) + exitCode = 1 + return + } + defer func() { + _ = fakeGitHub.Stop() + }() + + // Start fake OIDC server (port 9093) — used for both OIDC and Azure tests + fakeOIDC := NewFakeOIDCServer("9093") + if err := fakeOIDC.Start(); err != nil { + fmt.Printf("Failed to start fake OIDC server: %v\n", err) + exitCode = 1 + return + } + defer func() { + _ = fakeOIDC.Stop() + }() + // Wait for database to be ready fmt.Println("Waiting for database to be ready...") for i := range 30 { // Wait up to 30 seconds diff --git a/integration/oauth_flow_test.go b/integration/oauth_flow_test.go new file mode 100644 index 0000000..c6e8217 --- /dev/null +++ b/integration/oauth_flow_test.go @@ -0,0 +1,789 @@ +package integration + +import ( + "bufio" + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "os" + "os/exec" + "regexp" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestBasicOAuthFlow tests the basic OAuth server functionality +func TestBasicOAuthFlow(t *testing.T) { + cfg := buildTestConfig("http://localhost:8080", "mcp-front-oauth-test", + testOAuthConfigFromEnv(), + map[string]any{"postgres": testPostgresServer()}, + ) + startMCPFront(t, writeTestConfig(t, cfg), + "JWT_SECRET=test-jwt-secret-32-bytes-exactly!", + "ENCRYPTION_KEY=test-encryption-key-32-bytes-ok!", + "GOOGLE_CLIENT_ID=test-client-id-for-oauth", + "GOOGLE_CLIENT_SECRET=test-client-secret-for-oauth", + "MCP_FRONT_ENV=development", + ) + + // Wait for startup + waitForMCPFront(t) + + // Test OAuth discovery + resp, err := http.Get("http://localhost:8080/.well-known/oauth-authorization-server") + require.NoError(t, err, "Failed to get OAuth discovery") + defer resp.Body.Close() + + assert.Equal(t, 200, resp.StatusCode, "OAuth discovery failed") + + var discovery map[string]any + err = json.NewDecoder(resp.Body).Decode(&discovery) + require.NoError(t, err, "Failed to decode discovery") + + // Verify required endpoints + requiredEndpoints := []string{ + "issuer", + "authorization_endpoint", + "token_endpoint", + "registration_endpoint", + } + + for _, endpoint := range requiredEndpoints { + _, ok := discovery[endpoint] + assert.True(t, ok, "Missing required endpoint: %s", endpoint) + } + + // Verify client_secret_post is advertised + authMethods, ok := discovery["token_endpoint_auth_methods_supported"].([]any) + assert.True(t, ok, "token_endpoint_auth_methods_supported should be present") + + var hasNone, hasClientSecretPost bool + for _, method := range authMethods { + if method == "none" { + hasNone = true + } + if method == "client_secret_post" { + hasClientSecretPost = true + } + } + assert.True(t, hasNone, "Should support 'none' auth method for public clients") + assert.True(t, hasClientSecretPost, "Should support 'client_secret_post' auth method for confidential clients") +} + +// TestJWTSecretValidation tests JWT secret length requirements +func TestJWTSecretValidation(t *testing.T) { + oauthCfg := buildTestConfig("http://localhost:8080", "mcp-front-oauth-test", + testOAuthConfigFromEnv(), + map[string]any{"postgres": testPostgresServer()}, + ) + configPath := writeTestConfig(t, oauthCfg) + + tests := []struct { + name string + secret string + shouldFail bool + }{ + {"Short 3-byte secret", "123", true}, + {"Short 16-byte secret", "sixteen-byte-key", true}, + {"Valid 32-byte secret", "demo-jwt-secret-32-bytes-exactly!", false}, + {"Long 64-byte secret", "demo-jwt-secret-32-bytes-exactly!demo-jwt-secret-32-bytes-exactly!", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + + mcpCmd := exec.Command("../cmd/mcp-front/mcp-front", "-config", configPath) + mcpCmd.Env = []string{ + "PATH=" + os.Getenv("PATH"), + "JWT_SECRET=" + tt.secret, + "ENCRYPTION_KEY=test-encryption-key-32-bytes-ok!", + "GOOGLE_CLIENT_ID=test-client-id", + "GOOGLE_CLIENT_SECRET=test-client-secret", + "MCP_FRONT_ENV=development", + } + + // Capture stderr + stderrPipe, _ := mcpCmd.StderrPipe() + scanner := bufio.NewScanner(stderrPipe) + + if err := mcpCmd.Start(); err != nil { + t.Fatalf("Failed to start mcp-front: %v", err) + } + + // Read stderr to check for errors + errorFound := false + go func() { + for scanner.Scan() { + line := scanner.Text() + if contains(line, "JWT secret must be at least") { + errorFound = true + } + } + }() + + // Give it time to start or fail + time.Sleep(2 * time.Second) + + // Check if it's running + healthy := checkHealth() + + // Clean up + if mcpCmd.Process != nil { + _ = mcpCmd.Process.Kill() + _ = mcpCmd.Wait() + } + + if tt.shouldFail { + assert.False(t, healthy && !errorFound, "Expected failure with short JWT secret but server started successfully") + } else { + assert.True(t, healthy, "Expected success with valid JWT secret but server failed to start") + } + }) + } +} + +// TestClientRegistration tests dynamic client registration (RFC 7591) +func TestClientRegistration(t *testing.T) { + // Start OAuth server + mcpCmd := startOAuthServer(t, map[string]string{ + "MCP_FRONT_ENV": "development", + }) + defer stopServer(mcpCmd) + + if !waitForHealthCheck(30) { + t.Fatal("OAuth server failed to start") + } + + t.Run("PublicClientRegistration", func(t *testing.T) { + // Register a public client (no secret) + clientReq := map[string]any{ + "redirect_uris": []string{"http://127.0.0.1:6274/oauth/callback/debug"}, + "scope": "read write", + } + + body, _ := json.Marshal(clientReq) + resp, err := http.Post( + "http://localhost:8080/register", + "application/json", + bytes.NewBuffer(body), + ) + if err != nil { + t.Fatalf("Failed to register client: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != 201 { + body, _ := io.ReadAll(resp.Body) + t.Fatalf("Client registration failed with status %d: %s", resp.StatusCode, string(body)) + } + + var clientResp map[string]any + if err := json.NewDecoder(resp.Body).Decode(&clientResp); err != nil { + t.Fatalf("Failed to decode response: %v", err) + } + + // Verify response + if clientResp["client_id"] == "" { + t.Error("Client ID should not be empty") + } + if clientResp["client_secret"] != nil { + t.Error("Public client should not have a secret") + } + if scope, ok := clientResp["scope"].(string); !ok || scope != "read write" { + t.Errorf("Expected scope 'read write' as string, got: %v", clientResp["scope"]) + } + }) + + t.Run("MultipleRegistrations", func(t *testing.T) { + // Register multiple clients and verify they get different IDs + var clientIDs []string + + for i := range 3 { + clientReq := map[string]any{ + "redirect_uris": []string{fmt.Sprintf("http://example.com/callback%d", i)}, + "scope": "read", + } + + body, _ := json.Marshal(clientReq) + resp, err := http.Post( + "http://localhost:8080/register", + "application/json", + bytes.NewBuffer(body), + ) + if err != nil { + t.Fatalf("Failed to register client %d: %v", i, err) + } + defer resp.Body.Close() + + var clientResp map[string]any + _ = json.NewDecoder(resp.Body).Decode(&clientResp) + clientIDs = append(clientIDs, clientResp["client_id"].(string)) + } + + // Verify all IDs are unique + for i := 0; i < len(clientIDs); i++ { + for j := i + 1; j < len(clientIDs); j++ { + if clientIDs[i] == clientIDs[j] { + t.Errorf("Client IDs should be unique, but got duplicate: %s", clientIDs[i]) + } + } + } + + }) + + t.Run("ConfidentialClientRegistration", func(t *testing.T) { + // Register a confidential client with client_secret_post + clientReq := map[string]any{ + "redirect_uris": []string{"https://example.com/callback"}, + "scope": "read write", + "token_endpoint_auth_method": "client_secret_post", + } + + body, _ := json.Marshal(clientReq) + resp, err := http.Post( + "http://localhost:8080/register", + "application/json", + bytes.NewBuffer(body), + ) + if err != nil { + t.Fatalf("Failed to register confidential client: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != 201 { + body, _ := io.ReadAll(resp.Body) + t.Fatalf("Confidential client registration failed with status %d: %s", resp.StatusCode, string(body)) + } + + var clientResp map[string]any + if err := json.NewDecoder(resp.Body).Decode(&clientResp); err != nil { + t.Fatalf("Failed to decode response: %v", err) + } + + // Verify response includes client_secret + if clientResp["client_id"] == "" { + t.Error("Client ID should not be empty") + } + clientSecret, ok := clientResp["client_secret"].(string) + if !ok || clientSecret == "" { + t.Error("Confidential client should receive a client_secret") + } + // Verify secret has reasonable length (base64 of 32 bytes) + if len(clientSecret) < 40 { + t.Errorf("Client secret seems too short: %d chars", len(clientSecret)) + } + + tokenAuthMethod, ok := clientResp["token_endpoint_auth_method"].(string) + if !ok || tokenAuthMethod != "client_secret_post" { + t.Errorf("Expected token_endpoint_auth_method 'client_secret_post', got: %v", clientResp["token_endpoint_auth_method"]) + } + + // Verify scope is returned as string + if scope, ok := clientResp["scope"].(string); !ok || scope != "read write" { + t.Errorf("Expected scope 'read write' as string, got: %v", clientResp["scope"]) + } + }) + + t.Run("PublicVsConfidentialClients", func(t *testing.T) { + // Test that public clients don't get secrets and confidential ones do + + // First, create a public client + publicReq := map[string]any{ + "redirect_uris": []string{"https://public.example.com/callback"}, + "scope": "read", + // No token_endpoint_auth_method specified - defaults to "none" + } + + body, _ := json.Marshal(publicReq) + resp, err := http.Post( + "http://localhost:8080/register", + "application/json", + bytes.NewBuffer(body), + ) + require.NoError(t, err) + defer resp.Body.Close() + + var publicResp map[string]any + _ = json.NewDecoder(resp.Body).Decode(&publicResp) + + // Verify public client has no secret + if _, hasSecret := publicResp["client_secret"]; hasSecret { + t.Error("Public client should not have a secret") + } + if authMethod := publicResp["token_endpoint_auth_method"]; authMethod != "none" { + t.Errorf("Public client should have auth method 'none', got: %v", authMethod) + } + + // Now create a confidential client + confidentialReq := map[string]any{ + "redirect_uris": []string{"https://confidential.example.com/callback"}, + "scope": "read write", + "token_endpoint_auth_method": "client_secret_post", + } + + body, _ = json.Marshal(confidentialReq) + resp, err = http.Post( + "http://localhost:8080/register", + "application/json", + bytes.NewBuffer(body), + ) + require.NoError(t, err) + defer resp.Body.Close() + + var confResp map[string]any + _ = json.NewDecoder(resp.Body).Decode(&confResp) + + // Verify confidential client has a secret + if secret, ok := confResp["client_secret"].(string); !ok || secret == "" { + t.Error("Confidential client should have a secret") + } + if authMethod := confResp["token_endpoint_auth_method"]; authMethod != "client_secret_post" { + t.Errorf("Confidential client should have auth method 'client_secret_post', got: %v", authMethod) + } + }) +} + +// TestStateParameterHandling tests OAuth state parameter requirements +func TestStateParameterHandling(t *testing.T) { + tests := []struct { + name string + environment string + state string + expectError bool + }{ + {"Production without state", "production", "", true}, + {"Production with state", "production", "secure-random-state", false}, + {"Development without state", "development", "", false}, // Should auto-generate + {"Development with state", "development", "test-state", false}, + } + + for _, tt := range tests { + // capture range variable + t.Run(tt.name, func(t *testing.T) { + // Start server with specific environment + mcpCmd := startOAuthServer(t, map[string]string{ + "MCP_FRONT_ENV": tt.environment, + }) + defer stopServer(mcpCmd) + + if !waitForHealthCheck(10) { + t.Fatal("Server failed to start") + } + + // Register a client first + clientID := registerTestClient(t) + + // Create authorization request + params := url.Values{ + "response_type": {"code"}, + "client_id": {clientID}, + "redirect_uri": {"http://127.0.0.1:6274/oauth/callback"}, + "code_challenge": {"test-challenge"}, + "code_challenge_method": {"S256"}, + "scope": {"read write"}, + } + if tt.state != "" { + params.Set("state", tt.state) + } + + authURL := fmt.Sprintf("http://localhost:8080/authorize?%s", params.Encode()) + + // Use a client that doesn't follow redirects + client := &http.Client{ + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, + } + resp, err := client.Get(authURL) + if err != nil { + t.Fatalf("Authorization request failed: %v", err) + } + defer resp.Body.Close() + + if tt.expectError { + // OAuth errors are returned as redirects with error parameters + if resp.StatusCode == 302 || resp.StatusCode == 303 { + location := resp.Header.Get("Location") + if strings.Contains(location, "error=") { + } else { + t.Errorf("Expected error redirect for %s, got redirect without error", tt.name) + } + } else if resp.StatusCode >= 400 { + } else { + t.Errorf("Expected error for %s, got status %d", tt.name, resp.StatusCode) + } + } else { + if resp.StatusCode == 302 || resp.StatusCode == 303 { + location := resp.Header.Get("Location") + if strings.Contains(location, "error=") { + t.Errorf("Unexpected error redirect for %s: %s", tt.name, location) + } + } else if resp.StatusCode < 400 { + } else { + body, _ := io.ReadAll(resp.Body) + t.Errorf("Expected success for %s, got status %d: %s", tt.name, resp.StatusCode, string(body)) + } + } + }) + } +} + +// TestEnvironmentModes tests development vs production mode differences +func TestEnvironmentModes(t *testing.T) { + t.Run("DevelopmentMode", func(t *testing.T) { + mcpCmd := startOAuthServer(t, map[string]string{ + "MCP_FRONT_ENV": "development", + }) + defer stopServer(mcpCmd) + + if !waitForHealthCheck(30) { + t.Fatal("Server failed to start") + } + + // In development mode, missing state should be auto-generated + clientID := registerTestClient(t) + + params := url.Values{ + "response_type": {"code"}, + "client_id": {clientID}, + "redirect_uri": {"http://127.0.0.1:6274/oauth/callback"}, + "code_challenge": {"test-challenge"}, + "code_challenge_method": {"S256"}, + "scope": {"read"}, + // Intentionally omitting state parameter + } + + // Use a client that doesn't follow redirects + client := &http.Client{ + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, + } + resp, err := client.Get("http://localhost:8080/authorize?" + params.Encode()) + if err != nil { + t.Fatalf("Failed to make auth request: %v", err) + } + defer resp.Body.Close() + + // Should redirect (302) not error + if resp.StatusCode >= 400 && resp.StatusCode != 302 { + t.Errorf("Development mode should handle missing state, got status %d", resp.StatusCode) + } + }) + + t.Run("ProductionMode", func(t *testing.T) { + mcpCmd := startOAuthServer(t, map[string]string{ + "MCP_FRONT_ENV": "production", + }) + defer stopServer(mcpCmd) + + if !waitForHealthCheck(30) { + t.Fatal("Server failed to start") + } + + // In production mode, state should be required + clientID := registerTestClient(t) + + params := url.Values{ + "response_type": {"code"}, + "client_id": {clientID}, + "redirect_uri": {"http://127.0.0.1:6274/oauth/callback"}, + "code_challenge": {"test-challenge"}, + "code_challenge_method": {"S256"}, + "scope": {"read"}, + // Intentionally omitting state parameter + } + + // Use a client that doesn't follow redirects + client := &http.Client{ + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, + } + resp, err := client.Get("http://localhost:8080/authorize?" + params.Encode()) + if err != nil { + t.Fatalf("Failed to make auth request: %v", err) + } + defer resp.Body.Close() + + // Should error - OAuth errors are returned as redirects + if resp.StatusCode == 302 || resp.StatusCode == 303 { + location := resp.Header.Get("Location") + if strings.Contains(location, "error=") { + } else { + t.Errorf("Expected error redirect in production mode, got redirect without error") + } + } else if resp.StatusCode >= 400 { + } else { + t.Errorf("Production mode should require state parameter, got status %d", resp.StatusCode) + } + }) +} + +// TestOAuthEndpoints tests all OAuth endpoints comprehensively +func TestOAuthEndpoints(t *testing.T) { + mcpCmd := startOAuthServer(t, map[string]string{ + "MCP_FRONT_ENV": "development", + }) + defer stopServer(mcpCmd) + + if !waitForHealthCheck(10) { + t.Fatal("Server failed to start") + } + + t.Run("Discovery", func(t *testing.T) { + resp, err := http.Get("http://localhost:8080/.well-known/oauth-authorization-server") + if err != nil { + t.Fatalf("Discovery request failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != 200 { + t.Fatalf("Discovery failed with status %d", resp.StatusCode) + } + + var discovery map[string]any + if err := json.NewDecoder(resp.Body).Decode(&discovery); err != nil { + t.Fatalf("Failed to decode discovery response: %v", err) + } + + // Verify all required fields + required := []string{ + "issuer", + "authorization_endpoint", + "token_endpoint", + "registration_endpoint", + "response_types_supported", + "grant_types_supported", + "code_challenge_methods_supported", + } + + for _, field := range required { + if _, ok := discovery[field]; !ok { + t.Errorf("Missing required discovery field: %s", field) + } + } + + }) + + t.Run("HealthCheck", func(t *testing.T) { + resp, err := http.Get("http://localhost:8080/health") + if err != nil { + t.Fatalf("Health check failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != 200 { + t.Errorf("Health check should return 200, got %d", resp.StatusCode) + } + + var health map[string]string + if err := json.NewDecoder(resp.Body).Decode(&health); err != nil { + t.Fatalf("Failed to decode health response: %v", err) + } + if health["status"] != "ok" { + t.Errorf("Expected status 'ok', got '%s'", health["status"]) + } + + }) +} + +// TestCORSHeaders tests CORS headers for Claude.ai compatibility +func TestCORSHeaders(t *testing.T) { + mcpCmd := startOAuthServer(t, map[string]string{ + "MCP_FRONT_ENV": "development", + }) + defer stopServer(mcpCmd) + + if !waitForHealthCheck(10) { + t.Fatal("Server failed to start") + } + + // Test preflight request + req, _ := http.NewRequest("OPTIONS", "http://localhost:8080/register", nil) + req.Header.Set("Origin", "https://claude.ai") + req.Header.Set("Access-Control-Request-Method", "POST") + req.Header.Set("Access-Control-Request-Headers", "content-type") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("Preflight request failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != 200 { + t.Errorf("Preflight should return 200, got %d", resp.StatusCode) + } + + // Check CORS headers + expectedHeaders := map[string]string{ + "Access-Control-Allow-Origin": "https://claude.ai", + "Access-Control-Allow-Methods": "GET, POST, OPTIONS", + "Access-Control-Allow-Headers": "Content-Type, Authorization, Cache-Control, mcp-protocol-version", + } + + for header, expected := range expectedHeaders { + actual := resp.Header.Get(header) + if actual != expected { + t.Errorf("Expected %s: '%s', got '%s'", header, expected, actual) + } + } + +} + +// Shared helpers for OAuth tests + +func startOAuthServer(t *testing.T, env map[string]string) *exec.Cmd { + cfg := buildTestConfig("http://localhost:8080", "mcp-front-oauth-test", + testOAuthConfigFromEnv(), + map[string]any{"postgres": testPostgresServer()}, + ) + configPath := writeTestConfig(t, cfg) + + mcpCmd := exec.Command("../cmd/mcp-front/mcp-front", "-config", configPath) + + // Set default environment + mcpCmd.Env = []string{ + "PATH=" + os.Getenv("PATH"), + "JWT_SECRET=demo-jwt-secret-32-bytes-exactly!", + "ENCRYPTION_KEY=test-encryption-key-32-bytes-ok!", + "GOOGLE_CLIENT_ID=test-client-id-oauth", + "GOOGLE_CLIENT_SECRET=test-client-secret-oauth", + } + + // Override with provided env + for key, value := range env { + mcpCmd.Env = append(mcpCmd.Env, fmt.Sprintf("%s=%s", key, value)) + } + + // Capture stderr for debugging and also output to test log + var stderr bytes.Buffer + mcpCmd.Stderr = io.MultiWriter(&stderr, os.Stderr) + + if err := mcpCmd.Start(); err != nil { + t.Fatalf("Failed to start OAuth server: %v", err) + } + + // Give a moment for immediate failures + time.Sleep(100 * time.Millisecond) + + // Check if process died immediately + if mcpCmd.ProcessState != nil { + t.Fatalf("OAuth server died immediately: %s", stderr.String()) + } + + return mcpCmd +} + +// startOAuthServerWithTokenConfig starts the OAuth server with user token configuration +func startOAuthServerWithTokenConfig(t *testing.T) *exec.Cmd { + // Start with user token config + mcpCmd := exec.Command("../cmd/mcp-front/mcp-front", "-config", "config/config.oauth-token-test.json") + + // Set default environment + mcpCmd.Env = []string{ + "PATH=" + os.Getenv("PATH"), + "JWT_SECRET=demo-jwt-secret-32-bytes-exactly!", + "ENCRYPTION_KEY=test-encryption-key-32-bytes-ok!", + "GOOGLE_CLIENT_ID=test-client-id-oauth", + "GOOGLE_CLIENT_SECRET=test-client-secret-oauth", + "MCP_FRONT_ENV=development", + } + + // Capture stderr for debugging and also output to test log + var stderr bytes.Buffer + mcpCmd.Stderr = io.MultiWriter(&stderr, os.Stderr) + + if err := mcpCmd.Start(); err != nil { + t.Fatalf("Failed to start OAuth server: %v", err) + } + + // Give a moment for immediate failures + time.Sleep(100 * time.Millisecond) + + // Check if process died immediately + if mcpCmd.ProcessState != nil { + t.Fatalf("OAuth server died immediately: %s", stderr.String()) + } + + return mcpCmd +} + +func stopServer(cmd *exec.Cmd) { + if cmd != nil && cmd.Process != nil { + _ = cmd.Process.Kill() + _ = cmd.Wait() + // Give the OS time to release the port + time.Sleep(100 * time.Millisecond) + } +} + +func waitForHealthCheck(seconds int) bool { + for range seconds { + if checkHealth() { + return true + } + time.Sleep(1 * time.Second) + } + return false +} + +func checkHealth() bool { + resp, err := http.Get("http://localhost:8080/health") + if err == nil && resp.StatusCode == 200 { + resp.Body.Close() + return true + } + if resp != nil { + resp.Body.Close() + } + return false +} + +func registerTestClient(t *testing.T) string { + clientReq := map[string]any{ + "redirect_uris": []string{"http://127.0.0.1:6274/oauth/callback"}, + "scope": "openid email profile read write", + } + + body, _ := json.Marshal(clientReq) + resp, err := http.Post( + "http://localhost:8080/register", + "application/json", + bytes.NewBuffer(body), + ) + if err != nil { + t.Fatalf("Failed to register client: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != 201 { + body, _ := io.ReadAll(resp.Body) + t.Fatalf("Client registration failed: %d - %s", resp.StatusCode, string(body)) + } + + var clientResp map[string]any + _ = json.NewDecoder(resp.Body).Decode(&clientResp) + return clientResp["client_id"].(string) +} + +// extractCSRFToken extracts the CSRF token from the HTML response +func extractCSRFToken(t *testing.T, html string) string { + // Look for + re := regexp.MustCompile(`]+name="csrf_token"[^>]+value="([^"]+)"`) + matches := re.FindStringSubmatch(html) + require.GreaterOrEqual(t, len(matches), 2, "CSRF token not found in response") + return matches[1] +} + +// contains is a simple helper to check if string contains substring +func contains(s, substr string) bool { + return strings.Contains(s, substr) +} diff --git a/integration/oauth_idp_test.go b/integration/oauth_idp_test.go new file mode 100644 index 0000000..86b2171 --- /dev/null +++ b/integration/oauth_idp_test.go @@ -0,0 +1,256 @@ +package integration + +import ( + "io" + "net/http" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// idpEnvGitHub is the set of env vars needed for GitHub IDP integration tests. +var idpEnvGitHub = []string{ + "JWT_SECRET=test-jwt-secret-32-bytes-exactly!", + "ENCRYPTION_KEY=test-encryption-key-32-bytes-ok!", + "GITHUB_CLIENT_SECRET=test-github-client-secret", + "MCP_FRONT_ENV=development", +} + +// idpEnvOIDC is the set of env vars needed for OIDC IDP integration tests. +var idpEnvOIDC = []string{ + "JWT_SECRET=test-jwt-secret-32-bytes-exactly!", + "ENCRYPTION_KEY=test-encryption-key-32-bytes-ok!", + "OIDC_CLIENT_SECRET=test-oidc-client-secret", + "MCP_FRONT_ENV=development", +} + +// idpEnvAzure is the set of env vars needed for Azure IDP integration tests. +var idpEnvAzure = []string{ + "JWT_SECRET=test-jwt-secret-32-bytes-exactly!", + "ENCRYPTION_KEY=test-encryption-key-32-bytes-ok!", + "AZURE_CLIENT_SECRET=test-azure-client-secret", + "MCP_FRONT_ENV=development", +} + +// TestGitHubOAuthFlow tests the full OAuth flow using the GitHub IDP. +func TestGitHubOAuthFlow(t *testing.T) { + cfg := buildTestConfig("http://localhost:8080", "mcp-front-github-test", + testGitHubOAuthConfig("test-org"), + map[string]any{"postgres": testPostgresServer()}, + ) + startMCPFront(t, writeTestConfig(t, cfg), idpEnvGitHub...) + waitForMCPFront(t) + + accessToken := getOAuthAccessTokenForIDP(t, "http://localhost:8080/postgres", "localhost:9092") + + mcpClient := NewMCPSSEClient("http://localhost:8080") + mcpClient.SetAuthToken(accessToken) + + err := mcpClient.Connect() + require.NoError(t, err, "Should connect to postgres with GitHub-issued token") + defer mcpClient.Close() + + toolsResp, err := mcpClient.SendMCPRequest("tools/list", map[string]any{}) + require.NoError(t, err, "Should list tools with GitHub-issued token") + + resultMap, ok := toolsResp["result"].(map[string]any) + require.True(t, ok, "Expected result in tools response") + tools, ok := resultMap["tools"].([]any) + require.True(t, ok, "Expected tools array") + assert.NotEmpty(t, tools, "Should have tools") +} + +// TestGitHubOrgDenial verifies that users not in allowed orgs are denied. +func TestGitHubOrgDenial(t *testing.T) { + cfg := buildTestConfig("http://localhost:8080", "mcp-front-github-org-deny", + testGitHubOAuthConfig("org-that-doesnt-match"), + map[string]any{"postgres": testPostgresServer()}, + ) + startMCPFront(t, writeTestConfig(t, cfg), idpEnvGitHub...) + waitForMCPFront(t) + + // Register client and start auth flow + clientID := registerTestClient(t) + + client := &http.Client{ + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, + } + + // Step 1: Authorize + authResp, err := client.Get("http://localhost:8080/authorize?response_type=code&client_id=" + clientID + + "&redirect_uri=http://127.0.0.1:6274/oauth/callback&code_challenge=test-challenge&code_challenge_method=S256" + + "&scope=openid+email+profile&state=test-state&resource=http://localhost:8080/postgres") + require.NoError(t, err) + defer authResp.Body.Close() + require.Contains(t, []int{302, 303}, authResp.StatusCode) + + // Step 2: Follow to GitHub fake + location := authResp.Header.Get("Location") + idpResp, err := client.Get(location) + require.NoError(t, err) + defer idpResp.Body.Close() + + // Step 3: Follow callback — should get an error (access denied), not a code + callbackLocation := idpResp.Header.Get("Location") + callbackResp, err := client.Get(callbackLocation) + require.NoError(t, err) + defer callbackResp.Body.Close() + + // The callback should either: + // - Return a 403 directly, or + // - Redirect to the redirect_uri with an error parameter + if callbackResp.StatusCode == 302 || callbackResp.StatusCode == 303 { + finalLocation := callbackResp.Header.Get("Location") + assert.Contains(t, finalLocation, "error=", "Should redirect with error for org denial") + } else { + // Direct error response + body, _ := io.ReadAll(callbackResp.Body) + assert.Contains(t, string(body), "denied", "Should contain denial message") + } +} + +// TestOIDCOAuthFlow tests the full OAuth flow using a generic OIDC provider. +func TestOIDCOAuthFlow(t *testing.T) { + cfg := buildTestConfig("http://localhost:8080", "mcp-front-oidc-test", + testOIDCOAuthConfig(), + map[string]any{"postgres": testPostgresServer()}, + ) + startMCPFront(t, writeTestConfig(t, cfg), idpEnvOIDC...) + waitForMCPFront(t) + + accessToken := getOAuthAccessTokenForIDP(t, "http://localhost:8080/postgres", "localhost:9093") + + mcpClient := NewMCPSSEClient("http://localhost:8080") + mcpClient.SetAuthToken(accessToken) + + err := mcpClient.Connect() + require.NoError(t, err, "Should connect to postgres with OIDC-issued token") + defer mcpClient.Close() + + toolsResp, err := mcpClient.SendMCPRequest("tools/list", map[string]any{}) + require.NoError(t, err, "Should list tools with OIDC-issued token") + + resultMap, ok := toolsResp["result"].(map[string]any) + require.True(t, ok, "Expected result in tools response") + tools, ok := resultMap["tools"].([]any) + require.True(t, ok, "Expected tools array") + assert.NotEmpty(t, tools, "Should have tools") +} + +// TestAzureOAuthFlow tests the full OAuth flow using the Azure IDP (backed by the OIDC fake server). +func TestAzureOAuthFlow(t *testing.T) { + cfg := buildTestConfig("http://localhost:8080", "mcp-front-azure-test", + testAzureOAuthConfig(), + map[string]any{"postgres": testPostgresServer()}, + ) + startMCPFront(t, writeTestConfig(t, cfg), idpEnvAzure...) + waitForMCPFront(t) + + accessToken := getOAuthAccessTokenForIDP(t, "http://localhost:8080/postgres", "localhost:9093") + + mcpClient := NewMCPSSEClient("http://localhost:8080") + mcpClient.SetAuthToken(accessToken) + + err := mcpClient.Connect() + require.NoError(t, err, "Should connect to postgres with Azure-issued token") + defer mcpClient.Close() + + toolsResp, err := mcpClient.SendMCPRequest("tools/list", map[string]any{}) + require.NoError(t, err, "Should list tools with Azure-issued token") + + resultMap, ok := toolsResp["result"].(map[string]any) + require.True(t, ok, "Expected result in tools response") + tools, ok := resultMap["tools"].([]any) + require.True(t, ok, "Expected tools array") + assert.NotEmpty(t, tools, "Should have tools") +} + +// TestIDPDomainDenial verifies that domain restrictions are enforced across providers. +func TestIDPDomainDenial(t *testing.T) { + tests := []struct { + name string + authConfig map[string]any + env []string + idpHost string + }{ + { + name: "GitHub", + authConfig: func() map[string]any { + cfg := testGitHubOAuthConfig("test-org") + cfg["allowedDomains"] = []string{"wrong-domain.com"} + return cfg + }(), + env: idpEnvGitHub, + idpHost: "localhost:9092", + }, + { + name: "OIDC", + authConfig: func() map[string]any { + cfg := testOIDCOAuthConfig() + cfg["allowedDomains"] = []string{"wrong-domain.com"} + return cfg + }(), + env: idpEnvOIDC, + idpHost: "localhost:9093", + }, + { + name: "Azure", + authConfig: func() map[string]any { + cfg := testAzureOAuthConfig() + cfg["allowedDomains"] = []string{"wrong-domain.com"} + return cfg + }(), + env: idpEnvAzure, + idpHost: "localhost:9093", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := buildTestConfig("http://localhost:8080", "mcp-front-domain-deny-"+tt.name, + tt.authConfig, + map[string]any{"postgres": testPostgresServer()}, + ) + startMCPFront(t, writeTestConfig(t, cfg), tt.env...) + waitForMCPFront(t) + + clientID := registerTestClient(t) + + client := &http.Client{ + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, + } + + authResp, err := client.Get("http://localhost:8080/authorize?response_type=code&client_id=" + clientID + + "&redirect_uri=http://127.0.0.1:6274/oauth/callback&code_challenge=test-challenge&code_challenge_method=S256" + + "&scope=openid+email+profile&state=test-state&resource=http://localhost:8080/postgres") + require.NoError(t, err) + defer authResp.Body.Close() + require.Contains(t, []int{302, 303}, authResp.StatusCode) + + location := authResp.Header.Get("Location") + require.Contains(t, location, tt.idpHost, "Should redirect to expected IDP") + + idpResp, err := client.Get(location) + require.NoError(t, err) + defer idpResp.Body.Close() + + callbackLocation := idpResp.Header.Get("Location") + callbackResp, err := client.Get(callbackLocation) + require.NoError(t, err) + defer callbackResp.Body.Close() + + if callbackResp.StatusCode == 302 || callbackResp.StatusCode == 303 { + finalLocation := callbackResp.Header.Get("Location") + assert.Contains(t, finalLocation, "error=", "Should redirect with error for domain denial") + } else { + body, _ := io.ReadAll(callbackResp.Body) + assert.Contains(t, string(body), "denied", "Should contain denial message") + } + }) + } +} diff --git a/integration/oauth_rfc8707_test.go b/integration/oauth_rfc8707_test.go new file mode 100644 index 0000000..94c01f2 --- /dev/null +++ b/integration/oauth_rfc8707_test.go @@ -0,0 +1,198 @@ +package integration + +import ( + "crypto/sha256" + "encoding/base64" + "encoding/json" + "net/http" + "net/url" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestRFC8707ResourceIndicators validates RFC 8707 resource indicator functionality +func TestRFC8707ResourceIndicators(t *testing.T) { + startMCPFront(t, "config/config.oauth-rfc8707-test.json", + "JWT_SECRET=test-jwt-secret-32-bytes-exactly!", + "ENCRYPTION_KEY=test-encryption-key-32-bytes-ok!", + "GOOGLE_CLIENT_ID=test-client-id-for-oauth", + "GOOGLE_CLIENT_SECRET=test-client-secret-for-oauth", + "MCP_FRONT_ENV=development", + ) + + waitForMCPFront(t) + + t.Run("BaseProtectedResourceMetadataReturns404", func(t *testing.T) { + // Base metadata endpoint should return 404, directing clients to per-service endpoints + resp, err := http.Get("http://localhost:8080/.well-known/oauth-protected-resource") + require.NoError(t, err) + defer resp.Body.Close() + + assert.Equal(t, 404, resp.StatusCode, "Base protected resource metadata endpoint should return 404") + + var errResp map[string]any + err = json.NewDecoder(resp.Body).Decode(&errResp) + require.NoError(t, err) + + assert.Contains(t, errResp["message"], "per-service", "Error message should direct to per-service endpoints") + }) + + t.Run("PerServiceProtectedResourceMetadataEndpoint", func(t *testing.T) { + // Per-service metadata endpoint should return service-specific resource URI + resp, err := http.Get("http://localhost:8080/.well-known/oauth-protected-resource/test-sse") + require.NoError(t, err) + defer resp.Body.Close() + + assert.Equal(t, 200, resp.StatusCode, "Per-service protected resource metadata endpoint should exist") + + var metadata map[string]any + err = json.NewDecoder(resp.Body).Decode(&metadata) + require.NoError(t, err) + + // Resource should be service-specific, not base URL + assert.Equal(t, "http://localhost:8080/test-sse", metadata["resource"], + "Resource should be service-specific URL") + + authzServers, ok := metadata["authorization_servers"].([]any) + require.True(t, ok, "Should have authorization_servers array") + require.NotEmpty(t, authzServers) + assert.Equal(t, "http://localhost:8080", authzServers[0], + "Authorization server should be base issuer") + }) + + t.Run("UnknownServiceReturns404", func(t *testing.T) { + resp, err := http.Get("http://localhost:8080/.well-known/oauth-protected-resource/nonexistent-service") + require.NoError(t, err) + defer resp.Body.Close() + + assert.Equal(t, 404, resp.StatusCode, "Unknown service should return 404") + }) + + t.Run("TokenWithResourceParameter", func(t *testing.T) { + clientID := registerTestClient(t) + + codeVerifier := "test-code-verifier-that-is-at-least-43-characters-long" + h := sha256.New() + h.Write([]byte(codeVerifier)) + codeChallenge := base64.RawURLEncoding.EncodeToString(h.Sum(nil)) + + authParams := url.Values{ + "response_type": {"code"}, + "client_id": {clientID}, + "redirect_uri": {"http://127.0.0.1:6274/oauth/callback"}, + "code_challenge": {codeChallenge}, + "code_challenge_method": {"S256"}, + "scope": {"openid email profile"}, + "state": {"test-state"}, + "resource": {"http://localhost:8080/test-sse"}, + } + + client := &http.Client{ + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, + } + + authResp, err := client.Get("http://localhost:8080/authorize?" + authParams.Encode()) + require.NoError(t, err) + defer authResp.Body.Close() + + assert.Contains(t, []int{302, 303}, authResp.StatusCode, "Should redirect to Google OAuth") + + location := authResp.Header.Get("Location") + googleResp, err := client.Get(location) + require.NoError(t, err) + defer googleResp.Body.Close() + + callbackLocation := googleResp.Header.Get("Location") + callbackResp, err := client.Get(callbackLocation) + require.NoError(t, err) + defer callbackResp.Body.Close() + + finalURL, err := url.Parse(callbackResp.Header.Get("Location")) + require.NoError(t, err) + authCode := finalURL.Query().Get("code") + require.NotEmpty(t, authCode, "Should have authorization code") + + tokenParams := url.Values{ + "grant_type": {"authorization_code"}, + "code": {authCode}, + "redirect_uri": {"http://127.0.0.1:6274/oauth/callback"}, + "client_id": {clientID}, + "code_verifier": {codeVerifier}, + } + + tokenResp, err := http.PostForm("http://localhost:8080/token", tokenParams) + require.NoError(t, err) + defer tokenResp.Body.Close() + + require.Equal(t, 200, tokenResp.StatusCode, "Token exchange should succeed") + + var tokenData map[string]any + err = json.NewDecoder(tokenResp.Body).Decode(&tokenData) + require.NoError(t, err) + + testSSEToken := tokenData["access_token"].(string) + require.NotEmpty(t, testSSEToken, "Should have access token") + + t.Logf("Got token with test-sse audience: %s", testSSEToken[:20]+"...") + + // Verify token works for test-sse (matching audience) + req, _ := http.NewRequest("GET", "http://localhost:8080/test-sse/sse", nil) + req.Header.Set("Authorization", "Bearer "+testSSEToken) + req.Header.Set("Accept", "text/event-stream") + + sseResp, err := client.Do(req) + require.NoError(t, err) + defer sseResp.Body.Close() + + assert.Equal(t, 200, sseResp.StatusCode, + "Token with test-sse audience should access /test-sse/sse") + + // Verify token does NOT work for test-streamable (wrong audience) + req, _ = http.NewRequest("GET", "http://localhost:8080/test-streamable/sse", nil) + req.Header.Set("Authorization", "Bearer "+testSSEToken) + req.Header.Set("Accept", "text/event-stream") + + streamableResp, err := client.Do(req) + require.NoError(t, err) + defer streamableResp.Body.Close() + + assert.Equal(t, 401, streamableResp.StatusCode, + "Token with test-sse audience should NOT access /test-streamable/sse") + + wwwAuth := streamableResp.Header.Get("WWW-Authenticate") + assert.Contains(t, wwwAuth, "Bearer resource_metadata=", + "401 response should include RFC 9728 WWW-Authenticate header") + // Per RFC 9728 Section 5.2, the metadata URI should be service-specific + assert.Contains(t, wwwAuth, "/.well-known/oauth-protected-resource/test-streamable", + "401 response should point to per-service metadata endpoint") + }) + + t.Run("401ResponseIncludesServiceSpecificMetadataURI", func(t *testing.T) { + // Request to a protected endpoint without token should get 401 + // with service-specific metadata URI in WWW-Authenticate header + client := &http.Client{ + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, + } + + req, _ := http.NewRequest("GET", "http://localhost:8080/test-sse/sse", nil) + req.Header.Set("Accept", "text/event-stream") + + resp, err := client.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + assert.Equal(t, 401, resp.StatusCode, "Request without token should return 401") + + wwwAuth := resp.Header.Get("WWW-Authenticate") + assert.Contains(t, wwwAuth, "Bearer resource_metadata=", + "401 response should include RFC 9728 WWW-Authenticate header") + assert.Contains(t, wwwAuth, "/.well-known/oauth-protected-resource/test-sse", + "401 response should point to test-sse specific metadata endpoint") + }) +} diff --git a/integration/oauth_service_test.go b/integration/oauth_service_test.go new file mode 100644 index 0000000..20a8ca1 --- /dev/null +++ b/integration/oauth_service_test.go @@ -0,0 +1,88 @@ +package integration + +import ( + "io" + "net/http" + "net/http/cookiejar" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestServiceOAuthIntegration validates the complete OAuth flow for external services +func TestServiceOAuthIntegration(t *testing.T) { + // Start fake service OAuth provider on port 9091 + fakeService := NewFakeServiceOAuthServer("9091") + err := fakeService.Start() + require.NoError(t, err) + defer func() { _ = fakeService.Stop() }() + + // Start mcp-front with OAuth service configuration + startMCPFront(t, "config/config.oauth-service-integration-test.json", + "JWT_SECRET=demo-jwt-secret-32-bytes-exactly!", + "ENCRYPTION_KEY=test-encryption-key-32-bytes-ok!", + "GOOGLE_CLIENT_ID=test-client-id-oauth", + "GOOGLE_CLIENT_SECRET=test-client-secret-oauth", + "TEST_SERVICE_CLIENT_ID=service-client-id", + "TEST_SERVICE_CLIENT_SECRET=service-client-secret", + "MCP_FRONT_ENV=development", + ) + + if !waitForHealthCheck(30) { + t.Fatal("Server failed to start") + } + + // For this test, we use browser SSO instead of OAuth client flow + // This simulates a user in the browser connecting services + jar, _ := cookiejar.New(nil) + client := &http.Client{Jar: jar} + + // Complete Google OAuth to get browser session + // Access /my/tokens which triggers SSO flow + resp, err := client.Get("http://localhost:8080/my/tokens") + require.NoError(t, err) + defer resp.Body.Close() + + // Should have completed SSO and landed on /my/tokens + require.Equal(t, http.StatusOK, resp.StatusCode) + + t.Run("ServiceOAuthConnectFlow", func(t *testing.T) { + // User clicks "Connect" for the service + req, _ := http.NewRequest("GET", "http://localhost:8080/oauth/connect?service=test-service", nil) + + resp, err := client.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + // Should complete OAuth flow and redirect back with success + // The http.Client automatically follows redirects: + // 1. /oauth/connect → redirects to localhost:9091/oauth/authorize + // 2. Fake service → redirects to /oauth/callback/test-service?code=... + // 3. Callback → stores token, redirects to /my/tokens with success message + + body, _ := io.ReadAll(resp.Body) + bodyStr := string(body) + + // Final page should show success + assert.Contains(t, bodyStr, "Successfully connected", "Should show success message after OAuth flow") + assert.Contains(t, bodyStr, "Test OAuth Service", "Should mention service name") + }) + + t.Run("ConnectedServiceShownOnTokenPage", func(t *testing.T) { + // After OAuth connection, service should appear as connected + req, _ := http.NewRequest("GET", "http://localhost:8080/my/tokens", nil) + + resp, err := client.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + body, _ := io.ReadAll(resp.Body) + bodyStr := string(body) + + // Should show the service with connected status + assert.Contains(t, bodyStr, "Test OAuth Service") + // OAuth-connected services show disconnect button, not connect + assert.Contains(t, bodyStr, "Disconnect", "OAuth-connected service should show Disconnect button") + }) +} diff --git a/integration/oauth_test.go b/integration/oauth_test.go deleted file mode 100644 index ed15878..0000000 --- a/integration/oauth_test.go +++ /dev/null @@ -1,1575 +0,0 @@ -package integration - -import ( - "bufio" - "bytes" - "context" - "crypto/sha256" - "encoding/base64" - "encoding/json" - "fmt" - "io" - "net/http" - "net/http/cookiejar" - "net/url" - "os" - "os/exec" - "regexp" - "strings" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" - "github.com/stretchr/testify/require" -) - -// TestBasicOAuthFlow tests the basic OAuth server functionality -func TestBasicOAuthFlow(t *testing.T) { - cfg := buildTestConfig("http://localhost:8080", "mcp-front-oauth-test", - testOAuthConfigFromEnv(), - map[string]any{"postgres": testPostgresServer()}, - ) - startMCPFront(t, writeTestConfig(t, cfg), - "JWT_SECRET=test-jwt-secret-32-bytes-exactly!", - "ENCRYPTION_KEY=test-encryption-key-32-bytes-ok!", - "GOOGLE_CLIENT_ID=test-client-id-for-oauth", - "GOOGLE_CLIENT_SECRET=test-client-secret-for-oauth", - "MCP_FRONT_ENV=development", - ) - - // Wait for startup - waitForMCPFront(t) - - // Test OAuth discovery - resp, err := http.Get("http://localhost:8080/.well-known/oauth-authorization-server") - require.NoError(t, err, "Failed to get OAuth discovery") - defer resp.Body.Close() - - assert.Equal(t, 200, resp.StatusCode, "OAuth discovery failed") - - var discovery map[string]any - err = json.NewDecoder(resp.Body).Decode(&discovery) - require.NoError(t, err, "Failed to decode discovery") - - // Verify required endpoints - requiredEndpoints := []string{ - "issuer", - "authorization_endpoint", - "token_endpoint", - "registration_endpoint", - } - - for _, endpoint := range requiredEndpoints { - _, ok := discovery[endpoint] - assert.True(t, ok, "Missing required endpoint: %s", endpoint) - } - - // Verify client_secret_post is advertised - authMethods, ok := discovery["token_endpoint_auth_methods_supported"].([]any) - assert.True(t, ok, "token_endpoint_auth_methods_supported should be present") - - var hasNone, hasClientSecretPost bool - for _, method := range authMethods { - if method == "none" { - hasNone = true - } - if method == "client_secret_post" { - hasClientSecretPost = true - } - } - assert.True(t, hasNone, "Should support 'none' auth method for public clients") - assert.True(t, hasClientSecretPost, "Should support 'client_secret_post' auth method for confidential clients") -} - -// TestJWTSecretValidation tests JWT secret length requirements -func TestJWTSecretValidation(t *testing.T) { - oauthCfg := buildTestConfig("http://localhost:8080", "mcp-front-oauth-test", - testOAuthConfigFromEnv(), - map[string]any{"postgres": testPostgresServer()}, - ) - configPath := writeTestConfig(t, oauthCfg) - - tests := []struct { - name string - secret string - shouldFail bool - }{ - {"Short 3-byte secret", "123", true}, - {"Short 16-byte secret", "sixteen-byte-key", true}, - {"Valid 32-byte secret", "demo-jwt-secret-32-bytes-exactly!", false}, - {"Long 64-byte secret", "demo-jwt-secret-32-bytes-exactly!demo-jwt-secret-32-bytes-exactly!", false}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - - mcpCmd := exec.Command("../cmd/mcp-front/mcp-front", "-config", configPath) - mcpCmd.Env = []string{ - "PATH=" + os.Getenv("PATH"), - "JWT_SECRET=" + tt.secret, - "ENCRYPTION_KEY=test-encryption-key-32-bytes-ok!", - "GOOGLE_CLIENT_ID=test-client-id", - "GOOGLE_CLIENT_SECRET=test-client-secret", - "MCP_FRONT_ENV=development", - } - - // Capture stderr - stderrPipe, _ := mcpCmd.StderrPipe() - scanner := bufio.NewScanner(stderrPipe) - - if err := mcpCmd.Start(); err != nil { - t.Fatalf("Failed to start mcp-front: %v", err) - } - - // Read stderr to check for errors - errorFound := false - go func() { - for scanner.Scan() { - line := scanner.Text() - if contains(line, "JWT secret must be at least") { - errorFound = true - } - } - }() - - // Give it time to start or fail - time.Sleep(2 * time.Second) - - // Check if it's running - healthy := checkHealth() - - // Clean up - if mcpCmd.Process != nil { - _ = mcpCmd.Process.Kill() - _ = mcpCmd.Wait() - } - - if tt.shouldFail { - assert.False(t, healthy && !errorFound, "Expected failure with short JWT secret but server started successfully") - } else { - assert.True(t, healthy, "Expected success with valid JWT secret but server failed to start") - } - }) - } -} - -// TestClientRegistration tests dynamic client registration (RFC 7591) -func TestClientRegistration(t *testing.T) { - // Start OAuth server - mcpCmd := startOAuthServer(t, map[string]string{ - "MCP_FRONT_ENV": "development", - }) - defer stopServer(mcpCmd) - - if !waitForHealthCheck(30) { - t.Fatal("OAuth server failed to start") - } - - t.Run("PublicClientRegistration", func(t *testing.T) { - // Register a public client (no secret) - clientReq := map[string]any{ - "redirect_uris": []string{"http://127.0.0.1:6274/oauth/callback/debug"}, - "scope": "read write", - } - - body, _ := json.Marshal(clientReq) - resp, err := http.Post( - "http://localhost:8080/register", - "application/json", - bytes.NewBuffer(body), - ) - if err != nil { - t.Fatalf("Failed to register client: %v", err) - } - defer resp.Body.Close() - - if resp.StatusCode != 201 { - body, _ := io.ReadAll(resp.Body) - t.Fatalf("Client registration failed with status %d: %s", resp.StatusCode, string(body)) - } - - var clientResp map[string]any - if err := json.NewDecoder(resp.Body).Decode(&clientResp); err != nil { - t.Fatalf("Failed to decode response: %v", err) - } - - // Verify response - if clientResp["client_id"] == "" { - t.Error("Client ID should not be empty") - } - if clientResp["client_secret"] != nil { - t.Error("Public client should not have a secret") - } - if scope, ok := clientResp["scope"].(string); !ok || scope != "read write" { - t.Errorf("Expected scope 'read write' as string, got: %v", clientResp["scope"]) - } - }) - - t.Run("MultipleRegistrations", func(t *testing.T) { - // Register multiple clients and verify they get different IDs - var clientIDs []string - - for i := range 3 { - clientReq := map[string]any{ - "redirect_uris": []string{fmt.Sprintf("http://example.com/callback%d", i)}, - "scope": "read", - } - - body, _ := json.Marshal(clientReq) - resp, err := http.Post( - "http://localhost:8080/register", - "application/json", - bytes.NewBuffer(body), - ) - if err != nil { - t.Fatalf("Failed to register client %d: %v", i, err) - } - defer resp.Body.Close() - - var clientResp map[string]any - _ = json.NewDecoder(resp.Body).Decode(&clientResp) - clientIDs = append(clientIDs, clientResp["client_id"].(string)) - } - - // Verify all IDs are unique - for i := 0; i < len(clientIDs); i++ { - for j := i + 1; j < len(clientIDs); j++ { - if clientIDs[i] == clientIDs[j] { - t.Errorf("Client IDs should be unique, but got duplicate: %s", clientIDs[i]) - } - } - } - - }) - - t.Run("ConfidentialClientRegistration", func(t *testing.T) { - // Register a confidential client with client_secret_post - clientReq := map[string]any{ - "redirect_uris": []string{"https://example.com/callback"}, - "scope": "read write", - "token_endpoint_auth_method": "client_secret_post", - } - - body, _ := json.Marshal(clientReq) - resp, err := http.Post( - "http://localhost:8080/register", - "application/json", - bytes.NewBuffer(body), - ) - if err != nil { - t.Fatalf("Failed to register confidential client: %v", err) - } - defer resp.Body.Close() - - if resp.StatusCode != 201 { - body, _ := io.ReadAll(resp.Body) - t.Fatalf("Confidential client registration failed with status %d: %s", resp.StatusCode, string(body)) - } - - var clientResp map[string]any - if err := json.NewDecoder(resp.Body).Decode(&clientResp); err != nil { - t.Fatalf("Failed to decode response: %v", err) - } - - // Verify response includes client_secret - if clientResp["client_id"] == "" { - t.Error("Client ID should not be empty") - } - clientSecret, ok := clientResp["client_secret"].(string) - if !ok || clientSecret == "" { - t.Error("Confidential client should receive a client_secret") - } - // Verify secret has reasonable length (base64 of 32 bytes) - if len(clientSecret) < 40 { - t.Errorf("Client secret seems too short: %d chars", len(clientSecret)) - } - - tokenAuthMethod, ok := clientResp["token_endpoint_auth_method"].(string) - if !ok || tokenAuthMethod != "client_secret_post" { - t.Errorf("Expected token_endpoint_auth_method 'client_secret_post', got: %v", clientResp["token_endpoint_auth_method"]) - } - - // Verify scope is returned as string - if scope, ok := clientResp["scope"].(string); !ok || scope != "read write" { - t.Errorf("Expected scope 'read write' as string, got: %v", clientResp["scope"]) - } - }) - - t.Run("PublicVsConfidentialClients", func(t *testing.T) { - // Test that public clients don't get secrets and confidential ones do - - // First, create a public client - publicReq := map[string]any{ - "redirect_uris": []string{"https://public.example.com/callback"}, - "scope": "read", - // No token_endpoint_auth_method specified - defaults to "none" - } - - body, _ := json.Marshal(publicReq) - resp, err := http.Post( - "http://localhost:8080/register", - "application/json", - bytes.NewBuffer(body), - ) - require.NoError(t, err) - defer resp.Body.Close() - - var publicResp map[string]any - _ = json.NewDecoder(resp.Body).Decode(&publicResp) - - // Verify public client has no secret - if _, hasSecret := publicResp["client_secret"]; hasSecret { - t.Error("Public client should not have a secret") - } - if authMethod := publicResp["token_endpoint_auth_method"]; authMethod != "none" { - t.Errorf("Public client should have auth method 'none', got: %v", authMethod) - } - - // Now create a confidential client - confidentialReq := map[string]any{ - "redirect_uris": []string{"https://confidential.example.com/callback"}, - "scope": "read write", - "token_endpoint_auth_method": "client_secret_post", - } - - body, _ = json.Marshal(confidentialReq) - resp, err = http.Post( - "http://localhost:8080/register", - "application/json", - bytes.NewBuffer(body), - ) - require.NoError(t, err) - defer resp.Body.Close() - - var confResp map[string]any - _ = json.NewDecoder(resp.Body).Decode(&confResp) - - // Verify confidential client has a secret - if secret, ok := confResp["client_secret"].(string); !ok || secret == "" { - t.Error("Confidential client should have a secret") - } - if authMethod := confResp["token_endpoint_auth_method"]; authMethod != "client_secret_post" { - t.Errorf("Confidential client should have auth method 'client_secret_post', got: %v", authMethod) - } - }) -} - -// TestUserTokenFlow tests the user token management functionality with browser-based SSO -// This test expects the /my/* routes to work with Google SSO (session-based auth), -// not Bearer token auth. -func TestUserTokenFlow(t *testing.T) { - // Start OAuth server with user token configuration - mcpCmd := startOAuthServerWithTokenConfig(t) - defer stopServer(mcpCmd) - - if !waitForHealthCheck(30) { - t.Fatal("Server failed to start") - } - - // Create a client with cookie jar to simulate browser behavior - jar, _ := cookiejar.New(nil) - client := &http.Client{ - Jar: jar, - CheckRedirect: func(req *http.Request, via []*http.Request) error { - // Allow up to 10 redirects - if len(via) >= 10 { - return fmt.Errorf("too many redirects") - } - return nil - }, - } - - t.Run("UnauthenticatedRedirectsToSSO", func(t *testing.T) { - // Create a client that doesn't follow redirects to test the initial redirect - noRedirectClient := &http.Client{ - CheckRedirect: func(req *http.Request, via []*http.Request) error { - return http.ErrUseLastResponse - }, - } - - // Try to access /my/tokens without authentication - resp, err := noRedirectClient.Get("http://localhost:8080/my/tokens") - require.NoError(t, err) - defer resp.Body.Close() - - // Should get a redirect response - assert.Equal(t, http.StatusFound, resp.StatusCode, "Should get redirect status") - - // Check the redirect location - location := resp.Header.Get("Location") - assert.Contains(t, location, "localhost:9090/auth", "Should redirect to Google OAuth") - assert.Contains(t, location, "client_id=", "Should include client_id") - assert.Contains(t, location, "redirect_uri=", "Should include redirect_uri") - // Extract and validate the state parameter - parsedURL, err := url.Parse(location) - require.NoError(t, err) - stateParam := parsedURL.Query().Get("state") - require.NotEmpty(t, stateParam, "State parameter should be present") - - // State format: "browser:" prefix followed by signed token - // We verify structure but not internal format (that's implementation detail) - assert.True(t, strings.HasPrefix(stateParam, "browser:"), "State should start with browser:") - assert.Greater(t, len(stateParam), len("browser:"), "State should have content after prefix") - - // Verify state contains signature (has dot separator indicating signed data) - stateContent := strings.TrimPrefix(stateParam, "browser:") - assert.Contains(t, stateContent, ".", "Signed state should contain signature separator") - }) - - t.Run("AuthenticatedUserCanAccessTokens", func(t *testing.T) { - // The client with cookie jar will automatically follow the full SSO flow: - // 1. GET /my/tokens -> redirect to Google OAuth - // 2. Google OAuth redirects to /oauth/callback with code - // 3. Callback sets session cookie and redirects to /my/tokens - // 4. Client follows redirect with cookie and gets the page - - resp, err := client.Get("http://localhost:8080/my/tokens") - require.NoError(t, err) - defer resp.Body.Close() - - // After following all redirects, we should be at /my/tokens with 200 OK - assert.Equal(t, http.StatusOK, resp.StatusCode, "Should access /my/tokens after SSO") - finalURL := resp.Request.URL.String() - assert.Contains(t, finalURL, "/my/tokens", "Should end up at /my/tokens after SSO") - - // Read response body - body, err := io.ReadAll(resp.Body) - require.NoError(t, err) - bodyStr := string(body) - - // Should show both services without tokens - assert.Contains(t, bodyStr, "Notion", "Expected Notion service in response") - assert.Contains(t, bodyStr, "GitHub", "Expected GitHub service in response") - }) - - t.Run("SetTokenWithValidation", func(t *testing.T) { - // Assume we're already authenticated from previous test - // Get CSRF token first - resp, err := client.Get("http://localhost:8080/my/tokens") - require.NoError(t, err) - defer resp.Body.Close() - - body, err := io.ReadAll(resp.Body) - require.NoError(t, err) - - // Extract CSRF token from response - csrfToken := extractCSRFToken(t, string(body)) - - // Try to set invalid Notion token - form := url.Values{ - "service": []string{"notion"}, - "token": []string{"invalid-token"}, - "csrf_token": []string{csrfToken}, - } - - req, _ := http.NewRequest("POST", "http://localhost:8080/my/tokens/set", strings.NewReader(form.Encode())) - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - - // Use custom client that doesn't follow redirects for this test - noRedirectClient := &http.Client{ - Jar: jar, // Use same cookie jar - CheckRedirect: func(req *http.Request, via []*http.Request) error { - return http.ErrUseLastResponse - }, - } - - resp, err = noRedirectClient.Do(req) - require.NoError(t, err) - defer resp.Body.Close() - - // Should redirect with error - assert.Equal(t, http.StatusSeeOther, resp.StatusCode, "Expected redirect") - location := resp.Header.Get("Location") - assert.Contains(t, location, "error", "Expected error in redirect") - - // Get new CSRF token - resp, err = client.Get("http://localhost:8080/my/tokens") - require.NoError(t, err) - defer resp.Body.Close() - - body, err = io.ReadAll(resp.Body) - require.NoError(t, err) - csrfToken = extractCSRFToken(t, string(body)) - - // Set valid Notion token (regex expects exactly 43 chars after "secret_") - form = url.Values{ - "service": {"notion"}, - "token": {"secret_1234567890123456789012345678901234567890123"}, - "csrf_token": {csrfToken}, - } - - req, _ = http.NewRequest("POST", "http://localhost:8080/my/tokens/set", strings.NewReader(form.Encode())) - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - - resp, err = noRedirectClient.Do(req) - require.NoError(t, err) - defer resp.Body.Close() - - // Should redirect with success - assert.Equal(t, http.StatusSeeOther, resp.StatusCode, "Expected redirect") - location = resp.Header.Get("Location") - assert.Contains(t, location, "success", "Expected success in redirect") - }) -} - -// TestStateParameterHandling tests OAuth state parameter requirements -func TestStateParameterHandling(t *testing.T) { - tests := []struct { - name string - environment string - state string - expectError bool - }{ - {"Production without state", "production", "", true}, - {"Production with state", "production", "secure-random-state", false}, - {"Development without state", "development", "", false}, // Should auto-generate - {"Development with state", "development", "test-state", false}, - } - - for _, tt := range tests { - // capture range variable - t.Run(tt.name, func(t *testing.T) { - // Start server with specific environment - mcpCmd := startOAuthServer(t, map[string]string{ - "MCP_FRONT_ENV": tt.environment, - }) - defer stopServer(mcpCmd) - - if !waitForHealthCheck(10) { - t.Fatal("Server failed to start") - } - - // Register a client first - clientID := registerTestClient(t) - - // Create authorization request - params := url.Values{ - "response_type": {"code"}, - "client_id": {clientID}, - "redirect_uri": {"http://127.0.0.1:6274/oauth/callback"}, - "code_challenge": {"test-challenge"}, - "code_challenge_method": {"S256"}, - "scope": {"read write"}, - } - if tt.state != "" { - params.Set("state", tt.state) - } - - authURL := fmt.Sprintf("http://localhost:8080/authorize?%s", params.Encode()) - - // Use a client that doesn't follow redirects - client := &http.Client{ - CheckRedirect: func(req *http.Request, via []*http.Request) error { - return http.ErrUseLastResponse - }, - } - resp, err := client.Get(authURL) - if err != nil { - t.Fatalf("Authorization request failed: %v", err) - } - defer resp.Body.Close() - - if tt.expectError { - // OAuth errors are returned as redirects with error parameters - if resp.StatusCode == 302 || resp.StatusCode == 303 { - location := resp.Header.Get("Location") - if strings.Contains(location, "error=") { - } else { - t.Errorf("Expected error redirect for %s, got redirect without error", tt.name) - } - } else if resp.StatusCode >= 400 { - } else { - t.Errorf("Expected error for %s, got status %d", tt.name, resp.StatusCode) - } - } else { - if resp.StatusCode == 302 || resp.StatusCode == 303 { - location := resp.Header.Get("Location") - if strings.Contains(location, "error=") { - t.Errorf("Unexpected error redirect for %s: %s", tt.name, location) - } - } else if resp.StatusCode < 400 { - } else { - body, _ := io.ReadAll(resp.Body) - t.Errorf("Expected success for %s, got status %d: %s", tt.name, resp.StatusCode, string(body)) - } - } - }) - } -} - -// TestEnvironmentModes tests development vs production mode differences -func TestEnvironmentModes(t *testing.T) { - t.Run("DevelopmentMode", func(t *testing.T) { - mcpCmd := startOAuthServer(t, map[string]string{ - "MCP_FRONT_ENV": "development", - }) - defer stopServer(mcpCmd) - - if !waitForHealthCheck(30) { - t.Fatal("Server failed to start") - } - - // In development mode, missing state should be auto-generated - clientID := registerTestClient(t) - - params := url.Values{ - "response_type": {"code"}, - "client_id": {clientID}, - "redirect_uri": {"http://127.0.0.1:6274/oauth/callback"}, - "code_challenge": {"test-challenge"}, - "code_challenge_method": {"S256"}, - "scope": {"read"}, - // Intentionally omitting state parameter - } - - // Use a client that doesn't follow redirects - client := &http.Client{ - CheckRedirect: func(req *http.Request, via []*http.Request) error { - return http.ErrUseLastResponse - }, - } - resp, err := client.Get("http://localhost:8080/authorize?" + params.Encode()) - if err != nil { - t.Fatalf("Failed to make auth request: %v", err) - } - defer resp.Body.Close() - - // Should redirect (302) not error - if resp.StatusCode >= 400 && resp.StatusCode != 302 { - t.Errorf("Development mode should handle missing state, got status %d", resp.StatusCode) - } - }) - - t.Run("ProductionMode", func(t *testing.T) { - mcpCmd := startOAuthServer(t, map[string]string{ - "MCP_FRONT_ENV": "production", - }) - defer stopServer(mcpCmd) - - if !waitForHealthCheck(30) { - t.Fatal("Server failed to start") - } - - // In production mode, state should be required - clientID := registerTestClient(t) - - params := url.Values{ - "response_type": {"code"}, - "client_id": {clientID}, - "redirect_uri": {"http://127.0.0.1:6274/oauth/callback"}, - "code_challenge": {"test-challenge"}, - "code_challenge_method": {"S256"}, - "scope": {"read"}, - // Intentionally omitting state parameter - } - - // Use a client that doesn't follow redirects - client := &http.Client{ - CheckRedirect: func(req *http.Request, via []*http.Request) error { - return http.ErrUseLastResponse - }, - } - resp, err := client.Get("http://localhost:8080/authorize?" + params.Encode()) - if err != nil { - t.Fatalf("Failed to make auth request: %v", err) - } - defer resp.Body.Close() - - // Should error - OAuth errors are returned as redirects - if resp.StatusCode == 302 || resp.StatusCode == 303 { - location := resp.Header.Get("Location") - if strings.Contains(location, "error=") { - } else { - t.Errorf("Expected error redirect in production mode, got redirect without error") - } - } else if resp.StatusCode >= 400 { - } else { - t.Errorf("Production mode should require state parameter, got status %d", resp.StatusCode) - } - }) -} - -// TestOAuthEndpoints tests all OAuth endpoints comprehensively -func TestOAuthEndpoints(t *testing.T) { - mcpCmd := startOAuthServer(t, map[string]string{ - "MCP_FRONT_ENV": "development", - }) - defer stopServer(mcpCmd) - - if !waitForHealthCheck(10) { - t.Fatal("Server failed to start") - } - - t.Run("Discovery", func(t *testing.T) { - resp, err := http.Get("http://localhost:8080/.well-known/oauth-authorization-server") - if err != nil { - t.Fatalf("Discovery request failed: %v", err) - } - defer resp.Body.Close() - - if resp.StatusCode != 200 { - t.Fatalf("Discovery failed with status %d", resp.StatusCode) - } - - var discovery map[string]any - if err := json.NewDecoder(resp.Body).Decode(&discovery); err != nil { - t.Fatalf("Failed to decode discovery response: %v", err) - } - - // Verify all required fields - required := []string{ - "issuer", - "authorization_endpoint", - "token_endpoint", - "registration_endpoint", - "response_types_supported", - "grant_types_supported", - "code_challenge_methods_supported", - } - - for _, field := range required { - if _, ok := discovery[field]; !ok { - t.Errorf("Missing required discovery field: %s", field) - } - } - - }) - - t.Run("HealthCheck", func(t *testing.T) { - resp, err := http.Get("http://localhost:8080/health") - if err != nil { - t.Fatalf("Health check failed: %v", err) - } - defer resp.Body.Close() - - if resp.StatusCode != 200 { - t.Errorf("Health check should return 200, got %d", resp.StatusCode) - } - - var health map[string]string - if err := json.NewDecoder(resp.Body).Decode(&health); err != nil { - t.Fatalf("Failed to decode health response: %v", err) - } - if health["status"] != "ok" { - t.Errorf("Expected status 'ok', got '%s'", health["status"]) - } - - }) -} - -// TestCORSHeaders tests CORS headers for Claude.ai compatibility -func TestCORSHeaders(t *testing.T) { - mcpCmd := startOAuthServer(t, map[string]string{ - "MCP_FRONT_ENV": "development", - }) - defer stopServer(mcpCmd) - - if !waitForHealthCheck(10) { - t.Fatal("Server failed to start") - } - - // Test preflight request - req, _ := http.NewRequest("OPTIONS", "http://localhost:8080/register", nil) - req.Header.Set("Origin", "https://claude.ai") - req.Header.Set("Access-Control-Request-Method", "POST") - req.Header.Set("Access-Control-Request-Headers", "content-type") - - resp, err := http.DefaultClient.Do(req) - if err != nil { - t.Fatalf("Preflight request failed: %v", err) - } - defer resp.Body.Close() - - if resp.StatusCode != 200 { - t.Errorf("Preflight should return 200, got %d", resp.StatusCode) - } - - // Check CORS headers - expectedHeaders := map[string]string{ - "Access-Control-Allow-Origin": "https://claude.ai", - "Access-Control-Allow-Methods": "GET, POST, OPTIONS", - "Access-Control-Allow-Headers": "Content-Type, Authorization, Cache-Control, mcp-protocol-version", - } - - for header, expected := range expectedHeaders { - actual := resp.Header.Get(header) - if actual != expected { - t.Errorf("Expected %s: '%s', got '%s'", header, expected, actual) - } - } - -} - -// TestToolAdvertisementWithUserTokens tests that tools are advertised even without user tokens -// but fail gracefully when invoked without the required token, and succeed with the token -func TestToolAdvertisementWithUserTokens(t *testing.T) { - cfg := buildTestConfig("http://localhost:8080", "mcp-front-oauth-usertoken-test", - testOAuthConfigFromEnv(), - map[string]any{"postgres": testPostgresServer(withUserToken())}, - ) - startMCPFront(t, writeTestConfig(t, cfg), - "JWT_SECRET=demo-jwt-secret-32-bytes-exactly!", - "ENCRYPTION_KEY=test-encryption-key-32-bytes-ok!", - "GOOGLE_CLIENT_ID=test-client-id-oauth", - "GOOGLE_CLIENT_SECRET=test-client-secret-oauth", - "MCP_FRONT_ENV=development", - "LOG_LEVEL=debug", - ) - - if !waitForHealthCheck(30) { - t.Fatal("Server failed to start") - } - - // Complete OAuth flow to get a valid access token - accessToken := getOAuthAccessToken(t, "http://localhost:8080/postgres") - - t.Run("ToolsAdvertisedWithoutToken", func(t *testing.T) { - // Create MCP client with OAuth token - mcpClient := NewMCPSSEClient("http://localhost:8080") - mcpClient.SetAuthToken(accessToken) - - // Connect to postgres SSE endpoint - err := mcpClient.Connect() - require.NoError(t, err, "Should connect to postgres SSE endpoint without user token") - defer mcpClient.Close() - - // Request tools list - toolsResp, err := mcpClient.SendMCPRequest("tools/list", map[string]any{}) - require.NoError(t, err, "Should list tools without user token") - - // Verify we got tools - resultMap, ok := toolsResp["result"].(map[string]any) - require.True(t, ok, "Expected result in tools response") - - tools, ok := resultMap["tools"].([]any) - require.True(t, ok, "Expected tools array in result") - assert.NotEmpty(t, tools, "Should have tools advertised") - - // Check for common postgres tools - var toolNames []string - for _, tool := range tools { - if toolMap, ok := tool.(map[string]any); ok { - if name, ok := toolMap["name"].(string); ok { - toolNames = append(toolNames, name) - } - } - } - - assert.Contains(t, toolNames, "execute_sql", "Should have execute_sql tool") - t.Logf("Successfully advertised tools without user token: %v", toolNames) - }) - - t.Run("ToolInvocationFailsWithoutToken", func(t *testing.T) { - // Create MCP client with OAuth token - mcpClient := NewMCPSSEClient("http://localhost:8080") - mcpClient.SetAuthToken(accessToken) - - // Connect to postgres SSE endpoint - err := mcpClient.Connect() - require.NoError(t, err) - defer mcpClient.Close() - - // Try to invoke a tool without user token - queryParams := map[string]any{ - "name": "execute_sql", - "arguments": map[string]any{ - "sql": "SELECT 1", - }, - } - - result, err := mcpClient.SendMCPRequest("tools/call", queryParams) - require.NoError(t, err, "Should get response even without token") - - // MCP protocol returns errors as successful responses with error content - require.NotNil(t, result["result"], "Should have result in response") - - resultMap := result["result"].(map[string]any) - content := resultMap["content"].([]any) - require.NotEmpty(t, content, "Should have content in result") - - contentItem := content[0].(map[string]any) - errorJSON := contentItem["text"].(string) - - // Parse the error JSON - var errorData map[string]any - err = json.Unmarshal([]byte(errorJSON), &errorData) - require.NoError(t, err, "Error should be valid JSON") - - // Verify error structure - errorInfo := errorData["error"].(map[string]any) - assert.Equal(t, "token_required", errorInfo["code"], "Error code should be token_required") - - errorMessage := errorInfo["message"].(string) - assert.Contains(t, errorMessage, "token required", "Error should mention token required") - assert.Contains(t, errorMessage, "/my/tokens", "Error should mention token setup URL") - assert.Contains(t, errorMessage, "Test Service", "Error should mention service name") - - // Verify error data - errData := errorInfo["data"].(map[string]any) - assert.Equal(t, "postgres", errData["service"], "Should identify the service") - assert.Contains(t, errData["tokenSetupUrl"].(string), "/my/tokens", "Should include token setup URL") - - // Verify instructions - instructions := errData["instructions"].(map[string]any) - assert.Contains(t, instructions["ai"].(string), "CRITICAL", "Should have AI instructions") - assert.Contains(t, instructions["human"].(string), "token required", "Should have human instructions") - }) - - t.Run("ToolInvocationSucceedsWithUserToken", func(t *testing.T) { - // Step 1: GET /my/tokens to extract CSRF token - jar, err := cookiejar.New(nil) - require.NoError(t, err) - client := &http.Client{ - Jar: jar, // Need cookie jar for CSRF - CheckRedirect: func(req *http.Request, via []*http.Request) error { - return http.ErrUseLastResponse // Don't follow redirects - }, - } - - req, err := http.NewRequest("GET", "http://localhost:8080/my/tokens", nil) - require.NoError(t, err) - req.Header.Set("Authorization", "Bearer "+accessToken) - - resp, err := client.Do(req) - require.NoError(t, err) - defer resp.Body.Close() - - // Check if we got the page or a redirect - if resp.StatusCode == 302 || resp.StatusCode == 303 { - // Follow the redirect - location := resp.Header.Get("Location") - t.Logf("Got redirect to: %s", location) - - // Allow redirects for this request - client = &http.Client{ - Jar: jar, - } - - req, err = http.NewRequest("GET", "http://localhost:8080/my/tokens", nil) - require.NoError(t, err) - req.Header.Set("Authorization", "Bearer "+accessToken) - - resp, err = client.Do(req) - require.NoError(t, err) - defer resp.Body.Close() - } - - require.Equal(t, 200, resp.StatusCode, "Should be able to access token page") - - // Extract CSRF token from HTML - body, err := io.ReadAll(resp.Body) - require.NoError(t, err) - - // Look for the CSRF token in the form - csrfRegex := regexp.MustCompile(`name="csrf_token" value="([^"]+)"`) - matches := csrfRegex.FindSubmatch(body) - require.Len(t, matches, 2, "Should find CSRF token in form") - csrfToken := string(matches[1]) - - // Step 2: POST to /my/tokens/set with test token - formData := url.Values{ - "service": {"postgres"}, - "token": {"test-user-token-12345"}, - "csrf_token": {csrfToken}, - } - - req, err = http.NewRequest("POST", "http://localhost:8080/my/tokens/set", strings.NewReader(formData.Encode())) - require.NoError(t, err) - req.Header.Set("Authorization", "Bearer "+accessToken) - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - - resp, err = client.Do(req) - require.NoError(t, err) - defer resp.Body.Close() - - // Check the response - it might be 200 if following redirects - switch resp.StatusCode { - case 200: - // That's fine, it means the token was set and we got the page back - t.Log("Token set successfully, got page response") - case 302, 303: - // Also fine, redirect means success - t.Log("Token set successfully, got redirect") - default: - body, _ := io.ReadAll(resp.Body) - t.Fatalf("Unexpected response setting token: status=%d, body=%s", resp.StatusCode, string(body)) - } - - // Step 3: Now test tool invocation with the token - mcpClient := NewMCPSSEClient("http://localhost:8080") - mcpClient.SetAuthToken(accessToken) - - err = mcpClient.Connect() - require.NoError(t, err, "Should connect to postgres SSE endpoint") - defer mcpClient.Close() - - // Call the query tool with a simple query - queryParams := map[string]any{ - "name": "execute_sql", - "arguments": map[string]any{ - "sql": "SELECT 1 as test", - }, - } - - result, err := mcpClient.SendMCPRequest("tools/call", queryParams) - require.NoError(t, err, "Should successfully call tool with token") - - // Verify we got a successful result, not an error - require.NotNil(t, result["result"], "Should have result in response") - - resultMap := result["result"].(map[string]any) - content := resultMap["content"].([]any) - require.NotEmpty(t, content, "Should have content in result") - - contentItem := content[0].(map[string]any) - resultText := contentItem["text"].(string) - - // The result should contain actual query results, not an error - assert.NotContains(t, resultText, "token_required", "Should not have token error") - assert.NotContains(t, resultText, "Token Required", "Should not have token error message") - - // Postgres query result should contain our test value - assert.Contains(t, resultText, "1", "Should contain query result") - - t.Log("Successfully invoked tool with user token") - }) -} - -// Helper functions - -func startOAuthServer(t *testing.T, env map[string]string) *exec.Cmd { - cfg := buildTestConfig("http://localhost:8080", "mcp-front-oauth-test", - testOAuthConfigFromEnv(), - map[string]any{"postgres": testPostgresServer()}, - ) - configPath := writeTestConfig(t, cfg) - - mcpCmd := exec.Command("../cmd/mcp-front/mcp-front", "-config", configPath) - - // Set default environment - mcpCmd.Env = []string{ - "PATH=" + os.Getenv("PATH"), - "JWT_SECRET=demo-jwt-secret-32-bytes-exactly!", - "ENCRYPTION_KEY=test-encryption-key-32-bytes-ok!", - "GOOGLE_CLIENT_ID=test-client-id-oauth", - "GOOGLE_CLIENT_SECRET=test-client-secret-oauth", - } - - // Override with provided env - for key, value := range env { - mcpCmd.Env = append(mcpCmd.Env, fmt.Sprintf("%s=%s", key, value)) - } - - // Capture stderr for debugging and also output to test log - var stderr bytes.Buffer - mcpCmd.Stderr = io.MultiWriter(&stderr, os.Stderr) - - if err := mcpCmd.Start(); err != nil { - t.Fatalf("Failed to start OAuth server: %v", err) - } - - // Give a moment for immediate failures - time.Sleep(100 * time.Millisecond) - - // Check if process died immediately - if mcpCmd.ProcessState != nil { - t.Fatalf("OAuth server died immediately: %s", stderr.String()) - } - - return mcpCmd -} - -// startOAuthServerWithTokenConfig starts the OAuth server with user token configuration -func startOAuthServerWithTokenConfig(t *testing.T) *exec.Cmd { - // Start with user token config - mcpCmd := exec.Command("../cmd/mcp-front/mcp-front", "-config", "config/config.oauth-token-test.json") - - // Set default environment - mcpCmd.Env = []string{ - "PATH=" + os.Getenv("PATH"), - "JWT_SECRET=demo-jwt-secret-32-bytes-exactly!", - "ENCRYPTION_KEY=test-encryption-key-32-bytes-ok!", - "GOOGLE_CLIENT_ID=test-client-id-oauth", - "GOOGLE_CLIENT_SECRET=test-client-secret-oauth", - "MCP_FRONT_ENV=development", - } - - // Capture stderr for debugging and also output to test log - var stderr bytes.Buffer - mcpCmd.Stderr = io.MultiWriter(&stderr, os.Stderr) - - if err := mcpCmd.Start(); err != nil { - t.Fatalf("Failed to start OAuth server: %v", err) - } - - // Give a moment for immediate failures - time.Sleep(100 * time.Millisecond) - - // Check if process died immediately - if mcpCmd.ProcessState != nil { - t.Fatalf("OAuth server died immediately: %s", stderr.String()) - } - - return mcpCmd -} - -func stopServer(cmd *exec.Cmd) { - if cmd != nil && cmd.Process != nil { - _ = cmd.Process.Kill() - _ = cmd.Wait() - // Give the OS time to release the port - time.Sleep(100 * time.Millisecond) - } -} - -func waitForHealthCheck(seconds int) bool { - for range seconds { - if checkHealth() { - return true - } - time.Sleep(1 * time.Second) - } - return false -} - -func checkHealth() bool { - resp, err := http.Get("http://localhost:8080/health") - if err == nil && resp.StatusCode == 200 { - resp.Body.Close() - return true - } - if resp != nil { - resp.Body.Close() - } - return false -} - -func registerTestClient(t *testing.T) string { - clientReq := map[string]any{ - "redirect_uris": []string{"http://127.0.0.1:6274/oauth/callback"}, - "scope": "openid email profile read write", - } - - body, _ := json.Marshal(clientReq) - resp, err := http.Post( - "http://localhost:8080/register", - "application/json", - bytes.NewBuffer(body), - ) - if err != nil { - t.Fatalf("Failed to register client: %v", err) - } - defer resp.Body.Close() - - if resp.StatusCode != 201 { - body, _ := io.ReadAll(resp.Body) - t.Fatalf("Client registration failed: %d - %s", resp.StatusCode, string(body)) - } - - var clientResp map[string]any - _ = json.NewDecoder(resp.Body).Decode(&clientResp) - return clientResp["client_id"].(string) -} - -// extractCSRFToken extracts the CSRF token from the HTML response -func extractCSRFToken(t *testing.T, html string) string { - // Look for - re := regexp.MustCompile(`]+name="csrf_token"[^>]+value="([^"]+)"`) - matches := re.FindStringSubmatch(html) - require.GreaterOrEqual(t, len(matches), 2, "CSRF token not found in response") - return matches[1] -} - -// contains is a simple helper to check if string contains substring -func contains(s, substr string) bool { - return strings.Contains(s, substr) -} - -// TestServiceOAuthIntegration validates the complete OAuth flow for external services -func TestServiceOAuthIntegration(t *testing.T) { - // Start fake service OAuth provider on port 9091 - fakeService := NewFakeServiceOAuthServer("9091") - err := fakeService.Start() - require.NoError(t, err) - defer func() { _ = fakeService.Stop() }() - - // Start mcp-front with OAuth service configuration - startMCPFront(t, "config/config.oauth-service-integration-test.json", - "JWT_SECRET=demo-jwt-secret-32-bytes-exactly!", - "ENCRYPTION_KEY=test-encryption-key-32-bytes-ok!", - "GOOGLE_CLIENT_ID=test-client-id-oauth", - "GOOGLE_CLIENT_SECRET=test-client-secret-oauth", - "TEST_SERVICE_CLIENT_ID=service-client-id", - "TEST_SERVICE_CLIENT_SECRET=service-client-secret", - "MCP_FRONT_ENV=development", - ) - - if !waitForHealthCheck(30) { - t.Fatal("Server failed to start") - } - - // For this test, we use browser SSO instead of OAuth client flow - // This simulates a user in the browser connecting services - jar, _ := cookiejar.New(nil) - client := &http.Client{Jar: jar} - - // Complete Google OAuth to get browser session - // Access /my/tokens which triggers SSO flow - resp, err := client.Get("http://localhost:8080/my/tokens") - require.NoError(t, err) - defer resp.Body.Close() - - // Should have completed SSO and landed on /my/tokens - require.Equal(t, http.StatusOK, resp.StatusCode) - - t.Run("ServiceOAuthConnectFlow", func(t *testing.T) { - // User clicks "Connect" for the service - req, _ := http.NewRequest("GET", "http://localhost:8080/oauth/connect?service=test-service", nil) - - resp, err := client.Do(req) - require.NoError(t, err) - defer resp.Body.Close() - - // Should complete OAuth flow and redirect back with success - // The http.Client automatically follows redirects: - // 1. /oauth/connect → redirects to localhost:9091/oauth/authorize - // 2. Fake service → redirects to /oauth/callback/test-service?code=... - // 3. Callback → stores token, redirects to /my/tokens with success message - - body, _ := io.ReadAll(resp.Body) - bodyStr := string(body) - - // Final page should show success - assert.Contains(t, bodyStr, "Successfully connected", "Should show success message after OAuth flow") - assert.Contains(t, bodyStr, "Test OAuth Service", "Should mention service name") - }) - - t.Run("ConnectedServiceShownOnTokenPage", func(t *testing.T) { - // After OAuth connection, service should appear as connected - req, _ := http.NewRequest("GET", "http://localhost:8080/my/tokens", nil) - - resp, err := client.Do(req) - require.NoError(t, err) - defer resp.Body.Close() - - body, _ := io.ReadAll(resp.Body) - bodyStr := string(body) - - // Should show the service with connected status - assert.Contains(t, bodyStr, "Test OAuth Service") - // OAuth-connected services show disconnect button, not connect - assert.Contains(t, bodyStr, "Disconnect", "OAuth-connected service should show Disconnect button") - }) -} - -// getOAuthAccessToken completes the OAuth flow and returns a valid access token -func getOAuthAccessToken(t *testing.T, resource string) string { - // Register a test client - clientID := registerTestClient(t) - - // Generate PKCE challenge (must be at least 43 characters) - codeVerifier := "test-code-verifier-that-is-at-least-43-characters-long" - h := sha256.New() - h.Write([]byte(codeVerifier)) - codeChallenge := base64.RawURLEncoding.EncodeToString(h.Sum(nil)) - - // Step 1: Authorization request - authParams := url.Values{ - "response_type": {"code"}, - "client_id": {clientID}, - "redirect_uri": {"http://127.0.0.1:6274/oauth/callback"}, - "code_challenge": {codeChallenge}, - "code_challenge_method": {"S256"}, - "scope": {"openid email profile"}, - "state": {"test-state"}, - "resource": {resource}, - } - - client := &http.Client{ - CheckRedirect: func(req *http.Request, via []*http.Request) error { - return http.ErrUseLastResponse - }, - } - - authResp, err := client.Get("http://localhost:8080/authorize?" + authParams.Encode()) - require.NoError(t, err) - defer authResp.Body.Close() - - // Should redirect to Google OAuth (our mock) - require.Contains(t, []int{302, 303}, authResp.StatusCode, "Should redirect to Google OAuth") - location := authResp.Header.Get("Location") - require.Contains(t, location, "http://localhost:9090/auth", "Should redirect to mock Google OAuth") - - // Parse the redirect to get the state parameter - redirectURL, err := url.Parse(location) - require.NoError(t, err) - _ = redirectURL.Query().Get("state") // state is included in the redirect but not needed for this test - - // Step 2: Follow redirect to mock Google OAuth (which immediately redirects back) - googleResp, err := client.Get(location) - require.NoError(t, err) - defer googleResp.Body.Close() - - // Mock Google OAuth redirects back to callback - require.Contains(t, []int{302, 303}, googleResp.StatusCode, "Mock Google should redirect back") - callbackLocation := googleResp.Header.Get("Location") - require.Contains(t, callbackLocation, "/oauth/callback", "Should redirect to callback") - - // Step 3: Follow callback redirect - callbackResp, err := client.Get(callbackLocation) - require.NoError(t, err) - defer callbackResp.Body.Close() - - // Should redirect to the original redirect_uri with authorization code - require.Contains(t, []int{302, 303}, callbackResp.StatusCode, "Callback should redirect with code") - finalLocation := callbackResp.Header.Get("Location") - - // Parse authorization code from final redirect - finalURL, err := url.Parse(finalLocation) - require.NoError(t, err) - authCode := finalURL.Query().Get("code") - require.NotEmpty(t, authCode, "Should have authorization code") - - // Step 4: Exchange code for token - tokenParams := url.Values{ - "grant_type": {"authorization_code"}, - "code": {authCode}, - "redirect_uri": {"http://127.0.0.1:6274/oauth/callback"}, - "client_id": {clientID}, - "code_verifier": {codeVerifier}, - } - - tokenResp, err := http.PostForm("http://localhost:8080/token", tokenParams) - require.NoError(t, err) - defer tokenResp.Body.Close() - - if tokenResp.StatusCode != 200 { - body, _ := io.ReadAll(tokenResp.Body) - t.Logf("Token exchange failed with status %d: %s", tokenResp.StatusCode, string(body)) - } - - require.Equal(t, 200, tokenResp.StatusCode, "Token exchange should succeed") - - var tokenData map[string]any - err = json.NewDecoder(tokenResp.Body).Decode(&tokenData) - require.NoError(t, err) - - accessToken := tokenData["access_token"].(string) - require.NotEmpty(t, accessToken, "Should have access token") - - return accessToken -} - -// MockUserTokenStore mocks the UserTokenStore interface for testing -type MockUserTokenStore struct { - mock.Mock -} - -func (m *MockUserTokenStore) GetUserToken(ctx context.Context, email, service string) (string, error) { - args := m.Called(ctx, email, service) - return args.String(0), args.Error(1) -} - -func (m *MockUserTokenStore) SetUserToken(ctx context.Context, email, service, token string) error { - args := m.Called(ctx, email, service, token) - return args.Error(0) -} - -func (m *MockUserTokenStore) DeleteUserToken(ctx context.Context, email, service string) error { - args := m.Called(ctx, email, service) - return args.Error(0) -} - -func (m *MockUserTokenStore) ListUserServices(ctx context.Context, email string) ([]string, error) { - args := m.Called(ctx, email) - if args.Get(0) == nil { - return nil, args.Error(1) - } - return args.Get(0).([]string), args.Error(1) -} - -// TestRFC8707ResourceIndicators validates RFC 8707 resource indicator functionality -func TestRFC8707ResourceIndicators(t *testing.T) { - startMCPFront(t, "config/config.oauth-rfc8707-test.json", - "JWT_SECRET=test-jwt-secret-32-bytes-exactly!", - "ENCRYPTION_KEY=test-encryption-key-32-bytes-ok!", - "GOOGLE_CLIENT_ID=test-client-id-for-oauth", - "GOOGLE_CLIENT_SECRET=test-client-secret-for-oauth", - "MCP_FRONT_ENV=development", - ) - - waitForMCPFront(t) - - t.Run("BaseProtectedResourceMetadataReturns404", func(t *testing.T) { - // Base metadata endpoint should return 404, directing clients to per-service endpoints - resp, err := http.Get("http://localhost:8080/.well-known/oauth-protected-resource") - require.NoError(t, err) - defer resp.Body.Close() - - assert.Equal(t, 404, resp.StatusCode, "Base protected resource metadata endpoint should return 404") - - var errResp map[string]any - err = json.NewDecoder(resp.Body).Decode(&errResp) - require.NoError(t, err) - - assert.Contains(t, errResp["message"], "per-service", "Error message should direct to per-service endpoints") - }) - - t.Run("PerServiceProtectedResourceMetadataEndpoint", func(t *testing.T) { - // Per-service metadata endpoint should return service-specific resource URI - resp, err := http.Get("http://localhost:8080/.well-known/oauth-protected-resource/test-sse") - require.NoError(t, err) - defer resp.Body.Close() - - assert.Equal(t, 200, resp.StatusCode, "Per-service protected resource metadata endpoint should exist") - - var metadata map[string]any - err = json.NewDecoder(resp.Body).Decode(&metadata) - require.NoError(t, err) - - // Resource should be service-specific, not base URL - assert.Equal(t, "http://localhost:8080/test-sse", metadata["resource"], - "Resource should be service-specific URL") - - authzServers, ok := metadata["authorization_servers"].([]any) - require.True(t, ok, "Should have authorization_servers array") - require.NotEmpty(t, authzServers) - assert.Equal(t, "http://localhost:8080", authzServers[0], - "Authorization server should be base issuer") - }) - - t.Run("UnknownServiceReturns404", func(t *testing.T) { - resp, err := http.Get("http://localhost:8080/.well-known/oauth-protected-resource/nonexistent-service") - require.NoError(t, err) - defer resp.Body.Close() - - assert.Equal(t, 404, resp.StatusCode, "Unknown service should return 404") - }) - - t.Run("TokenWithResourceParameter", func(t *testing.T) { - clientID := registerTestClient(t) - - codeVerifier := "test-code-verifier-that-is-at-least-43-characters-long" - h := sha256.New() - h.Write([]byte(codeVerifier)) - codeChallenge := base64.RawURLEncoding.EncodeToString(h.Sum(nil)) - - authParams := url.Values{ - "response_type": {"code"}, - "client_id": {clientID}, - "redirect_uri": {"http://127.0.0.1:6274/oauth/callback"}, - "code_challenge": {codeChallenge}, - "code_challenge_method": {"S256"}, - "scope": {"openid email profile"}, - "state": {"test-state"}, - "resource": {"http://localhost:8080/test-sse"}, - } - - client := &http.Client{ - CheckRedirect: func(req *http.Request, via []*http.Request) error { - return http.ErrUseLastResponse - }, - } - - authResp, err := client.Get("http://localhost:8080/authorize?" + authParams.Encode()) - require.NoError(t, err) - defer authResp.Body.Close() - - assert.Contains(t, []int{302, 303}, authResp.StatusCode, "Should redirect to Google OAuth") - - location := authResp.Header.Get("Location") - googleResp, err := client.Get(location) - require.NoError(t, err) - defer googleResp.Body.Close() - - callbackLocation := googleResp.Header.Get("Location") - callbackResp, err := client.Get(callbackLocation) - require.NoError(t, err) - defer callbackResp.Body.Close() - - finalURL, err := url.Parse(callbackResp.Header.Get("Location")) - require.NoError(t, err) - authCode := finalURL.Query().Get("code") - require.NotEmpty(t, authCode, "Should have authorization code") - - tokenParams := url.Values{ - "grant_type": {"authorization_code"}, - "code": {authCode}, - "redirect_uri": {"http://127.0.0.1:6274/oauth/callback"}, - "client_id": {clientID}, - "code_verifier": {codeVerifier}, - } - - tokenResp, err := http.PostForm("http://localhost:8080/token", tokenParams) - require.NoError(t, err) - defer tokenResp.Body.Close() - - require.Equal(t, 200, tokenResp.StatusCode, "Token exchange should succeed") - - var tokenData map[string]any - err = json.NewDecoder(tokenResp.Body).Decode(&tokenData) - require.NoError(t, err) - - testSSEToken := tokenData["access_token"].(string) - require.NotEmpty(t, testSSEToken, "Should have access token") - - t.Logf("Got token with test-sse audience: %s", testSSEToken[:20]+"...") - - // Verify token works for test-sse (matching audience) - req, _ := http.NewRequest("GET", "http://localhost:8080/test-sse/sse", nil) - req.Header.Set("Authorization", "Bearer "+testSSEToken) - req.Header.Set("Accept", "text/event-stream") - - sseResp, err := client.Do(req) - require.NoError(t, err) - defer sseResp.Body.Close() - - assert.Equal(t, 200, sseResp.StatusCode, - "Token with test-sse audience should access /test-sse/sse") - - // Verify token does NOT work for test-streamable (wrong audience) - req, _ = http.NewRequest("GET", "http://localhost:8080/test-streamable/sse", nil) - req.Header.Set("Authorization", "Bearer "+testSSEToken) - req.Header.Set("Accept", "text/event-stream") - - streamableResp, err := client.Do(req) - require.NoError(t, err) - defer streamableResp.Body.Close() - - assert.Equal(t, 401, streamableResp.StatusCode, - "Token with test-sse audience should NOT access /test-streamable/sse") - - wwwAuth := streamableResp.Header.Get("WWW-Authenticate") - assert.Contains(t, wwwAuth, "Bearer resource_metadata=", - "401 response should include RFC 9728 WWW-Authenticate header") - // Per RFC 9728 Section 5.2, the metadata URI should be service-specific - assert.Contains(t, wwwAuth, "/.well-known/oauth-protected-resource/test-streamable", - "401 response should point to per-service metadata endpoint") - }) - - t.Run("401ResponseIncludesServiceSpecificMetadataURI", func(t *testing.T) { - // Request to a protected endpoint without token should get 401 - // with service-specific metadata URI in WWW-Authenticate header - client := &http.Client{ - CheckRedirect: func(req *http.Request, via []*http.Request) error { - return http.ErrUseLastResponse - }, - } - - req, _ := http.NewRequest("GET", "http://localhost:8080/test-sse/sse", nil) - req.Header.Set("Accept", "text/event-stream") - - resp, err := client.Do(req) - require.NoError(t, err) - defer resp.Body.Close() - - assert.Equal(t, 401, resp.StatusCode, "Request without token should return 401") - - wwwAuth := resp.Header.Get("WWW-Authenticate") - assert.Contains(t, wwwAuth, "Bearer resource_metadata=", - "401 response should include RFC 9728 WWW-Authenticate header") - assert.Contains(t, wwwAuth, "/.well-known/oauth-protected-resource/test-sse", - "401 response should point to test-sse specific metadata endpoint") - }) -} diff --git a/integration/oauth_user_tokens_test.go b/integration/oauth_user_tokens_test.go new file mode 100644 index 0000000..6612d9d --- /dev/null +++ b/integration/oauth_user_tokens_test.go @@ -0,0 +1,536 @@ +package integration + +import ( + "context" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "net/http" + "net/http/cookiejar" + "net/url" + "regexp" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +// TestUserTokenFlow tests the user token management functionality with browser-based SSO +// This test expects the /my/* routes to work with Google SSO (session-based auth), +// not Bearer token auth. +func TestUserTokenFlow(t *testing.T) { + // Start OAuth server with user token configuration + mcpCmd := startOAuthServerWithTokenConfig(t) + defer stopServer(mcpCmd) + + if !waitForHealthCheck(30) { + t.Fatal("Server failed to start") + } + + // Create a client with cookie jar to simulate browser behavior + jar, _ := cookiejar.New(nil) + client := &http.Client{ + Jar: jar, + CheckRedirect: func(req *http.Request, via []*http.Request) error { + // Allow up to 10 redirects + if len(via) >= 10 { + return fmt.Errorf("too many redirects") + } + return nil + }, + } + + t.Run("UnauthenticatedRedirectsToSSO", func(t *testing.T) { + // Create a client that doesn't follow redirects to test the initial redirect + noRedirectClient := &http.Client{ + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, + } + + // Try to access /my/tokens without authentication + resp, err := noRedirectClient.Get("http://localhost:8080/my/tokens") + require.NoError(t, err) + defer resp.Body.Close() + + // Should get a redirect response + assert.Equal(t, http.StatusFound, resp.StatusCode, "Should get redirect status") + + // Check the redirect location + location := resp.Header.Get("Location") + assert.Contains(t, location, "localhost:9090/auth", "Should redirect to Google OAuth") + assert.Contains(t, location, "client_id=", "Should include client_id") + assert.Contains(t, location, "redirect_uri=", "Should include redirect_uri") + // Extract and validate the state parameter + parsedURL, err := url.Parse(location) + require.NoError(t, err) + stateParam := parsedURL.Query().Get("state") + require.NotEmpty(t, stateParam, "State parameter should be present") + + // State format: "browser:" prefix followed by signed token + // We verify structure but not internal format (that's implementation detail) + assert.True(t, strings.HasPrefix(stateParam, "browser:"), "State should start with browser:") + assert.Greater(t, len(stateParam), len("browser:"), "State should have content after prefix") + + // Verify state contains signature (has dot separator indicating signed data) + stateContent := strings.TrimPrefix(stateParam, "browser:") + assert.Contains(t, stateContent, ".", "Signed state should contain signature separator") + }) + + t.Run("AuthenticatedUserCanAccessTokens", func(t *testing.T) { + // The client with cookie jar will automatically follow the full SSO flow: + // 1. GET /my/tokens -> redirect to Google OAuth + // 2. Google OAuth redirects to /oauth/callback with code + // 3. Callback sets session cookie and redirects to /my/tokens + // 4. Client follows redirect with cookie and gets the page + + resp, err := client.Get("http://localhost:8080/my/tokens") + require.NoError(t, err) + defer resp.Body.Close() + + // After following all redirects, we should be at /my/tokens with 200 OK + assert.Equal(t, http.StatusOK, resp.StatusCode, "Should access /my/tokens after SSO") + finalURL := resp.Request.URL.String() + assert.Contains(t, finalURL, "/my/tokens", "Should end up at /my/tokens after SSO") + + // Read response body + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + bodyStr := string(body) + + // Should show both services without tokens + assert.Contains(t, bodyStr, "Notion", "Expected Notion service in response") + assert.Contains(t, bodyStr, "GitHub", "Expected GitHub service in response") + }) + + t.Run("SetTokenWithValidation", func(t *testing.T) { + // Assume we're already authenticated from previous test + // Get CSRF token first + resp, err := client.Get("http://localhost:8080/my/tokens") + require.NoError(t, err) + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + // Extract CSRF token from response + csrfToken := extractCSRFToken(t, string(body)) + + // Try to set invalid Notion token + form := url.Values{ + "service": []string{"notion"}, + "token": []string{"invalid-token"}, + "csrf_token": []string{csrfToken}, + } + + req, _ := http.NewRequest("POST", "http://localhost:8080/my/tokens/set", strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + // Use custom client that doesn't follow redirects for this test + noRedirectClient := &http.Client{ + Jar: jar, // Use same cookie jar + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, + } + + resp, err = noRedirectClient.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + // Should redirect with error + assert.Equal(t, http.StatusSeeOther, resp.StatusCode, "Expected redirect") + location := resp.Header.Get("Location") + assert.Contains(t, location, "error", "Expected error in redirect") + + // Get new CSRF token + resp, err = client.Get("http://localhost:8080/my/tokens") + require.NoError(t, err) + defer resp.Body.Close() + + body, err = io.ReadAll(resp.Body) + require.NoError(t, err) + csrfToken = extractCSRFToken(t, string(body)) + + // Set valid Notion token (regex expects exactly 43 chars after "secret_") + form = url.Values{ + "service": {"notion"}, + "token": {"secret_1234567890123456789012345678901234567890123"}, + "csrf_token": {csrfToken}, + } + + req, _ = http.NewRequest("POST", "http://localhost:8080/my/tokens/set", strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + resp, err = noRedirectClient.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + // Should redirect with success + assert.Equal(t, http.StatusSeeOther, resp.StatusCode, "Expected redirect") + location = resp.Header.Get("Location") + assert.Contains(t, location, "success", "Expected success in redirect") + }) +} + +// TestToolAdvertisementWithUserTokens tests that tools are advertised even without user tokens +// but fail gracefully when invoked without the required token, and succeed with the token +func TestToolAdvertisementWithUserTokens(t *testing.T) { + cfg := buildTestConfig("http://localhost:8080", "mcp-front-oauth-usertoken-test", + testOAuthConfigFromEnv(), + map[string]any{"postgres": testPostgresServer(withUserToken())}, + ) + startMCPFront(t, writeTestConfig(t, cfg), + "JWT_SECRET=demo-jwt-secret-32-bytes-exactly!", + "ENCRYPTION_KEY=test-encryption-key-32-bytes-ok!", + "GOOGLE_CLIENT_ID=test-client-id-oauth", + "GOOGLE_CLIENT_SECRET=test-client-secret-oauth", + "MCP_FRONT_ENV=development", + "LOG_LEVEL=debug", + ) + + if !waitForHealthCheck(30) { + t.Fatal("Server failed to start") + } + + // Complete OAuth flow to get a valid access token + accessToken := getOAuthAccessToken(t, "http://localhost:8080/postgres") + + t.Run("ToolsAdvertisedWithoutToken", func(t *testing.T) { + // Create MCP client with OAuth token + mcpClient := NewMCPSSEClient("http://localhost:8080") + mcpClient.SetAuthToken(accessToken) + + // Connect to postgres SSE endpoint + err := mcpClient.Connect() + require.NoError(t, err, "Should connect to postgres SSE endpoint without user token") + defer mcpClient.Close() + + // Request tools list + toolsResp, err := mcpClient.SendMCPRequest("tools/list", map[string]any{}) + require.NoError(t, err, "Should list tools without user token") + + // Verify we got tools + resultMap, ok := toolsResp["result"].(map[string]any) + require.True(t, ok, "Expected result in tools response") + + tools, ok := resultMap["tools"].([]any) + require.True(t, ok, "Expected tools array in result") + assert.NotEmpty(t, tools, "Should have tools advertised") + + // Check for common postgres tools + var toolNames []string + for _, tool := range tools { + if toolMap, ok := tool.(map[string]any); ok { + if name, ok := toolMap["name"].(string); ok { + toolNames = append(toolNames, name) + } + } + } + + assert.Contains(t, toolNames, "execute_sql", "Should have execute_sql tool") + t.Logf("Successfully advertised tools without user token: %v", toolNames) + }) + + t.Run("ToolInvocationFailsWithoutToken", func(t *testing.T) { + // Create MCP client with OAuth token + mcpClient := NewMCPSSEClient("http://localhost:8080") + mcpClient.SetAuthToken(accessToken) + + // Connect to postgres SSE endpoint + err := mcpClient.Connect() + require.NoError(t, err) + defer mcpClient.Close() + + // Try to invoke a tool without user token + queryParams := map[string]any{ + "name": "execute_sql", + "arguments": map[string]any{ + "sql": "SELECT 1", + }, + } + + result, err := mcpClient.SendMCPRequest("tools/call", queryParams) + require.NoError(t, err, "Should get response even without token") + + // MCP protocol returns errors as successful responses with error content + require.NotNil(t, result["result"], "Should have result in response") + + resultMap := result["result"].(map[string]any) + content := resultMap["content"].([]any) + require.NotEmpty(t, content, "Should have content in result") + + contentItem := content[0].(map[string]any) + errorJSON := contentItem["text"].(string) + + // Parse the error JSON + var errorData map[string]any + err = json.Unmarshal([]byte(errorJSON), &errorData) + require.NoError(t, err, "Error should be valid JSON") + + // Verify error structure + errorInfo := errorData["error"].(map[string]any) + assert.Equal(t, "token_required", errorInfo["code"], "Error code should be token_required") + + errorMessage := errorInfo["message"].(string) + assert.Contains(t, errorMessage, "token required", "Error should mention token required") + assert.Contains(t, errorMessage, "/my/tokens", "Error should mention token setup URL") + assert.Contains(t, errorMessage, "Test Service", "Error should mention service name") + + // Verify error data + errData := errorInfo["data"].(map[string]any) + assert.Equal(t, "postgres", errData["service"], "Should identify the service") + assert.Contains(t, errData["tokenSetupUrl"].(string), "/my/tokens", "Should include token setup URL") + + // Verify instructions + instructions := errData["instructions"].(map[string]any) + assert.Contains(t, instructions["ai"].(string), "CRITICAL", "Should have AI instructions") + assert.Contains(t, instructions["human"].(string), "token required", "Should have human instructions") + }) + + t.Run("ToolInvocationSucceedsWithUserToken", func(t *testing.T) { + // Step 1: GET /my/tokens to extract CSRF token + jar, err := cookiejar.New(nil) + require.NoError(t, err) + client := &http.Client{ + Jar: jar, // Need cookie jar for CSRF + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse // Don't follow redirects + }, + } + + req, err := http.NewRequest("GET", "http://localhost:8080/my/tokens", nil) + require.NoError(t, err) + req.Header.Set("Authorization", "Bearer "+accessToken) + + resp, err := client.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + // Check if we got the page or a redirect + if resp.StatusCode == 302 || resp.StatusCode == 303 { + // Follow the redirect + location := resp.Header.Get("Location") + t.Logf("Got redirect to: %s", location) + + // Allow redirects for this request + client = &http.Client{ + Jar: jar, + } + + req, err = http.NewRequest("GET", "http://localhost:8080/my/tokens", nil) + require.NoError(t, err) + req.Header.Set("Authorization", "Bearer "+accessToken) + + resp, err = client.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + } + + require.Equal(t, 200, resp.StatusCode, "Should be able to access token page") + + // Extract CSRF token from HTML + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + // Look for the CSRF token in the form + csrfRegex := regexp.MustCompile(`name="csrf_token" value="([^"]+)"`) + matches := csrfRegex.FindSubmatch(body) + require.Len(t, matches, 2, "Should find CSRF token in form") + csrfToken := string(matches[1]) + + // Step 2: POST to /my/tokens/set with test token + formData := url.Values{ + "service": {"postgres"}, + "token": {"test-user-token-12345"}, + "csrf_token": {csrfToken}, + } + + req, err = http.NewRequest("POST", "http://localhost:8080/my/tokens/set", strings.NewReader(formData.Encode())) + require.NoError(t, err) + req.Header.Set("Authorization", "Bearer "+accessToken) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + resp, err = client.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + // Check the response - it might be 200 if following redirects + switch resp.StatusCode { + case 200: + // That's fine, it means the token was set and we got the page back + t.Log("Token set successfully, got page response") + case 302, 303: + // Also fine, redirect means success + t.Log("Token set successfully, got redirect") + default: + body, _ := io.ReadAll(resp.Body) + t.Fatalf("Unexpected response setting token: status=%d, body=%s", resp.StatusCode, string(body)) + } + + // Step 3: Now test tool invocation with the token + mcpClient := NewMCPSSEClient("http://localhost:8080") + mcpClient.SetAuthToken(accessToken) + + err = mcpClient.Connect() + require.NoError(t, err, "Should connect to postgres SSE endpoint") + defer mcpClient.Close() + + // Call the query tool with a simple query + queryParams := map[string]any{ + "name": "execute_sql", + "arguments": map[string]any{ + "sql": "SELECT 1 as test", + }, + } + + result, err := mcpClient.SendMCPRequest("tools/call", queryParams) + require.NoError(t, err, "Should successfully call tool with token") + + // Verify we got a successful result, not an error + require.NotNil(t, result["result"], "Should have result in response") + + resultMap := result["result"].(map[string]any) + content := resultMap["content"].([]any) + require.NotEmpty(t, content, "Should have content in result") + + contentItem := content[0].(map[string]any) + resultText := contentItem["text"].(string) + + // The result should contain actual query results, not an error + assert.NotContains(t, resultText, "token_required", "Should not have token error") + assert.NotContains(t, resultText, "Token Required", "Should not have token error message") + + // Postgres query result should contain our test value + assert.Contains(t, resultText, "1", "Should contain query result") + + t.Log("Successfully invoked tool with user token") + }) +} + +// getOAuthAccessTokenForIDP completes the OAuth flow and returns a valid access token. +// expectedIDPHost is the host:port of the expected IDP redirect target (e.g., "localhost:9090" for Google). +func getOAuthAccessTokenForIDP(t *testing.T, resource, expectedIDPHost string) string { + t.Helper() + + clientID := registerTestClient(t) + + codeVerifier := "test-code-verifier-that-is-at-least-43-characters-long" + h := sha256.New() + h.Write([]byte(codeVerifier)) + codeChallenge := base64.RawURLEncoding.EncodeToString(h.Sum(nil)) + + authParams := url.Values{ + "response_type": {"code"}, + "client_id": {clientID}, + "redirect_uri": {"http://127.0.0.1:6274/oauth/callback"}, + "code_challenge": {codeChallenge}, + "code_challenge_method": {"S256"}, + "scope": {"openid email profile"}, + "state": {"test-state"}, + "resource": {resource}, + } + + client := &http.Client{ + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, + } + + // Step 1: Authorization request + authResp, err := client.Get("http://localhost:8080/authorize?" + authParams.Encode()) + require.NoError(t, err) + defer authResp.Body.Close() + + require.Contains(t, []int{302, 303}, authResp.StatusCode, "Should redirect to IDP") + location := authResp.Header.Get("Location") + require.Contains(t, location, expectedIDPHost, "Should redirect to expected IDP") + + // Step 2: Follow redirect to IDP (which immediately redirects back) + idpResp, err := client.Get(location) + require.NoError(t, err) + defer idpResp.Body.Close() + + require.Contains(t, []int{302, 303}, idpResp.StatusCode, "IDP should redirect back") + callbackLocation := idpResp.Header.Get("Location") + require.Contains(t, callbackLocation, "/oauth/callback", "Should redirect to callback") + + // Step 3: Follow callback redirect + callbackResp, err := client.Get(callbackLocation) + require.NoError(t, err) + defer callbackResp.Body.Close() + + require.Contains(t, []int{302, 303}, callbackResp.StatusCode, "Callback should redirect with code") + finalLocation := callbackResp.Header.Get("Location") + + finalURL, err := url.Parse(finalLocation) + require.NoError(t, err) + authCode := finalURL.Query().Get("code") + require.NotEmpty(t, authCode, "Should have authorization code") + + // Step 4: Exchange code for token + tokenParams := url.Values{ + "grant_type": {"authorization_code"}, + "code": {authCode}, + "redirect_uri": {"http://127.0.0.1:6274/oauth/callback"}, + "client_id": {clientID}, + "code_verifier": {codeVerifier}, + } + + tokenResp, err := http.PostForm("http://localhost:8080/token", tokenParams) + require.NoError(t, err) + defer tokenResp.Body.Close() + + if tokenResp.StatusCode != 200 { + body, _ := io.ReadAll(tokenResp.Body) + t.Logf("Token exchange failed with status %d: %s", tokenResp.StatusCode, string(body)) + } + + require.Equal(t, 200, tokenResp.StatusCode, "Token exchange should succeed") + + var tokenData map[string]any + err = json.NewDecoder(tokenResp.Body).Decode(&tokenData) + require.NoError(t, err) + + accessToken := tokenData["access_token"].(string) + require.NotEmpty(t, accessToken, "Should have access token") + + return accessToken +} + +// getOAuthAccessToken completes the OAuth flow using the Google IDP and returns a valid access token. +func getOAuthAccessToken(t *testing.T, resource string) string { + return getOAuthAccessTokenForIDP(t, resource, "localhost:9090") +} + +// MockUserTokenStore mocks the UserTokenStore interface for testing +type MockUserTokenStore struct { + mock.Mock +} + +func (m *MockUserTokenStore) GetUserToken(ctx context.Context, email, service string) (string, error) { + args := m.Called(ctx, email, service) + return args.String(0), args.Error(1) +} + +func (m *MockUserTokenStore) SetUserToken(ctx context.Context, email, service, token string) error { + args := m.Called(ctx, email, service, token) + return args.Error(0) +} + +func (m *MockUserTokenStore) DeleteUserToken(ctx context.Context, email, service string) error { + args := m.Called(ctx, email, service) + return args.Error(0) +} + +func (m *MockUserTokenStore) ListUserServices(ctx context.Context, email string) ([]string, error) { + args := m.Called(ctx, email) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).([]string), args.Error(1) +} diff --git a/integration/streamable_client.go b/integration/streamable_client.go deleted file mode 100644 index 76ab1b7..0000000 --- a/integration/streamable_client.go +++ /dev/null @@ -1 +0,0 @@ -package integration diff --git a/integration/test_clients.go b/integration/test_clients.go new file mode 100644 index 0000000..a5cb357 --- /dev/null +++ b/integration/test_clients.go @@ -0,0 +1,442 @@ +package integration + +import ( + "bufio" + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "sync" + "time" +) + +// MCPSSEClient simulates an MCP client for testing +type MCPSSEClient struct { + baseURL string + token string + sseConn io.ReadCloser + messageEndpoint string + sseScanner *bufio.Scanner + sessionID string +} + +// NewMCPSSEClient creates a new MCP client for testing +func NewMCPSSEClient(baseURL string) *MCPSSEClient { + return &MCPSSEClient{ + baseURL: baseURL, + } +} + +// Authenticate sets up authentication for the client +func (c *MCPSSEClient) Authenticate() error { + c.token = "test-token" + return nil +} + +// SetAuthToken sets a specific auth token for the client +func (c *MCPSSEClient) SetAuthToken(token string) { + c.token = token +} + +// Connect establishes an SSE connection and retrieves the message endpoint +func (c *MCPSSEClient) Connect() error { + return c.ConnectToServer("postgres") +} + +// ConnectToServer establishes an SSE connection to a specific server +func (c *MCPSSEClient) ConnectToServer(serverName string) error { + // Close any existing connection + if c.sseConn != nil { + c.sseConn.Close() + c.sseConn = nil + c.messageEndpoint = "" + } + + sseURL := c.baseURL + "/" + serverName + "/sse" + tracef("ConnectToServer: requesting %s", sseURL) + + req, err := http.NewRequest("GET", sseURL, nil) + if err != nil { + return fmt.Errorf("failed to create SSE request: %v", err) + } + + req.Header.Set("Accept", "text/event-stream") + req.Header.Set("Authorization", "Bearer "+c.token) + req.Header.Set("Cache-Control", "no-cache") + tracef("ConnectToServer: headers set, making request") + + // Don't use a timeout on the client for SSE + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + return fmt.Errorf("SSE connection failed: %v", err) + } + + tracef("ConnectToServer: got response status %d", resp.StatusCode) + if resp.StatusCode != 200 { + body, _ := io.ReadAll(resp.Body) + resp.Body.Close() + return fmt.Errorf("SSE connection returned %d: %s", resp.StatusCode, string(body)) + } + + // Store the connection + c.sseConn = resp.Body + c.sseScanner = bufio.NewScanner(resp.Body) + + // Read initial SSE messages to get the endpoint + // For inline servers, we don't get a message endpoint - we use the server path directly + gotEndpointMessage := false + for c.sseScanner.Scan() { + line := c.sseScanner.Text() + tracef("ConnectToServer: SSE line: %s", line) + + // Look for data lines + if after, ok := strings.CutPrefix(line, "data: "); ok { + data := after + + // Check if it's an endpoint message (for inline servers) + if strings.Contains(data, `"type":"endpoint"`) { + gotEndpointMessage = true + // For inline servers, construct the message endpoint + c.messageEndpoint = c.baseURL + "/" + serverName + "/message" + tracef("ConnectToServer: inline server detected, using endpoint: %s", c.messageEndpoint) + break + } + + // Check if it's a message endpoint URL (for stdio servers) + if strings.Contains(data, "http://") || strings.Contains(data, "https://") { + c.messageEndpoint = data + + // Extract session ID from endpoint URL + if u, err := url.Parse(data); err == nil { + c.sessionID = u.Query().Get("sessionId") + } + + tracef("ConnectToServer: found endpoint: %s", c.messageEndpoint) + break + } + } + } + + if c.messageEndpoint == "" && !gotEndpointMessage { + c.sseConn.Close() + c.sseConn = nil + return fmt.Errorf("no message endpoint received") + } + + tracef("Connect: successfully connected to MCP server") + return nil +} + +// ValidateBackendConnectivity checks if we can connect to the MCP server +func (c *MCPSSEClient) ValidateBackendConnectivity() error { + return c.Connect() +} + +// Close closes the SSE connection +func (c *MCPSSEClient) Close() { + if c.sseConn != nil { + c.sseConn.Close() + c.sseConn = nil + c.messageEndpoint = "" + c.sseScanner = nil + } +} + +// SendMCPRequest sends an MCP JSON-RPC request and returns the response +func (c *MCPSSEClient) SendMCPRequest(method string, params any) (map[string]any, error) { + // Ensure we have a connection + if c.messageEndpoint == "" { + if err := c.Connect(); err != nil { + return nil, fmt.Errorf("failed to connect: %v", err) + } + } + + // Send MCP request to the message endpoint + request := map[string]any{ + "jsonrpc": "2.0", + "id": 1, + "method": method, + "params": params, + } + + reqBody, err := json.Marshal(request) + if err != nil { + return nil, err + } + + msgReq, err := http.NewRequest("POST", c.messageEndpoint, bytes.NewBuffer(reqBody)) + if err != nil { + return nil, err + } + + msgReq.Header.Set("Content-Type", "application/json") + msgReq.Header.Set("Authorization", "Bearer "+c.token) + + client := &http.Client{Timeout: 30 * time.Second} + msgResp, err := client.Do(msgReq) + if err != nil { + return nil, err + } + defer msgResp.Body.Close() + + respBody, err := io.ReadAll(msgResp.Body) + if err != nil { + return nil, err + } + + if msgResp.StatusCode != 200 && msgResp.StatusCode != 202 { + return nil, fmt.Errorf("MCP request failed: %d - %s", msgResp.StatusCode, string(respBody)) + } + + // Handle 202 and empty responses - read response from SSE stream + if msgResp.StatusCode == 202 || len(respBody) == 0 { + // Read response from SSE stream + for c.sseScanner.Scan() { + line := c.sseScanner.Text() + + if after, ok := strings.CutPrefix(line, "data: "); ok { + data := after + // Try to parse as JSON + var msg map[string]any + if err := json.Unmarshal([]byte(data), &msg); err == nil { + // Check if this is our response (matching ID) + if id, ok := msg["id"]; ok && id == float64(1) { + return msg, nil + } + } + } + } + + if err := c.sseScanner.Err(); err != nil { + return nil, fmt.Errorf("SSE scanner error: %v", err) + } + + return nil, fmt.Errorf("no response received from SSE stream") + } + + var result map[string]any + if err := json.Unmarshal(respBody, &result); err != nil { + return nil, fmt.Errorf("failed to parse response: %v - %s", err, string(respBody)) + } + + return result, nil +} + +// MCPStreamableClient is a test client for HTTP-Streamable MCP servers +type MCPStreamableClient struct { + baseURL string + serverName string + token string + httpClient *http.Client + + // For GET SSE streaming + sseConn io.ReadCloser + sseScanner *bufio.Scanner + sseCancel chan struct{} + + mu sync.Mutex +} + +// NewMCPStreamableClient creates a new streamable-http test client +func NewMCPStreamableClient(baseURL string) *MCPStreamableClient { + return &MCPStreamableClient{ + baseURL: baseURL, + httpClient: &http.Client{ + Timeout: 30 * time.Second, + }, + } +} + +// SetAuthToken sets the authentication token +func (c *MCPStreamableClient) SetAuthToken(token string) { + c.token = token +} + +// ConnectToServer establishes connection to a streamable-http server +func (c *MCPStreamableClient) ConnectToServer(serverName string) error { + c.mu.Lock() + defer c.mu.Unlock() + + // Close any existing connection + c.close() + + c.serverName = serverName + + // For streamable-http, we can optionally open a GET SSE stream for server-initiated messages + // But it's not required for basic request/response + return c.openSSEStream() +} + +// openSSEStream opens a GET SSE connection for receiving server-initiated messages +func (c *MCPStreamableClient) openSSEStream() error { + url := c.baseURL + "/" + c.serverName + "/sse" + + req, err := http.NewRequest("GET", url, nil) + if err != nil { + return fmt.Errorf("failed to create GET request: %v", err) + } + + req.Header.Set("Accept", "text/event-stream") + req.Header.Set("Authorization", "Bearer "+c.token) + req.Header.Set("Cache-Control", "no-cache") + + // Use a client without timeout for SSE + sseClient := &http.Client{} + resp, err := sseClient.Do(req) + if err != nil { + return fmt.Errorf("SSE connection failed: %v", err) + } + + if resp.StatusCode != 200 { + body, _ := io.ReadAll(resp.Body) + resp.Body.Close() + return fmt.Errorf("SSE connection returned %d: %s", resp.StatusCode, string(body)) + } + + c.sseConn = resp.Body + c.sseScanner = bufio.NewScanner(resp.Body) + c.sseCancel = make(chan struct{}) + + // Start reading SSE messages in background + go c.readSSEMessages() + + return nil +} + +// readSSEMessages reads server-initiated messages from the SSE stream +func (c *MCPStreamableClient) readSSEMessages() { + for { + select { + case <-c.sseCancel: + return + default: + if c.sseScanner.Scan() { + line := c.sseScanner.Text() + if after, ok := strings.CutPrefix(line, "data: "); ok { + data := after + // In a real implementation, we'd process server-initiated messages here + tracef("StreamableClient: received SSE message: %s", data) + } + } else { + // Scanner stopped - connection closed or error + return + } + } + } +} + +// SendMCPRequest sends a JSON-RPC request via POST +func (c *MCPStreamableClient) SendMCPRequest(method string, params any) (map[string]any, error) { + c.mu.Lock() + serverName := c.serverName + c.mu.Unlock() + + if serverName == "" { + return nil, fmt.Errorf("not connected to any server") + } + + // For streamable-http, we POST to the server endpoint + url := c.baseURL + "/" + serverName + "/sse" + + // Construct JSON-RPC request + request := map[string]any{ + "jsonrpc": "2.0", + "id": 1, + "method": method, + "params": params, + } + + body, err := json.Marshal(request) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %v", err) + } + + req, err := http.NewRequest("POST", url, bytes.NewReader(body)) + if err != nil { + return nil, fmt.Errorf("failed to create POST request: %v", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+c.token) + // Accept both JSON and SSE responses + req.Header.Set("Accept", "application/json, text/event-stream") + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("request failed: %v", err) + } + defer resp.Body.Close() + + // Check content type to determine response format + contentType := resp.Header.Get("Content-Type") + + if strings.HasPrefix(contentType, "text/event-stream") { + // Handle SSE response + return c.handleSSEResponse(resp.Body) + } else { + // Handle JSON response + return c.handleJSONResponse(resp.Body) + } +} + +// handleJSONResponse processes a regular JSON response +func (c *MCPStreamableClient) handleJSONResponse(body io.Reader) (map[string]any, error) { + var response map[string]any + if err := json.NewDecoder(body).Decode(&response); err != nil { + return nil, fmt.Errorf("failed to decode JSON response: %v", err) + } + return response, nil +} + +// handleSSEResponse processes an SSE stream response from a POST +func (c *MCPStreamableClient) handleSSEResponse(body io.Reader) (map[string]any, error) { + scanner := bufio.NewScanner(body) + var lastResponse map[string]any + + for scanner.Scan() { + line := scanner.Text() + if after, ok := strings.CutPrefix(line, "data: "); ok { + data := after + var msg map[string]any + if err := json.Unmarshal([]byte(data), &msg); err == nil { + // Keep the last response with an ID (not a notification) + if _, hasID := msg["id"]; hasID { + lastResponse = msg + } + } + } + } + + if lastResponse == nil { + return nil, fmt.Errorf("no response received in SSE stream") + } + + return lastResponse, nil +} + +// Close closes all connections +func (c *MCPStreamableClient) Close() { + c.mu.Lock() + defer c.mu.Unlock() + c.close() +} + +// close is the internal close method (must be called with lock held) +func (c *MCPStreamableClient) close() { + if c.sseCancel != nil { + close(c.sseCancel) + c.sseCancel = nil + } + + if c.sseConn != nil { + c.sseConn.Close() + c.sseConn = nil + c.sseScanner = nil + } + + c.serverName = "" +} diff --git a/integration/test_fakes.go b/integration/test_fakes.go new file mode 100644 index 0000000..019b8c0 --- /dev/null +++ b/integration/test_fakes.go @@ -0,0 +1,355 @@ +package integration + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "time" +) + +// FakeGCPServer provides a fake GCP OAuth server for testing +type FakeGCPServer struct { + server *http.Server + port string +} + +// NewFakeGCPServer creates a new fake GCP server +func NewFakeGCPServer(port string) *FakeGCPServer { + mux := http.NewServeMux() + + mux.HandleFunc("/auth", func(w http.ResponseWriter, r *http.Request) { + redirectURI := r.URL.Query().Get("redirect_uri") + state := r.URL.Query().Get("state") + http.Redirect(w, r, fmt.Sprintf("%s?code=test-auth-code&state=%s", redirectURI, state), http.StatusFound) + }) + + mux.HandleFunc("/token", func(w http.ResponseWriter, r *http.Request) { + // Parse the form data + if err := r.ParseForm(); err != nil { + http.Error(w, "Invalid request", http.StatusBadRequest) + return + } + + // Check the authorization code + code := r.FormValue("code") + if code != "test-auth-code" { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + _ = json.NewEncoder(w).Encode(map[string]any{ + "error": "invalid_grant", + "error_description": "Invalid authorization code", + }) + return + } + + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "access_token": "test-access-token", + "token_type": "Bearer", + "expires_in": 3600, + }) + }) + + mux.HandleFunc("/userinfo", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "email": "test@test.com", + "hd": "test.com", + }) + }) + + server := &http.Server{ + Addr: ":" + port, + Handler: mux, + } + + return &FakeGCPServer{ + server: server, + port: port, + } +} + +// Start starts the fake GCP server +func (m *FakeGCPServer) Start() error { + go func() { + if err := m.server.ListenAndServe(); err != nil && err != http.ErrServerClosed { + panic(err) + } + }() + + time.Sleep(100 * time.Millisecond) + return nil +} + +// Stop stops the fake GCP server +func (m *FakeGCPServer) Stop() error { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + return m.server.Shutdown(ctx) +} + +// FakeServiceOAuthServer provides a fake OAuth server for external services (like Linear, GitHub) +type FakeServiceOAuthServer struct { + server *http.Server + port string +} + +// NewFakeServiceOAuthServer creates a new fake service OAuth server +func NewFakeServiceOAuthServer(port string) *FakeServiceOAuthServer { + mux := http.NewServeMux() + + mux.HandleFunc("/oauth/authorize", func(w http.ResponseWriter, r *http.Request) { + redirectURI := r.URL.Query().Get("redirect_uri") + state := r.URL.Query().Get("state") + http.Redirect(w, r, fmt.Sprintf("%s?code=service-auth-code&state=%s", redirectURI, state), http.StatusFound) + }) + + mux.HandleFunc("/oauth/token", func(w http.ResponseWriter, r *http.Request) { + if err := r.ParseForm(); err != nil { + http.Error(w, "Invalid request", http.StatusBadRequest) + return + } + + code := r.FormValue("code") + if code != "service-auth-code" { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + _ = json.NewEncoder(w).Encode(map[string]any{ + "error": "invalid_grant", + "error_description": "Invalid authorization code", + }) + return + } + + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "access_token": "service-oauth-access-token", + "refresh_token": "service-oauth-refresh-token", + "token_type": "Bearer", + "expires_in": 3600, + }) + }) + + server := &http.Server{ + Addr: ":" + port, + Handler: mux, + } + + return &FakeServiceOAuthServer{ + server: server, + port: port, + } +} + +// Start starts the fake service OAuth server +func (s *FakeServiceOAuthServer) Start() error { + go func() { + if err := s.server.ListenAndServe(); err != nil && err != http.ErrServerClosed { + panic(err) + } + }() + + time.Sleep(100 * time.Millisecond) + return nil +} + +// Stop stops the fake service OAuth server +func (s *FakeServiceOAuthServer) Stop() error { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + return s.server.Shutdown(ctx) +} + +// FakeGitHubServer simulates GitHub's OAuth and API endpoints for integration testing. +type FakeGitHubServer struct { + server *http.Server + port string +} + +// NewFakeGitHubServer creates a new fake GitHub server. +// orgs controls what organizations the /user/orgs endpoint returns. +func NewFakeGitHubServer(port string, orgs []string) *FakeGitHubServer { + mux := http.NewServeMux() + + mux.HandleFunc("/login/oauth/authorize", func(w http.ResponseWriter, r *http.Request) { + redirectURI := r.URL.Query().Get("redirect_uri") + state := r.URL.Query().Get("state") + http.Redirect(w, r, fmt.Sprintf("%s?code=github-test-code&state=%s", redirectURI, state), http.StatusFound) + }) + + mux.HandleFunc("/login/oauth/access_token", func(w http.ResponseWriter, r *http.Request) { + if err := r.ParseForm(); err != nil { + http.Error(w, "Invalid request", http.StatusBadRequest) + return + } + + code := r.FormValue("code") + if code != "github-test-code" { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + _ = json.NewEncoder(w).Encode(map[string]any{ + "error": "bad_verification_code", + "error_description": "Invalid authorization code", + }) + return + } + + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "access_token": "github-test-token", + "token_type": "bearer", + "scope": "user:email,read:org", + }) + }) + + mux.HandleFunc("/user", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "id": 12345, + "login": "testuser", + "email": "test@test.com", + "name": "Test User", + "avatar_url": "https://github.com/avatar.jpg", + }) + }) + + mux.HandleFunc("/user/emails", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode([]map[string]any{ + {"email": "test@test.com", "primary": true, "verified": true}, + }) + }) + + mux.HandleFunc("/user/orgs", func(w http.ResponseWriter, r *http.Request) { + orgList := make([]map[string]any, len(orgs)) + for i, org := range orgs { + orgList[i] = map[string]any{"login": org} + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(orgList) + }) + + server := &http.Server{ + Addr: ":" + port, + Handler: mux, + } + + return &FakeGitHubServer{ + server: server, + port: port, + } +} + +// Start starts the fake GitHub server +func (s *FakeGitHubServer) Start() error { + go func() { + if err := s.server.ListenAndServe(); err != nil && err != http.ErrServerClosed { + panic(err) + } + }() + + time.Sleep(100 * time.Millisecond) + return nil +} + +// Stop stops the fake GitHub server +func (s *FakeGitHubServer) Stop() error { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + return s.server.Shutdown(ctx) +} + +// FakeOIDCServer simulates a generic OIDC provider for integration testing. +// Used for both generic OIDC and Azure tests (Azure is OIDC-compliant). +type FakeOIDCServer struct { + server *http.Server + port string +} + +// NewFakeOIDCServer creates a new fake OIDC server. +func NewFakeOIDCServer(port string) *FakeOIDCServer { + mux := http.NewServeMux() + + baseURL := "http://localhost:" + port + + mux.HandleFunc("/.well-known/openid-configuration", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "issuer": baseURL, + "authorization_endpoint": baseURL + "/authorize", + "token_endpoint": baseURL + "/token", + "userinfo_endpoint": baseURL + "/userinfo", + }) + }) + + mux.HandleFunc("/authorize", func(w http.ResponseWriter, r *http.Request) { + redirectURI := r.URL.Query().Get("redirect_uri") + state := r.URL.Query().Get("state") + http.Redirect(w, r, fmt.Sprintf("%s?code=oidc-test-code&state=%s", redirectURI, state), http.StatusFound) + }) + + mux.HandleFunc("/token", func(w http.ResponseWriter, r *http.Request) { + if err := r.ParseForm(); err != nil { + http.Error(w, "Invalid request", http.StatusBadRequest) + return + } + + code := r.FormValue("code") + if code != "oidc-test-code" { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + _ = json.NewEncoder(w).Encode(map[string]any{ + "error": "invalid_grant", + "error_description": "Invalid authorization code", + }) + return + } + + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "access_token": "oidc-test-token", + "token_type": "Bearer", + "expires_in": 3600, + }) + }) + + mux.HandleFunc("/userinfo", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "sub": "oidc-12345", + "email": "test@oidc-test.com", + "email_verified": true, + "name": "OIDC User", + }) + }) + + server := &http.Server{ + Addr: ":" + port, + Handler: mux, + } + + return &FakeOIDCServer{ + server: server, + port: port, + } +} + +// Start starts the fake OIDC server +func (s *FakeOIDCServer) Start() error { + go func() { + if err := s.server.ListenAndServe(); err != nil && err != http.ErrServerClosed { + panic(err) + } + }() + + time.Sleep(100 * time.Millisecond) + return nil +} + +// Stop stops the fake OIDC server +func (s *FakeOIDCServer) Stop() error { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + return s.server.Shutdown(ctx) +} diff --git a/integration/test_helpers.go b/integration/test_helpers.go new file mode 100644 index 0000000..1b8715f --- /dev/null +++ b/integration/test_helpers.go @@ -0,0 +1,463 @@ +package integration + +import ( + "encoding/json" + "fmt" + "net/http" + "os" + "os/exec" + "slices" + "strings" + "syscall" + "testing" + "time" +) + +// ToolboxImage is the Docker image for the MCP Toolbox for Databases. +// Used as the MCP server backing integration tests. All test configs +// that reference a postgres MCP server should use this image. +const ToolboxImage = "us-central1-docker.pkg.dev/database-toolbox/toolbox/toolbox:latest" + +// testPostgresDockerArgs returns the Docker args for running the toolbox +// as a stdio MCP server against the test postgres database. +func testPostgresDockerArgs() []string { + return []string{ + "run", "--rm", "-i", "--network", "host", + "-e", "POSTGRES_HOST=localhost", + "-e", "POSTGRES_PORT=15432", + "-e", "POSTGRES_DATABASE=testdb", + "-e", "POSTGRES_USER=testuser", + "-e", "POSTGRES_PASSWORD=testpass", + ToolboxImage, + "--stdio", "--prebuilt", "postgres", + } +} + +// testPostgresServer returns an MCP server config for the test postgres database. +// Options can customize auth, logging, etc. +func testPostgresServer(opts ...serverOption) map[string]any { + args := make([]any, len(testPostgresDockerArgs())) + for i, a := range testPostgresDockerArgs() { + args[i] = a + } + s := map[string]any{ + "transportType": "stdio", + "command": "docker", + "args": args, + } + for _, opt := range opts { + opt(s) + } + return s +} + +type serverOption func(map[string]any) + +func withBearerTokens(tokens ...string) serverOption { + return func(s map[string]any) { + s["serviceAuths"] = []map[string]any{ + {"type": "bearer", "tokens": tokens}, + } + } +} + +func withBasicAuth(username, passwordEnvVar string) serverOption { + return func(s map[string]any) { + auths, _ := s["serviceAuths"].([]map[string]any) + auths = append(auths, map[string]any{ + "type": "basic", + "username": username, + "password": map[string]string{"$env": passwordEnvVar}, + }) + s["serviceAuths"] = auths + } +} + +func withLogEnabled() serverOption { + return func(s map[string]any) { + s["options"] = map[string]any{"logEnabled": true} + } +} + +func withUserToken() serverOption { + return func(s map[string]any) { + s["env"] = map[string]any{ + "USER_TOKEN": map[string]string{"$userToken": "{{token}}"}, + } + s["requiresUserToken"] = true + s["userAuthentication"] = map[string]any{ + "type": "manual", + "displayName": "Test Service", + "instructions": "Enter your test token", + "helpUrl": "https://example.com/help", + } + } +} + +// testOAuthConfigFromEnv returns an OAuth auth config that reads secrets from env vars. +func testOAuthConfigFromEnv() map[string]any { + return map[string]any{ + "kind": "oauth", + "issuer": "http://localhost:8080", + "gcpProject": "test-project", + "idp": map[string]any{ + "provider": "google", + "clientId": map[string]string{"$env": "GOOGLE_CLIENT_ID"}, + "clientSecret": map[string]string{"$env": "GOOGLE_CLIENT_SECRET"}, + "redirectUri": "http://localhost:8080/oauth/callback", + "authorizationUrl": "http://localhost:9090/auth", + "tokenUrl": "http://localhost:9090/token", + "userInfoUrl": "http://localhost:9090/userinfo", + }, + "allowedDomains": []string{"test.com", "stainless.com", "claude.ai"}, + "allowedOrigins": []string{"https://claude.ai"}, + "tokenTtl": "1h", + "storage": "memory", + "jwtSecret": map[string]string{"$env": "JWT_SECRET"}, + "encryptionKey": map[string]string{"$env": "ENCRYPTION_KEY"}, + } +} + +// testGitHubOAuthConfig returns an OAuth auth config for GitHub IDP testing. +// Secrets use $env references; callers must pass JWT_SECRET, ENCRYPTION_KEY, +// and GITHUB_CLIENT_SECRET env vars. +func testGitHubOAuthConfig(allowedOrgs ...string) map[string]any { + idpCfg := map[string]any{ + "provider": "github", + "clientId": "test-github-client-id", + "clientSecret": map[string]string{"$env": "GITHUB_CLIENT_SECRET"}, + "redirectUri": "http://localhost:8080/oauth/callback", + "authorizationUrl": "http://localhost:9092/login/oauth/authorize", + "tokenUrl": "http://localhost:9092/login/oauth/access_token", + "userInfoUrl": "http://localhost:9092", + } + if len(allowedOrgs) > 0 { + idpCfg["allowedOrgs"] = allowedOrgs + } + return map[string]any{ + "kind": "oauth", + "issuer": "http://localhost:8080", + "gcpProject": "test-project", + "idp": idpCfg, + "allowedDomains": []string{"test.com"}, + "allowedOrigins": []string{"https://claude.ai"}, + "tokenTtl": "1h", + "storage": "memory", + "jwtSecret": map[string]string{"$env": "JWT_SECRET"}, + "encryptionKey": map[string]string{"$env": "ENCRYPTION_KEY"}, + } +} + +// testOIDCOAuthConfig returns an OAuth auth config for generic OIDC IDP testing. +// Secrets use $env references; callers must pass JWT_SECRET, ENCRYPTION_KEY, +// and OIDC_CLIENT_SECRET env vars. +func testOIDCOAuthConfig() map[string]any { + return map[string]any{ + "kind": "oauth", + "issuer": "http://localhost:8080", + "gcpProject": "test-project", + "idp": map[string]any{ + "provider": "oidc", + "clientId": "test-oidc-client-id", + "clientSecret": map[string]string{"$env": "OIDC_CLIENT_SECRET"}, + "redirectUri": "http://localhost:8080/oauth/callback", + "authorizationUrl": "http://localhost:9093/authorize", + "tokenUrl": "http://localhost:9093/token", + "userInfoUrl": "http://localhost:9093/userinfo", + }, + "allowedDomains": []string{"oidc-test.com"}, + "allowedOrigins": []string{"https://claude.ai"}, + "tokenTtl": "1h", + "storage": "memory", + "jwtSecret": map[string]string{"$env": "JWT_SECRET"}, + "encryptionKey": map[string]string{"$env": "ENCRYPTION_KEY"}, + } +} + +// testAzureOAuthConfig returns an OAuth auth config for Azure IDP testing. +// Secrets use $env references; callers must pass JWT_SECRET, ENCRYPTION_KEY, +// and AZURE_CLIENT_SECRET env vars. +func testAzureOAuthConfig() map[string]any { + return map[string]any{ + "kind": "oauth", + "issuer": "http://localhost:8080", + "gcpProject": "test-project", + "idp": map[string]any{ + "provider": "azure", + "tenantId": "test-tenant", + "clientId": "test-azure-client-id", + "clientSecret": map[string]string{"$env": "AZURE_CLIENT_SECRET"}, + "redirectUri": "http://localhost:8080/oauth/callback", + "authorizationUrl": "http://localhost:9093/authorize", + "tokenUrl": "http://localhost:9093/token", + "userInfoUrl": "http://localhost:9093/userinfo", + }, + "allowedDomains": []string{"oidc-test.com"}, + "allowedOrigins": []string{"https://claude.ai"}, + "tokenTtl": "1h", + "storage": "memory", + "jwtSecret": map[string]string{"$env": "JWT_SECRET"}, + "encryptionKey": map[string]string{"$env": "ENCRYPTION_KEY"}, + } +} + +// writeTestConfig writes a config map to a temporary JSON file and returns its path. +// The file is automatically cleaned up when the test finishes. +func writeTestConfig(t *testing.T, cfg map[string]any) string { + t.Helper() + data, err := json.MarshalIndent(cfg, "", " ") + if err != nil { + t.Fatalf("Failed to marshal test config: %v", err) + } + f, err := os.CreateTemp(t.TempDir(), "config-*.json") + if err != nil { + t.Fatalf("Failed to create temp config file: %v", err) + } + if _, err := f.Write(data); err != nil { + t.Fatalf("Failed to write temp config: %v", err) + } + if err := f.Close(); err != nil { + t.Fatalf("Failed to close temp config: %v", err) + } + return f.Name() +} + +// buildTestConfig builds a complete mcp-front config map. +func buildTestConfig(baseURL, name string, auth map[string]any, mcpServers map[string]any) map[string]any { + proxy := map[string]any{ + "baseURL": baseURL, + "addr": ":8080", + "name": name, + } + if auth != nil { + proxy["auth"] = auth + } + return map[string]any{ + "version": "v0.0.1-DEV_EDITION_EXPECT_CHANGES", + "proxy": proxy, + "mcpServers": mcpServers, + } +} + +// TestConfig holds all timeout configurations for integration tests +type TestConfig struct { + SessionTimeout string + CleanupInterval string + CleanupWaitTime string + TimerResetWaitTime string + MultiUserWaitTime string +} + +// GetTestConfig returns test configuration from environment variables or defaults +func GetTestConfig() TestConfig { + c := TestConfig{ + SessionTimeout: "10s", + CleanupInterval: "2s", + CleanupWaitTime: "15s", + TimerResetWaitTime: "12s", + MultiUserWaitTime: "15s", + } + + // Override from environment if set + if v := os.Getenv("SESSION_TIMEOUT"); v != "" { + c.SessionTimeout = v + } + if v := os.Getenv("SESSION_CLEANUP_INTERVAL"); v != "" { + c.CleanupInterval = v + } + if v := os.Getenv("TEST_CLEANUP_WAIT_TIME"); v != "" { + c.CleanupWaitTime = v + } + if v := os.Getenv("TEST_TIMER_RESET_WAIT_TIME"); v != "" { + c.TimerResetWaitTime = v + } + if v := os.Getenv("TEST_MULTI_USER_WAIT_TIME"); v != "" { + c.MultiUserWaitTime = v + } + + return c +} + +func waitForDB(t *testing.T) { + waitForSec := 5 + for range waitForSec { + // Check if container is running + psCmd := exec.Command("docker", "compose", "ps", "-q", "test-postgres") + if output, err := psCmd.Output(); err != nil || len(output) == 0 { + time.Sleep(1 * time.Second) + continue + } + + // Check if database is ready + checkCmd := exec.Command("docker", "compose", "exec", "-T", "test-postgres", "pg_isready", "-U", "testuser", "-d", "testdb") + if err := checkCmd.Run(); err == nil { + return + } + time.Sleep(1 * time.Second) + } + + t.Fatalf("Database failed to become ready after %d seconds", waitForSec) +} + +// trace logs a message if TRACE environment variable is set +func trace(t *testing.T, format string, args ...any) { + if os.Getenv("TRACE") == "1" { + t.Logf("TRACE: "+format, args...) + } +} + +// tracef logs a formatted message to stdout if TRACE is set (for use outside tests) +func tracef(format string, args ...any) { + if os.Getenv("TRACE") == "1" { + fmt.Printf("TRACE: "+format+"\n", args...) + } +} + +// startMCPFront starts the mcp-front server with the given config +func startMCPFront(t *testing.T, configPath string, extraEnv ...string) { + mcpCmd := exec.Command("../cmd/mcp-front/mcp-front", "-config", configPath) + + // Get test config for session timeouts + testConfig := GetTestConfig() + + // Build default environment with test timeouts + defaultEnv := []string{ + "SESSION_TIMEOUT=" + testConfig.SessionTimeout, + "SESSION_CLEANUP_INTERVAL=" + testConfig.CleanupInterval, + } + + // Start with system environment + mcpCmd.Env = os.Environ() + + // Apply defaults first + mcpCmd.Env = append(mcpCmd.Env, defaultEnv...) + + // Apply extra env (can override defaults) + mcpCmd.Env = append(mcpCmd.Env, extraEnv...) + + // Pass through LOG_LEVEL and LOG_FORMAT if set + if logLevel := os.Getenv("LOG_LEVEL"); logLevel != "" { + mcpCmd.Env = append(mcpCmd.Env, "LOG_LEVEL="+logLevel) + } + if logFormat := os.Getenv("LOG_FORMAT"); logFormat != "" { + mcpCmd.Env = append(mcpCmd.Env, "LOG_FORMAT="+logFormat) + } + + // Capture output to log file if MCP_LOG_FILE is set + if logFile := os.Getenv("MCP_LOG_FILE"); logFile != "" { + f, err := os.OpenFile(logFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644) + if err == nil { + mcpCmd.Stderr = f + mcpCmd.Stdout = f + t.Cleanup(func() { f.Close() }) + } + } + + if err := mcpCmd.Start(); err != nil { + t.Fatalf("Failed to start mcp-front: %v", err) + } + + // Register cleanup that runs even if test is killed + t.Cleanup(func() { + stopMCPFront(mcpCmd) + }) +} + +// stopMCPFront stops the mcp-front server gracefully +func stopMCPFront(cmd *exec.Cmd) { + if cmd == nil || cmd.Process == nil { + return + } + + // Try graceful shutdown first (SIGINT) + if err := cmd.Process.Signal(syscall.SIGINT); err != nil { + // If SIGINT fails, force kill immediately + _ = cmd.Process.Kill() + _ = cmd.Wait() + return + } + + // Wait up to 5 seconds for graceful shutdown + done := make(chan error, 1) + go func() { + done <- cmd.Wait() + }() + + select { + case <-done: + // Graceful shutdown completed + return + case <-time.After(5 * time.Second): + // Timeout, force kill + _ = cmd.Process.Kill() + _ = cmd.Wait() + } +} + +// waitForMCPFront waits for the mcp-front server to be ready +func waitForMCPFront(t *testing.T) { + t.Helper() + for range 10 { + resp, err := http.Get("http://localhost:8080/health") + if err == nil && resp.StatusCode == 200 { + resp.Body.Close() + return + } + if resp != nil { + resp.Body.Close() + } + time.Sleep(1 * time.Second) + } + t.Fatal("mcp-front failed to become ready after 10 seconds") +} + +// getMCPContainers returns a list of running toolbox container IDs +func getMCPContainers() []string { + cmd := exec.Command("docker", "ps", "--format", "{{.ID}}", "--filter", "ancestor="+ToolboxImage) + output, err := cmd.Output() + if err != nil { + return nil + } + + var containers []string + for line := range strings.SplitSeq(strings.TrimSpace(string(output)), "\n") { + if line != "" { + containers = append(containers, line) + } + } + return containers +} + +// cleanupContainers forces cleanup of containers that weren't in the initial set +func cleanupContainers(t *testing.T, initialContainers []string) { + time.Sleep(2 * time.Second) + containers := getMCPContainers() + for _, container := range containers { + isInitial := slices.Contains(initialContainers, container) + if !isInitial { + t.Logf("Force stopping container: %s...", container) + if err := exec.Command("docker", "stop", container).Run(); err != nil { + t.Logf("Failed to stop container %s: %v", container, err) + } else { + t.Logf("Stopped container: %s", container) + } + } + } +} + +// TestQuickSmoke provides a fast validation test +func TestQuickSmoke(t *testing.T) { + t.Log("Running quick smoke test...") + + // Just verify the test infrastructure works + client := NewMCPSSEClient("http://localhost:8080") + if client == nil { + t.Fatal("Failed to create client") + } + + if err := client.Authenticate(); err != nil { + t.Fatal("Failed to set up authentication") + } + + t.Log("Quick smoke test passed - test infrastructure is working") +} diff --git a/integration/test_utils.go b/integration/test_utils.go deleted file mode 100644 index 6557f13..0000000 --- a/integration/test_utils.go +++ /dev/null @@ -1,992 +0,0 @@ -package integration - -import ( - "bufio" - "bytes" - "context" - "encoding/json" - "fmt" - "io" - "net/http" - "net/url" - "os" - "os/exec" - "slices" - "strings" - "sync" - "syscall" - "testing" - "time" -) - -// ToolboxImage is the Docker image for the MCP Toolbox for Databases. -// Used as the MCP server backing integration tests. All test configs -// that reference a postgres MCP server should use this image. -const ToolboxImage = "us-central1-docker.pkg.dev/database-toolbox/toolbox/toolbox:latest" - -// testPostgresDockerArgs returns the Docker args for running the toolbox -// as a stdio MCP server against the test postgres database. -func testPostgresDockerArgs() []string { - return []string{ - "run", "--rm", "-i", "--network", "host", - "-e", "POSTGRES_HOST=localhost", - "-e", "POSTGRES_PORT=15432", - "-e", "POSTGRES_DATABASE=testdb", - "-e", "POSTGRES_USER=testuser", - "-e", "POSTGRES_PASSWORD=testpass", - ToolboxImage, - "--stdio", "--prebuilt", "postgres", - } -} - -// testPostgresServer returns an MCP server config for the test postgres database. -// Options can customize auth, logging, etc. -func testPostgresServer(opts ...serverOption) map[string]any { - args := make([]any, len(testPostgresDockerArgs())) - for i, a := range testPostgresDockerArgs() { - args[i] = a - } - s := map[string]any{ - "transportType": "stdio", - "command": "docker", - "args": args, - } - for _, opt := range opts { - opt(s) - } - return s -} - -type serverOption func(map[string]any) - -func withBearerTokens(tokens ...string) serverOption { - return func(s map[string]any) { - s["serviceAuths"] = []map[string]any{ - {"type": "bearer", "tokens": tokens}, - } - } -} - -func withBasicAuth(username, passwordEnvVar string) serverOption { - return func(s map[string]any) { - auths, _ := s["serviceAuths"].([]map[string]any) - auths = append(auths, map[string]any{ - "type": "basic", - "username": username, - "password": map[string]string{"$env": passwordEnvVar}, - }) - s["serviceAuths"] = auths - } -} - -func withLogEnabled() serverOption { - return func(s map[string]any) { - s["options"] = map[string]any{"logEnabled": true} - } -} - -func withUserToken() serverOption { - return func(s map[string]any) { - s["env"] = map[string]any{ - "USER_TOKEN": map[string]string{"$userToken": "{{token}}"}, - } - s["requiresUserToken"] = true - s["userAuthentication"] = map[string]any{ - "type": "manual", - "displayName": "Test Service", - "instructions": "Enter your test token", - "helpUrl": "https://example.com/help", - } - } -} - -// testOAuthConfig returns a standard OAuth auth config for testing. -// Uses hardcoded values suitable for integration tests with the fake GCP server. -func testOAuthConfig() map[string]any { - return map[string]any{ - "kind": "oauth", - "issuer": "http://localhost:8080", - "gcpProject": "test-project", - "idp": map[string]any{ - "provider": "google", - "clientId": "test-client-id", - "clientSecret": "test-client-secret-for-integration-testing", - "redirectUri": "http://localhost:8080/oauth/callback", - "authorizationUrl": "http://localhost:9090/auth", - "tokenUrl": "http://localhost:9090/token", - "userInfoUrl": "http://localhost:9090/userinfo", - }, - "allowedDomains": []string{"test.com"}, - "allowedOrigins": []string{"https://claude.ai"}, - "tokenTtl": "1h", - "storage": "memory", - "jwtSecret": "test-jwt-secret-for-integration-testing-32-chars-long", - "encryptionKey": "test-encryption-key-32-bytes-aes", - } -} - -// testOAuthConfigFromEnv returns an OAuth auth config that reads secrets from env vars. -func testOAuthConfigFromEnv() map[string]any { - return map[string]any{ - "kind": "oauth", - "issuer": "http://localhost:8080", - "gcpProject": "test-project", - "idp": map[string]any{ - "provider": "google", - "clientId": map[string]string{"$env": "GOOGLE_CLIENT_ID"}, - "clientSecret": map[string]string{"$env": "GOOGLE_CLIENT_SECRET"}, - "redirectUri": "http://localhost:8080/oauth/callback", - "authorizationUrl": "http://localhost:9090/auth", - "tokenUrl": "http://localhost:9090/token", - "userInfoUrl": "http://localhost:9090/userinfo", - }, - "allowedDomains": []string{"test.com", "stainless.com", "claude.ai"}, - "allowedOrigins": []string{"https://claude.ai"}, - "tokenTtl": "1h", - "storage": "memory", - "jwtSecret": map[string]string{"$env": "JWT_SECRET"}, - "encryptionKey": map[string]string{"$env": "ENCRYPTION_KEY"}, - } -} - -// writeTestConfig writes a config map to a temporary JSON file and returns its path. -// The file is automatically cleaned up when the test finishes. -func writeTestConfig(t *testing.T, cfg map[string]any) string { - t.Helper() - data, err := json.MarshalIndent(cfg, "", " ") - if err != nil { - t.Fatalf("Failed to marshal test config: %v", err) - } - f, err := os.CreateTemp(t.TempDir(), "config-*.json") - if err != nil { - t.Fatalf("Failed to create temp config file: %v", err) - } - if _, err := f.Write(data); err != nil { - t.Fatalf("Failed to write temp config: %v", err) - } - if err := f.Close(); err != nil { - t.Fatalf("Failed to close temp config: %v", err) - } - return f.Name() -} - -// buildTestConfig builds a complete mcp-front config map. -func buildTestConfig(baseURL, name string, auth map[string]any, mcpServers map[string]any) map[string]any { - proxy := map[string]any{ - "baseURL": baseURL, - "addr": ":8080", - "name": name, - } - if auth != nil { - proxy["auth"] = auth - } - return map[string]any{ - "version": "v0.0.1-DEV_EDITION_EXPECT_CHANGES", - "proxy": proxy, - "mcpServers": mcpServers, - } -} - -// MCPSSEClient simulates an MCP client for testing -type MCPSSEClient struct { - baseURL string - token string - sseConn io.ReadCloser - messageEndpoint string - sseScanner *bufio.Scanner - sessionID string -} - -// NewMCPSSEClient creates a new MCP client for testing -func NewMCPSSEClient(baseURL string) *MCPSSEClient { - return &MCPSSEClient{ - baseURL: baseURL, - } -} - -// Authenticate sets up authentication for the client -func (c *MCPSSEClient) Authenticate() error { - c.token = "test-token" - return nil -} - -// SetAuthToken sets a specific auth token for the client -func (c *MCPSSEClient) SetAuthToken(token string) { - c.token = token -} - -// Connect establishes an SSE connection and retrieves the message endpoint -func (c *MCPSSEClient) Connect() error { - return c.ConnectToServer("postgres") -} - -// ConnectToServer establishes an SSE connection to a specific server -func (c *MCPSSEClient) ConnectToServer(serverName string) error { - // Close any existing connection - if c.sseConn != nil { - c.sseConn.Close() - c.sseConn = nil - c.messageEndpoint = "" - } - - sseURL := c.baseURL + "/" + serverName + "/sse" - tracef("ConnectToServer: requesting %s", sseURL) - - req, err := http.NewRequest("GET", sseURL, nil) - if err != nil { - return fmt.Errorf("failed to create SSE request: %v", err) - } - - req.Header.Set("Accept", "text/event-stream") - req.Header.Set("Authorization", "Bearer "+c.token) - req.Header.Set("Cache-Control", "no-cache") - tracef("ConnectToServer: headers set, making request") - - // Don't use a timeout on the client for SSE - client := &http.Client{} - resp, err := client.Do(req) - if err != nil { - return fmt.Errorf("SSE connection failed: %v", err) - } - - tracef("ConnectToServer: got response status %d", resp.StatusCode) - if resp.StatusCode != 200 { - body, _ := io.ReadAll(resp.Body) - resp.Body.Close() - return fmt.Errorf("SSE connection returned %d: %s", resp.StatusCode, string(body)) - } - - // Store the connection - c.sseConn = resp.Body - c.sseScanner = bufio.NewScanner(resp.Body) - - // Read initial SSE messages to get the endpoint - // For inline servers, we don't get a message endpoint - we use the server path directly - gotEndpointMessage := false - for c.sseScanner.Scan() { - line := c.sseScanner.Text() - tracef("ConnectToServer: SSE line: %s", line) - - // Look for data lines - if after, ok := strings.CutPrefix(line, "data: "); ok { - data := after - - // Check if it's an endpoint message (for inline servers) - if strings.Contains(data, `"type":"endpoint"`) { - gotEndpointMessage = true - // For inline servers, construct the message endpoint - c.messageEndpoint = c.baseURL + "/" + serverName + "/message" - tracef("ConnectToServer: inline server detected, using endpoint: %s", c.messageEndpoint) - break - } - - // Check if it's a message endpoint URL (for stdio servers) - if strings.Contains(data, "http://") || strings.Contains(data, "https://") { - c.messageEndpoint = data - - // Extract session ID from endpoint URL - if u, err := url.Parse(data); err == nil { - c.sessionID = u.Query().Get("sessionId") - } - - tracef("ConnectToServer: found endpoint: %s", c.messageEndpoint) - break - } - } - } - - if c.messageEndpoint == "" && !gotEndpointMessage { - c.sseConn.Close() - c.sseConn = nil - return fmt.Errorf("no message endpoint received") - } - - tracef("Connect: successfully connected to MCP server") - return nil -} - -// ValidateBackendConnectivity checks if we can connect to the MCP server -func (c *MCPSSEClient) ValidateBackendConnectivity() error { - return c.Connect() -} - -// Close closes the SSE connection -func (c *MCPSSEClient) Close() { - if c.sseConn != nil { - c.sseConn.Close() - c.sseConn = nil - c.messageEndpoint = "" - c.sseScanner = nil - } -} - -// SendMCPRequest sends an MCP JSON-RPC request and returns the response -func (c *MCPSSEClient) SendMCPRequest(method string, params any) (map[string]any, error) { - // Ensure we have a connection - if c.messageEndpoint == "" { - if err := c.Connect(); err != nil { - return nil, fmt.Errorf("failed to connect: %v", err) - } - } - - // Send MCP request to the message endpoint - request := map[string]any{ - "jsonrpc": "2.0", - "id": 1, - "method": method, - "params": params, - } - - reqBody, err := json.Marshal(request) - if err != nil { - return nil, err - } - - msgReq, err := http.NewRequest("POST", c.messageEndpoint, bytes.NewBuffer(reqBody)) - if err != nil { - return nil, err - } - - msgReq.Header.Set("Content-Type", "application/json") - msgReq.Header.Set("Authorization", "Bearer "+c.token) - - client := &http.Client{Timeout: 30 * time.Second} - msgResp, err := client.Do(msgReq) - if err != nil { - return nil, err - } - defer msgResp.Body.Close() - - respBody, err := io.ReadAll(msgResp.Body) - if err != nil { - return nil, err - } - - if msgResp.StatusCode != 200 && msgResp.StatusCode != 202 { - return nil, fmt.Errorf("MCP request failed: %d - %s", msgResp.StatusCode, string(respBody)) - } - - // Handle 202 and empty responses - read response from SSE stream - if msgResp.StatusCode == 202 || len(respBody) == 0 { - // Read response from SSE stream - for c.sseScanner.Scan() { - line := c.sseScanner.Text() - - if after, ok := strings.CutPrefix(line, "data: "); ok { - data := after - // Try to parse as JSON - var msg map[string]any - if err := json.Unmarshal([]byte(data), &msg); err == nil { - // Check if this is our response (matching ID) - if id, ok := msg["id"]; ok && id == float64(1) { - return msg, nil - } - } - } - } - - if err := c.sseScanner.Err(); err != nil { - return nil, fmt.Errorf("SSE scanner error: %v", err) - } - - return nil, fmt.Errorf("no response received from SSE stream") - } - - var result map[string]any - if err := json.Unmarshal(respBody, &result); err != nil { - return nil, fmt.Errorf("failed to parse response: %v - %s", err, string(respBody)) - } - - return result, nil -} - -// MCPStreamableClient is a test client for HTTP-Streamable MCP servers -type MCPStreamableClient struct { - baseURL string - serverName string - token string - httpClient *http.Client - - // For GET SSE streaming - sseConn io.ReadCloser - sseScanner *bufio.Scanner - sseCancel chan struct{} - - mu sync.Mutex -} - -// NewMCPStreamableClient creates a new streamable-http test client -func NewMCPStreamableClient(baseURL string) *MCPStreamableClient { - return &MCPStreamableClient{ - baseURL: baseURL, - httpClient: &http.Client{ - Timeout: 30 * time.Second, - }, - } -} - -// SetAuthToken sets the authentication token -func (c *MCPStreamableClient) SetAuthToken(token string) { - c.token = token -} - -// ConnectToServer establishes connection to a streamable-http server -func (c *MCPStreamableClient) ConnectToServer(serverName string) error { - c.mu.Lock() - defer c.mu.Unlock() - - // Close any existing connection - c.close() - - c.serverName = serverName - - // For streamable-http, we can optionally open a GET SSE stream for server-initiated messages - // But it's not required for basic request/response - return c.openSSEStream() -} - -// openSSEStream opens a GET SSE connection for receiving server-initiated messages -func (c *MCPStreamableClient) openSSEStream() error { - url := c.baseURL + "/" + c.serverName + "/sse" - - req, err := http.NewRequest("GET", url, nil) - if err != nil { - return fmt.Errorf("failed to create GET request: %v", err) - } - - req.Header.Set("Accept", "text/event-stream") - req.Header.Set("Authorization", "Bearer "+c.token) - req.Header.Set("Cache-Control", "no-cache") - - // Use a client without timeout for SSE - sseClient := &http.Client{} - resp, err := sseClient.Do(req) - if err != nil { - return fmt.Errorf("SSE connection failed: %v", err) - } - - if resp.StatusCode != 200 { - body, _ := io.ReadAll(resp.Body) - resp.Body.Close() - return fmt.Errorf("SSE connection returned %d: %s", resp.StatusCode, string(body)) - } - - c.sseConn = resp.Body - c.sseScanner = bufio.NewScanner(resp.Body) - c.sseCancel = make(chan struct{}) - - // Start reading SSE messages in background - go c.readSSEMessages() - - return nil -} - -// readSSEMessages reads server-initiated messages from the SSE stream -func (c *MCPStreamableClient) readSSEMessages() { - for { - select { - case <-c.sseCancel: - return - default: - if c.sseScanner.Scan() { - line := c.sseScanner.Text() - if after, ok := strings.CutPrefix(line, "data: "); ok { - data := after - // In a real implementation, we'd process server-initiated messages here - tracef("StreamableClient: received SSE message: %s", data) - } - } else { - // Scanner stopped - connection closed or error - return - } - } - } -} - -// SendMCPRequest sends a JSON-RPC request via POST -func (c *MCPStreamableClient) SendMCPRequest(method string, params any) (map[string]any, error) { - c.mu.Lock() - serverName := c.serverName - c.mu.Unlock() - - if serverName == "" { - return nil, fmt.Errorf("not connected to any server") - } - - // For streamable-http, we POST to the server endpoint - url := c.baseURL + "/" + serverName + "/sse" - - // Construct JSON-RPC request - request := map[string]any{ - "jsonrpc": "2.0", - "id": 1, - "method": method, - "params": params, - } - - body, err := json.Marshal(request) - if err != nil { - return nil, fmt.Errorf("failed to marshal request: %v", err) - } - - req, err := http.NewRequest("POST", url, bytes.NewReader(body)) - if err != nil { - return nil, fmt.Errorf("failed to create POST request: %v", err) - } - - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+c.token) - // Accept both JSON and SSE responses - req.Header.Set("Accept", "application/json, text/event-stream") - - resp, err := c.httpClient.Do(req) - if err != nil { - return nil, fmt.Errorf("request failed: %v", err) - } - defer resp.Body.Close() - - // Check content type to determine response format - contentType := resp.Header.Get("Content-Type") - - if strings.HasPrefix(contentType, "text/event-stream") { - // Handle SSE response - return c.handleSSEResponse(resp.Body) - } else { - // Handle JSON response - return c.handleJSONResponse(resp.Body) - } -} - -// handleJSONResponse processes a regular JSON response -func (c *MCPStreamableClient) handleJSONResponse(body io.Reader) (map[string]any, error) { - var response map[string]any - if err := json.NewDecoder(body).Decode(&response); err != nil { - return nil, fmt.Errorf("failed to decode JSON response: %v", err) - } - return response, nil -} - -// handleSSEResponse processes an SSE stream response from a POST -func (c *MCPStreamableClient) handleSSEResponse(body io.Reader) (map[string]any, error) { - scanner := bufio.NewScanner(body) - var lastResponse map[string]any - - for scanner.Scan() { - line := scanner.Text() - if after, ok := strings.CutPrefix(line, "data: "); ok { - data := after - var msg map[string]any - if err := json.Unmarshal([]byte(data), &msg); err == nil { - // Keep the last response with an ID (not a notification) - if _, hasID := msg["id"]; hasID { - lastResponse = msg - } - } - } - } - - if lastResponse == nil { - return nil, fmt.Errorf("no response received in SSE stream") - } - - return lastResponse, nil -} - -// Close closes all connections -func (c *MCPStreamableClient) Close() { - c.mu.Lock() - defer c.mu.Unlock() - c.close() -} - -// close is the internal close method (must be called with lock held) -func (c *MCPStreamableClient) close() { - if c.sseCancel != nil { - close(c.sseCancel) - c.sseCancel = nil - } - - if c.sseConn != nil { - c.sseConn.Close() - c.sseConn = nil - c.sseScanner = nil - } - - c.serverName = "" -} - -// FakeGCPServer provides a fake GCP OAuth server for testing -type FakeGCPServer struct { - server *http.Server - port string -} - -// NewFakeGCPServer creates a new fake GCP server -func NewFakeGCPServer(port string) *FakeGCPServer { - mux := http.NewServeMux() - - mux.HandleFunc("/auth", func(w http.ResponseWriter, r *http.Request) { - redirectURI := r.URL.Query().Get("redirect_uri") - state := r.URL.Query().Get("state") - http.Redirect(w, r, fmt.Sprintf("%s?code=test-auth-code&state=%s", redirectURI, state), http.StatusFound) - }) - - mux.HandleFunc("/token", func(w http.ResponseWriter, r *http.Request) { - // Parse the form data - if err := r.ParseForm(); err != nil { - http.Error(w, "Invalid request", http.StatusBadRequest) - return - } - - // Check the authorization code - code := r.FormValue("code") - if code != "test-auth-code" { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusBadRequest) - _ = json.NewEncoder(w).Encode(map[string]any{ - "error": "invalid_grant", - "error_description": "Invalid authorization code", - }) - return - } - - w.Header().Set("Content-Type", "application/json") - _ = json.NewEncoder(w).Encode(map[string]any{ - "access_token": "test-access-token", - "token_type": "Bearer", - "expires_in": 3600, - }) - }) - - mux.HandleFunc("/userinfo", func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - _ = json.NewEncoder(w).Encode(map[string]any{ - "email": "test@test.com", - "hd": "test.com", - }) - }) - - server := &http.Server{ - Addr: ":" + port, - Handler: mux, - } - - return &FakeGCPServer{ - server: server, - port: port, - } -} - -// Start starts the fake GCP server -func (m *FakeGCPServer) Start() error { - go func() { - if err := m.server.ListenAndServe(); err != nil && err != http.ErrServerClosed { - panic(err) - } - }() - - time.Sleep(100 * time.Millisecond) - return nil -} - -// Stop stops the fake GCP server -func (m *FakeGCPServer) Stop() error { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - return m.server.Shutdown(ctx) -} - -// FakeServiceOAuthServer provides a fake OAuth server for external services (like Linear, GitHub) -type FakeServiceOAuthServer struct { - server *http.Server - port string -} - -// NewFakeServiceOAuthServer creates a new fake service OAuth server -func NewFakeServiceOAuthServer(port string) *FakeServiceOAuthServer { - mux := http.NewServeMux() - - mux.HandleFunc("/oauth/authorize", func(w http.ResponseWriter, r *http.Request) { - redirectURI := r.URL.Query().Get("redirect_uri") - state := r.URL.Query().Get("state") - http.Redirect(w, r, fmt.Sprintf("%s?code=service-auth-code&state=%s", redirectURI, state), http.StatusFound) - }) - - mux.HandleFunc("/oauth/token", func(w http.ResponseWriter, r *http.Request) { - if err := r.ParseForm(); err != nil { - http.Error(w, "Invalid request", http.StatusBadRequest) - return - } - - code := r.FormValue("code") - if code != "service-auth-code" { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusBadRequest) - _ = json.NewEncoder(w).Encode(map[string]any{ - "error": "invalid_grant", - "error_description": "Invalid authorization code", - }) - return - } - - w.Header().Set("Content-Type", "application/json") - _ = json.NewEncoder(w).Encode(map[string]any{ - "access_token": "service-oauth-access-token", - "refresh_token": "service-oauth-refresh-token", - "token_type": "Bearer", - "expires_in": 3600, - }) - }) - - server := &http.Server{ - Addr: ":" + port, - Handler: mux, - } - - return &FakeServiceOAuthServer{ - server: server, - port: port, - } -} - -// Start starts the fake service OAuth server -func (s *FakeServiceOAuthServer) Start() error { - go func() { - if err := s.server.ListenAndServe(); err != nil && err != http.ErrServerClosed { - panic(err) - } - }() - - time.Sleep(100 * time.Millisecond) - return nil -} - -// Stop stops the fake service OAuth server -func (s *FakeServiceOAuthServer) Stop() error { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - return s.server.Shutdown(ctx) -} - -// TestConfig holds all timeout configurations for integration tests -type TestConfig struct { - SessionTimeout string - CleanupInterval string - CleanupWaitTime string - TimerResetWaitTime string - MultiUserWaitTime string -} - -// GetTestConfig returns test configuration from environment variables or defaults -func GetTestConfig() TestConfig { - c := TestConfig{ - SessionTimeout: "10s", - CleanupInterval: "2s", - CleanupWaitTime: "15s", - TimerResetWaitTime: "12s", - MultiUserWaitTime: "15s", - } - - // Override from environment if set - if v := os.Getenv("SESSION_TIMEOUT"); v != "" { - c.SessionTimeout = v - } - if v := os.Getenv("SESSION_CLEANUP_INTERVAL"); v != "" { - c.CleanupInterval = v - } - if v := os.Getenv("TEST_CLEANUP_WAIT_TIME"); v != "" { - c.CleanupWaitTime = v - } - if v := os.Getenv("TEST_TIMER_RESET_WAIT_TIME"); v != "" { - c.TimerResetWaitTime = v - } - if v := os.Getenv("TEST_MULTI_USER_WAIT_TIME"); v != "" { - c.MultiUserWaitTime = v - } - - return c -} - -func waitForDB(t *testing.T) { - waitForSec := 5 - for range waitForSec { - // Check if container is running - psCmd := exec.Command("docker", "compose", "ps", "-q", "test-postgres") - if output, err := psCmd.Output(); err != nil || len(output) == 0 { - time.Sleep(1 * time.Second) - continue - } - - // Check if database is ready - checkCmd := exec.Command("docker", "compose", "exec", "-T", "test-postgres", "pg_isready", "-U", "testuser", "-d", "testdb") - if err := checkCmd.Run(); err == nil { - return - } - time.Sleep(1 * time.Second) - } - - t.Fatalf("Database failed to become ready after %d seconds", waitForSec) -} - -// trace logs a message if TRACE environment variable is set -func trace(t *testing.T, format string, args ...any) { - if os.Getenv("TRACE") == "1" { - t.Logf("TRACE: "+format, args...) - } -} - -// tracef logs a formatted message to stdout if TRACE is set (for use outside tests) -func tracef(format string, args ...any) { - if os.Getenv("TRACE") == "1" { - fmt.Printf("TRACE: "+format+"\n", args...) - } -} - -// startMCPFront starts the mcp-front server with the given config -func startMCPFront(t *testing.T, configPath string, extraEnv ...string) { - mcpCmd := exec.Command("../cmd/mcp-front/mcp-front", "-config", configPath) - - // Get test config for session timeouts - testConfig := GetTestConfig() - - // Build default environment with test timeouts - defaultEnv := []string{ - "SESSION_TIMEOUT=" + testConfig.SessionTimeout, - "SESSION_CLEANUP_INTERVAL=" + testConfig.CleanupInterval, - } - - // Start with system environment - mcpCmd.Env = os.Environ() - - // Apply defaults first - mcpCmd.Env = append(mcpCmd.Env, defaultEnv...) - - // Apply extra env (can override defaults) - mcpCmd.Env = append(mcpCmd.Env, extraEnv...) - - // Pass through LOG_LEVEL and LOG_FORMAT if set - if logLevel := os.Getenv("LOG_LEVEL"); logLevel != "" { - mcpCmd.Env = append(mcpCmd.Env, "LOG_LEVEL="+logLevel) - } - if logFormat := os.Getenv("LOG_FORMAT"); logFormat != "" { - mcpCmd.Env = append(mcpCmd.Env, "LOG_FORMAT="+logFormat) - } - - // Capture output to log file if MCP_LOG_FILE is set - if logFile := os.Getenv("MCP_LOG_FILE"); logFile != "" { - f, err := os.OpenFile(logFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644) - if err == nil { - mcpCmd.Stderr = f - mcpCmd.Stdout = f - t.Cleanup(func() { f.Close() }) - } - } - - if err := mcpCmd.Start(); err != nil { - t.Fatalf("Failed to start mcp-front: %v", err) - } - - // Register cleanup that runs even if test is killed - t.Cleanup(func() { - stopMCPFront(mcpCmd) - }) -} - -// stopMCPFront stops the mcp-front server gracefully -func stopMCPFront(cmd *exec.Cmd) { - if cmd == nil || cmd.Process == nil { - return - } - - // Try graceful shutdown first (SIGINT) - if err := cmd.Process.Signal(syscall.SIGINT); err != nil { - // If SIGINT fails, force kill immediately - _ = cmd.Process.Kill() - _ = cmd.Wait() - return - } - - // Wait up to 5 seconds for graceful shutdown - done := make(chan error, 1) - go func() { - done <- cmd.Wait() - }() - - select { - case <-done: - // Graceful shutdown completed - return - case <-time.After(5 * time.Second): - // Timeout, force kill - _ = cmd.Process.Kill() - _ = cmd.Wait() - } -} - -// waitForMCPFront waits for the mcp-front server to be ready -func waitForMCPFront(t *testing.T) { - t.Helper() - for range 10 { - resp, err := http.Get("http://localhost:8080/health") - if err == nil && resp.StatusCode == 200 { - resp.Body.Close() - return - } - if resp != nil { - resp.Body.Close() - } - time.Sleep(1 * time.Second) - } - t.Fatal("mcp-front failed to become ready after 10 seconds") -} - -// getMCPContainers returns a list of running toolbox container IDs -func getMCPContainers() []string { - cmd := exec.Command("docker", "ps", "--format", "{{.ID}}", "--filter", "ancestor="+ToolboxImage) - output, err := cmd.Output() - if err != nil { - return nil - } - - var containers []string - for line := range strings.SplitSeq(strings.TrimSpace(string(output)), "\n") { - if line != "" { - containers = append(containers, line) - } - } - return containers -} - -// cleanupContainers forces cleanup of containers that weren't in the initial set -func cleanupContainers(t *testing.T, initialContainers []string) { - time.Sleep(2 * time.Second) - containers := getMCPContainers() - for _, container := range containers { - isInitial := slices.Contains(initialContainers, container) - if !isInitial { - t.Logf("Force stopping container: %s...", container) - if err := exec.Command("docker", "stop", container).Run(); err != nil { - t.Logf("Failed to stop container %s: %v", container, err) - } else { - t.Logf("Stopped container: %s", container) - } - } - } -} - -// TestQuickSmoke provides a fast validation test -func TestQuickSmoke(t *testing.T) { - t.Log("Running quick smoke test...") - - // Just verify the test infrastructure works - client := NewMCPSSEClient("http://localhost:8080") - if client == nil { - t.Fatal("Failed to create client") - } - - if err := client.Authenticate(); err != nil { - t.Fatal("Failed to set up authentication") - } - - t.Log("Quick smoke test passed - test infrastructure is working") -} diff --git a/internal/idp/azure.go b/internal/idp/azure.go index ccd0682..6694740 100644 --- a/internal/idp/azure.go +++ b/internal/idp/azure.go @@ -4,22 +4,31 @@ import "fmt" // NewAzureProvider creates an Azure AD provider using OIDC discovery. // Azure AD is OIDC-compliant, so we use the generic OIDC provider with Azure's tenant-specific discovery URL. -func NewAzureProvider(tenantID, clientID, clientSecret, redirectURI string) (*OIDCProvider, error) { +// Optional direct endpoint overrides (authorizationURL, tokenURL, userInfoURL) skip discovery +// when all three are provided — useful for testing. +func NewAzureProvider(tenantID, clientID, clientSecret, redirectURI, authorizationURL, tokenURL, userInfoURL string) (*OIDCProvider, error) { if tenantID == "" { return nil, fmt.Errorf("tenantId is required for Azure AD") } - discoveryURL := fmt.Sprintf( - "https://login.microsoftonline.com/%s/v2.0/.well-known/openid-configuration", - tenantID, - ) - - return NewOIDCProvider(OIDCConfig{ + cfg := OIDCConfig{ ProviderType: "azure", - DiscoveryURL: discoveryURL, ClientID: clientID, ClientSecret: clientSecret, RedirectURI: redirectURI, Scopes: []string{"openid", "email", "profile"}, - }) + } + + if authorizationURL != "" && tokenURL != "" && userInfoURL != "" { + cfg.AuthorizationURL = authorizationURL + cfg.TokenURL = tokenURL + cfg.UserInfoURL = userInfoURL + } else { + cfg.DiscoveryURL = fmt.Sprintf( + "https://login.microsoftonline.com/%s/v2.0/.well-known/openid-configuration", + tenantID, + ) + } + + return NewOIDCProvider(cfg) } diff --git a/internal/idp/azure_test.go b/internal/idp/azure_test.go index d163ce0..453a23a 100644 --- a/internal/idp/azure_test.go +++ b/internal/idp/azure_test.go @@ -8,7 +8,7 @@ import ( ) func TestNewAzureProvider_MissingTenantID(t *testing.T) { - _, err := NewAzureProvider("", "client-id", "client-secret", "https://example.com/callback") + _, err := NewAzureProvider("", "client-id", "client-secret", "https://example.com/callback", "", "", "") require.Error(t, err) assert.Contains(t, err.Error(), "tenantId is required") diff --git a/internal/idp/factory.go b/internal/idp/factory.go index ff55d8d..20e7e4c 100644 --- a/internal/idp/factory.go +++ b/internal/idp/factory.go @@ -25,6 +25,9 @@ func NewProvider(cfg config.IDPConfig) (Provider, error) { cfg.ClientID, string(cfg.ClientSecret), cfg.RedirectURI, + cfg.AuthorizationURL, + cfg.TokenURL, + cfg.UserInfoURL, ) case "github": @@ -32,6 +35,9 @@ func NewProvider(cfg config.IDPConfig) (Provider, error) { cfg.ClientID, string(cfg.ClientSecret), cfg.RedirectURI, + cfg.AuthorizationURL, + cfg.TokenURL, + cfg.UserInfoURL, ), nil case "oidc": diff --git a/internal/idp/github.go b/internal/idp/github.go index e737cd1..b13ec04 100644 --- a/internal/idp/github.go +++ b/internal/idp/github.go @@ -40,16 +40,31 @@ type githubOrgResponse struct { } // NewGitHubProvider creates a new GitHub OAuth provider. -func NewGitHubProvider(clientID, clientSecret, redirectURI string) *GitHubProvider { +// Optional endpoint overrides (authorizationURL, tokenURL, apiBaseURL) allow +// pointing at a non-GitHub server for testing. +func NewGitHubProvider(clientID, clientSecret, redirectURI, authorizationURL, tokenURL, apiBaseURL string) *GitHubProvider { + endpoint := github.Endpoint + if authorizationURL != "" { + endpoint.AuthURL = authorizationURL + } + if tokenURL != "" { + endpoint.TokenURL = tokenURL + } + + apiBase := "https://api.github.com" + if apiBaseURL != "" { + apiBase = apiBaseURL + } + return &GitHubProvider{ config: oauth2.Config{ ClientID: clientID, ClientSecret: clientSecret, RedirectURL: redirectURI, Scopes: []string{"user:email", "read:org"}, - Endpoint: github.Endpoint, + Endpoint: endpoint, }, - apiBaseURL: "https://api.github.com", + apiBaseURL: apiBase, } } diff --git a/internal/idp/github_test.go b/internal/idp/github_test.go index e532ef3..ccf0c87 100644 --- a/internal/idp/github_test.go +++ b/internal/idp/github_test.go @@ -13,12 +13,12 @@ import ( ) func TestGitHubProvider_Type(t *testing.T) { - provider := NewGitHubProvider("client-id", "client-secret", "https://example.com/callback") + provider := NewGitHubProvider("client-id", "client-secret", "https://example.com/callback", "", "", "") assert.Equal(t, "github", provider.Type()) } func TestGitHubProvider_AuthURL(t *testing.T) { - provider := NewGitHubProvider("client-id", "client-secret", "https://example.com/callback") + provider := NewGitHubProvider("client-id", "client-secret", "https://example.com/callback", "", "", "") authURL := provider.AuthURL("test-state")