diff --git a/CLAUDE.md b/CLAUDE.md index ec5900f..e0c17dd 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -245,15 +245,16 @@ When refactoring for better design: ```bash # Build everything -make build +./scripts/build # Format everything -make format +./scripts/format # Lint everything -make lint +./scripts/lint # Test mcp-front +./scripts/test go test ./internal/... -v go test ./integration -v @@ -261,7 +262,7 @@ go test ./integration -v ./mcp-front -config config.json # Start docs dev server -make doc +./scripts/docs ``` ## Documentation Site Guidelines diff --git a/cmd/mcp-front/main.go b/cmd/mcp-front/main.go index 7fd26a3..2cdb86d 100644 --- a/cmd/mcp-front/main.go +++ b/cmd/mcp-front/main.go @@ -22,17 +22,20 @@ func generateDefaultConfig(path string) error { "addr": ":8080", "name": "mcp-front", "auth": map[string]any{ - "kind": "oauth", - "issuer": "https://mcp.yourcompany.com", - "allowedDomains": []string{"yourcompany.com"}, - "allowedOrigins": []string{"https://claude.ai"}, - "tokenTtl": "24h", - "storage": "memory", - "googleClientId": map[string]string{"$env": "GOOGLE_CLIENT_ID"}, - "googleClientSecret": map[string]string{"$env": "GOOGLE_CLIENT_SECRET"}, - "googleRedirectUri": "https://mcp.yourcompany.com/oauth/callback", - "jwtSecret": map[string]string{"$env": "JWT_SECRET"}, - "encryptionKey": map[string]string{"$env": "ENCRYPTION_KEY"}, + "kind": "oauth", + "issuer": "https://mcp.yourcompany.com", + "allowedDomains": []string{"yourcompany.com"}, + "allowedOrigins": []string{"https://claude.ai"}, + "tokenTtl": "24h", + "storage": "memory", + "idp": map[string]any{ + "provider": "google", + "clientId": map[string]string{"$env": "GOOGLE_CLIENT_ID"}, + "clientSecret": map[string]string{"$env": "GOOGLE_CLIENT_SECRET"}, + "redirectUri": "https://mcp.yourcompany.com/oauth/callback", + }, + "jwtSecret": map[string]string{"$env": "JWT_SECRET"}, + "encryptionKey": map[string]string{"$env": "ENCRYPTION_KEY"}, }, }, "mcpServers": map[string]any{ @@ -40,9 +43,21 @@ func generateDefaultConfig(path string) error { "transportType": "stdio", "command": "docker", "args": []any{ - "run", "--rm", "-i", - "mcp/postgres:latest", - map[string]string{"$env": "POSTGRES_URL"}, + "run", "--rm", "-i", "--network", "host", + "-e", "POSTGRES_HOST", + "-e", "POSTGRES_PORT", + "-e", "POSTGRES_DATABASE", + "-e", "POSTGRES_USER", + "-e", "POSTGRES_PASSWORD", + "us-central1-docker.pkg.dev/database-toolbox/toolbox/toolbox:latest", + "--stdio", "--prebuilt", "postgres", + }, + "env": map[string]any{ + "POSTGRES_HOST": map[string]string{"$env": "POSTGRES_HOST"}, + "POSTGRES_PORT": map[string]string{"$env": "POSTGRES_PORT"}, + "POSTGRES_DATABASE": map[string]string{"$env": "POSTGRES_DATABASE"}, + "POSTGRES_USER": map[string]string{"$env": "POSTGRES_USER"}, + "POSTGRES_PASSWORD": map[string]string{"$env": "POSTGRES_PASSWORD"}, }, }, }, diff --git a/config-admin-example.json b/config-admin-example.json index e0fd169..b35934c 100644 --- a/config-admin-example.json +++ b/config-admin-example.json @@ -7,13 +7,16 @@ "auth": { "kind": "oauth", "issuer": "https://mcp.example.com", + "idp": { + "provider": "google", + "clientId": {"$env": "GOOGLE_CLIENT_ID"}, + "clientSecret": {"$env": "GOOGLE_CLIENT_SECRET"}, + "redirectUri": "https://mcp.example.com/oauth/callback" + }, "allowedDomains": ["example.com"], "allowedOrigins": ["https://claude.ai"], "tokenTtl": "1h", "storage": "memory", - "googleClientId": {"$env": "GOOGLE_CLIENT_ID"}, - "googleClientSecret": {"$env": "GOOGLE_CLIENT_SECRET"}, - "googleRedirectUri": "https://mcp.example.com/oauth/callback", "jwtSecret": {"$env": "JWT_SECRET"}, "encryptionKey": {"$env": "ENCRYPTION_KEY"} }, diff --git a/config-inline-example.json b/config-inline-example.json index 44038cd..cfa2466 100644 --- a/config-inline-example.json +++ b/config-inline-example.json @@ -7,13 +7,16 @@ "auth": { "kind": "oauth", "issuer": "https://mcp.example.com", + "idp": { + "provider": "google", + "clientId": {"$env": "GOOGLE_CLIENT_ID"}, + "clientSecret": {"$env": "GOOGLE_CLIENT_SECRET"}, + "redirectUri": "https://mcp.example.com/oauth/callback" + }, "allowedDomains": ["example.com"], "allowedOrigins": ["https://claude.ai"], "tokenTtl": "1h", "storage": "memory", - "googleClientId": {"$env": "GOOGLE_CLIENT_ID"}, - "googleClientSecret": {"$env": "GOOGLE_CLIENT_SECRET"}, - "googleRedirectUri": "https://mcp.example.com/oauth/callback", "jwtSecret": {"$env": "JWT_SECRET"}, "encryptionKey": {"$env": "ENCRYPTION_KEY"} } diff --git a/config-inline-test.json b/config-inline-test.json index 1d005ba..9f7ac5a 100644 --- a/config-inline-test.json +++ b/config-inline-test.json @@ -7,13 +7,16 @@ "auth": { "kind": "oauth", "issuer": "http://localhost:8080", + "idp": { + "provider": "google", + "clientId": {"$env": "GOOGLE_CLIENT_ID"}, + "clientSecret": {"$env": "GOOGLE_CLIENT_SECRET"}, + "redirectUri": "http://localhost:8080/oauth/callback" + }, "allowedDomains": ["gmail.com"], "allowedOrigins": ["https://claude.ai"], "tokenTtl": "1h", "storage": "memory", - "googleClientId": {"$env": "GOOGLE_CLIENT_ID"}, - "googleClientSecret": {"$env": "GOOGLE_CLIENT_SECRET"}, - "googleRedirectUri": "http://localhost:8080/oauth/callback", "jwtSecret": {"$env": "JWT_SECRET"}, "encryptionKey": {"$env": "ENCRYPTION_KEY"} } diff --git a/config-oauth-firestore.example.json b/config-oauth-firestore.example.json index 72697a6..aaa7d4d 100644 --- a/config-oauth-firestore.example.json +++ b/config-oauth-firestore.example.json @@ -8,13 +8,16 @@ "kind": "oauth", "issuer": {"$env": "OAUTH_ISSUER"}, "gcpProject": {"$env": "GCP_PROJECT"}, + "idp": { + "provider": "google", + "clientId": {"$env": "GOOGLE_CLIENT_ID"}, + "clientSecret": {"$env": "GOOGLE_CLIENT_SECRET"}, + "redirectUri": {"$env": "GOOGLE_REDIRECT_URI"} + }, "allowedDomains": ["yourcompany.com", "contractors.yourcompany.com"], "allowedOrigins": ["https://claude.ai", "https://yourcompany.com"], "tokenTtl": "24h", "storage": "firestore", - "googleClientId": {"$env": "GOOGLE_CLIENT_ID"}, - "googleClientSecret": {"$env": "GOOGLE_CLIENT_SECRET"}, - "googleRedirectUri": {"$env": "GOOGLE_REDIRECT_URI"}, "jwtSecret": {"$env": "JWT_SECRET"}, "encryptionKey": {"$env": "ENCRYPTION_KEY"} } @@ -25,11 +28,20 @@ "command": "docker", "args": [ "run", "--rm", "-i", "--network", "host", - "mcp/postgres", - {"$env": "DATABASE_URL"} + "-e", "POSTGRES_HOST", + "-e", "POSTGRES_PORT", + "-e", "POSTGRES_DATABASE", + "-e", "POSTGRES_USER", + "-e", "POSTGRES_PASSWORD", + "us-central1-docker.pkg.dev/database-toolbox/toolbox/toolbox:latest", + "--stdio", "--prebuilt", "postgres" ], "env": { - "PGPASSWORD": {"$env": "POSTGRES_PASSWORD"} + "POSTGRES_HOST": {"$env": "POSTGRES_HOST"}, + "POSTGRES_PORT": {"$env": "POSTGRES_PORT"}, + "POSTGRES_DATABASE": {"$env": "POSTGRES_DATABASE"}, + "POSTGRES_USER": {"$env": "POSTGRES_USER"}, + "POSTGRES_PASSWORD": {"$env": "POSTGRES_PASSWORD"} } }, "notion": { diff --git a/config-oauth.example.json b/config-oauth.example.json index 8b2b1ff..80b522d 100644 --- a/config-oauth.example.json +++ b/config-oauth.example.json @@ -8,13 +8,16 @@ "kind": "oauth", "issuer": {"$env": "OAUTH_ISSUER"}, "gcpProject": {"$env": "GCP_PROJECT"}, + "idp": { + "provider": "google", + "clientId": {"$env": "GOOGLE_CLIENT_ID"}, + "clientSecret": {"$env": "GOOGLE_CLIENT_SECRET"}, + "redirectUri": {"$env": "GOOGLE_REDIRECT_URI"} + }, "allowedDomains": ["yourcompany.com", "contractors.yourcompany.com"], "allowedOrigins": ["https://claude.ai", "https://yourcompany.com"], "tokenTtl": "24h", "storage": "memory", - "googleClientId": {"$env": "GOOGLE_CLIENT_ID"}, - "googleClientSecret": {"$env": "GOOGLE_CLIENT_SECRET"}, - "googleRedirectUri": {"$env": "GOOGLE_REDIRECT_URI"}, "jwtSecret": {"$env": "JWT_SECRET"}, "encryptionKey": {"$env": "ENCRYPTION_KEY"} } @@ -25,11 +28,20 @@ "command": "docker", "args": [ "run", "--rm", "-i", "--network", "host", - "mcp/postgres", - {"$env": "DATABASE_URL"} + "-e", "POSTGRES_HOST", + "-e", "POSTGRES_PORT", + "-e", "POSTGRES_DATABASE", + "-e", "POSTGRES_USER", + "-e", "POSTGRES_PASSWORD", + "us-central1-docker.pkg.dev/database-toolbox/toolbox/toolbox:latest", + "--stdio", "--prebuilt", "postgres" ], "env": { - "PGPASSWORD": {"$env": "POSTGRES_PASSWORD"} + "POSTGRES_HOST": {"$env": "POSTGRES_HOST"}, + "POSTGRES_PORT": {"$env": "POSTGRES_PORT"}, + "POSTGRES_DATABASE": {"$env": "POSTGRES_DATABASE"}, + "POSTGRES_USER": {"$env": "POSTGRES_USER"}, + "POSTGRES_PASSWORD": {"$env": "POSTGRES_PASSWORD"} } }, "notion": { diff --git a/config-oauth.json b/config-oauth.json index dec7a3d..a49ba37 100644 --- a/config-oauth.json +++ b/config-oauth.json @@ -8,13 +8,16 @@ "kind": "oauth", "issuer": "https://mcp-internal.yourcompany.org", "gcpProject": {"$env": "GCP_PROJECT"}, + "idp": { + "provider": "google", + "clientId": {"$env": "GOOGLE_CLIENT_ID"}, + "clientSecret": {"$env": "GOOGLE_CLIENT_SECRET"}, + "redirectUri": "https://mcp-internal.yourcompany.org/oauth/callback" + }, "allowedDomains": ["yourcompany.com"], "allowedOrigins": ["https://claude.ai"], "tokenTtl": "24h", "storage": "memory", - "googleClientId": {"$env": "GOOGLE_CLIENT_ID"}, - "googleClientSecret": {"$env": "GOOGLE_CLIENT_SECRET"}, - "googleRedirectUri": "https://mcp-internal.yourcompany.org/oauth/callback", "jwtSecret": {"$env": "JWT_SECRET"}, "encryptionKey": {"$env": "ENCRYPTION_KEY"} } @@ -22,12 +25,24 @@ "mcpServers": { "postgres": { "transportType": "stdio", - "command": "docker", + "command": "docker", "args": [ - "run", "--rm", "-i", - "mcp/postgres:latest", - "postgresql://user:password@localhost:5432/database" - ] + "run", "--rm", "-i", "--network", "host", + "-e", "POSTGRES_HOST", + "-e", "POSTGRES_PORT", + "-e", "POSTGRES_DATABASE", + "-e", "POSTGRES_USER", + "-e", "POSTGRES_PASSWORD", + "us-central1-docker.pkg.dev/database-toolbox/toolbox/toolbox:latest", + "--stdio", "--prebuilt", "postgres" + ], + "env": { + "POSTGRES_HOST": {"$env": "POSTGRES_HOST"}, + "POSTGRES_PORT": {"$env": "POSTGRES_PORT"}, + "POSTGRES_DATABASE": {"$env": "POSTGRES_DATABASE"}, + "POSTGRES_USER": {"$env": "POSTGRES_USER"}, + "POSTGRES_PASSWORD": {"$env": "POSTGRES_PASSWORD"} + } }, "notion": { "transportType": "stdio", diff --git a/config-token.example.json b/config-token.example.json index b13702e..fc07d09 100644 --- a/config-token.example.json +++ b/config-token.example.json @@ -11,8 +11,13 @@ "command": "docker", "args": [ "run", "--rm", "-i", "--network", "host", - "mcp/postgres", - "postgresql://testuser:testpass@localhost:5432/testdb" + "-e", "POSTGRES_HOST=localhost", + "-e", "POSTGRES_PORT=5432", + "-e", "POSTGRES_DATABASE=testdb", + "-e", "POSTGRES_USER=testuser", + "-e", "POSTGRES_PASSWORD=testpass", + "us-central1-docker.pkg.dev/database-toolbox/toolbox/toolbox:latest", + "--stdio", "--prebuilt", "postgres" ], "serviceAuths": [ { diff --git a/config-user-tokens-example.json b/config-user-tokens-example.json index bd93f36..8abab60 100644 --- a/config-user-tokens-example.json +++ b/config-user-tokens-example.json @@ -8,13 +8,16 @@ "kind": "oauth", "issuer": {"$env": "OAUTH_ISSUER"}, "gcpProject": {"$env": "GCP_PROJECT"}, + "idp": { + "provider": "google", + "clientId": {"$env": "GOOGLE_CLIENT_ID"}, + "clientSecret": {"$env": "GOOGLE_CLIENT_SECRET"}, + "redirectUri": {"$env": "GOOGLE_REDIRECT_URI"} + }, "allowedDomains": ["yourcompany.com"], "allowedOrigins": ["https://claude.ai"], "tokenTtl": "24h", "storage": "memory", - "googleClientId": {"$env": "GOOGLE_CLIENT_ID"}, - "googleClientSecret": {"$env": "GOOGLE_CLIENT_SECRET"}, - "googleRedirectUri": {"$env": "GOOGLE_REDIRECT_URI"}, "jwtSecret": {"$env": "JWT_SECRET"}, "encryptionKey": {"$env": "ENCRYPTION_KEY"} } @@ -57,11 +60,20 @@ "command": "docker", "args": [ "run", "--rm", "-i", "--network", "host", - "mcp/postgres", - {"$env": "DATABASE_URL"} + "-e", "POSTGRES_HOST", + "-e", "POSTGRES_PORT", + "-e", "POSTGRES_DATABASE", + "-e", "POSTGRES_USER", + "-e", "POSTGRES_PASSWORD", + "us-central1-docker.pkg.dev/database-toolbox/toolbox/toolbox:latest", + "--stdio", "--prebuilt", "postgres" ], "env": { - "PGPASSWORD": {"$env": "POSTGRES_PASSWORD"} + "POSTGRES_HOST": {"$env": "POSTGRES_HOST"}, + "POSTGRES_PORT": {"$env": "POSTGRES_PORT"}, + "POSTGRES_DATABASE": {"$env": "POSTGRES_DATABASE"}, + "POSTGRES_USER": {"$env": "POSTGRES_USER"}, + "POSTGRES_PASSWORD": {"$env": "POSTGRES_PASSWORD"} } } } diff --git a/docs-site/src/content/docs/configuration.md b/docs-site/src/content/docs/configuration.md index 1351707..7a2f908 100644 --- a/docs-site/src/content/docs/configuration.md +++ b/docs-site/src/content/docs/configuration.md @@ -49,13 +49,16 @@ For production, use OAuth with Google. Claude redirects users to Google for auth "auth": { "kind": "oauth", "issuer": "https://mcp.company.com", + "idp": { + "provider": "google", + "clientId": { "$env": "GOOGLE_CLIENT_ID" }, + "clientSecret": { "$env": "GOOGLE_CLIENT_SECRET" }, + "redirectUri": "https://mcp.company.com/oauth/callback" + }, "allowedDomains": ["company.com"], "allowedOrigins": ["https://claude.ai"], "tokenTtl": "24h", "storage": "memory", - "googleClientId": { "$env": "GOOGLE_CLIENT_ID" }, - "googleClientSecret": { "$env": "GOOGLE_CLIENT_SECRET" }, - "googleRedirectUri": "https://mcp.company.com/oauth/callback", "jwtSecret": { "$env": "JWT_SECRET" }, "encryptionKey": { "$env": "ENCRYPTION_KEY" } } @@ -66,7 +69,7 @@ The `issuer` should match your `baseURL`. `allowedDomains` restricts access to s `tokenTtl` controls how long JWT tokens are valid. Shorter times are more secure but require more frequent logins. -Security requirements: `googleClientSecret`, `jwtSecret`, and `encryptionKey` must be environment variables. The JWT secret must be at least 32 bytes. The encryption key must be exactly 32 bytes. +Security requirements: `idp.clientSecret`, `jwtSecret`, and `encryptionKey` must be environment variables. The JWT secret must be at least 32 bytes. The encryption key must be exactly 32 bytes. For production, set `storage` to "firestore" and add `gcpProject`, `firestoreDatabase`, and `firestoreCollection` fields. @@ -345,13 +348,16 @@ Set `MCP_FRONT_ENV=development` when testing OAuth locally. It allows http:// UR "auth": { "kind": "oauth", "issuer": "https://mcp.company.com", + "idp": { + "provider": "google", + "clientId": { "$env": "GOOGLE_CLIENT_ID" }, + "clientSecret": { "$env": "GOOGLE_CLIENT_SECRET" }, + "redirectUri": "https://mcp.company.com/oauth/callback" + }, "allowedDomains": ["company.com"], "allowedOrigins": ["https://claude.ai"], "tokenTtl": "4h", "storage": "firestore", - "googleClientId": { "$env": "GOOGLE_CLIENT_ID" }, - "googleClientSecret": { "$env": "GOOGLE_CLIENT_SECRET" }, - "googleRedirectUri": "https://mcp.company.com/oauth/callback", "jwtSecret": { "$env": "JWT_SECRET" }, "encryptionKey": { "$env": "ENCRYPTION_KEY" }, "gcpProject": { "$env": "GOOGLE_CLOUD_PROJECT" }, diff --git a/docs-site/src/content/docs/examples/oauth-google.md b/docs-site/src/content/docs/examples/oauth-google.md index d4598eb..d35aea1 100644 --- a/docs-site/src/content/docs/examples/oauth-google.md +++ b/docs-site/src/content/docs/examples/oauth-google.md @@ -39,15 +39,26 @@ Create `config.json`: ```json { - "version": "1.0", + "version": "v0.0.1-DEV_EDITION_EXPECT_CHANGES", "proxy": { "name": "Company MCP Proxy", - "baseUrl": "https://mcp.company.com", + "baseURL": "https://mcp.company.com", "addr": ":8080", "auth": { "kind": "oauth", "issuer": "https://mcp.company.com", - "allowedDomains": ["company.com"] + "idp": { + "provider": "google", + "clientId": { "$env": "GOOGLE_CLIENT_ID" }, + "clientSecret": { "$env": "GOOGLE_CLIENT_SECRET" }, + "redirectUri": "https://mcp.company.com/oauth/callback" + }, + "allowedDomains": ["company.com"], + "allowedOrigins": ["https://claude.ai"], + "tokenTtl": "24h", + "storage": "memory", + "jwtSecret": { "$env": "JWT_SECRET" }, + "encryptionKey": { "$env": "ENCRYPTION_KEY" } } }, "mcpServers": { @@ -80,6 +91,7 @@ Create `config.json`: export GOOGLE_CLIENT_ID="your-client-id.apps.googleusercontent.com" export GOOGLE_CLIENT_SECRET="your-client-secret" export JWT_SECRET=$(openssl rand -base64 32) +export ENCRYPTION_KEY=$(openssl rand -base64 32) ``` ## 4. Run with Docker @@ -89,6 +101,7 @@ docker run -p 8080:8080 \ -e GOOGLE_CLIENT_ID \ -e GOOGLE_CLIENT_SECRET \ -e JWT_SECRET \ + -e ENCRYPTION_KEY \ -v $(pwd)/config.json:/config.json \ ghcr.io/dgellow/mcp-front:latest ``` diff --git a/docs-site/src/content/docs/index.mdx b/docs-site/src/content/docs/index.mdx index c7275ee..f1d79fb 100644 --- a/docs-site/src/content/docs/index.mdx +++ b/docs-site/src/content/docs/index.mdx @@ -41,9 +41,13 @@ Claude redirects users to Google for authentication, and MCP Front validates the "auth": { "kind": "oauth", "issuer": "https://mcp.company.com", + "idp": { + "provider": "google", + "clientId": { "$env": "GOOGLE_CLIENT_ID" }, + "clientSecret": { "$env": "GOOGLE_CLIENT_SECRET" }, + "redirectUri": "https://mcp.company.com/oauth/callback" + }, "allowedDomains": ["company.com"], - "googleClientId": { "$env": "GOOGLE_CLIENT_ID" }, - "googleClientSecret": { "$env": "GOOGLE_CLIENT_SECRET" }, "jwtSecret": { "$env": "JWT_SECRET" }, "encryptionKey": { "$env": "ENCRYPTION_KEY" } } diff --git a/integration/base_path_test.go b/integration/base_path_test.go index 88b8e13..ca0ec3c 100644 --- a/integration/base_path_test.go +++ b/integration/base_path_test.go @@ -11,7 +11,12 @@ import ( func TestBasePathRouting(t *testing.T) { waitForDB(t) - startMCPFront(t, "config/config.base-path-test.json") + cfg := buildTestConfig( + "http://localhost:8080/mcp-api", "mcp-front-base-path-test", + nil, + map[string]any{"postgres": testPostgresServer(withBearerTokens("test-token"))}, + ) + startMCPFront(t, writeTestConfig(t, cfg)) waitForMCPFront(t) initialContainers := getMCPContainers() diff --git a/integration/basic_auth_test.go b/integration/basic_auth_test.go index 1f31a91..4368f1a 100644 --- a/integration/basic_auth_test.go +++ b/integration/basic_auth_test.go @@ -10,8 +10,11 @@ import ( ) func TestBasicAuth(t *testing.T) { - // Start mcp-front with basic auth config - startMCPFront(t, "config/config.basic-auth-test.json", + cfg := buildTestConfig("http://localhost:8080", "mcp-front-basic-auth-test", + nil, + map[string]any{"postgres": testPostgresServer(withBasicAuth("admin", "ADMIN_PASSWORD"), withBasicAuth("user", "USER_PASSWORD"))}, + ) + startMCPFront(t, writeTestConfig(t, cfg), "ADMIN_PASSWORD=adminpass123", "USER_PASSWORD=userpass456", ) diff --git a/integration/config/config.base-path-test.json b/integration/config/config.base-path-test.json deleted file mode 100644 index 5b60230..0000000 --- a/integration/config/config.base-path-test.json +++ /dev/null @@ -1,28 +0,0 @@ -{ - "version": "v0.0.1-DEV_EDITION_EXPECT_CHANGES", - "proxy": { - "baseURL": "http://localhost:8080/mcp-api", - "addr": ":8080", - "name": "mcp-front-base-path-test" - }, - "mcpServers": { - "postgres": { - "transportType": "stdio", - "command": "docker", - "args": [ - "run", - "-i", - "--network", - "host", - "mcp/postgres", - "postgresql://testuser:testpass@localhost:15432/testdb" - ], - "serviceAuths": [ - { - "type": "bearer", - "tokens": ["test-token"] - } - ] - } - } -} diff --git a/integration/config/config.basic-auth-test.json b/integration/config/config.basic-auth-test.json deleted file mode 100644 index 8ef2e53..0000000 --- a/integration/config/config.basic-auth-test.json +++ /dev/null @@ -1,34 +0,0 @@ -{ - "version": "v0.0.1-DEV_EDITION_EXPECT_CHANGES", - "proxy": { - "baseURL": "http://localhost:8080", - "addr": ":8080", - "name": "mcp-front-basic-auth-test" - }, - "mcpServers": { - "postgres": { - "transportType": "stdio", - "command": "docker", - "args": [ - "run", - "-i", - "--network", - "host", - "mcp/postgres", - "postgresql://testuser:testpass@localhost:15432/testdb" - ], - "serviceAuths": [ - { - "type": "basic", - "username": "admin", - "password": {"$env": "ADMIN_PASSWORD"} - }, - { - "type": "basic", - "username": "user", - "password": {"$env": "USER_PASSWORD"} - } - ] - } - } -} \ No newline at end of file diff --git a/integration/config/config.demo-token.json b/integration/config/config.demo-token.json index a9b07b5..7610c11 100644 --- a/integration/config/config.demo-token.json +++ b/integration/config/config.demo-token.json @@ -11,8 +11,13 @@ "command": "docker", "args": [ "run", "--rm", "-i", "--network", "host", - "mcp/postgres", - "postgresql://testuser:testpass@localhost:15432/testdb" + "-e", "POSTGRES_HOST=localhost", + "-e", "POSTGRES_PORT=15432", + "-e", "POSTGRES_DATABASE=testdb", + "-e", "POSTGRES_USER=testuser", + "-e", "POSTGRES_PASSWORD=testpass", + "us-central1-docker.pkg.dev/database-toolbox/toolbox/toolbox:latest", + "--stdio", "--prebuilt", "postgres" ], "options": { "logEnabled": true diff --git a/integration/config/config.oauth-integration-test.json b/integration/config/config.oauth-integration-test.json deleted file mode 100644 index 0a40a0f..0000000 --- a/integration/config/config.oauth-integration-test.json +++ /dev/null @@ -1,41 +0,0 @@ -{ - "version": "v0.0.1-DEV_EDITION_EXPECT_CHANGES", - "proxy": { - "baseURL": "http://localhost:8080", - "addr": ":8080", - "name": "mcp-front-oauth-test", - "auth": { - "kind": "oauth", - "issuer": "http://localhost:8080", - "gcpProject": "test-project", - "allowedDomains": [ - "test.com" - ], - "allowedOrigins": [ - "https://claude.ai" - ], - "tokenTtl": "1h", - "storage": "memory", - "googleClientId": "test-client-id", - "googleClientSecret": "test-client-secret-for-integration-testing", - "googleRedirectUri": "http://localhost:8080/oauth/callback", - "jwtSecret": "test-jwt-secret-for-integration-testing-32-chars-long", - "encryptionKey": "test-encryption-key-32-bytes-aes" - } - }, - "mcpServers": { - "postgres": { - "transportType": "stdio", - "command": "docker", - "args": [ - "run", - "--rm", - "-i", - "--network", - "host", - "mcp/postgres", - "postgresql://testuser:testpass@localhost:15432/testdb" - ] - } - } -} \ No newline at end of file diff --git a/integration/config/config.oauth-rfc8707-test.json b/integration/config/config.oauth-rfc8707-test.json index f169a64..75d1162 100644 --- a/integration/config/config.oauth-rfc8707-test.json +++ b/integration/config/config.oauth-rfc8707-test.json @@ -8,13 +8,19 @@ "kind": "oauth", "issuer": "http://localhost:8080", "gcpProject": "test-project", + "idp": { + "provider": "google", + "clientId": {"$env": "GOOGLE_CLIENT_ID"}, + "clientSecret": {"$env": "GOOGLE_CLIENT_SECRET"}, + "redirectUri": "http://localhost:8080/oauth/callback", + "authorizationUrl": "http://localhost:9090/auth", + "tokenUrl": "http://localhost:9090/token", + "userInfoUrl": "http://localhost:9090/userinfo" + }, "allowedDomains": ["test.com"], "allowedOrigins": ["https://claude.ai"], "tokenTtl": "1h", "storage": "memory", - "googleClientId": {"$env": "GOOGLE_CLIENT_ID"}, - "googleClientSecret": {"$env": "GOOGLE_CLIENT_SECRET"}, - "googleRedirectUri": "http://localhost:8080/oauth/callback", "jwtSecret": {"$env": "JWT_SECRET"}, "encryptionKey": {"$env": "ENCRYPTION_KEY"} } diff --git a/integration/config/config.oauth-service-integration-test.json b/integration/config/config.oauth-service-integration-test.json index e46b23c..e1c4e61 100644 --- a/integration/config/config.oauth-service-integration-test.json +++ b/integration/config/config.oauth-service-integration-test.json @@ -8,13 +8,19 @@ "kind": "oauth", "issuer": "http://localhost:8080", "gcpProject": "test-project", + "idp": { + "provider": "google", + "clientId": {"$env": "GOOGLE_CLIENT_ID"}, + "clientSecret": {"$env": "GOOGLE_CLIENT_SECRET"}, + "redirectUri": "http://localhost:8080/oauth/callback", + "authorizationUrl": "http://localhost:9090/auth", + "tokenUrl": "http://localhost:9090/token", + "userInfoUrl": "http://localhost:9090/userinfo" + }, "allowedDomains": ["test.com"], "allowedOrigins": ["https://claude.ai"], "tokenTtl": "1h", "storage": "memory", - "googleClientId": {"$env": "GOOGLE_CLIENT_ID"}, - "googleClientSecret": {"$env": "GOOGLE_CLIENT_SECRET"}, - "googleRedirectUri": "http://localhost:8080/oauth/callback", "jwtSecret": {"$env": "JWT_SECRET"}, "encryptionKey": {"$env": "ENCRYPTION_KEY"} } diff --git a/integration/config/config.oauth-service-test.json b/integration/config/config.oauth-service-test.json deleted file mode 100644 index 9d2f6d4..0000000 --- a/integration/config/config.oauth-service-test.json +++ /dev/null @@ -1,47 +0,0 @@ -{ - "version": "v0.0.1-DEV_EDITION_EXPECT_CHANGES", - "proxy": { - "baseURL": "http://localhost:8080", - "addr": ":8080", - "name": "mcp-front-oauth-usertoken-test", - "auth": { - "kind": "oauth", - "issuer": "http://localhost:8080", - "gcpProject": "test-project", - "allowedDomains": ["test.com"], - "allowedOrigins": ["https://claude.ai"], - "tokenTtl": "1h", - "storage": "memory", - "googleClientId": {"$env": "GOOGLE_CLIENT_ID"}, - "googleClientSecret": {"$env": "GOOGLE_CLIENT_SECRET"}, - "googleRedirectUri": "http://localhost:8080/oauth/callback", - "jwtSecret": {"$env": "JWT_SECRET"}, - "encryptionKey": {"$env": "ENCRYPTION_KEY"} - } - }, - "mcpServers": { - "postgres": { - "transportType": "stdio", - "command": "docker", - "args": [ - "run", - "--rm", - "-i", - "--network", - "host", - "mcp/postgres", - "postgresql://testuser:testpass@localhost:15432/testdb" - ], - "env": { - "USER_TOKEN": {"$userToken": "{{token}}"} - }, - "requiresUserToken": true, - "userAuthentication": { - "type": "manual", - "displayName": "Test Service", - "instructions": "Enter your test token", - "helpUrl": "https://example.com/help" - } - } - } -} \ No newline at end of file diff --git a/integration/config/config.oauth-test.json b/integration/config/config.oauth-test.json deleted file mode 100644 index d08faca..0000000 --- a/integration/config/config.oauth-test.json +++ /dev/null @@ -1,33 +0,0 @@ -{ - "version": "v0.0.1-DEV_EDITION_EXPECT_CHANGES", - "proxy": { - "baseURL": "http://localhost:8080", - "addr": ":8080", - "name": "mcp-front-oauth-test", - "auth": { - "kind": "oauth", - "issuer": "http://localhost:8080", - "gcpProject": "test-project", - "allowedDomains": ["test.com", "stainless.com", "claude.ai"], - "allowedOrigins": ["https://claude.ai"], - "tokenTtl": "1h", - "storage": "memory", - "googleClientId": {"$env": "GOOGLE_CLIENT_ID"}, - "googleClientSecret": {"$env": "GOOGLE_CLIENT_SECRET"}, - "googleRedirectUri": "http://localhost:8080/oauth/callback", - "jwtSecret": {"$env": "JWT_SECRET"}, - "encryptionKey": {"$env": "ENCRYPTION_KEY"} - } - }, - "mcpServers": { - "postgres": { - "transportType": "stdio", - "command": "docker", - "args": [ - "run", "--rm", "-i", "--network", "host", - "mcp/postgres", - "postgresql://testuser:testpass@localhost:15432/testdb" - ] - } - } -} \ No newline at end of file diff --git a/integration/config/config.oauth-token-test.json b/integration/config/config.oauth-token-test.json index 5d39ce4..f3874b7 100644 --- a/integration/config/config.oauth-token-test.json +++ b/integration/config/config.oauth-token-test.json @@ -8,13 +8,19 @@ "kind": "oauth", "issuer": "http://localhost:8080", "gcpProject": "test-project", + "idp": { + "provider": "google", + "clientId": {"$env": "GOOGLE_CLIENT_ID"}, + "clientSecret": {"$env": "GOOGLE_CLIENT_SECRET"}, + "redirectUri": "http://localhost:8080/oauth/callback", + "authorizationUrl": "http://localhost:9090/auth", + "tokenUrl": "http://localhost:9090/token", + "userInfoUrl": "http://localhost:9090/userinfo" + }, "allowedDomains": ["test.com"], "allowedOrigins": ["https://claude.ai"], "tokenTtl": "1h", "storage": "memory", - "googleClientId": {"$env": "GOOGLE_CLIENT_ID"}, - "googleClientSecret": {"$env": "GOOGLE_CLIENT_SECRET"}, - "googleRedirectUri": "http://localhost:8080/oauth/callback", "jwtSecret": {"$env": "JWT_SECRET"}, "encryptionKey": {"$env": "ENCRYPTION_KEY"} } diff --git a/integration/config/config.oauth-usertoken-tools-test.json b/integration/config/config.oauth-usertoken-tools-test.json deleted file mode 100644 index 9d2f6d4..0000000 --- a/integration/config/config.oauth-usertoken-tools-test.json +++ /dev/null @@ -1,47 +0,0 @@ -{ - "version": "v0.0.1-DEV_EDITION_EXPECT_CHANGES", - "proxy": { - "baseURL": "http://localhost:8080", - "addr": ":8080", - "name": "mcp-front-oauth-usertoken-test", - "auth": { - "kind": "oauth", - "issuer": "http://localhost:8080", - "gcpProject": "test-project", - "allowedDomains": ["test.com"], - "allowedOrigins": ["https://claude.ai"], - "tokenTtl": "1h", - "storage": "memory", - "googleClientId": {"$env": "GOOGLE_CLIENT_ID"}, - "googleClientSecret": {"$env": "GOOGLE_CLIENT_SECRET"}, - "googleRedirectUri": "http://localhost:8080/oauth/callback", - "jwtSecret": {"$env": "JWT_SECRET"}, - "encryptionKey": {"$env": "ENCRYPTION_KEY"} - } - }, - "mcpServers": { - "postgres": { - "transportType": "stdio", - "command": "docker", - "args": [ - "run", - "--rm", - "-i", - "--network", - "host", - "mcp/postgres", - "postgresql://testuser:testpass@localhost:15432/testdb" - ], - "env": { - "USER_TOKEN": {"$userToken": "{{token}}"} - }, - "requiresUserToken": true, - "userAuthentication": { - "type": "manual", - "displayName": "Test Service", - "instructions": "Enter your test token", - "helpUrl": "https://example.com/help" - } - } - } -} \ No newline at end of file diff --git a/integration/config/config.test.json b/integration/config/config.test.json deleted file mode 100644 index 773d78c..0000000 --- a/integration/config/config.test.json +++ /dev/null @@ -1,31 +0,0 @@ -{ - "version": "v0.0.1-DEV_EDITION_EXPECT_CHANGES", - "proxy": { - "baseURL": "http://localhost:8080", - "addr": ":8080", - "name": "mcp-front-test" - }, - "mcpServers": { - "postgres": { - "transportType": "stdio", - "command": "docker", - "args": [ - "run", - "-i", - "--network", - "host", - "mcp/postgres", - "postgresql://testuser:testpass@localhost:15432/testdb" - ], - "options": { - "logEnabled": true - }, - "serviceAuths": [ - { - "type": "bearer", - "tokens": ["test-token", "alt-test-token"] - } - ] - } - } -} \ No newline at end of file diff --git a/integration/integration_test.go b/integration/integration_test.go index 72ff8b4..2bd6473 100644 --- a/integration/integration_test.go +++ b/integration/integration_test.go @@ -15,11 +15,11 @@ func TestIntegration(t *testing.T) { waitForDB(t) trace(t, "Starting mcp-front") - startMCPFront(t, "config/config.test.json", - "GOOGLE_OAUTH_AUTH_URL=http://localhost:9090/auth", - "GOOGLE_OAUTH_TOKEN_URL=http://localhost:9090/token", - "GOOGLE_USERINFO_URL=http://localhost:9090/userinfo", + cfg := buildTestConfig("http://localhost:8080", "mcp-front-test", + nil, + map[string]any{"postgres": testPostgresServer(withBearerTokens("test-token", "alt-test-token"), withLogEnabled())}, ) + startMCPFront(t, writeTestConfig(t, cfg)) waitForMCPFront(t) trace(t, "mcp-front is ready") @@ -49,7 +49,7 @@ func TestIntegration(t *testing.T) { t.Log("Connected to MCP server with session") queryParams := map[string]any{ - "name": "query", + "name": "execute_sql", "arguments": map[string]any{ "sql": "SELECT COUNT(*) as user_count FROM users", }, @@ -75,23 +75,31 @@ func TestIntegration(t *testing.T) { assert.NotEmpty(t, content, "Query result missing content") t.Log("Query executed successfully") - // Test resources list - resourcesResult, err := client.SendMCPRequest("resources/list", map[string]any{}) - require.NoError(t, err, "Failed to list resources") + // Test tools list + toolsResult, err := client.SendMCPRequest("tools/list", map[string]any{}) + require.NoError(t, err, "Failed to list tools") - t.Logf("Resources response: %+v", resourcesResult) + t.Logf("Tools response: %+v", toolsResult) - // Check for error in resources response - errorMap, hasError = resourcesResult["error"].(map[string]any) - assert.False(t, hasError, "Resources list returned error: %v", errorMap) + errorMap, hasError = toolsResult["error"].(map[string]any) + assert.False(t, hasError, "Tools list returned error: %v", errorMap) - // Verify we got resources - resultMap, ok = resourcesResult["result"].(map[string]any) - require.True(t, ok, "Expected result in resources response") + resultMap, ok = toolsResult["result"].(map[string]any) + require.True(t, ok, "Expected result in tools response") - resources, ok := resultMap["resources"].([]any) - require.True(t, ok, "Expected resources array in result") - assert.NotEmpty(t, resources, "Expected at least one resource") - t.Logf("Found %d resources", len(resources)) + tools, ok := resultMap["tools"].([]any) + require.True(t, ok, "Expected tools array in result") + assert.NotEmpty(t, tools, "Expected at least one tool") + t.Logf("Found %d tools", len(tools)) + // Verify execute_sql tool is present + var toolNames []string + for _, tool := range tools { + if toolMap, ok := tool.(map[string]any); ok { + if name, ok := toolMap["name"].(string); ok { + toolNames = append(toolNames, name) + } + } + } + assert.Contains(t, toolNames, "execute_sql", "Should have execute_sql tool") } diff --git a/integration/isolation_test.go b/integration/isolation_test.go index 41d5d0f..1036e91 100644 --- a/integration/isolation_test.go +++ b/integration/isolation_test.go @@ -18,12 +18,16 @@ func TestMultiUserSessionIsolation(t *testing.T) { // Start mcp-front with bearer token auth trace(t, "Starting mcp-front") - startMCPFront(t, "config/config.test.json") + cfg := buildTestConfig("http://localhost:8080", "mcp-front-test", + nil, + map[string]any{"postgres": testPostgresServer(withBearerTokens("test-token", "alt-test-token"), withLogEnabled())}, + ) + startMCPFront(t, writeTestConfig(t, cfg)) waitForMCPFront(t) // Get initial container count initialContainers := getMCPContainers() - t.Logf("Initial mcp/postgres containers: %d", len(initialContainers)) + t.Logf("Initial toolbox containers: %d", len(initialContainers)) // Create two clients with different auth tokens client1 := NewMCPSSEClient("http://localhost:8080") @@ -69,7 +73,7 @@ func TestMultiUserSessionIsolation(t *testing.T) { } query1Result, err := client1.SendMCPRequest("tools/call", map[string]any{ - "name": "query", + "name": "execute_sql", "arguments": map[string]any{ "sql": "SELECT 'user1-query1' as test_id, COUNT(*) as count FROM users", }, @@ -116,13 +120,13 @@ func TestMultiUserSessionIsolation(t *testing.T) { // Verify that client1 and client2 have different containers if client1Container != "" && client2Container != "" && client1Container == client2Container { t.Errorf("CRITICAL: Both users are using the same Docker container! Container ID: %s", client1Container) - t.Error("This indicates session isolation is NOT working - users are sharing the same mcp/postgres instance") + t.Error("This indicates session isolation is NOT working - users are sharing the same toolbox instance") } else if client1Container != "" && client2Container != "" { t.Logf("Confirmed different stdio processes: User1 container=%s, User2 container=%s", client1Container, client2Container) } query2Result, err := client2.SendMCPRequest("tools/call", map[string]any{ - "name": "query", + "name": "execute_sql", "arguments": map[string]any{ "sql": "SELECT 'user2-query1' as test_id, COUNT(*) as count FROM orders", }, @@ -135,7 +139,7 @@ func TestMultiUserSessionIsolation(t *testing.T) { // Step 3: First user sends another query t.Log("\nStep 3: First user sends another query") query3Result, err := client1.SendMCPRequest("tools/call", map[string]any{ - "name": "query", + "name": "execute_sql", "arguments": map[string]any{ "sql": "SELECT 'user1-query2' as test_id, current_timestamp as ts", }, @@ -148,7 +152,7 @@ func TestMultiUserSessionIsolation(t *testing.T) { // Step 4: First user sends another query t.Log("\nStep 4: First user sends another query") query4Result, err := client1.SendMCPRequest("tools/call", map[string]any{ - "name": "query", + "name": "execute_sql", "arguments": map[string]any{ "sql": "SELECT 'user1-query3' as test_id, version() as db_version", }, @@ -161,7 +165,7 @@ func TestMultiUserSessionIsolation(t *testing.T) { // Step 5: Second user sends a query t.Log("\nStep 5: Second user sends a query") query5Result, err := client2.SendMCPRequest("tools/call", map[string]any{ - "name": "query", + "name": "execute_sql", "arguments": map[string]any{ "sql": "SELECT 'user2-query2' as test_id, current_database() as db_name", }, @@ -208,12 +212,16 @@ func TestSessionCleanupAfterTimeout(t *testing.T) { // Start mcp-front with test timeout configuration trace(t, "Starting mcp-front with test session timeout") - startMCPFront(t, "config/config.test.json") + cfg := buildTestConfig("http://localhost:8080", "mcp-front-test", + nil, + map[string]any{"postgres": testPostgresServer(withBearerTokens("test-token", "alt-test-token"), withLogEnabled())}, + ) + startMCPFront(t, writeTestConfig(t, cfg)) waitForMCPFront(t) // Get initial container count initialContainers := getMCPContainers() - t.Logf("Initial mcp/postgres containers: %d", len(initialContainers)) + t.Logf("Initial toolbox containers: %d", len(initialContainers)) // Create a client and connect client := NewMCPSSEClient("http://localhost:8080") @@ -237,7 +245,7 @@ func TestSessionCleanupAfterTimeout(t *testing.T) { // Send a query to ensure session is active _, err = client.SendMCPRequest("tools/call", map[string]any{ - "name": "query", + "name": "execute_sql", "arguments": map[string]any{ "sql": "SELECT 'test' as test_id", }, @@ -277,12 +285,16 @@ func TestSessionTimerReset(t *testing.T) { // Start mcp-front with test timeout configuration trace(t, "Starting mcp-front with test session timeout") - startMCPFront(t, "config/config.test.json") + cfg := buildTestConfig("http://localhost:8080", "mcp-front-test", + nil, + map[string]any{"postgres": testPostgresServer(withBearerTokens("test-token", "alt-test-token"), withLogEnabled())}, + ) + startMCPFront(t, writeTestConfig(t, cfg)) waitForMCPFront(t) // Get initial container count initialContainers := getMCPContainers() - t.Logf("Initial mcp/postgres containers: %d", len(initialContainers)) + t.Logf("Initial toolbox containers: %d", len(initialContainers)) // Create a client and connect client := NewMCPSSEClient("http://localhost:8080") @@ -308,7 +320,7 @@ func TestSessionTimerReset(t *testing.T) { for i := range 3 { t.Logf("Sending keepalive query %d/3...", i+1) _, err := client.SendMCPRequest("tools/call", map[string]any{ - "name": "query", + "name": "execute_sql", "arguments": map[string]any{ "sql": "SELECT 'keepalive' as status, NOW() as timestamp", }, @@ -357,12 +369,16 @@ func TestMultiUserTimerIndependence(t *testing.T) { // Start mcp-front with test timeout configuration trace(t, "Starting mcp-front with test session timeout") - startMCPFront(t, "config/config.test.json") + cfg := buildTestConfig("http://localhost:8080", "mcp-front-test", + nil, + map[string]any{"postgres": testPostgresServer(withBearerTokens("test-token", "alt-test-token"), withLogEnabled())}, + ) + startMCPFront(t, writeTestConfig(t, cfg)) waitForMCPFront(t) // Get initial container count initialContainers := getMCPContainers() - t.Logf("Initial mcp/postgres containers: %d", len(initialContainers)) + t.Logf("Initial toolbox containers: %d", len(initialContainers)) // Create two clients client1 := NewMCPSSEClient("http://localhost:8080") @@ -411,7 +427,7 @@ func TestMultiUserTimerIndependence(t *testing.T) { for i := range 4 { time.Sleep(4 * time.Second) _, err := client2.SendMCPRequest("tools/call", map[string]any{ - "name": "query", + "name": "execute_sql", "arguments": map[string]any{ "sql": "SELECT 'client2-keepalive' as status", }, diff --git a/integration/main_test.go b/integration/main_test.go index 49eb876..39527d7 100644 --- a/integration/main_test.go +++ b/integration/main_test.go @@ -22,6 +22,15 @@ func TestMain(m *testing.M) { os.Exit(1) } + // Pull the toolbox image so the first test doesn't timeout on image pull + fmt.Println("Pulling toolbox image...") + pullCmd := exec.Command("docker", "pull", ToolboxImage) + pullCmd.Stdout = os.Stdout + pullCmd.Stderr = os.Stderr + if err := pullCmd.Run(); err != nil { + fmt.Printf("Warning: failed to pull toolbox image: %v\n", err) + } + // Set up local log file for mcp-front output logFile := "mcp-front-test.log" os.Setenv("MCP_LOG_FILE", logFile) @@ -53,7 +62,7 @@ func TestMain(m *testing.M) { os.Exit(exitCode) }() - // Start fake GCP server for OAuth + // Start fake GCP server for OAuth (port 9090) fakeGCP := NewFakeGCPServer("9090") err := fakeGCP.Start() if err != nil { @@ -65,6 +74,28 @@ func TestMain(m *testing.M) { _ = fakeGCP.Stop() }() + // Start fake GitHub server (port 9092) + fakeGitHub := NewFakeGitHubServer("9092", []string{"test-org", "another-org"}) + if err := fakeGitHub.Start(); err != nil { + fmt.Printf("Failed to start fake GitHub server: %v\n", err) + exitCode = 1 + return + } + defer func() { + _ = fakeGitHub.Stop() + }() + + // Start fake OIDC server (port 9093) — used for both OIDC and Azure tests + fakeOIDC := NewFakeOIDCServer("9093") + if err := fakeOIDC.Start(); err != nil { + fmt.Printf("Failed to start fake OIDC server: %v\n", err) + exitCode = 1 + return + } + defer func() { + _ = fakeOIDC.Stop() + }() + // Wait for database to be ready fmt.Println("Waiting for database to be ready...") for i := range 30 { // Wait up to 30 seconds diff --git a/integration/oauth_flow_test.go b/integration/oauth_flow_test.go new file mode 100644 index 0000000..c6e8217 --- /dev/null +++ b/integration/oauth_flow_test.go @@ -0,0 +1,789 @@ +package integration + +import ( + "bufio" + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "os" + "os/exec" + "regexp" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestBasicOAuthFlow tests the basic OAuth server functionality +func TestBasicOAuthFlow(t *testing.T) { + cfg := buildTestConfig("http://localhost:8080", "mcp-front-oauth-test", + testOAuthConfigFromEnv(), + map[string]any{"postgres": testPostgresServer()}, + ) + startMCPFront(t, writeTestConfig(t, cfg), + "JWT_SECRET=test-jwt-secret-32-bytes-exactly!", + "ENCRYPTION_KEY=test-encryption-key-32-bytes-ok!", + "GOOGLE_CLIENT_ID=test-client-id-for-oauth", + "GOOGLE_CLIENT_SECRET=test-client-secret-for-oauth", + "MCP_FRONT_ENV=development", + ) + + // Wait for startup + waitForMCPFront(t) + + // Test OAuth discovery + resp, err := http.Get("http://localhost:8080/.well-known/oauth-authorization-server") + require.NoError(t, err, "Failed to get OAuth discovery") + defer resp.Body.Close() + + assert.Equal(t, 200, resp.StatusCode, "OAuth discovery failed") + + var discovery map[string]any + err = json.NewDecoder(resp.Body).Decode(&discovery) + require.NoError(t, err, "Failed to decode discovery") + + // Verify required endpoints + requiredEndpoints := []string{ + "issuer", + "authorization_endpoint", + "token_endpoint", + "registration_endpoint", + } + + for _, endpoint := range requiredEndpoints { + _, ok := discovery[endpoint] + assert.True(t, ok, "Missing required endpoint: %s", endpoint) + } + + // Verify client_secret_post is advertised + authMethods, ok := discovery["token_endpoint_auth_methods_supported"].([]any) + assert.True(t, ok, "token_endpoint_auth_methods_supported should be present") + + var hasNone, hasClientSecretPost bool + for _, method := range authMethods { + if method == "none" { + hasNone = true + } + if method == "client_secret_post" { + hasClientSecretPost = true + } + } + assert.True(t, hasNone, "Should support 'none' auth method for public clients") + assert.True(t, hasClientSecretPost, "Should support 'client_secret_post' auth method for confidential clients") +} + +// TestJWTSecretValidation tests JWT secret length requirements +func TestJWTSecretValidation(t *testing.T) { + oauthCfg := buildTestConfig("http://localhost:8080", "mcp-front-oauth-test", + testOAuthConfigFromEnv(), + map[string]any{"postgres": testPostgresServer()}, + ) + configPath := writeTestConfig(t, oauthCfg) + + tests := []struct { + name string + secret string + shouldFail bool + }{ + {"Short 3-byte secret", "123", true}, + {"Short 16-byte secret", "sixteen-byte-key", true}, + {"Valid 32-byte secret", "demo-jwt-secret-32-bytes-exactly!", false}, + {"Long 64-byte secret", "demo-jwt-secret-32-bytes-exactly!demo-jwt-secret-32-bytes-exactly!", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + + mcpCmd := exec.Command("../cmd/mcp-front/mcp-front", "-config", configPath) + mcpCmd.Env = []string{ + "PATH=" + os.Getenv("PATH"), + "JWT_SECRET=" + tt.secret, + "ENCRYPTION_KEY=test-encryption-key-32-bytes-ok!", + "GOOGLE_CLIENT_ID=test-client-id", + "GOOGLE_CLIENT_SECRET=test-client-secret", + "MCP_FRONT_ENV=development", + } + + // Capture stderr + stderrPipe, _ := mcpCmd.StderrPipe() + scanner := bufio.NewScanner(stderrPipe) + + if err := mcpCmd.Start(); err != nil { + t.Fatalf("Failed to start mcp-front: %v", err) + } + + // Read stderr to check for errors + errorFound := false + go func() { + for scanner.Scan() { + line := scanner.Text() + if contains(line, "JWT secret must be at least") { + errorFound = true + } + } + }() + + // Give it time to start or fail + time.Sleep(2 * time.Second) + + // Check if it's running + healthy := checkHealth() + + // Clean up + if mcpCmd.Process != nil { + _ = mcpCmd.Process.Kill() + _ = mcpCmd.Wait() + } + + if tt.shouldFail { + assert.False(t, healthy && !errorFound, "Expected failure with short JWT secret but server started successfully") + } else { + assert.True(t, healthy, "Expected success with valid JWT secret but server failed to start") + } + }) + } +} + +// TestClientRegistration tests dynamic client registration (RFC 7591) +func TestClientRegistration(t *testing.T) { + // Start OAuth server + mcpCmd := startOAuthServer(t, map[string]string{ + "MCP_FRONT_ENV": "development", + }) + defer stopServer(mcpCmd) + + if !waitForHealthCheck(30) { + t.Fatal("OAuth server failed to start") + } + + t.Run("PublicClientRegistration", func(t *testing.T) { + // Register a public client (no secret) + clientReq := map[string]any{ + "redirect_uris": []string{"http://127.0.0.1:6274/oauth/callback/debug"}, + "scope": "read write", + } + + body, _ := json.Marshal(clientReq) + resp, err := http.Post( + "http://localhost:8080/register", + "application/json", + bytes.NewBuffer(body), + ) + if err != nil { + t.Fatalf("Failed to register client: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != 201 { + body, _ := io.ReadAll(resp.Body) + t.Fatalf("Client registration failed with status %d: %s", resp.StatusCode, string(body)) + } + + var clientResp map[string]any + if err := json.NewDecoder(resp.Body).Decode(&clientResp); err != nil { + t.Fatalf("Failed to decode response: %v", err) + } + + // Verify response + if clientResp["client_id"] == "" { + t.Error("Client ID should not be empty") + } + if clientResp["client_secret"] != nil { + t.Error("Public client should not have a secret") + } + if scope, ok := clientResp["scope"].(string); !ok || scope != "read write" { + t.Errorf("Expected scope 'read write' as string, got: %v", clientResp["scope"]) + } + }) + + t.Run("MultipleRegistrations", func(t *testing.T) { + // Register multiple clients and verify they get different IDs + var clientIDs []string + + for i := range 3 { + clientReq := map[string]any{ + "redirect_uris": []string{fmt.Sprintf("http://example.com/callback%d", i)}, + "scope": "read", + } + + body, _ := json.Marshal(clientReq) + resp, err := http.Post( + "http://localhost:8080/register", + "application/json", + bytes.NewBuffer(body), + ) + if err != nil { + t.Fatalf("Failed to register client %d: %v", i, err) + } + defer resp.Body.Close() + + var clientResp map[string]any + _ = json.NewDecoder(resp.Body).Decode(&clientResp) + clientIDs = append(clientIDs, clientResp["client_id"].(string)) + } + + // Verify all IDs are unique + for i := 0; i < len(clientIDs); i++ { + for j := i + 1; j < len(clientIDs); j++ { + if clientIDs[i] == clientIDs[j] { + t.Errorf("Client IDs should be unique, but got duplicate: %s", clientIDs[i]) + } + } + } + + }) + + t.Run("ConfidentialClientRegistration", func(t *testing.T) { + // Register a confidential client with client_secret_post + clientReq := map[string]any{ + "redirect_uris": []string{"https://example.com/callback"}, + "scope": "read write", + "token_endpoint_auth_method": "client_secret_post", + } + + body, _ := json.Marshal(clientReq) + resp, err := http.Post( + "http://localhost:8080/register", + "application/json", + bytes.NewBuffer(body), + ) + if err != nil { + t.Fatalf("Failed to register confidential client: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != 201 { + body, _ := io.ReadAll(resp.Body) + t.Fatalf("Confidential client registration failed with status %d: %s", resp.StatusCode, string(body)) + } + + var clientResp map[string]any + if err := json.NewDecoder(resp.Body).Decode(&clientResp); err != nil { + t.Fatalf("Failed to decode response: %v", err) + } + + // Verify response includes client_secret + if clientResp["client_id"] == "" { + t.Error("Client ID should not be empty") + } + clientSecret, ok := clientResp["client_secret"].(string) + if !ok || clientSecret == "" { + t.Error("Confidential client should receive a client_secret") + } + // Verify secret has reasonable length (base64 of 32 bytes) + if len(clientSecret) < 40 { + t.Errorf("Client secret seems too short: %d chars", len(clientSecret)) + } + + tokenAuthMethod, ok := clientResp["token_endpoint_auth_method"].(string) + if !ok || tokenAuthMethod != "client_secret_post" { + t.Errorf("Expected token_endpoint_auth_method 'client_secret_post', got: %v", clientResp["token_endpoint_auth_method"]) + } + + // Verify scope is returned as string + if scope, ok := clientResp["scope"].(string); !ok || scope != "read write" { + t.Errorf("Expected scope 'read write' as string, got: %v", clientResp["scope"]) + } + }) + + t.Run("PublicVsConfidentialClients", func(t *testing.T) { + // Test that public clients don't get secrets and confidential ones do + + // First, create a public client + publicReq := map[string]any{ + "redirect_uris": []string{"https://public.example.com/callback"}, + "scope": "read", + // No token_endpoint_auth_method specified - defaults to "none" + } + + body, _ := json.Marshal(publicReq) + resp, err := http.Post( + "http://localhost:8080/register", + "application/json", + bytes.NewBuffer(body), + ) + require.NoError(t, err) + defer resp.Body.Close() + + var publicResp map[string]any + _ = json.NewDecoder(resp.Body).Decode(&publicResp) + + // Verify public client has no secret + if _, hasSecret := publicResp["client_secret"]; hasSecret { + t.Error("Public client should not have a secret") + } + if authMethod := publicResp["token_endpoint_auth_method"]; authMethod != "none" { + t.Errorf("Public client should have auth method 'none', got: %v", authMethod) + } + + // Now create a confidential client + confidentialReq := map[string]any{ + "redirect_uris": []string{"https://confidential.example.com/callback"}, + "scope": "read write", + "token_endpoint_auth_method": "client_secret_post", + } + + body, _ = json.Marshal(confidentialReq) + resp, err = http.Post( + "http://localhost:8080/register", + "application/json", + bytes.NewBuffer(body), + ) + require.NoError(t, err) + defer resp.Body.Close() + + var confResp map[string]any + _ = json.NewDecoder(resp.Body).Decode(&confResp) + + // Verify confidential client has a secret + if secret, ok := confResp["client_secret"].(string); !ok || secret == "" { + t.Error("Confidential client should have a secret") + } + if authMethod := confResp["token_endpoint_auth_method"]; authMethod != "client_secret_post" { + t.Errorf("Confidential client should have auth method 'client_secret_post', got: %v", authMethod) + } + }) +} + +// TestStateParameterHandling tests OAuth state parameter requirements +func TestStateParameterHandling(t *testing.T) { + tests := []struct { + name string + environment string + state string + expectError bool + }{ + {"Production without state", "production", "", true}, + {"Production with state", "production", "secure-random-state", false}, + {"Development without state", "development", "", false}, // Should auto-generate + {"Development with state", "development", "test-state", false}, + } + + for _, tt := range tests { + // capture range variable + t.Run(tt.name, func(t *testing.T) { + // Start server with specific environment + mcpCmd := startOAuthServer(t, map[string]string{ + "MCP_FRONT_ENV": tt.environment, + }) + defer stopServer(mcpCmd) + + if !waitForHealthCheck(10) { + t.Fatal("Server failed to start") + } + + // Register a client first + clientID := registerTestClient(t) + + // Create authorization request + params := url.Values{ + "response_type": {"code"}, + "client_id": {clientID}, + "redirect_uri": {"http://127.0.0.1:6274/oauth/callback"}, + "code_challenge": {"test-challenge"}, + "code_challenge_method": {"S256"}, + "scope": {"read write"}, + } + if tt.state != "" { + params.Set("state", tt.state) + } + + authURL := fmt.Sprintf("http://localhost:8080/authorize?%s", params.Encode()) + + // Use a client that doesn't follow redirects + client := &http.Client{ + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, + } + resp, err := client.Get(authURL) + if err != nil { + t.Fatalf("Authorization request failed: %v", err) + } + defer resp.Body.Close() + + if tt.expectError { + // OAuth errors are returned as redirects with error parameters + if resp.StatusCode == 302 || resp.StatusCode == 303 { + location := resp.Header.Get("Location") + if strings.Contains(location, "error=") { + } else { + t.Errorf("Expected error redirect for %s, got redirect without error", tt.name) + } + } else if resp.StatusCode >= 400 { + } else { + t.Errorf("Expected error for %s, got status %d", tt.name, resp.StatusCode) + } + } else { + if resp.StatusCode == 302 || resp.StatusCode == 303 { + location := resp.Header.Get("Location") + if strings.Contains(location, "error=") { + t.Errorf("Unexpected error redirect for %s: %s", tt.name, location) + } + } else if resp.StatusCode < 400 { + } else { + body, _ := io.ReadAll(resp.Body) + t.Errorf("Expected success for %s, got status %d: %s", tt.name, resp.StatusCode, string(body)) + } + } + }) + } +} + +// TestEnvironmentModes tests development vs production mode differences +func TestEnvironmentModes(t *testing.T) { + t.Run("DevelopmentMode", func(t *testing.T) { + mcpCmd := startOAuthServer(t, map[string]string{ + "MCP_FRONT_ENV": "development", + }) + defer stopServer(mcpCmd) + + if !waitForHealthCheck(30) { + t.Fatal("Server failed to start") + } + + // In development mode, missing state should be auto-generated + clientID := registerTestClient(t) + + params := url.Values{ + "response_type": {"code"}, + "client_id": {clientID}, + "redirect_uri": {"http://127.0.0.1:6274/oauth/callback"}, + "code_challenge": {"test-challenge"}, + "code_challenge_method": {"S256"}, + "scope": {"read"}, + // Intentionally omitting state parameter + } + + // Use a client that doesn't follow redirects + client := &http.Client{ + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, + } + resp, err := client.Get("http://localhost:8080/authorize?" + params.Encode()) + if err != nil { + t.Fatalf("Failed to make auth request: %v", err) + } + defer resp.Body.Close() + + // Should redirect (302) not error + if resp.StatusCode >= 400 && resp.StatusCode != 302 { + t.Errorf("Development mode should handle missing state, got status %d", resp.StatusCode) + } + }) + + t.Run("ProductionMode", func(t *testing.T) { + mcpCmd := startOAuthServer(t, map[string]string{ + "MCP_FRONT_ENV": "production", + }) + defer stopServer(mcpCmd) + + if !waitForHealthCheck(30) { + t.Fatal("Server failed to start") + } + + // In production mode, state should be required + clientID := registerTestClient(t) + + params := url.Values{ + "response_type": {"code"}, + "client_id": {clientID}, + "redirect_uri": {"http://127.0.0.1:6274/oauth/callback"}, + "code_challenge": {"test-challenge"}, + "code_challenge_method": {"S256"}, + "scope": {"read"}, + // Intentionally omitting state parameter + } + + // Use a client that doesn't follow redirects + client := &http.Client{ + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, + } + resp, err := client.Get("http://localhost:8080/authorize?" + params.Encode()) + if err != nil { + t.Fatalf("Failed to make auth request: %v", err) + } + defer resp.Body.Close() + + // Should error - OAuth errors are returned as redirects + if resp.StatusCode == 302 || resp.StatusCode == 303 { + location := resp.Header.Get("Location") + if strings.Contains(location, "error=") { + } else { + t.Errorf("Expected error redirect in production mode, got redirect without error") + } + } else if resp.StatusCode >= 400 { + } else { + t.Errorf("Production mode should require state parameter, got status %d", resp.StatusCode) + } + }) +} + +// TestOAuthEndpoints tests all OAuth endpoints comprehensively +func TestOAuthEndpoints(t *testing.T) { + mcpCmd := startOAuthServer(t, map[string]string{ + "MCP_FRONT_ENV": "development", + }) + defer stopServer(mcpCmd) + + if !waitForHealthCheck(10) { + t.Fatal("Server failed to start") + } + + t.Run("Discovery", func(t *testing.T) { + resp, err := http.Get("http://localhost:8080/.well-known/oauth-authorization-server") + if err != nil { + t.Fatalf("Discovery request failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != 200 { + t.Fatalf("Discovery failed with status %d", resp.StatusCode) + } + + var discovery map[string]any + if err := json.NewDecoder(resp.Body).Decode(&discovery); err != nil { + t.Fatalf("Failed to decode discovery response: %v", err) + } + + // Verify all required fields + required := []string{ + "issuer", + "authorization_endpoint", + "token_endpoint", + "registration_endpoint", + "response_types_supported", + "grant_types_supported", + "code_challenge_methods_supported", + } + + for _, field := range required { + if _, ok := discovery[field]; !ok { + t.Errorf("Missing required discovery field: %s", field) + } + } + + }) + + t.Run("HealthCheck", func(t *testing.T) { + resp, err := http.Get("http://localhost:8080/health") + if err != nil { + t.Fatalf("Health check failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != 200 { + t.Errorf("Health check should return 200, got %d", resp.StatusCode) + } + + var health map[string]string + if err := json.NewDecoder(resp.Body).Decode(&health); err != nil { + t.Fatalf("Failed to decode health response: %v", err) + } + if health["status"] != "ok" { + t.Errorf("Expected status 'ok', got '%s'", health["status"]) + } + + }) +} + +// TestCORSHeaders tests CORS headers for Claude.ai compatibility +func TestCORSHeaders(t *testing.T) { + mcpCmd := startOAuthServer(t, map[string]string{ + "MCP_FRONT_ENV": "development", + }) + defer stopServer(mcpCmd) + + if !waitForHealthCheck(10) { + t.Fatal("Server failed to start") + } + + // Test preflight request + req, _ := http.NewRequest("OPTIONS", "http://localhost:8080/register", nil) + req.Header.Set("Origin", "https://claude.ai") + req.Header.Set("Access-Control-Request-Method", "POST") + req.Header.Set("Access-Control-Request-Headers", "content-type") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("Preflight request failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != 200 { + t.Errorf("Preflight should return 200, got %d", resp.StatusCode) + } + + // Check CORS headers + expectedHeaders := map[string]string{ + "Access-Control-Allow-Origin": "https://claude.ai", + "Access-Control-Allow-Methods": "GET, POST, OPTIONS", + "Access-Control-Allow-Headers": "Content-Type, Authorization, Cache-Control, mcp-protocol-version", + } + + for header, expected := range expectedHeaders { + actual := resp.Header.Get(header) + if actual != expected { + t.Errorf("Expected %s: '%s', got '%s'", header, expected, actual) + } + } + +} + +// Shared helpers for OAuth tests + +func startOAuthServer(t *testing.T, env map[string]string) *exec.Cmd { + cfg := buildTestConfig("http://localhost:8080", "mcp-front-oauth-test", + testOAuthConfigFromEnv(), + map[string]any{"postgres": testPostgresServer()}, + ) + configPath := writeTestConfig(t, cfg) + + mcpCmd := exec.Command("../cmd/mcp-front/mcp-front", "-config", configPath) + + // Set default environment + mcpCmd.Env = []string{ + "PATH=" + os.Getenv("PATH"), + "JWT_SECRET=demo-jwt-secret-32-bytes-exactly!", + "ENCRYPTION_KEY=test-encryption-key-32-bytes-ok!", + "GOOGLE_CLIENT_ID=test-client-id-oauth", + "GOOGLE_CLIENT_SECRET=test-client-secret-oauth", + } + + // Override with provided env + for key, value := range env { + mcpCmd.Env = append(mcpCmd.Env, fmt.Sprintf("%s=%s", key, value)) + } + + // Capture stderr for debugging and also output to test log + var stderr bytes.Buffer + mcpCmd.Stderr = io.MultiWriter(&stderr, os.Stderr) + + if err := mcpCmd.Start(); err != nil { + t.Fatalf("Failed to start OAuth server: %v", err) + } + + // Give a moment for immediate failures + time.Sleep(100 * time.Millisecond) + + // Check if process died immediately + if mcpCmd.ProcessState != nil { + t.Fatalf("OAuth server died immediately: %s", stderr.String()) + } + + return mcpCmd +} + +// startOAuthServerWithTokenConfig starts the OAuth server with user token configuration +func startOAuthServerWithTokenConfig(t *testing.T) *exec.Cmd { + // Start with user token config + mcpCmd := exec.Command("../cmd/mcp-front/mcp-front", "-config", "config/config.oauth-token-test.json") + + // Set default environment + mcpCmd.Env = []string{ + "PATH=" + os.Getenv("PATH"), + "JWT_SECRET=demo-jwt-secret-32-bytes-exactly!", + "ENCRYPTION_KEY=test-encryption-key-32-bytes-ok!", + "GOOGLE_CLIENT_ID=test-client-id-oauth", + "GOOGLE_CLIENT_SECRET=test-client-secret-oauth", + "MCP_FRONT_ENV=development", + } + + // Capture stderr for debugging and also output to test log + var stderr bytes.Buffer + mcpCmd.Stderr = io.MultiWriter(&stderr, os.Stderr) + + if err := mcpCmd.Start(); err != nil { + t.Fatalf("Failed to start OAuth server: %v", err) + } + + // Give a moment for immediate failures + time.Sleep(100 * time.Millisecond) + + // Check if process died immediately + if mcpCmd.ProcessState != nil { + t.Fatalf("OAuth server died immediately: %s", stderr.String()) + } + + return mcpCmd +} + +func stopServer(cmd *exec.Cmd) { + if cmd != nil && cmd.Process != nil { + _ = cmd.Process.Kill() + _ = cmd.Wait() + // Give the OS time to release the port + time.Sleep(100 * time.Millisecond) + } +} + +func waitForHealthCheck(seconds int) bool { + for range seconds { + if checkHealth() { + return true + } + time.Sleep(1 * time.Second) + } + return false +} + +func checkHealth() bool { + resp, err := http.Get("http://localhost:8080/health") + if err == nil && resp.StatusCode == 200 { + resp.Body.Close() + return true + } + if resp != nil { + resp.Body.Close() + } + return false +} + +func registerTestClient(t *testing.T) string { + clientReq := map[string]any{ + "redirect_uris": []string{"http://127.0.0.1:6274/oauth/callback"}, + "scope": "openid email profile read write", + } + + body, _ := json.Marshal(clientReq) + resp, err := http.Post( + "http://localhost:8080/register", + "application/json", + bytes.NewBuffer(body), + ) + if err != nil { + t.Fatalf("Failed to register client: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != 201 { + body, _ := io.ReadAll(resp.Body) + t.Fatalf("Client registration failed: %d - %s", resp.StatusCode, string(body)) + } + + var clientResp map[string]any + _ = json.NewDecoder(resp.Body).Decode(&clientResp) + return clientResp["client_id"].(string) +} + +// extractCSRFToken extracts the CSRF token from the HTML response +func extractCSRFToken(t *testing.T, html string) string { + // Look for + re := regexp.MustCompile(`]+name="csrf_token"[^>]+value="([^"]+)"`) + matches := re.FindStringSubmatch(html) + require.GreaterOrEqual(t, len(matches), 2, "CSRF token not found in response") + return matches[1] +} + +// contains is a simple helper to check if string contains substring +func contains(s, substr string) bool { + return strings.Contains(s, substr) +} diff --git a/integration/oauth_idp_test.go b/integration/oauth_idp_test.go new file mode 100644 index 0000000..86b2171 --- /dev/null +++ b/integration/oauth_idp_test.go @@ -0,0 +1,256 @@ +package integration + +import ( + "io" + "net/http" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// idpEnvGitHub is the set of env vars needed for GitHub IDP integration tests. +var idpEnvGitHub = []string{ + "JWT_SECRET=test-jwt-secret-32-bytes-exactly!", + "ENCRYPTION_KEY=test-encryption-key-32-bytes-ok!", + "GITHUB_CLIENT_SECRET=test-github-client-secret", + "MCP_FRONT_ENV=development", +} + +// idpEnvOIDC is the set of env vars needed for OIDC IDP integration tests. +var idpEnvOIDC = []string{ + "JWT_SECRET=test-jwt-secret-32-bytes-exactly!", + "ENCRYPTION_KEY=test-encryption-key-32-bytes-ok!", + "OIDC_CLIENT_SECRET=test-oidc-client-secret", + "MCP_FRONT_ENV=development", +} + +// idpEnvAzure is the set of env vars needed for Azure IDP integration tests. +var idpEnvAzure = []string{ + "JWT_SECRET=test-jwt-secret-32-bytes-exactly!", + "ENCRYPTION_KEY=test-encryption-key-32-bytes-ok!", + "AZURE_CLIENT_SECRET=test-azure-client-secret", + "MCP_FRONT_ENV=development", +} + +// TestGitHubOAuthFlow tests the full OAuth flow using the GitHub IDP. +func TestGitHubOAuthFlow(t *testing.T) { + cfg := buildTestConfig("http://localhost:8080", "mcp-front-github-test", + testGitHubOAuthConfig("test-org"), + map[string]any{"postgres": testPostgresServer()}, + ) + startMCPFront(t, writeTestConfig(t, cfg), idpEnvGitHub...) + waitForMCPFront(t) + + accessToken := getOAuthAccessTokenForIDP(t, "http://localhost:8080/postgres", "localhost:9092") + + mcpClient := NewMCPSSEClient("http://localhost:8080") + mcpClient.SetAuthToken(accessToken) + + err := mcpClient.Connect() + require.NoError(t, err, "Should connect to postgres with GitHub-issued token") + defer mcpClient.Close() + + toolsResp, err := mcpClient.SendMCPRequest("tools/list", map[string]any{}) + require.NoError(t, err, "Should list tools with GitHub-issued token") + + resultMap, ok := toolsResp["result"].(map[string]any) + require.True(t, ok, "Expected result in tools response") + tools, ok := resultMap["tools"].([]any) + require.True(t, ok, "Expected tools array") + assert.NotEmpty(t, tools, "Should have tools") +} + +// TestGitHubOrgDenial verifies that users not in allowed orgs are denied. +func TestGitHubOrgDenial(t *testing.T) { + cfg := buildTestConfig("http://localhost:8080", "mcp-front-github-org-deny", + testGitHubOAuthConfig("org-that-doesnt-match"), + map[string]any{"postgres": testPostgresServer()}, + ) + startMCPFront(t, writeTestConfig(t, cfg), idpEnvGitHub...) + waitForMCPFront(t) + + // Register client and start auth flow + clientID := registerTestClient(t) + + client := &http.Client{ + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, + } + + // Step 1: Authorize + authResp, err := client.Get("http://localhost:8080/authorize?response_type=code&client_id=" + clientID + + "&redirect_uri=http://127.0.0.1:6274/oauth/callback&code_challenge=test-challenge&code_challenge_method=S256" + + "&scope=openid+email+profile&state=test-state&resource=http://localhost:8080/postgres") + require.NoError(t, err) + defer authResp.Body.Close() + require.Contains(t, []int{302, 303}, authResp.StatusCode) + + // Step 2: Follow to GitHub fake + location := authResp.Header.Get("Location") + idpResp, err := client.Get(location) + require.NoError(t, err) + defer idpResp.Body.Close() + + // Step 3: Follow callback — should get an error (access denied), not a code + callbackLocation := idpResp.Header.Get("Location") + callbackResp, err := client.Get(callbackLocation) + require.NoError(t, err) + defer callbackResp.Body.Close() + + // The callback should either: + // - Return a 403 directly, or + // - Redirect to the redirect_uri with an error parameter + if callbackResp.StatusCode == 302 || callbackResp.StatusCode == 303 { + finalLocation := callbackResp.Header.Get("Location") + assert.Contains(t, finalLocation, "error=", "Should redirect with error for org denial") + } else { + // Direct error response + body, _ := io.ReadAll(callbackResp.Body) + assert.Contains(t, string(body), "denied", "Should contain denial message") + } +} + +// TestOIDCOAuthFlow tests the full OAuth flow using a generic OIDC provider. +func TestOIDCOAuthFlow(t *testing.T) { + cfg := buildTestConfig("http://localhost:8080", "mcp-front-oidc-test", + testOIDCOAuthConfig(), + map[string]any{"postgres": testPostgresServer()}, + ) + startMCPFront(t, writeTestConfig(t, cfg), idpEnvOIDC...) + waitForMCPFront(t) + + accessToken := getOAuthAccessTokenForIDP(t, "http://localhost:8080/postgres", "localhost:9093") + + mcpClient := NewMCPSSEClient("http://localhost:8080") + mcpClient.SetAuthToken(accessToken) + + err := mcpClient.Connect() + require.NoError(t, err, "Should connect to postgres with OIDC-issued token") + defer mcpClient.Close() + + toolsResp, err := mcpClient.SendMCPRequest("tools/list", map[string]any{}) + require.NoError(t, err, "Should list tools with OIDC-issued token") + + resultMap, ok := toolsResp["result"].(map[string]any) + require.True(t, ok, "Expected result in tools response") + tools, ok := resultMap["tools"].([]any) + require.True(t, ok, "Expected tools array") + assert.NotEmpty(t, tools, "Should have tools") +} + +// TestAzureOAuthFlow tests the full OAuth flow using the Azure IDP (backed by the OIDC fake server). +func TestAzureOAuthFlow(t *testing.T) { + cfg := buildTestConfig("http://localhost:8080", "mcp-front-azure-test", + testAzureOAuthConfig(), + map[string]any{"postgres": testPostgresServer()}, + ) + startMCPFront(t, writeTestConfig(t, cfg), idpEnvAzure...) + waitForMCPFront(t) + + accessToken := getOAuthAccessTokenForIDP(t, "http://localhost:8080/postgres", "localhost:9093") + + mcpClient := NewMCPSSEClient("http://localhost:8080") + mcpClient.SetAuthToken(accessToken) + + err := mcpClient.Connect() + require.NoError(t, err, "Should connect to postgres with Azure-issued token") + defer mcpClient.Close() + + toolsResp, err := mcpClient.SendMCPRequest("tools/list", map[string]any{}) + require.NoError(t, err, "Should list tools with Azure-issued token") + + resultMap, ok := toolsResp["result"].(map[string]any) + require.True(t, ok, "Expected result in tools response") + tools, ok := resultMap["tools"].([]any) + require.True(t, ok, "Expected tools array") + assert.NotEmpty(t, tools, "Should have tools") +} + +// TestIDPDomainDenial verifies that domain restrictions are enforced across providers. +func TestIDPDomainDenial(t *testing.T) { + tests := []struct { + name string + authConfig map[string]any + env []string + idpHost string + }{ + { + name: "GitHub", + authConfig: func() map[string]any { + cfg := testGitHubOAuthConfig("test-org") + cfg["allowedDomains"] = []string{"wrong-domain.com"} + return cfg + }(), + env: idpEnvGitHub, + idpHost: "localhost:9092", + }, + { + name: "OIDC", + authConfig: func() map[string]any { + cfg := testOIDCOAuthConfig() + cfg["allowedDomains"] = []string{"wrong-domain.com"} + return cfg + }(), + env: idpEnvOIDC, + idpHost: "localhost:9093", + }, + { + name: "Azure", + authConfig: func() map[string]any { + cfg := testAzureOAuthConfig() + cfg["allowedDomains"] = []string{"wrong-domain.com"} + return cfg + }(), + env: idpEnvAzure, + idpHost: "localhost:9093", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := buildTestConfig("http://localhost:8080", "mcp-front-domain-deny-"+tt.name, + tt.authConfig, + map[string]any{"postgres": testPostgresServer()}, + ) + startMCPFront(t, writeTestConfig(t, cfg), tt.env...) + waitForMCPFront(t) + + clientID := registerTestClient(t) + + client := &http.Client{ + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, + } + + authResp, err := client.Get("http://localhost:8080/authorize?response_type=code&client_id=" + clientID + + "&redirect_uri=http://127.0.0.1:6274/oauth/callback&code_challenge=test-challenge&code_challenge_method=S256" + + "&scope=openid+email+profile&state=test-state&resource=http://localhost:8080/postgres") + require.NoError(t, err) + defer authResp.Body.Close() + require.Contains(t, []int{302, 303}, authResp.StatusCode) + + location := authResp.Header.Get("Location") + require.Contains(t, location, tt.idpHost, "Should redirect to expected IDP") + + idpResp, err := client.Get(location) + require.NoError(t, err) + defer idpResp.Body.Close() + + callbackLocation := idpResp.Header.Get("Location") + callbackResp, err := client.Get(callbackLocation) + require.NoError(t, err) + defer callbackResp.Body.Close() + + if callbackResp.StatusCode == 302 || callbackResp.StatusCode == 303 { + finalLocation := callbackResp.Header.Get("Location") + assert.Contains(t, finalLocation, "error=", "Should redirect with error for domain denial") + } else { + body, _ := io.ReadAll(callbackResp.Body) + assert.Contains(t, string(body), "denied", "Should contain denial message") + } + }) + } +} diff --git a/integration/oauth_rfc8707_test.go b/integration/oauth_rfc8707_test.go new file mode 100644 index 0000000..94c01f2 --- /dev/null +++ b/integration/oauth_rfc8707_test.go @@ -0,0 +1,198 @@ +package integration + +import ( + "crypto/sha256" + "encoding/base64" + "encoding/json" + "net/http" + "net/url" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestRFC8707ResourceIndicators validates RFC 8707 resource indicator functionality +func TestRFC8707ResourceIndicators(t *testing.T) { + startMCPFront(t, "config/config.oauth-rfc8707-test.json", + "JWT_SECRET=test-jwt-secret-32-bytes-exactly!", + "ENCRYPTION_KEY=test-encryption-key-32-bytes-ok!", + "GOOGLE_CLIENT_ID=test-client-id-for-oauth", + "GOOGLE_CLIENT_SECRET=test-client-secret-for-oauth", + "MCP_FRONT_ENV=development", + ) + + waitForMCPFront(t) + + t.Run("BaseProtectedResourceMetadataReturns404", func(t *testing.T) { + // Base metadata endpoint should return 404, directing clients to per-service endpoints + resp, err := http.Get("http://localhost:8080/.well-known/oauth-protected-resource") + require.NoError(t, err) + defer resp.Body.Close() + + assert.Equal(t, 404, resp.StatusCode, "Base protected resource metadata endpoint should return 404") + + var errResp map[string]any + err = json.NewDecoder(resp.Body).Decode(&errResp) + require.NoError(t, err) + + assert.Contains(t, errResp["message"], "per-service", "Error message should direct to per-service endpoints") + }) + + t.Run("PerServiceProtectedResourceMetadataEndpoint", func(t *testing.T) { + // Per-service metadata endpoint should return service-specific resource URI + resp, err := http.Get("http://localhost:8080/.well-known/oauth-protected-resource/test-sse") + require.NoError(t, err) + defer resp.Body.Close() + + assert.Equal(t, 200, resp.StatusCode, "Per-service protected resource metadata endpoint should exist") + + var metadata map[string]any + err = json.NewDecoder(resp.Body).Decode(&metadata) + require.NoError(t, err) + + // Resource should be service-specific, not base URL + assert.Equal(t, "http://localhost:8080/test-sse", metadata["resource"], + "Resource should be service-specific URL") + + authzServers, ok := metadata["authorization_servers"].([]any) + require.True(t, ok, "Should have authorization_servers array") + require.NotEmpty(t, authzServers) + assert.Equal(t, "http://localhost:8080", authzServers[0], + "Authorization server should be base issuer") + }) + + t.Run("UnknownServiceReturns404", func(t *testing.T) { + resp, err := http.Get("http://localhost:8080/.well-known/oauth-protected-resource/nonexistent-service") + require.NoError(t, err) + defer resp.Body.Close() + + assert.Equal(t, 404, resp.StatusCode, "Unknown service should return 404") + }) + + t.Run("TokenWithResourceParameter", func(t *testing.T) { + clientID := registerTestClient(t) + + codeVerifier := "test-code-verifier-that-is-at-least-43-characters-long" + h := sha256.New() + h.Write([]byte(codeVerifier)) + codeChallenge := base64.RawURLEncoding.EncodeToString(h.Sum(nil)) + + authParams := url.Values{ + "response_type": {"code"}, + "client_id": {clientID}, + "redirect_uri": {"http://127.0.0.1:6274/oauth/callback"}, + "code_challenge": {codeChallenge}, + "code_challenge_method": {"S256"}, + "scope": {"openid email profile"}, + "state": {"test-state"}, + "resource": {"http://localhost:8080/test-sse"}, + } + + client := &http.Client{ + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, + } + + authResp, err := client.Get("http://localhost:8080/authorize?" + authParams.Encode()) + require.NoError(t, err) + defer authResp.Body.Close() + + assert.Contains(t, []int{302, 303}, authResp.StatusCode, "Should redirect to Google OAuth") + + location := authResp.Header.Get("Location") + googleResp, err := client.Get(location) + require.NoError(t, err) + defer googleResp.Body.Close() + + callbackLocation := googleResp.Header.Get("Location") + callbackResp, err := client.Get(callbackLocation) + require.NoError(t, err) + defer callbackResp.Body.Close() + + finalURL, err := url.Parse(callbackResp.Header.Get("Location")) + require.NoError(t, err) + authCode := finalURL.Query().Get("code") + require.NotEmpty(t, authCode, "Should have authorization code") + + tokenParams := url.Values{ + "grant_type": {"authorization_code"}, + "code": {authCode}, + "redirect_uri": {"http://127.0.0.1:6274/oauth/callback"}, + "client_id": {clientID}, + "code_verifier": {codeVerifier}, + } + + tokenResp, err := http.PostForm("http://localhost:8080/token", tokenParams) + require.NoError(t, err) + defer tokenResp.Body.Close() + + require.Equal(t, 200, tokenResp.StatusCode, "Token exchange should succeed") + + var tokenData map[string]any + err = json.NewDecoder(tokenResp.Body).Decode(&tokenData) + require.NoError(t, err) + + testSSEToken := tokenData["access_token"].(string) + require.NotEmpty(t, testSSEToken, "Should have access token") + + t.Logf("Got token with test-sse audience: %s", testSSEToken[:20]+"...") + + // Verify token works for test-sse (matching audience) + req, _ := http.NewRequest("GET", "http://localhost:8080/test-sse/sse", nil) + req.Header.Set("Authorization", "Bearer "+testSSEToken) + req.Header.Set("Accept", "text/event-stream") + + sseResp, err := client.Do(req) + require.NoError(t, err) + defer sseResp.Body.Close() + + assert.Equal(t, 200, sseResp.StatusCode, + "Token with test-sse audience should access /test-sse/sse") + + // Verify token does NOT work for test-streamable (wrong audience) + req, _ = http.NewRequest("GET", "http://localhost:8080/test-streamable/sse", nil) + req.Header.Set("Authorization", "Bearer "+testSSEToken) + req.Header.Set("Accept", "text/event-stream") + + streamableResp, err := client.Do(req) + require.NoError(t, err) + defer streamableResp.Body.Close() + + assert.Equal(t, 401, streamableResp.StatusCode, + "Token with test-sse audience should NOT access /test-streamable/sse") + + wwwAuth := streamableResp.Header.Get("WWW-Authenticate") + assert.Contains(t, wwwAuth, "Bearer resource_metadata=", + "401 response should include RFC 9728 WWW-Authenticate header") + // Per RFC 9728 Section 5.2, the metadata URI should be service-specific + assert.Contains(t, wwwAuth, "/.well-known/oauth-protected-resource/test-streamable", + "401 response should point to per-service metadata endpoint") + }) + + t.Run("401ResponseIncludesServiceSpecificMetadataURI", func(t *testing.T) { + // Request to a protected endpoint without token should get 401 + // with service-specific metadata URI in WWW-Authenticate header + client := &http.Client{ + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, + } + + req, _ := http.NewRequest("GET", "http://localhost:8080/test-sse/sse", nil) + req.Header.Set("Accept", "text/event-stream") + + resp, err := client.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + assert.Equal(t, 401, resp.StatusCode, "Request without token should return 401") + + wwwAuth := resp.Header.Get("WWW-Authenticate") + assert.Contains(t, wwwAuth, "Bearer resource_metadata=", + "401 response should include RFC 9728 WWW-Authenticate header") + assert.Contains(t, wwwAuth, "/.well-known/oauth-protected-resource/test-sse", + "401 response should point to test-sse specific metadata endpoint") + }) +} diff --git a/integration/oauth_service_test.go b/integration/oauth_service_test.go new file mode 100644 index 0000000..20a8ca1 --- /dev/null +++ b/integration/oauth_service_test.go @@ -0,0 +1,88 @@ +package integration + +import ( + "io" + "net/http" + "net/http/cookiejar" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestServiceOAuthIntegration validates the complete OAuth flow for external services +func TestServiceOAuthIntegration(t *testing.T) { + // Start fake service OAuth provider on port 9091 + fakeService := NewFakeServiceOAuthServer("9091") + err := fakeService.Start() + require.NoError(t, err) + defer func() { _ = fakeService.Stop() }() + + // Start mcp-front with OAuth service configuration + startMCPFront(t, "config/config.oauth-service-integration-test.json", + "JWT_SECRET=demo-jwt-secret-32-bytes-exactly!", + "ENCRYPTION_KEY=test-encryption-key-32-bytes-ok!", + "GOOGLE_CLIENT_ID=test-client-id-oauth", + "GOOGLE_CLIENT_SECRET=test-client-secret-oauth", + "TEST_SERVICE_CLIENT_ID=service-client-id", + "TEST_SERVICE_CLIENT_SECRET=service-client-secret", + "MCP_FRONT_ENV=development", + ) + + if !waitForHealthCheck(30) { + t.Fatal("Server failed to start") + } + + // For this test, we use browser SSO instead of OAuth client flow + // This simulates a user in the browser connecting services + jar, _ := cookiejar.New(nil) + client := &http.Client{Jar: jar} + + // Complete Google OAuth to get browser session + // Access /my/tokens which triggers SSO flow + resp, err := client.Get("http://localhost:8080/my/tokens") + require.NoError(t, err) + defer resp.Body.Close() + + // Should have completed SSO and landed on /my/tokens + require.Equal(t, http.StatusOK, resp.StatusCode) + + t.Run("ServiceOAuthConnectFlow", func(t *testing.T) { + // User clicks "Connect" for the service + req, _ := http.NewRequest("GET", "http://localhost:8080/oauth/connect?service=test-service", nil) + + resp, err := client.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + // Should complete OAuth flow and redirect back with success + // The http.Client automatically follows redirects: + // 1. /oauth/connect → redirects to localhost:9091/oauth/authorize + // 2. Fake service → redirects to /oauth/callback/test-service?code=... + // 3. Callback → stores token, redirects to /my/tokens with success message + + body, _ := io.ReadAll(resp.Body) + bodyStr := string(body) + + // Final page should show success + assert.Contains(t, bodyStr, "Successfully connected", "Should show success message after OAuth flow") + assert.Contains(t, bodyStr, "Test OAuth Service", "Should mention service name") + }) + + t.Run("ConnectedServiceShownOnTokenPage", func(t *testing.T) { + // After OAuth connection, service should appear as connected + req, _ := http.NewRequest("GET", "http://localhost:8080/my/tokens", nil) + + resp, err := client.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + body, _ := io.ReadAll(resp.Body) + bodyStr := string(body) + + // Should show the service with connected status + assert.Contains(t, bodyStr, "Test OAuth Service") + // OAuth-connected services show disconnect button, not connect + assert.Contains(t, bodyStr, "Disconnect", "OAuth-connected service should show Disconnect button") + }) +} diff --git a/integration/oauth_test.go b/integration/oauth_test.go deleted file mode 100644 index 0b2949f..0000000 --- a/integration/oauth_test.go +++ /dev/null @@ -1,1577 +0,0 @@ -package integration - -import ( - "bufio" - "bytes" - "context" - "crypto/sha256" - "encoding/base64" - "encoding/json" - "fmt" - "io" - "net/http" - "net/http/cookiejar" - "net/url" - "os" - "os/exec" - "regexp" - "strings" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" - "github.com/stretchr/testify/require" -) - -// TestBasicOAuthFlow tests the basic OAuth server functionality -func TestBasicOAuthFlow(t *testing.T) { - // Start mcp-front with OAuth config - startMCPFront(t, "config/config.oauth-test.json", - "JWT_SECRET=test-jwt-secret-32-bytes-exactly!", - "ENCRYPTION_KEY=test-encryption-key-32-bytes-ok!", - "GOOGLE_CLIENT_ID=test-client-id-for-oauth", - "GOOGLE_CLIENT_SECRET=test-client-secret-for-oauth", - "MCP_FRONT_ENV=development", - "GOOGLE_OAUTH_AUTH_URL=http://localhost:9090/auth", - "GOOGLE_OAUTH_TOKEN_URL=http://localhost:9090/token", - "GOOGLE_USERINFO_URL=http://localhost:9090/userinfo", - ) - - // Wait for startup - waitForMCPFront(t) - - // Test OAuth discovery - resp, err := http.Get("http://localhost:8080/.well-known/oauth-authorization-server") - require.NoError(t, err, "Failed to get OAuth discovery") - defer resp.Body.Close() - - assert.Equal(t, 200, resp.StatusCode, "OAuth discovery failed") - - var discovery map[string]any - err = json.NewDecoder(resp.Body).Decode(&discovery) - require.NoError(t, err, "Failed to decode discovery") - - // Verify required endpoints - requiredEndpoints := []string{ - "issuer", - "authorization_endpoint", - "token_endpoint", - "registration_endpoint", - } - - for _, endpoint := range requiredEndpoints { - _, ok := discovery[endpoint] - assert.True(t, ok, "Missing required endpoint: %s", endpoint) - } - - // Verify client_secret_post is advertised - authMethods, ok := discovery["token_endpoint_auth_methods_supported"].([]any) - assert.True(t, ok, "token_endpoint_auth_methods_supported should be present") - - var hasNone, hasClientSecretPost bool - for _, method := range authMethods { - if method == "none" { - hasNone = true - } - if method == "client_secret_post" { - hasClientSecretPost = true - } - } - assert.True(t, hasNone, "Should support 'none' auth method for public clients") - assert.True(t, hasClientSecretPost, "Should support 'client_secret_post' auth method for confidential clients") -} - -// TestJWTSecretValidation tests JWT secret length requirements -func TestJWTSecretValidation(t *testing.T) { - tests := []struct { - name string - secret string - shouldFail bool - }{ - {"Short 3-byte secret", "123", true}, - {"Short 16-byte secret", "sixteen-byte-key", true}, - {"Valid 32-byte secret", "demo-jwt-secret-32-bytes-exactly!", false}, - {"Long 64-byte secret", "demo-jwt-secret-32-bytes-exactly!demo-jwt-secret-32-bytes-exactly!", false}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - - // Start mcp-front with specific JWT secret - mcpCmd := exec.Command("../cmd/mcp-front/mcp-front", "-config", "config/config.oauth-test.json") - mcpCmd.Env = []string{ - "PATH=" + os.Getenv("PATH"), - "JWT_SECRET=" + tt.secret, - "ENCRYPTION_KEY=test-encryption-key-32-bytes-ok!", - "GOOGLE_CLIENT_ID=test-client-id", - "GOOGLE_CLIENT_SECRET=test-client-secret", - "MCP_FRONT_ENV=development", - } - - // Capture stderr - stderrPipe, _ := mcpCmd.StderrPipe() - scanner := bufio.NewScanner(stderrPipe) - - if err := mcpCmd.Start(); err != nil { - t.Fatalf("Failed to start mcp-front: %v", err) - } - - // Read stderr to check for errors - errorFound := false - go func() { - for scanner.Scan() { - line := scanner.Text() - if contains(line, "JWT secret must be at least") { - errorFound = true - } - } - }() - - // Give it time to start or fail - time.Sleep(2 * time.Second) - - // Check if it's running - healthy := checkHealth() - - // Clean up - if mcpCmd.Process != nil { - _ = mcpCmd.Process.Kill() - _ = mcpCmd.Wait() - } - - if tt.shouldFail { - assert.False(t, healthy && !errorFound, "Expected failure with short JWT secret but server started successfully") - } else { - assert.True(t, healthy, "Expected success with valid JWT secret but server failed to start") - } - }) - } -} - -// TestClientRegistration tests dynamic client registration (RFC 7591) -func TestClientRegistration(t *testing.T) { - // Start OAuth server - mcpCmd := startOAuthServer(t, map[string]string{ - "MCP_FRONT_ENV": "development", - }) - defer stopServer(mcpCmd) - - if !waitForHealthCheck(30) { - t.Fatal("OAuth server failed to start") - } - - t.Run("PublicClientRegistration", func(t *testing.T) { - // Register a public client (no secret) - clientReq := map[string]any{ - "redirect_uris": []string{"http://127.0.0.1:6274/oauth/callback/debug"}, - "scope": "read write", - } - - body, _ := json.Marshal(clientReq) - resp, err := http.Post( - "http://localhost:8080/register", - "application/json", - bytes.NewBuffer(body), - ) - if err != nil { - t.Fatalf("Failed to register client: %v", err) - } - defer resp.Body.Close() - - if resp.StatusCode != 201 { - body, _ := io.ReadAll(resp.Body) - t.Fatalf("Client registration failed with status %d: %s", resp.StatusCode, string(body)) - } - - var clientResp map[string]any - if err := json.NewDecoder(resp.Body).Decode(&clientResp); err != nil { - t.Fatalf("Failed to decode response: %v", err) - } - - // Verify response - if clientResp["client_id"] == "" { - t.Error("Client ID should not be empty") - } - if clientResp["client_secret"] != nil { - t.Error("Public client should not have a secret") - } - if scope, ok := clientResp["scope"].(string); !ok || scope != "read write" { - t.Errorf("Expected scope 'read write' as string, got: %v", clientResp["scope"]) - } - }) - - t.Run("MultipleRegistrations", func(t *testing.T) { - // Register multiple clients and verify they get different IDs - var clientIDs []string - - for i := range 3 { - clientReq := map[string]any{ - "redirect_uris": []string{fmt.Sprintf("http://example.com/callback%d", i)}, - "scope": "read", - } - - body, _ := json.Marshal(clientReq) - resp, err := http.Post( - "http://localhost:8080/register", - "application/json", - bytes.NewBuffer(body), - ) - if err != nil { - t.Fatalf("Failed to register client %d: %v", i, err) - } - defer resp.Body.Close() - - var clientResp map[string]any - _ = json.NewDecoder(resp.Body).Decode(&clientResp) - clientIDs = append(clientIDs, clientResp["client_id"].(string)) - } - - // Verify all IDs are unique - for i := 0; i < len(clientIDs); i++ { - for j := i + 1; j < len(clientIDs); j++ { - if clientIDs[i] == clientIDs[j] { - t.Errorf("Client IDs should be unique, but got duplicate: %s", clientIDs[i]) - } - } - } - - }) - - t.Run("ConfidentialClientRegistration", func(t *testing.T) { - // Register a confidential client with client_secret_post - clientReq := map[string]any{ - "redirect_uris": []string{"https://example.com/callback"}, - "scope": "read write", - "token_endpoint_auth_method": "client_secret_post", - } - - body, _ := json.Marshal(clientReq) - resp, err := http.Post( - "http://localhost:8080/register", - "application/json", - bytes.NewBuffer(body), - ) - if err != nil { - t.Fatalf("Failed to register confidential client: %v", err) - } - defer resp.Body.Close() - - if resp.StatusCode != 201 { - body, _ := io.ReadAll(resp.Body) - t.Fatalf("Confidential client registration failed with status %d: %s", resp.StatusCode, string(body)) - } - - var clientResp map[string]any - if err := json.NewDecoder(resp.Body).Decode(&clientResp); err != nil { - t.Fatalf("Failed to decode response: %v", err) - } - - // Verify response includes client_secret - if clientResp["client_id"] == "" { - t.Error("Client ID should not be empty") - } - clientSecret, ok := clientResp["client_secret"].(string) - if !ok || clientSecret == "" { - t.Error("Confidential client should receive a client_secret") - } - // Verify secret has reasonable length (base64 of 32 bytes) - if len(clientSecret) < 40 { - t.Errorf("Client secret seems too short: %d chars", len(clientSecret)) - } - - tokenAuthMethod, ok := clientResp["token_endpoint_auth_method"].(string) - if !ok || tokenAuthMethod != "client_secret_post" { - t.Errorf("Expected token_endpoint_auth_method 'client_secret_post', got: %v", clientResp["token_endpoint_auth_method"]) - } - - // Verify scope is returned as string - if scope, ok := clientResp["scope"].(string); !ok || scope != "read write" { - t.Errorf("Expected scope 'read write' as string, got: %v", clientResp["scope"]) - } - }) - - t.Run("PublicVsConfidentialClients", func(t *testing.T) { - // Test that public clients don't get secrets and confidential ones do - - // First, create a public client - publicReq := map[string]any{ - "redirect_uris": []string{"https://public.example.com/callback"}, - "scope": "read", - // No token_endpoint_auth_method specified - defaults to "none" - } - - body, _ := json.Marshal(publicReq) - resp, err := http.Post( - "http://localhost:8080/register", - "application/json", - bytes.NewBuffer(body), - ) - require.NoError(t, err) - defer resp.Body.Close() - - var publicResp map[string]any - _ = json.NewDecoder(resp.Body).Decode(&publicResp) - - // Verify public client has no secret - if _, hasSecret := publicResp["client_secret"]; hasSecret { - t.Error("Public client should not have a secret") - } - if authMethod := publicResp["token_endpoint_auth_method"]; authMethod != "none" { - t.Errorf("Public client should have auth method 'none', got: %v", authMethod) - } - - // Now create a confidential client - confidentialReq := map[string]any{ - "redirect_uris": []string{"https://confidential.example.com/callback"}, - "scope": "read write", - "token_endpoint_auth_method": "client_secret_post", - } - - body, _ = json.Marshal(confidentialReq) - resp, err = http.Post( - "http://localhost:8080/register", - "application/json", - bytes.NewBuffer(body), - ) - require.NoError(t, err) - defer resp.Body.Close() - - var confResp map[string]any - _ = json.NewDecoder(resp.Body).Decode(&confResp) - - // Verify confidential client has a secret - if secret, ok := confResp["client_secret"].(string); !ok || secret == "" { - t.Error("Confidential client should have a secret") - } - if authMethod := confResp["token_endpoint_auth_method"]; authMethod != "client_secret_post" { - t.Errorf("Confidential client should have auth method 'client_secret_post', got: %v", authMethod) - } - }) -} - -// TestUserTokenFlow tests the user token management functionality with browser-based SSO -// This test expects the /my/* routes to work with Google SSO (session-based auth), -// not Bearer token auth. -func TestUserTokenFlow(t *testing.T) { - // Start OAuth server with user token configuration - mcpCmd := startOAuthServerWithTokenConfig(t) - defer stopServer(mcpCmd) - - if !waitForHealthCheck(30) { - t.Fatal("Server failed to start") - } - - // Create a client with cookie jar to simulate browser behavior - jar, _ := cookiejar.New(nil) - client := &http.Client{ - Jar: jar, - CheckRedirect: func(req *http.Request, via []*http.Request) error { - // Allow up to 10 redirects - if len(via) >= 10 { - return fmt.Errorf("too many redirects") - } - return nil - }, - } - - t.Run("UnauthenticatedRedirectsToSSO", func(t *testing.T) { - // Create a client that doesn't follow redirects to test the initial redirect - noRedirectClient := &http.Client{ - CheckRedirect: func(req *http.Request, via []*http.Request) error { - return http.ErrUseLastResponse - }, - } - - // Try to access /my/tokens without authentication - resp, err := noRedirectClient.Get("http://localhost:8080/my/tokens") - require.NoError(t, err) - defer resp.Body.Close() - - // Should get a redirect response - assert.Equal(t, http.StatusFound, resp.StatusCode, "Should get redirect status") - - // Check the redirect location - location := resp.Header.Get("Location") - assert.Contains(t, location, "localhost:9090/auth", "Should redirect to Google OAuth") - assert.Contains(t, location, "client_id=", "Should include client_id") - assert.Contains(t, location, "redirect_uri=", "Should include redirect_uri") - // Extract and validate the state parameter - parsedURL, err := url.Parse(location) - require.NoError(t, err) - stateParam := parsedURL.Query().Get("state") - require.NotEmpty(t, stateParam, "State parameter should be present") - - // State format: "browser:" prefix followed by signed token - // We verify structure but not internal format (that's implementation detail) - assert.True(t, strings.HasPrefix(stateParam, "browser:"), "State should start with browser:") - assert.Greater(t, len(stateParam), len("browser:"), "State should have content after prefix") - - // Verify state contains signature (has dot separator indicating signed data) - stateContent := strings.TrimPrefix(stateParam, "browser:") - assert.Contains(t, stateContent, ".", "Signed state should contain signature separator") - }) - - t.Run("AuthenticatedUserCanAccessTokens", func(t *testing.T) { - // The client with cookie jar will automatically follow the full SSO flow: - // 1. GET /my/tokens -> redirect to Google OAuth - // 2. Google OAuth redirects to /oauth/callback with code - // 3. Callback sets session cookie and redirects to /my/tokens - // 4. Client follows redirect with cookie and gets the page - - resp, err := client.Get("http://localhost:8080/my/tokens") - require.NoError(t, err) - defer resp.Body.Close() - - // After following all redirects, we should be at /my/tokens with 200 OK - assert.Equal(t, http.StatusOK, resp.StatusCode, "Should access /my/tokens after SSO") - finalURL := resp.Request.URL.String() - assert.Contains(t, finalURL, "/my/tokens", "Should end up at /my/tokens after SSO") - - // Read response body - body, err := io.ReadAll(resp.Body) - require.NoError(t, err) - bodyStr := string(body) - - // Should show both services without tokens - assert.Contains(t, bodyStr, "Notion", "Expected Notion service in response") - assert.Contains(t, bodyStr, "GitHub", "Expected GitHub service in response") - }) - - t.Run("SetTokenWithValidation", func(t *testing.T) { - // Assume we're already authenticated from previous test - // Get CSRF token first - resp, err := client.Get("http://localhost:8080/my/tokens") - require.NoError(t, err) - defer resp.Body.Close() - - body, err := io.ReadAll(resp.Body) - require.NoError(t, err) - - // Extract CSRF token from response - csrfToken := extractCSRFToken(t, string(body)) - - // Try to set invalid Notion token - form := url.Values{ - "service": []string{"notion"}, - "token": []string{"invalid-token"}, - "csrf_token": []string{csrfToken}, - } - - req, _ := http.NewRequest("POST", "http://localhost:8080/my/tokens/set", strings.NewReader(form.Encode())) - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - - // Use custom client that doesn't follow redirects for this test - noRedirectClient := &http.Client{ - Jar: jar, // Use same cookie jar - CheckRedirect: func(req *http.Request, via []*http.Request) error { - return http.ErrUseLastResponse - }, - } - - resp, err = noRedirectClient.Do(req) - require.NoError(t, err) - defer resp.Body.Close() - - // Should redirect with error - assert.Equal(t, http.StatusSeeOther, resp.StatusCode, "Expected redirect") - location := resp.Header.Get("Location") - assert.Contains(t, location, "error", "Expected error in redirect") - - // Get new CSRF token - resp, err = client.Get("http://localhost:8080/my/tokens") - require.NoError(t, err) - defer resp.Body.Close() - - body, err = io.ReadAll(resp.Body) - require.NoError(t, err) - csrfToken = extractCSRFToken(t, string(body)) - - // Set valid Notion token (regex expects exactly 43 chars after "secret_") - form = url.Values{ - "service": {"notion"}, - "token": {"secret_1234567890123456789012345678901234567890123"}, - "csrf_token": {csrfToken}, - } - - req, _ = http.NewRequest("POST", "http://localhost:8080/my/tokens/set", strings.NewReader(form.Encode())) - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - - resp, err = noRedirectClient.Do(req) - require.NoError(t, err) - defer resp.Body.Close() - - // Should redirect with success - assert.Equal(t, http.StatusSeeOther, resp.StatusCode, "Expected redirect") - location = resp.Header.Get("Location") - assert.Contains(t, location, "success", "Expected success in redirect") - }) -} - -// TestStateParameterHandling tests OAuth state parameter requirements -func TestStateParameterHandling(t *testing.T) { - tests := []struct { - name string - environment string - state string - expectError bool - }{ - {"Production without state", "production", "", true}, - {"Production with state", "production", "secure-random-state", false}, - {"Development without state", "development", "", false}, // Should auto-generate - {"Development with state", "development", "test-state", false}, - } - - for _, tt := range tests { - // capture range variable - t.Run(tt.name, func(t *testing.T) { - // Start server with specific environment - mcpCmd := startOAuthServer(t, map[string]string{ - "MCP_FRONT_ENV": tt.environment, - }) - defer stopServer(mcpCmd) - - if !waitForHealthCheck(10) { - t.Fatal("Server failed to start") - } - - // Register a client first - clientID := registerTestClient(t) - - // Create authorization request - params := url.Values{ - "response_type": {"code"}, - "client_id": {clientID}, - "redirect_uri": {"http://127.0.0.1:6274/oauth/callback"}, - "code_challenge": {"test-challenge"}, - "code_challenge_method": {"S256"}, - "scope": {"read write"}, - } - if tt.state != "" { - params.Set("state", tt.state) - } - - authURL := fmt.Sprintf("http://localhost:8080/authorize?%s", params.Encode()) - - // Use a client that doesn't follow redirects - client := &http.Client{ - CheckRedirect: func(req *http.Request, via []*http.Request) error { - return http.ErrUseLastResponse - }, - } - resp, err := client.Get(authURL) - if err != nil { - t.Fatalf("Authorization request failed: %v", err) - } - defer resp.Body.Close() - - if tt.expectError { - // OAuth errors are returned as redirects with error parameters - if resp.StatusCode == 302 || resp.StatusCode == 303 { - location := resp.Header.Get("Location") - if strings.Contains(location, "error=") { - } else { - t.Errorf("Expected error redirect for %s, got redirect without error", tt.name) - } - } else if resp.StatusCode >= 400 { - } else { - t.Errorf("Expected error for %s, got status %d", tt.name, resp.StatusCode) - } - } else { - if resp.StatusCode == 302 || resp.StatusCode == 303 { - location := resp.Header.Get("Location") - if strings.Contains(location, "error=") { - t.Errorf("Unexpected error redirect for %s: %s", tt.name, location) - } - } else if resp.StatusCode < 400 { - } else { - body, _ := io.ReadAll(resp.Body) - t.Errorf("Expected success for %s, got status %d: %s", tt.name, resp.StatusCode, string(body)) - } - } - }) - } -} - -// TestEnvironmentModes tests development vs production mode differences -func TestEnvironmentModes(t *testing.T) { - t.Run("DevelopmentMode", func(t *testing.T) { - mcpCmd := startOAuthServer(t, map[string]string{ - "MCP_FRONT_ENV": "development", - }) - defer stopServer(mcpCmd) - - if !waitForHealthCheck(30) { - t.Fatal("Server failed to start") - } - - // In development mode, missing state should be auto-generated - clientID := registerTestClient(t) - - params := url.Values{ - "response_type": {"code"}, - "client_id": {clientID}, - "redirect_uri": {"http://127.0.0.1:6274/oauth/callback"}, - "code_challenge": {"test-challenge"}, - "code_challenge_method": {"S256"}, - "scope": {"read"}, - // Intentionally omitting state parameter - } - - // Use a client that doesn't follow redirects - client := &http.Client{ - CheckRedirect: func(req *http.Request, via []*http.Request) error { - return http.ErrUseLastResponse - }, - } - resp, err := client.Get("http://localhost:8080/authorize?" + params.Encode()) - if err != nil { - t.Fatalf("Failed to make auth request: %v", err) - } - defer resp.Body.Close() - - // Should redirect (302) not error - if resp.StatusCode >= 400 && resp.StatusCode != 302 { - t.Errorf("Development mode should handle missing state, got status %d", resp.StatusCode) - } - }) - - t.Run("ProductionMode", func(t *testing.T) { - mcpCmd := startOAuthServer(t, map[string]string{ - "MCP_FRONT_ENV": "production", - }) - defer stopServer(mcpCmd) - - if !waitForHealthCheck(30) { - t.Fatal("Server failed to start") - } - - // In production mode, state should be required - clientID := registerTestClient(t) - - params := url.Values{ - "response_type": {"code"}, - "client_id": {clientID}, - "redirect_uri": {"http://127.0.0.1:6274/oauth/callback"}, - "code_challenge": {"test-challenge"}, - "code_challenge_method": {"S256"}, - "scope": {"read"}, - // Intentionally omitting state parameter - } - - // Use a client that doesn't follow redirects - client := &http.Client{ - CheckRedirect: func(req *http.Request, via []*http.Request) error { - return http.ErrUseLastResponse - }, - } - resp, err := client.Get("http://localhost:8080/authorize?" + params.Encode()) - if err != nil { - t.Fatalf("Failed to make auth request: %v", err) - } - defer resp.Body.Close() - - // Should error - OAuth errors are returned as redirects - if resp.StatusCode == 302 || resp.StatusCode == 303 { - location := resp.Header.Get("Location") - if strings.Contains(location, "error=") { - } else { - t.Errorf("Expected error redirect in production mode, got redirect without error") - } - } else if resp.StatusCode >= 400 { - } else { - t.Errorf("Production mode should require state parameter, got status %d", resp.StatusCode) - } - }) -} - -// TestOAuthEndpoints tests all OAuth endpoints comprehensively -func TestOAuthEndpoints(t *testing.T) { - mcpCmd := startOAuthServer(t, map[string]string{ - "MCP_FRONT_ENV": "development", - }) - defer stopServer(mcpCmd) - - if !waitForHealthCheck(10) { - t.Fatal("Server failed to start") - } - - t.Run("Discovery", func(t *testing.T) { - resp, err := http.Get("http://localhost:8080/.well-known/oauth-authorization-server") - if err != nil { - t.Fatalf("Discovery request failed: %v", err) - } - defer resp.Body.Close() - - if resp.StatusCode != 200 { - t.Fatalf("Discovery failed with status %d", resp.StatusCode) - } - - var discovery map[string]any - if err := json.NewDecoder(resp.Body).Decode(&discovery); err != nil { - t.Fatalf("Failed to decode discovery response: %v", err) - } - - // Verify all required fields - required := []string{ - "issuer", - "authorization_endpoint", - "token_endpoint", - "registration_endpoint", - "response_types_supported", - "grant_types_supported", - "code_challenge_methods_supported", - } - - for _, field := range required { - if _, ok := discovery[field]; !ok { - t.Errorf("Missing required discovery field: %s", field) - } - } - - }) - - t.Run("HealthCheck", func(t *testing.T) { - resp, err := http.Get("http://localhost:8080/health") - if err != nil { - t.Fatalf("Health check failed: %v", err) - } - defer resp.Body.Close() - - if resp.StatusCode != 200 { - t.Errorf("Health check should return 200, got %d", resp.StatusCode) - } - - var health map[string]string - if err := json.NewDecoder(resp.Body).Decode(&health); err != nil { - t.Fatalf("Failed to decode health response: %v", err) - } - if health["status"] != "ok" { - t.Errorf("Expected status 'ok', got '%s'", health["status"]) - } - - }) -} - -// TestCORSHeaders tests CORS headers for Claude.ai compatibility -func TestCORSHeaders(t *testing.T) { - mcpCmd := startOAuthServer(t, map[string]string{ - "MCP_FRONT_ENV": "development", - }) - defer stopServer(mcpCmd) - - if !waitForHealthCheck(10) { - t.Fatal("Server failed to start") - } - - // Test preflight request - req, _ := http.NewRequest("OPTIONS", "http://localhost:8080/register", nil) - req.Header.Set("Origin", "https://claude.ai") - req.Header.Set("Access-Control-Request-Method", "POST") - req.Header.Set("Access-Control-Request-Headers", "content-type") - - resp, err := http.DefaultClient.Do(req) - if err != nil { - t.Fatalf("Preflight request failed: %v", err) - } - defer resp.Body.Close() - - if resp.StatusCode != 200 { - t.Errorf("Preflight should return 200, got %d", resp.StatusCode) - } - - // Check CORS headers - expectedHeaders := map[string]string{ - "Access-Control-Allow-Origin": "https://claude.ai", - "Access-Control-Allow-Methods": "GET, POST, OPTIONS", - "Access-Control-Allow-Headers": "Content-Type, Authorization, Cache-Control, mcp-protocol-version", - } - - for header, expected := range expectedHeaders { - actual := resp.Header.Get(header) - if actual != expected { - t.Errorf("Expected %s: '%s', got '%s'", header, expected, actual) - } - } - -} - -// TestToolAdvertisementWithUserTokens tests that tools are advertised even without user tokens -// but fail gracefully when invoked without the required token, and succeed with the token -func TestToolAdvertisementWithUserTokens(t *testing.T) { - // Start OAuth server with user token configuration - startMCPFront(t, "config/config.oauth-usertoken-tools-test.json", - "JWT_SECRET=demo-jwt-secret-32-bytes-exactly!", - "ENCRYPTION_KEY=test-encryption-key-32-bytes-ok!", - "GOOGLE_CLIENT_ID=test-client-id-oauth", - "GOOGLE_CLIENT_SECRET=test-client-secret-oauth", - "GOOGLE_OAUTH_AUTH_URL=http://localhost:9090/auth", - "GOOGLE_OAUTH_TOKEN_URL=http://localhost:9090/token", - "GOOGLE_USERINFO_URL=http://localhost:9090/userinfo", - "MCP_FRONT_ENV=development", - "LOG_LEVEL=debug", - ) - - if !waitForHealthCheck(30) { - t.Fatal("Server failed to start") - } - - // Complete OAuth flow to get a valid access token - accessToken := getOAuthAccessToken(t, "http://localhost:8080/postgres") - - t.Run("ToolsAdvertisedWithoutToken", func(t *testing.T) { - // Create MCP client with OAuth token - mcpClient := NewMCPSSEClient("http://localhost:8080") - mcpClient.SetAuthToken(accessToken) - - // Connect to postgres SSE endpoint - err := mcpClient.Connect() - require.NoError(t, err, "Should connect to postgres SSE endpoint without user token") - defer mcpClient.Close() - - // Request tools list - toolsResp, err := mcpClient.SendMCPRequest("tools/list", map[string]any{}) - require.NoError(t, err, "Should list tools without user token") - - // Verify we got tools - resultMap, ok := toolsResp["result"].(map[string]any) - require.True(t, ok, "Expected result in tools response") - - tools, ok := resultMap["tools"].([]any) - require.True(t, ok, "Expected tools array in result") - assert.NotEmpty(t, tools, "Should have tools advertised") - - // Check for common postgres tools - var toolNames []string - for _, tool := range tools { - if toolMap, ok := tool.(map[string]any); ok { - if name, ok := toolMap["name"].(string); ok { - toolNames = append(toolNames, name) - } - } - } - - assert.Contains(t, toolNames, "query", "Should have query tool") - t.Logf("Successfully advertised tools without user token: %v", toolNames) - }) - - t.Run("ToolInvocationFailsWithoutToken", func(t *testing.T) { - // Create MCP client with OAuth token - mcpClient := NewMCPSSEClient("http://localhost:8080") - mcpClient.SetAuthToken(accessToken) - - // Connect to postgres SSE endpoint - err := mcpClient.Connect() - require.NoError(t, err) - defer mcpClient.Close() - - // Try to invoke a tool without user token - queryParams := map[string]any{ - "name": "query", - "arguments": map[string]any{ - "sql": "SELECT 1", - }, - } - - result, err := mcpClient.SendMCPRequest("tools/call", queryParams) - require.NoError(t, err, "Should get response even without token") - - // MCP protocol returns errors as successful responses with error content - require.NotNil(t, result["result"], "Should have result in response") - - resultMap := result["result"].(map[string]any) - content := resultMap["content"].([]any) - require.NotEmpty(t, content, "Should have content in result") - - contentItem := content[0].(map[string]any) - errorJSON := contentItem["text"].(string) - - // Parse the error JSON - var errorData map[string]any - err = json.Unmarshal([]byte(errorJSON), &errorData) - require.NoError(t, err, "Error should be valid JSON") - - // Verify error structure - errorInfo := errorData["error"].(map[string]any) - assert.Equal(t, "token_required", errorInfo["code"], "Error code should be token_required") - - errorMessage := errorInfo["message"].(string) - assert.Contains(t, errorMessage, "token required", "Error should mention token required") - assert.Contains(t, errorMessage, "/my/tokens", "Error should mention token setup URL") - assert.Contains(t, errorMessage, "Test Service", "Error should mention service name") - - // Verify error data - errData := errorInfo["data"].(map[string]any) - assert.Equal(t, "postgres", errData["service"], "Should identify the service") - assert.Contains(t, errData["tokenSetupUrl"].(string), "/my/tokens", "Should include token setup URL") - - // Verify instructions - instructions := errData["instructions"].(map[string]any) - assert.Contains(t, instructions["ai"].(string), "CRITICAL", "Should have AI instructions") - assert.Contains(t, instructions["human"].(string), "token required", "Should have human instructions") - }) - - t.Run("ToolInvocationSucceedsWithUserToken", func(t *testing.T) { - // Step 1: GET /my/tokens to extract CSRF token - jar, err := cookiejar.New(nil) - require.NoError(t, err) - client := &http.Client{ - Jar: jar, // Need cookie jar for CSRF - CheckRedirect: func(req *http.Request, via []*http.Request) error { - return http.ErrUseLastResponse // Don't follow redirects - }, - } - - req, err := http.NewRequest("GET", "http://localhost:8080/my/tokens", nil) - require.NoError(t, err) - req.Header.Set("Authorization", "Bearer "+accessToken) - - resp, err := client.Do(req) - require.NoError(t, err) - defer resp.Body.Close() - - // Check if we got the page or a redirect - if resp.StatusCode == 302 || resp.StatusCode == 303 { - // Follow the redirect - location := resp.Header.Get("Location") - t.Logf("Got redirect to: %s", location) - - // Allow redirects for this request - client = &http.Client{ - Jar: jar, - } - - req, err = http.NewRequest("GET", "http://localhost:8080/my/tokens", nil) - require.NoError(t, err) - req.Header.Set("Authorization", "Bearer "+accessToken) - - resp, err = client.Do(req) - require.NoError(t, err) - defer resp.Body.Close() - } - - require.Equal(t, 200, resp.StatusCode, "Should be able to access token page") - - // Extract CSRF token from HTML - body, err := io.ReadAll(resp.Body) - require.NoError(t, err) - - // Look for the CSRF token in the form - csrfRegex := regexp.MustCompile(`name="csrf_token" value="([^"]+)"`) - matches := csrfRegex.FindSubmatch(body) - require.Len(t, matches, 2, "Should find CSRF token in form") - csrfToken := string(matches[1]) - - // Step 2: POST to /my/tokens/set with test token - formData := url.Values{ - "service": {"postgres"}, - "token": {"test-user-token-12345"}, - "csrf_token": {csrfToken}, - } - - req, err = http.NewRequest("POST", "http://localhost:8080/my/tokens/set", strings.NewReader(formData.Encode())) - require.NoError(t, err) - req.Header.Set("Authorization", "Bearer "+accessToken) - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - - resp, err = client.Do(req) - require.NoError(t, err) - defer resp.Body.Close() - - // Check the response - it might be 200 if following redirects - switch resp.StatusCode { - case 200: - // That's fine, it means the token was set and we got the page back - t.Log("Token set successfully, got page response") - case 302, 303: - // Also fine, redirect means success - t.Log("Token set successfully, got redirect") - default: - body, _ := io.ReadAll(resp.Body) - t.Fatalf("Unexpected response setting token: status=%d, body=%s", resp.StatusCode, string(body)) - } - - // Step 3: Now test tool invocation with the token - mcpClient := NewMCPSSEClient("http://localhost:8080") - mcpClient.SetAuthToken(accessToken) - - err = mcpClient.Connect() - require.NoError(t, err, "Should connect to postgres SSE endpoint") - defer mcpClient.Close() - - // Call the query tool with a simple query - queryParams := map[string]any{ - "name": "query", - "arguments": map[string]any{ - "sql": "SELECT 1 as test", - }, - } - - result, err := mcpClient.SendMCPRequest("tools/call", queryParams) - require.NoError(t, err, "Should successfully call tool with token") - - // Verify we got a successful result, not an error - require.NotNil(t, result["result"], "Should have result in response") - - resultMap := result["result"].(map[string]any) - content := resultMap["content"].([]any) - require.NotEmpty(t, content, "Should have content in result") - - contentItem := content[0].(map[string]any) - resultText := contentItem["text"].(string) - - // The result should contain actual query results, not an error - assert.NotContains(t, resultText, "token_required", "Should not have token error") - assert.NotContains(t, resultText, "Token Required", "Should not have token error message") - - // Postgres query result should contain our test value - assert.Contains(t, resultText, "1", "Should contain query result") - - t.Log("Successfully invoked tool with user token") - }) -} - -// Helper functions - -func startOAuthServer(t *testing.T, env map[string]string) *exec.Cmd { - // Start with OAuth config - mcpCmd := exec.Command("../cmd/mcp-front/mcp-front", "-config", "config/config.oauth-test.json") - - // Set default environment - mcpCmd.Env = []string{ - "PATH=" + os.Getenv("PATH"), - "JWT_SECRET=demo-jwt-secret-32-bytes-exactly!", - "ENCRYPTION_KEY=test-encryption-key-32-bytes-ok!", - "GOOGLE_CLIENT_ID=test-client-id-oauth", - "GOOGLE_CLIENT_SECRET=test-client-secret-oauth", - "GOOGLE_OAUTH_AUTH_URL=http://localhost:9090/auth", - "GOOGLE_OAUTH_TOKEN_URL=http://localhost:9090/token", - "GOOGLE_USERINFO_URL=http://localhost:9090/userinfo", - } - - // Override with provided env - for key, value := range env { - mcpCmd.Env = append(mcpCmd.Env, fmt.Sprintf("%s=%s", key, value)) - } - - // Capture stderr for debugging and also output to test log - var stderr bytes.Buffer - mcpCmd.Stderr = io.MultiWriter(&stderr, os.Stderr) - - if err := mcpCmd.Start(); err != nil { - t.Fatalf("Failed to start OAuth server: %v", err) - } - - // Give a moment for immediate failures - time.Sleep(100 * time.Millisecond) - - // Check if process died immediately - if mcpCmd.ProcessState != nil { - t.Fatalf("OAuth server died immediately: %s", stderr.String()) - } - - return mcpCmd -} - -// startOAuthServerWithTokenConfig starts the OAuth server with user token configuration -func startOAuthServerWithTokenConfig(t *testing.T) *exec.Cmd { - // Start with user token config - mcpCmd := exec.Command("../cmd/mcp-front/mcp-front", "-config", "config/config.oauth-token-test.json") - - // Set default environment - mcpCmd.Env = []string{ - "PATH=" + os.Getenv("PATH"), - "JWT_SECRET=demo-jwt-secret-32-bytes-exactly!", - "ENCRYPTION_KEY=test-encryption-key-32-bytes-ok!", - "GOOGLE_CLIENT_ID=test-client-id-oauth", - "GOOGLE_CLIENT_SECRET=test-client-secret-oauth", - "GOOGLE_OAUTH_AUTH_URL=http://localhost:9090/auth", - "GOOGLE_OAUTH_TOKEN_URL=http://localhost:9090/token", - "GOOGLE_USERINFO_URL=http://localhost:9090/userinfo", - "MCP_FRONT_ENV=development", - } - - // Capture stderr for debugging and also output to test log - var stderr bytes.Buffer - mcpCmd.Stderr = io.MultiWriter(&stderr, os.Stderr) - - if err := mcpCmd.Start(); err != nil { - t.Fatalf("Failed to start OAuth server: %v", err) - } - - // Give a moment for immediate failures - time.Sleep(100 * time.Millisecond) - - // Check if process died immediately - if mcpCmd.ProcessState != nil { - t.Fatalf("OAuth server died immediately: %s", stderr.String()) - } - - return mcpCmd -} - -func stopServer(cmd *exec.Cmd) { - if cmd != nil && cmd.Process != nil { - _ = cmd.Process.Kill() - _ = cmd.Wait() - // Give the OS time to release the port - time.Sleep(100 * time.Millisecond) - } -} - -func waitForHealthCheck(seconds int) bool { - for range seconds { - if checkHealth() { - return true - } - time.Sleep(1 * time.Second) - } - return false -} - -func checkHealth() bool { - resp, err := http.Get("http://localhost:8080/health") - if err == nil && resp.StatusCode == 200 { - resp.Body.Close() - return true - } - if resp != nil { - resp.Body.Close() - } - return false -} - -func registerTestClient(t *testing.T) string { - clientReq := map[string]any{ - "redirect_uris": []string{"http://127.0.0.1:6274/oauth/callback"}, - "scope": "openid email profile read write", - } - - body, _ := json.Marshal(clientReq) - resp, err := http.Post( - "http://localhost:8080/register", - "application/json", - bytes.NewBuffer(body), - ) - if err != nil { - t.Fatalf("Failed to register client: %v", err) - } - defer resp.Body.Close() - - if resp.StatusCode != 201 { - body, _ := io.ReadAll(resp.Body) - t.Fatalf("Client registration failed: %d - %s", resp.StatusCode, string(body)) - } - - var clientResp map[string]any - _ = json.NewDecoder(resp.Body).Decode(&clientResp) - return clientResp["client_id"].(string) -} - -// extractCSRFToken extracts the CSRF token from the HTML response -func extractCSRFToken(t *testing.T, html string) string { - // Look for - re := regexp.MustCompile(`]+name="csrf_token"[^>]+value="([^"]+)"`) - matches := re.FindStringSubmatch(html) - require.GreaterOrEqual(t, len(matches), 2, "CSRF token not found in response") - return matches[1] -} - -// contains is a simple helper to check if string contains substring -func contains(s, substr string) bool { - return strings.Contains(s, substr) -} - -// TestServiceOAuthIntegration validates the complete OAuth flow for external services -func TestServiceOAuthIntegration(t *testing.T) { - // Start fake service OAuth provider on port 9091 - fakeService := NewFakeServiceOAuthServer("9091") - err := fakeService.Start() - require.NoError(t, err) - defer func() { _ = fakeService.Stop() }() - - // Start mcp-front with OAuth service configuration - startMCPFront(t, "config/config.oauth-service-integration-test.json", - "JWT_SECRET=demo-jwt-secret-32-bytes-exactly!", - "ENCRYPTION_KEY=test-encryption-key-32-bytes-ok!", - "GOOGLE_CLIENT_ID=test-client-id-oauth", - "GOOGLE_CLIENT_SECRET=test-client-secret-oauth", - "TEST_SERVICE_CLIENT_ID=service-client-id", - "TEST_SERVICE_CLIENT_SECRET=service-client-secret", - "GOOGLE_OAUTH_AUTH_URL=http://localhost:9090/auth", - "GOOGLE_OAUTH_TOKEN_URL=http://localhost:9090/token", - "GOOGLE_USERINFO_URL=http://localhost:9090/userinfo", - "MCP_FRONT_ENV=development", - ) - - if !waitForHealthCheck(30) { - t.Fatal("Server failed to start") - } - - // For this test, we use browser SSO instead of OAuth client flow - // This simulates a user in the browser connecting services - jar, _ := cookiejar.New(nil) - client := &http.Client{Jar: jar} - - // Complete Google OAuth to get browser session - // Access /my/tokens which triggers SSO flow - resp, err := client.Get("http://localhost:8080/my/tokens") - require.NoError(t, err) - defer resp.Body.Close() - - // Should have completed SSO and landed on /my/tokens - require.Equal(t, http.StatusOK, resp.StatusCode) - - t.Run("ServiceOAuthConnectFlow", func(t *testing.T) { - // User clicks "Connect" for the service - req, _ := http.NewRequest("GET", "http://localhost:8080/oauth/connect?service=test-service", nil) - - resp, err := client.Do(req) - require.NoError(t, err) - defer resp.Body.Close() - - // Should complete OAuth flow and redirect back with success - // The http.Client automatically follows redirects: - // 1. /oauth/connect → redirects to localhost:9091/oauth/authorize - // 2. Fake service → redirects to /oauth/callback/test-service?code=... - // 3. Callback → stores token, redirects to /my/tokens with success message - - body, _ := io.ReadAll(resp.Body) - bodyStr := string(body) - - // Final page should show success - assert.Contains(t, bodyStr, "Successfully connected", "Should show success message after OAuth flow") - assert.Contains(t, bodyStr, "Test OAuth Service", "Should mention service name") - }) - - t.Run("ConnectedServiceShownOnTokenPage", func(t *testing.T) { - // After OAuth connection, service should appear as connected - req, _ := http.NewRequest("GET", "http://localhost:8080/my/tokens", nil) - - resp, err := client.Do(req) - require.NoError(t, err) - defer resp.Body.Close() - - body, _ := io.ReadAll(resp.Body) - bodyStr := string(body) - - // Should show the service with connected status - assert.Contains(t, bodyStr, "Test OAuth Service") - // OAuth-connected services show disconnect button, not connect - assert.Contains(t, bodyStr, "Disconnect", "OAuth-connected service should show Disconnect button") - }) -} - -// getOAuthAccessToken completes the OAuth flow and returns a valid access token -func getOAuthAccessToken(t *testing.T, resource string) string { - // Register a test client - clientID := registerTestClient(t) - - // Generate PKCE challenge (must be at least 43 characters) - codeVerifier := "test-code-verifier-that-is-at-least-43-characters-long" - h := sha256.New() - h.Write([]byte(codeVerifier)) - codeChallenge := base64.RawURLEncoding.EncodeToString(h.Sum(nil)) - - // Step 1: Authorization request - authParams := url.Values{ - "response_type": {"code"}, - "client_id": {clientID}, - "redirect_uri": {"http://127.0.0.1:6274/oauth/callback"}, - "code_challenge": {codeChallenge}, - "code_challenge_method": {"S256"}, - "scope": {"openid email profile"}, - "state": {"test-state"}, - "resource": {resource}, - } - - client := &http.Client{ - CheckRedirect: func(req *http.Request, via []*http.Request) error { - return http.ErrUseLastResponse - }, - } - - authResp, err := client.Get("http://localhost:8080/authorize?" + authParams.Encode()) - require.NoError(t, err) - defer authResp.Body.Close() - - // Should redirect to Google OAuth (our mock) - require.Contains(t, []int{302, 303}, authResp.StatusCode, "Should redirect to Google OAuth") - location := authResp.Header.Get("Location") - require.Contains(t, location, "http://localhost:9090/auth", "Should redirect to mock Google OAuth") - - // Parse the redirect to get the state parameter - redirectURL, err := url.Parse(location) - require.NoError(t, err) - _ = redirectURL.Query().Get("state") // state is included in the redirect but not needed for this test - - // Step 2: Follow redirect to mock Google OAuth (which immediately redirects back) - googleResp, err := client.Get(location) - require.NoError(t, err) - defer googleResp.Body.Close() - - // Mock Google OAuth redirects back to callback - require.Contains(t, []int{302, 303}, googleResp.StatusCode, "Mock Google should redirect back") - callbackLocation := googleResp.Header.Get("Location") - require.Contains(t, callbackLocation, "/oauth/callback", "Should redirect to callback") - - // Step 3: Follow callback redirect - callbackResp, err := client.Get(callbackLocation) - require.NoError(t, err) - defer callbackResp.Body.Close() - - // Should redirect to the original redirect_uri with authorization code - require.Contains(t, []int{302, 303}, callbackResp.StatusCode, "Callback should redirect with code") - finalLocation := callbackResp.Header.Get("Location") - - // Parse authorization code from final redirect - finalURL, err := url.Parse(finalLocation) - require.NoError(t, err) - authCode := finalURL.Query().Get("code") - require.NotEmpty(t, authCode, "Should have authorization code") - - // Step 4: Exchange code for token - tokenParams := url.Values{ - "grant_type": {"authorization_code"}, - "code": {authCode}, - "redirect_uri": {"http://127.0.0.1:6274/oauth/callback"}, - "client_id": {clientID}, - "code_verifier": {codeVerifier}, - } - - tokenResp, err := http.PostForm("http://localhost:8080/token", tokenParams) - require.NoError(t, err) - defer tokenResp.Body.Close() - - if tokenResp.StatusCode != 200 { - body, _ := io.ReadAll(tokenResp.Body) - t.Logf("Token exchange failed with status %d: %s", tokenResp.StatusCode, string(body)) - } - - require.Equal(t, 200, tokenResp.StatusCode, "Token exchange should succeed") - - var tokenData map[string]any - err = json.NewDecoder(tokenResp.Body).Decode(&tokenData) - require.NoError(t, err) - - accessToken := tokenData["access_token"].(string) - require.NotEmpty(t, accessToken, "Should have access token") - - return accessToken -} - -// MockUserTokenStore mocks the UserTokenStore interface for testing -type MockUserTokenStore struct { - mock.Mock -} - -func (m *MockUserTokenStore) GetUserToken(ctx context.Context, email, service string) (string, error) { - args := m.Called(ctx, email, service) - return args.String(0), args.Error(1) -} - -func (m *MockUserTokenStore) SetUserToken(ctx context.Context, email, service, token string) error { - args := m.Called(ctx, email, service, token) - return args.Error(0) -} - -func (m *MockUserTokenStore) DeleteUserToken(ctx context.Context, email, service string) error { - args := m.Called(ctx, email, service) - return args.Error(0) -} - -func (m *MockUserTokenStore) ListUserServices(ctx context.Context, email string) ([]string, error) { - args := m.Called(ctx, email) - if args.Get(0) == nil { - return nil, args.Error(1) - } - return args.Get(0).([]string), args.Error(1) -} - -// TestRFC8707ResourceIndicators validates RFC 8707 resource indicator functionality -func TestRFC8707ResourceIndicators(t *testing.T) { - startMCPFront(t, "config/config.oauth-rfc8707-test.json", - "JWT_SECRET=test-jwt-secret-32-bytes-exactly!", - "ENCRYPTION_KEY=test-encryption-key-32-bytes-ok!", - "GOOGLE_CLIENT_ID=test-client-id-for-oauth", - "GOOGLE_CLIENT_SECRET=test-client-secret-for-oauth", - "MCP_FRONT_ENV=development", - "GOOGLE_OAUTH_AUTH_URL=http://localhost:9090/auth", - "GOOGLE_OAUTH_TOKEN_URL=http://localhost:9090/token", - "GOOGLE_USERINFO_URL=http://localhost:9090/userinfo", - ) - - waitForMCPFront(t) - - t.Run("BaseProtectedResourceMetadataReturns404", func(t *testing.T) { - // Base metadata endpoint should return 404, directing clients to per-service endpoints - resp, err := http.Get("http://localhost:8080/.well-known/oauth-protected-resource") - require.NoError(t, err) - defer resp.Body.Close() - - assert.Equal(t, 404, resp.StatusCode, "Base protected resource metadata endpoint should return 404") - - var errResp map[string]any - err = json.NewDecoder(resp.Body).Decode(&errResp) - require.NoError(t, err) - - assert.Contains(t, errResp["message"], "per-service", "Error message should direct to per-service endpoints") - }) - - t.Run("PerServiceProtectedResourceMetadataEndpoint", func(t *testing.T) { - // Per-service metadata endpoint should return service-specific resource URI - resp, err := http.Get("http://localhost:8080/.well-known/oauth-protected-resource/test-sse") - require.NoError(t, err) - defer resp.Body.Close() - - assert.Equal(t, 200, resp.StatusCode, "Per-service protected resource metadata endpoint should exist") - - var metadata map[string]any - err = json.NewDecoder(resp.Body).Decode(&metadata) - require.NoError(t, err) - - // Resource should be service-specific, not base URL - assert.Equal(t, "http://localhost:8080/test-sse", metadata["resource"], - "Resource should be service-specific URL") - - authzServers, ok := metadata["authorization_servers"].([]any) - require.True(t, ok, "Should have authorization_servers array") - require.NotEmpty(t, authzServers) - assert.Equal(t, "http://localhost:8080", authzServers[0], - "Authorization server should be base issuer") - }) - - t.Run("UnknownServiceReturns404", func(t *testing.T) { - resp, err := http.Get("http://localhost:8080/.well-known/oauth-protected-resource/nonexistent-service") - require.NoError(t, err) - defer resp.Body.Close() - - assert.Equal(t, 404, resp.StatusCode, "Unknown service should return 404") - }) - - t.Run("TokenWithResourceParameter", func(t *testing.T) { - clientID := registerTestClient(t) - - codeVerifier := "test-code-verifier-that-is-at-least-43-characters-long" - h := sha256.New() - h.Write([]byte(codeVerifier)) - codeChallenge := base64.RawURLEncoding.EncodeToString(h.Sum(nil)) - - authParams := url.Values{ - "response_type": {"code"}, - "client_id": {clientID}, - "redirect_uri": {"http://127.0.0.1:6274/oauth/callback"}, - "code_challenge": {codeChallenge}, - "code_challenge_method": {"S256"}, - "scope": {"openid email profile"}, - "state": {"test-state"}, - "resource": {"http://localhost:8080/test-sse"}, - } - - client := &http.Client{ - CheckRedirect: func(req *http.Request, via []*http.Request) error { - return http.ErrUseLastResponse - }, - } - - authResp, err := client.Get("http://localhost:8080/authorize?" + authParams.Encode()) - require.NoError(t, err) - defer authResp.Body.Close() - - assert.Contains(t, []int{302, 303}, authResp.StatusCode, "Should redirect to Google OAuth") - - location := authResp.Header.Get("Location") - googleResp, err := client.Get(location) - require.NoError(t, err) - defer googleResp.Body.Close() - - callbackLocation := googleResp.Header.Get("Location") - callbackResp, err := client.Get(callbackLocation) - require.NoError(t, err) - defer callbackResp.Body.Close() - - finalURL, err := url.Parse(callbackResp.Header.Get("Location")) - require.NoError(t, err) - authCode := finalURL.Query().Get("code") - require.NotEmpty(t, authCode, "Should have authorization code") - - tokenParams := url.Values{ - "grant_type": {"authorization_code"}, - "code": {authCode}, - "redirect_uri": {"http://127.0.0.1:6274/oauth/callback"}, - "client_id": {clientID}, - "code_verifier": {codeVerifier}, - } - - tokenResp, err := http.PostForm("http://localhost:8080/token", tokenParams) - require.NoError(t, err) - defer tokenResp.Body.Close() - - require.Equal(t, 200, tokenResp.StatusCode, "Token exchange should succeed") - - var tokenData map[string]any - err = json.NewDecoder(tokenResp.Body).Decode(&tokenData) - require.NoError(t, err) - - testSSEToken := tokenData["access_token"].(string) - require.NotEmpty(t, testSSEToken, "Should have access token") - - t.Logf("Got token with test-sse audience: %s", testSSEToken[:20]+"...") - - // Verify token works for test-sse (matching audience) - req, _ := http.NewRequest("GET", "http://localhost:8080/test-sse/sse", nil) - req.Header.Set("Authorization", "Bearer "+testSSEToken) - req.Header.Set("Accept", "text/event-stream") - - sseResp, err := client.Do(req) - require.NoError(t, err) - defer sseResp.Body.Close() - - assert.Equal(t, 200, sseResp.StatusCode, - "Token with test-sse audience should access /test-sse/sse") - - // Verify token does NOT work for test-streamable (wrong audience) - req, _ = http.NewRequest("GET", "http://localhost:8080/test-streamable/sse", nil) - req.Header.Set("Authorization", "Bearer "+testSSEToken) - req.Header.Set("Accept", "text/event-stream") - - streamableResp, err := client.Do(req) - require.NoError(t, err) - defer streamableResp.Body.Close() - - assert.Equal(t, 401, streamableResp.StatusCode, - "Token with test-sse audience should NOT access /test-streamable/sse") - - wwwAuth := streamableResp.Header.Get("WWW-Authenticate") - assert.Contains(t, wwwAuth, "Bearer resource_metadata=", - "401 response should include RFC 9728 WWW-Authenticate header") - // Per RFC 9728 Section 5.2, the metadata URI should be service-specific - assert.Contains(t, wwwAuth, "/.well-known/oauth-protected-resource/test-streamable", - "401 response should point to per-service metadata endpoint") - }) - - t.Run("401ResponseIncludesServiceSpecificMetadataURI", func(t *testing.T) { - // Request to a protected endpoint without token should get 401 - // with service-specific metadata URI in WWW-Authenticate header - client := &http.Client{ - CheckRedirect: func(req *http.Request, via []*http.Request) error { - return http.ErrUseLastResponse - }, - } - - req, _ := http.NewRequest("GET", "http://localhost:8080/test-sse/sse", nil) - req.Header.Set("Accept", "text/event-stream") - - resp, err := client.Do(req) - require.NoError(t, err) - defer resp.Body.Close() - - assert.Equal(t, 401, resp.StatusCode, "Request without token should return 401") - - wwwAuth := resp.Header.Get("WWW-Authenticate") - assert.Contains(t, wwwAuth, "Bearer resource_metadata=", - "401 response should include RFC 9728 WWW-Authenticate header") - assert.Contains(t, wwwAuth, "/.well-known/oauth-protected-resource/test-sse", - "401 response should point to test-sse specific metadata endpoint") - }) -} diff --git a/integration/oauth_user_tokens_test.go b/integration/oauth_user_tokens_test.go new file mode 100644 index 0000000..6612d9d --- /dev/null +++ b/integration/oauth_user_tokens_test.go @@ -0,0 +1,536 @@ +package integration + +import ( + "context" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "net/http" + "net/http/cookiejar" + "net/url" + "regexp" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +// TestUserTokenFlow tests the user token management functionality with browser-based SSO +// This test expects the /my/* routes to work with Google SSO (session-based auth), +// not Bearer token auth. +func TestUserTokenFlow(t *testing.T) { + // Start OAuth server with user token configuration + mcpCmd := startOAuthServerWithTokenConfig(t) + defer stopServer(mcpCmd) + + if !waitForHealthCheck(30) { + t.Fatal("Server failed to start") + } + + // Create a client with cookie jar to simulate browser behavior + jar, _ := cookiejar.New(nil) + client := &http.Client{ + Jar: jar, + CheckRedirect: func(req *http.Request, via []*http.Request) error { + // Allow up to 10 redirects + if len(via) >= 10 { + return fmt.Errorf("too many redirects") + } + return nil + }, + } + + t.Run("UnauthenticatedRedirectsToSSO", func(t *testing.T) { + // Create a client that doesn't follow redirects to test the initial redirect + noRedirectClient := &http.Client{ + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, + } + + // Try to access /my/tokens without authentication + resp, err := noRedirectClient.Get("http://localhost:8080/my/tokens") + require.NoError(t, err) + defer resp.Body.Close() + + // Should get a redirect response + assert.Equal(t, http.StatusFound, resp.StatusCode, "Should get redirect status") + + // Check the redirect location + location := resp.Header.Get("Location") + assert.Contains(t, location, "localhost:9090/auth", "Should redirect to Google OAuth") + assert.Contains(t, location, "client_id=", "Should include client_id") + assert.Contains(t, location, "redirect_uri=", "Should include redirect_uri") + // Extract and validate the state parameter + parsedURL, err := url.Parse(location) + require.NoError(t, err) + stateParam := parsedURL.Query().Get("state") + require.NotEmpty(t, stateParam, "State parameter should be present") + + // State format: "browser:" prefix followed by signed token + // We verify structure but not internal format (that's implementation detail) + assert.True(t, strings.HasPrefix(stateParam, "browser:"), "State should start with browser:") + assert.Greater(t, len(stateParam), len("browser:"), "State should have content after prefix") + + // Verify state contains signature (has dot separator indicating signed data) + stateContent := strings.TrimPrefix(stateParam, "browser:") + assert.Contains(t, stateContent, ".", "Signed state should contain signature separator") + }) + + t.Run("AuthenticatedUserCanAccessTokens", func(t *testing.T) { + // The client with cookie jar will automatically follow the full SSO flow: + // 1. GET /my/tokens -> redirect to Google OAuth + // 2. Google OAuth redirects to /oauth/callback with code + // 3. Callback sets session cookie and redirects to /my/tokens + // 4. Client follows redirect with cookie and gets the page + + resp, err := client.Get("http://localhost:8080/my/tokens") + require.NoError(t, err) + defer resp.Body.Close() + + // After following all redirects, we should be at /my/tokens with 200 OK + assert.Equal(t, http.StatusOK, resp.StatusCode, "Should access /my/tokens after SSO") + finalURL := resp.Request.URL.String() + assert.Contains(t, finalURL, "/my/tokens", "Should end up at /my/tokens after SSO") + + // Read response body + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + bodyStr := string(body) + + // Should show both services without tokens + assert.Contains(t, bodyStr, "Notion", "Expected Notion service in response") + assert.Contains(t, bodyStr, "GitHub", "Expected GitHub service in response") + }) + + t.Run("SetTokenWithValidation", func(t *testing.T) { + // Assume we're already authenticated from previous test + // Get CSRF token first + resp, err := client.Get("http://localhost:8080/my/tokens") + require.NoError(t, err) + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + // Extract CSRF token from response + csrfToken := extractCSRFToken(t, string(body)) + + // Try to set invalid Notion token + form := url.Values{ + "service": []string{"notion"}, + "token": []string{"invalid-token"}, + "csrf_token": []string{csrfToken}, + } + + req, _ := http.NewRequest("POST", "http://localhost:8080/my/tokens/set", strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + // Use custom client that doesn't follow redirects for this test + noRedirectClient := &http.Client{ + Jar: jar, // Use same cookie jar + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, + } + + resp, err = noRedirectClient.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + // Should redirect with error + assert.Equal(t, http.StatusSeeOther, resp.StatusCode, "Expected redirect") + location := resp.Header.Get("Location") + assert.Contains(t, location, "error", "Expected error in redirect") + + // Get new CSRF token + resp, err = client.Get("http://localhost:8080/my/tokens") + require.NoError(t, err) + defer resp.Body.Close() + + body, err = io.ReadAll(resp.Body) + require.NoError(t, err) + csrfToken = extractCSRFToken(t, string(body)) + + // Set valid Notion token (regex expects exactly 43 chars after "secret_") + form = url.Values{ + "service": {"notion"}, + "token": {"secret_1234567890123456789012345678901234567890123"}, + "csrf_token": {csrfToken}, + } + + req, _ = http.NewRequest("POST", "http://localhost:8080/my/tokens/set", strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + resp, err = noRedirectClient.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + // Should redirect with success + assert.Equal(t, http.StatusSeeOther, resp.StatusCode, "Expected redirect") + location = resp.Header.Get("Location") + assert.Contains(t, location, "success", "Expected success in redirect") + }) +} + +// TestToolAdvertisementWithUserTokens tests that tools are advertised even without user tokens +// but fail gracefully when invoked without the required token, and succeed with the token +func TestToolAdvertisementWithUserTokens(t *testing.T) { + cfg := buildTestConfig("http://localhost:8080", "mcp-front-oauth-usertoken-test", + testOAuthConfigFromEnv(), + map[string]any{"postgres": testPostgresServer(withUserToken())}, + ) + startMCPFront(t, writeTestConfig(t, cfg), + "JWT_SECRET=demo-jwt-secret-32-bytes-exactly!", + "ENCRYPTION_KEY=test-encryption-key-32-bytes-ok!", + "GOOGLE_CLIENT_ID=test-client-id-oauth", + "GOOGLE_CLIENT_SECRET=test-client-secret-oauth", + "MCP_FRONT_ENV=development", + "LOG_LEVEL=debug", + ) + + if !waitForHealthCheck(30) { + t.Fatal("Server failed to start") + } + + // Complete OAuth flow to get a valid access token + accessToken := getOAuthAccessToken(t, "http://localhost:8080/postgres") + + t.Run("ToolsAdvertisedWithoutToken", func(t *testing.T) { + // Create MCP client with OAuth token + mcpClient := NewMCPSSEClient("http://localhost:8080") + mcpClient.SetAuthToken(accessToken) + + // Connect to postgres SSE endpoint + err := mcpClient.Connect() + require.NoError(t, err, "Should connect to postgres SSE endpoint without user token") + defer mcpClient.Close() + + // Request tools list + toolsResp, err := mcpClient.SendMCPRequest("tools/list", map[string]any{}) + require.NoError(t, err, "Should list tools without user token") + + // Verify we got tools + resultMap, ok := toolsResp["result"].(map[string]any) + require.True(t, ok, "Expected result in tools response") + + tools, ok := resultMap["tools"].([]any) + require.True(t, ok, "Expected tools array in result") + assert.NotEmpty(t, tools, "Should have tools advertised") + + // Check for common postgres tools + var toolNames []string + for _, tool := range tools { + if toolMap, ok := tool.(map[string]any); ok { + if name, ok := toolMap["name"].(string); ok { + toolNames = append(toolNames, name) + } + } + } + + assert.Contains(t, toolNames, "execute_sql", "Should have execute_sql tool") + t.Logf("Successfully advertised tools without user token: %v", toolNames) + }) + + t.Run("ToolInvocationFailsWithoutToken", func(t *testing.T) { + // Create MCP client with OAuth token + mcpClient := NewMCPSSEClient("http://localhost:8080") + mcpClient.SetAuthToken(accessToken) + + // Connect to postgres SSE endpoint + err := mcpClient.Connect() + require.NoError(t, err) + defer mcpClient.Close() + + // Try to invoke a tool without user token + queryParams := map[string]any{ + "name": "execute_sql", + "arguments": map[string]any{ + "sql": "SELECT 1", + }, + } + + result, err := mcpClient.SendMCPRequest("tools/call", queryParams) + require.NoError(t, err, "Should get response even without token") + + // MCP protocol returns errors as successful responses with error content + require.NotNil(t, result["result"], "Should have result in response") + + resultMap := result["result"].(map[string]any) + content := resultMap["content"].([]any) + require.NotEmpty(t, content, "Should have content in result") + + contentItem := content[0].(map[string]any) + errorJSON := contentItem["text"].(string) + + // Parse the error JSON + var errorData map[string]any + err = json.Unmarshal([]byte(errorJSON), &errorData) + require.NoError(t, err, "Error should be valid JSON") + + // Verify error structure + errorInfo := errorData["error"].(map[string]any) + assert.Equal(t, "token_required", errorInfo["code"], "Error code should be token_required") + + errorMessage := errorInfo["message"].(string) + assert.Contains(t, errorMessage, "token required", "Error should mention token required") + assert.Contains(t, errorMessage, "/my/tokens", "Error should mention token setup URL") + assert.Contains(t, errorMessage, "Test Service", "Error should mention service name") + + // Verify error data + errData := errorInfo["data"].(map[string]any) + assert.Equal(t, "postgres", errData["service"], "Should identify the service") + assert.Contains(t, errData["tokenSetupUrl"].(string), "/my/tokens", "Should include token setup URL") + + // Verify instructions + instructions := errData["instructions"].(map[string]any) + assert.Contains(t, instructions["ai"].(string), "CRITICAL", "Should have AI instructions") + assert.Contains(t, instructions["human"].(string), "token required", "Should have human instructions") + }) + + t.Run("ToolInvocationSucceedsWithUserToken", func(t *testing.T) { + // Step 1: GET /my/tokens to extract CSRF token + jar, err := cookiejar.New(nil) + require.NoError(t, err) + client := &http.Client{ + Jar: jar, // Need cookie jar for CSRF + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse // Don't follow redirects + }, + } + + req, err := http.NewRequest("GET", "http://localhost:8080/my/tokens", nil) + require.NoError(t, err) + req.Header.Set("Authorization", "Bearer "+accessToken) + + resp, err := client.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + // Check if we got the page or a redirect + if resp.StatusCode == 302 || resp.StatusCode == 303 { + // Follow the redirect + location := resp.Header.Get("Location") + t.Logf("Got redirect to: %s", location) + + // Allow redirects for this request + client = &http.Client{ + Jar: jar, + } + + req, err = http.NewRequest("GET", "http://localhost:8080/my/tokens", nil) + require.NoError(t, err) + req.Header.Set("Authorization", "Bearer "+accessToken) + + resp, err = client.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + } + + require.Equal(t, 200, resp.StatusCode, "Should be able to access token page") + + // Extract CSRF token from HTML + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + // Look for the CSRF token in the form + csrfRegex := regexp.MustCompile(`name="csrf_token" value="([^"]+)"`) + matches := csrfRegex.FindSubmatch(body) + require.Len(t, matches, 2, "Should find CSRF token in form") + csrfToken := string(matches[1]) + + // Step 2: POST to /my/tokens/set with test token + formData := url.Values{ + "service": {"postgres"}, + "token": {"test-user-token-12345"}, + "csrf_token": {csrfToken}, + } + + req, err = http.NewRequest("POST", "http://localhost:8080/my/tokens/set", strings.NewReader(formData.Encode())) + require.NoError(t, err) + req.Header.Set("Authorization", "Bearer "+accessToken) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + resp, err = client.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + // Check the response - it might be 200 if following redirects + switch resp.StatusCode { + case 200: + // That's fine, it means the token was set and we got the page back + t.Log("Token set successfully, got page response") + case 302, 303: + // Also fine, redirect means success + t.Log("Token set successfully, got redirect") + default: + body, _ := io.ReadAll(resp.Body) + t.Fatalf("Unexpected response setting token: status=%d, body=%s", resp.StatusCode, string(body)) + } + + // Step 3: Now test tool invocation with the token + mcpClient := NewMCPSSEClient("http://localhost:8080") + mcpClient.SetAuthToken(accessToken) + + err = mcpClient.Connect() + require.NoError(t, err, "Should connect to postgres SSE endpoint") + defer mcpClient.Close() + + // Call the query tool with a simple query + queryParams := map[string]any{ + "name": "execute_sql", + "arguments": map[string]any{ + "sql": "SELECT 1 as test", + }, + } + + result, err := mcpClient.SendMCPRequest("tools/call", queryParams) + require.NoError(t, err, "Should successfully call tool with token") + + // Verify we got a successful result, not an error + require.NotNil(t, result["result"], "Should have result in response") + + resultMap := result["result"].(map[string]any) + content := resultMap["content"].([]any) + require.NotEmpty(t, content, "Should have content in result") + + contentItem := content[0].(map[string]any) + resultText := contentItem["text"].(string) + + // The result should contain actual query results, not an error + assert.NotContains(t, resultText, "token_required", "Should not have token error") + assert.NotContains(t, resultText, "Token Required", "Should not have token error message") + + // Postgres query result should contain our test value + assert.Contains(t, resultText, "1", "Should contain query result") + + t.Log("Successfully invoked tool with user token") + }) +} + +// getOAuthAccessTokenForIDP completes the OAuth flow and returns a valid access token. +// expectedIDPHost is the host:port of the expected IDP redirect target (e.g., "localhost:9090" for Google). +func getOAuthAccessTokenForIDP(t *testing.T, resource, expectedIDPHost string) string { + t.Helper() + + clientID := registerTestClient(t) + + codeVerifier := "test-code-verifier-that-is-at-least-43-characters-long" + h := sha256.New() + h.Write([]byte(codeVerifier)) + codeChallenge := base64.RawURLEncoding.EncodeToString(h.Sum(nil)) + + authParams := url.Values{ + "response_type": {"code"}, + "client_id": {clientID}, + "redirect_uri": {"http://127.0.0.1:6274/oauth/callback"}, + "code_challenge": {codeChallenge}, + "code_challenge_method": {"S256"}, + "scope": {"openid email profile"}, + "state": {"test-state"}, + "resource": {resource}, + } + + client := &http.Client{ + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, + } + + // Step 1: Authorization request + authResp, err := client.Get("http://localhost:8080/authorize?" + authParams.Encode()) + require.NoError(t, err) + defer authResp.Body.Close() + + require.Contains(t, []int{302, 303}, authResp.StatusCode, "Should redirect to IDP") + location := authResp.Header.Get("Location") + require.Contains(t, location, expectedIDPHost, "Should redirect to expected IDP") + + // Step 2: Follow redirect to IDP (which immediately redirects back) + idpResp, err := client.Get(location) + require.NoError(t, err) + defer idpResp.Body.Close() + + require.Contains(t, []int{302, 303}, idpResp.StatusCode, "IDP should redirect back") + callbackLocation := idpResp.Header.Get("Location") + require.Contains(t, callbackLocation, "/oauth/callback", "Should redirect to callback") + + // Step 3: Follow callback redirect + callbackResp, err := client.Get(callbackLocation) + require.NoError(t, err) + defer callbackResp.Body.Close() + + require.Contains(t, []int{302, 303}, callbackResp.StatusCode, "Callback should redirect with code") + finalLocation := callbackResp.Header.Get("Location") + + finalURL, err := url.Parse(finalLocation) + require.NoError(t, err) + authCode := finalURL.Query().Get("code") + require.NotEmpty(t, authCode, "Should have authorization code") + + // Step 4: Exchange code for token + tokenParams := url.Values{ + "grant_type": {"authorization_code"}, + "code": {authCode}, + "redirect_uri": {"http://127.0.0.1:6274/oauth/callback"}, + "client_id": {clientID}, + "code_verifier": {codeVerifier}, + } + + tokenResp, err := http.PostForm("http://localhost:8080/token", tokenParams) + require.NoError(t, err) + defer tokenResp.Body.Close() + + if tokenResp.StatusCode != 200 { + body, _ := io.ReadAll(tokenResp.Body) + t.Logf("Token exchange failed with status %d: %s", tokenResp.StatusCode, string(body)) + } + + require.Equal(t, 200, tokenResp.StatusCode, "Token exchange should succeed") + + var tokenData map[string]any + err = json.NewDecoder(tokenResp.Body).Decode(&tokenData) + require.NoError(t, err) + + accessToken := tokenData["access_token"].(string) + require.NotEmpty(t, accessToken, "Should have access token") + + return accessToken +} + +// getOAuthAccessToken completes the OAuth flow using the Google IDP and returns a valid access token. +func getOAuthAccessToken(t *testing.T, resource string) string { + return getOAuthAccessTokenForIDP(t, resource, "localhost:9090") +} + +// MockUserTokenStore mocks the UserTokenStore interface for testing +type MockUserTokenStore struct { + mock.Mock +} + +func (m *MockUserTokenStore) GetUserToken(ctx context.Context, email, service string) (string, error) { + args := m.Called(ctx, email, service) + return args.String(0), args.Error(1) +} + +func (m *MockUserTokenStore) SetUserToken(ctx context.Context, email, service, token string) error { + args := m.Called(ctx, email, service, token) + return args.Error(0) +} + +func (m *MockUserTokenStore) DeleteUserToken(ctx context.Context, email, service string) error { + args := m.Called(ctx, email, service) + return args.Error(0) +} + +func (m *MockUserTokenStore) ListUserServices(ctx context.Context, email string) ([]string, error) { + args := m.Called(ctx, email) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).([]string), args.Error(1) +} diff --git a/integration/security_test.go b/integration/security_test.go index 2305c92..653ca14 100644 --- a/integration/security_test.go +++ b/integration/security_test.go @@ -15,7 +15,11 @@ func TestSecurityScenarios(t *testing.T) { waitForDB(t) // Start mcp-front - startMCPFront(t, "config/config.test.json") + cfg := buildTestConfig("http://localhost:8080", "mcp-front-test", + nil, + map[string]any{"postgres": testPostgresServer(withBearerTokens("test-token", "alt-test-token"), withLogEnabled())}, + ) + startMCPFront(t, writeTestConfig(t, cfg)) // Wait for server to be ready waitForMCPFront(t) @@ -78,7 +82,7 @@ func TestSecurityScenarios(t *testing.T) { }) t.Run("SQLInjectionAttempts", func(t *testing.T) { - t.Skip("Skipping SQL injection tests, it's not a responsibility of mcp-front to guard mcp/postgres") + t.Skip("Skipping SQL injection tests, it's not a responsibility of mcp-front to guard toolbox") client := NewMCPSSEClient("http://localhost:8080") _ = client.Authenticate() @@ -101,7 +105,7 @@ func TestSecurityScenarios(t *testing.T) { // Try to inject via the query parameter _, err := client.SendMCPRequest("tools/call", map[string]any{ - "name": "query", + "name": "execute_sql", "arguments": map[string]any{ "query": payload, }, @@ -297,7 +301,11 @@ func TestFailureScenarios(t *testing.T) { // Database is already started by TestMain, just wait for readiness waitForDB(t) - startMCPFront(t, "config/config.test.json") + cfg := buildTestConfig("http://localhost:8080", "mcp-front-test", + nil, + map[string]any{"postgres": testPostgresServer(withBearerTokens("test-token", "alt-test-token"), withLogEnabled())}, + ) + startMCPFront(t, writeTestConfig(t, cfg)) // Wait for server to be ready waitForMCPFront(t) diff --git a/integration/streamable_client.go b/integration/streamable_client.go deleted file mode 100644 index 76ab1b7..0000000 --- a/integration/streamable_client.go +++ /dev/null @@ -1 +0,0 @@ -package integration diff --git a/integration/test_clients.go b/integration/test_clients.go new file mode 100644 index 0000000..a5cb357 --- /dev/null +++ b/integration/test_clients.go @@ -0,0 +1,442 @@ +package integration + +import ( + "bufio" + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "sync" + "time" +) + +// MCPSSEClient simulates an MCP client for testing +type MCPSSEClient struct { + baseURL string + token string + sseConn io.ReadCloser + messageEndpoint string + sseScanner *bufio.Scanner + sessionID string +} + +// NewMCPSSEClient creates a new MCP client for testing +func NewMCPSSEClient(baseURL string) *MCPSSEClient { + return &MCPSSEClient{ + baseURL: baseURL, + } +} + +// Authenticate sets up authentication for the client +func (c *MCPSSEClient) Authenticate() error { + c.token = "test-token" + return nil +} + +// SetAuthToken sets a specific auth token for the client +func (c *MCPSSEClient) SetAuthToken(token string) { + c.token = token +} + +// Connect establishes an SSE connection and retrieves the message endpoint +func (c *MCPSSEClient) Connect() error { + return c.ConnectToServer("postgres") +} + +// ConnectToServer establishes an SSE connection to a specific server +func (c *MCPSSEClient) ConnectToServer(serverName string) error { + // Close any existing connection + if c.sseConn != nil { + c.sseConn.Close() + c.sseConn = nil + c.messageEndpoint = "" + } + + sseURL := c.baseURL + "/" + serverName + "/sse" + tracef("ConnectToServer: requesting %s", sseURL) + + req, err := http.NewRequest("GET", sseURL, nil) + if err != nil { + return fmt.Errorf("failed to create SSE request: %v", err) + } + + req.Header.Set("Accept", "text/event-stream") + req.Header.Set("Authorization", "Bearer "+c.token) + req.Header.Set("Cache-Control", "no-cache") + tracef("ConnectToServer: headers set, making request") + + // Don't use a timeout on the client for SSE + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + return fmt.Errorf("SSE connection failed: %v", err) + } + + tracef("ConnectToServer: got response status %d", resp.StatusCode) + if resp.StatusCode != 200 { + body, _ := io.ReadAll(resp.Body) + resp.Body.Close() + return fmt.Errorf("SSE connection returned %d: %s", resp.StatusCode, string(body)) + } + + // Store the connection + c.sseConn = resp.Body + c.sseScanner = bufio.NewScanner(resp.Body) + + // Read initial SSE messages to get the endpoint + // For inline servers, we don't get a message endpoint - we use the server path directly + gotEndpointMessage := false + for c.sseScanner.Scan() { + line := c.sseScanner.Text() + tracef("ConnectToServer: SSE line: %s", line) + + // Look for data lines + if after, ok := strings.CutPrefix(line, "data: "); ok { + data := after + + // Check if it's an endpoint message (for inline servers) + if strings.Contains(data, `"type":"endpoint"`) { + gotEndpointMessage = true + // For inline servers, construct the message endpoint + c.messageEndpoint = c.baseURL + "/" + serverName + "/message" + tracef("ConnectToServer: inline server detected, using endpoint: %s", c.messageEndpoint) + break + } + + // Check if it's a message endpoint URL (for stdio servers) + if strings.Contains(data, "http://") || strings.Contains(data, "https://") { + c.messageEndpoint = data + + // Extract session ID from endpoint URL + if u, err := url.Parse(data); err == nil { + c.sessionID = u.Query().Get("sessionId") + } + + tracef("ConnectToServer: found endpoint: %s", c.messageEndpoint) + break + } + } + } + + if c.messageEndpoint == "" && !gotEndpointMessage { + c.sseConn.Close() + c.sseConn = nil + return fmt.Errorf("no message endpoint received") + } + + tracef("Connect: successfully connected to MCP server") + return nil +} + +// ValidateBackendConnectivity checks if we can connect to the MCP server +func (c *MCPSSEClient) ValidateBackendConnectivity() error { + return c.Connect() +} + +// Close closes the SSE connection +func (c *MCPSSEClient) Close() { + if c.sseConn != nil { + c.sseConn.Close() + c.sseConn = nil + c.messageEndpoint = "" + c.sseScanner = nil + } +} + +// SendMCPRequest sends an MCP JSON-RPC request and returns the response +func (c *MCPSSEClient) SendMCPRequest(method string, params any) (map[string]any, error) { + // Ensure we have a connection + if c.messageEndpoint == "" { + if err := c.Connect(); err != nil { + return nil, fmt.Errorf("failed to connect: %v", err) + } + } + + // Send MCP request to the message endpoint + request := map[string]any{ + "jsonrpc": "2.0", + "id": 1, + "method": method, + "params": params, + } + + reqBody, err := json.Marshal(request) + if err != nil { + return nil, err + } + + msgReq, err := http.NewRequest("POST", c.messageEndpoint, bytes.NewBuffer(reqBody)) + if err != nil { + return nil, err + } + + msgReq.Header.Set("Content-Type", "application/json") + msgReq.Header.Set("Authorization", "Bearer "+c.token) + + client := &http.Client{Timeout: 30 * time.Second} + msgResp, err := client.Do(msgReq) + if err != nil { + return nil, err + } + defer msgResp.Body.Close() + + respBody, err := io.ReadAll(msgResp.Body) + if err != nil { + return nil, err + } + + if msgResp.StatusCode != 200 && msgResp.StatusCode != 202 { + return nil, fmt.Errorf("MCP request failed: %d - %s", msgResp.StatusCode, string(respBody)) + } + + // Handle 202 and empty responses - read response from SSE stream + if msgResp.StatusCode == 202 || len(respBody) == 0 { + // Read response from SSE stream + for c.sseScanner.Scan() { + line := c.sseScanner.Text() + + if after, ok := strings.CutPrefix(line, "data: "); ok { + data := after + // Try to parse as JSON + var msg map[string]any + if err := json.Unmarshal([]byte(data), &msg); err == nil { + // Check if this is our response (matching ID) + if id, ok := msg["id"]; ok && id == float64(1) { + return msg, nil + } + } + } + } + + if err := c.sseScanner.Err(); err != nil { + return nil, fmt.Errorf("SSE scanner error: %v", err) + } + + return nil, fmt.Errorf("no response received from SSE stream") + } + + var result map[string]any + if err := json.Unmarshal(respBody, &result); err != nil { + return nil, fmt.Errorf("failed to parse response: %v - %s", err, string(respBody)) + } + + return result, nil +} + +// MCPStreamableClient is a test client for HTTP-Streamable MCP servers +type MCPStreamableClient struct { + baseURL string + serverName string + token string + httpClient *http.Client + + // For GET SSE streaming + sseConn io.ReadCloser + sseScanner *bufio.Scanner + sseCancel chan struct{} + + mu sync.Mutex +} + +// NewMCPStreamableClient creates a new streamable-http test client +func NewMCPStreamableClient(baseURL string) *MCPStreamableClient { + return &MCPStreamableClient{ + baseURL: baseURL, + httpClient: &http.Client{ + Timeout: 30 * time.Second, + }, + } +} + +// SetAuthToken sets the authentication token +func (c *MCPStreamableClient) SetAuthToken(token string) { + c.token = token +} + +// ConnectToServer establishes connection to a streamable-http server +func (c *MCPStreamableClient) ConnectToServer(serverName string) error { + c.mu.Lock() + defer c.mu.Unlock() + + // Close any existing connection + c.close() + + c.serverName = serverName + + // For streamable-http, we can optionally open a GET SSE stream for server-initiated messages + // But it's not required for basic request/response + return c.openSSEStream() +} + +// openSSEStream opens a GET SSE connection for receiving server-initiated messages +func (c *MCPStreamableClient) openSSEStream() error { + url := c.baseURL + "/" + c.serverName + "/sse" + + req, err := http.NewRequest("GET", url, nil) + if err != nil { + return fmt.Errorf("failed to create GET request: %v", err) + } + + req.Header.Set("Accept", "text/event-stream") + req.Header.Set("Authorization", "Bearer "+c.token) + req.Header.Set("Cache-Control", "no-cache") + + // Use a client without timeout for SSE + sseClient := &http.Client{} + resp, err := sseClient.Do(req) + if err != nil { + return fmt.Errorf("SSE connection failed: %v", err) + } + + if resp.StatusCode != 200 { + body, _ := io.ReadAll(resp.Body) + resp.Body.Close() + return fmt.Errorf("SSE connection returned %d: %s", resp.StatusCode, string(body)) + } + + c.sseConn = resp.Body + c.sseScanner = bufio.NewScanner(resp.Body) + c.sseCancel = make(chan struct{}) + + // Start reading SSE messages in background + go c.readSSEMessages() + + return nil +} + +// readSSEMessages reads server-initiated messages from the SSE stream +func (c *MCPStreamableClient) readSSEMessages() { + for { + select { + case <-c.sseCancel: + return + default: + if c.sseScanner.Scan() { + line := c.sseScanner.Text() + if after, ok := strings.CutPrefix(line, "data: "); ok { + data := after + // In a real implementation, we'd process server-initiated messages here + tracef("StreamableClient: received SSE message: %s", data) + } + } else { + // Scanner stopped - connection closed or error + return + } + } + } +} + +// SendMCPRequest sends a JSON-RPC request via POST +func (c *MCPStreamableClient) SendMCPRequest(method string, params any) (map[string]any, error) { + c.mu.Lock() + serverName := c.serverName + c.mu.Unlock() + + if serverName == "" { + return nil, fmt.Errorf("not connected to any server") + } + + // For streamable-http, we POST to the server endpoint + url := c.baseURL + "/" + serverName + "/sse" + + // Construct JSON-RPC request + request := map[string]any{ + "jsonrpc": "2.0", + "id": 1, + "method": method, + "params": params, + } + + body, err := json.Marshal(request) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %v", err) + } + + req, err := http.NewRequest("POST", url, bytes.NewReader(body)) + if err != nil { + return nil, fmt.Errorf("failed to create POST request: %v", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+c.token) + // Accept both JSON and SSE responses + req.Header.Set("Accept", "application/json, text/event-stream") + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("request failed: %v", err) + } + defer resp.Body.Close() + + // Check content type to determine response format + contentType := resp.Header.Get("Content-Type") + + if strings.HasPrefix(contentType, "text/event-stream") { + // Handle SSE response + return c.handleSSEResponse(resp.Body) + } else { + // Handle JSON response + return c.handleJSONResponse(resp.Body) + } +} + +// handleJSONResponse processes a regular JSON response +func (c *MCPStreamableClient) handleJSONResponse(body io.Reader) (map[string]any, error) { + var response map[string]any + if err := json.NewDecoder(body).Decode(&response); err != nil { + return nil, fmt.Errorf("failed to decode JSON response: %v", err) + } + return response, nil +} + +// handleSSEResponse processes an SSE stream response from a POST +func (c *MCPStreamableClient) handleSSEResponse(body io.Reader) (map[string]any, error) { + scanner := bufio.NewScanner(body) + var lastResponse map[string]any + + for scanner.Scan() { + line := scanner.Text() + if after, ok := strings.CutPrefix(line, "data: "); ok { + data := after + var msg map[string]any + if err := json.Unmarshal([]byte(data), &msg); err == nil { + // Keep the last response with an ID (not a notification) + if _, hasID := msg["id"]; hasID { + lastResponse = msg + } + } + } + } + + if lastResponse == nil { + return nil, fmt.Errorf("no response received in SSE stream") + } + + return lastResponse, nil +} + +// Close closes all connections +func (c *MCPStreamableClient) Close() { + c.mu.Lock() + defer c.mu.Unlock() + c.close() +} + +// close is the internal close method (must be called with lock held) +func (c *MCPStreamableClient) close() { + if c.sseCancel != nil { + close(c.sseCancel) + c.sseCancel = nil + } + + if c.sseConn != nil { + c.sseConn.Close() + c.sseConn = nil + c.sseScanner = nil + } + + c.serverName = "" +} diff --git a/integration/test_fakes.go b/integration/test_fakes.go new file mode 100644 index 0000000..019b8c0 --- /dev/null +++ b/integration/test_fakes.go @@ -0,0 +1,355 @@ +package integration + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "time" +) + +// FakeGCPServer provides a fake GCP OAuth server for testing +type FakeGCPServer struct { + server *http.Server + port string +} + +// NewFakeGCPServer creates a new fake GCP server +func NewFakeGCPServer(port string) *FakeGCPServer { + mux := http.NewServeMux() + + mux.HandleFunc("/auth", func(w http.ResponseWriter, r *http.Request) { + redirectURI := r.URL.Query().Get("redirect_uri") + state := r.URL.Query().Get("state") + http.Redirect(w, r, fmt.Sprintf("%s?code=test-auth-code&state=%s", redirectURI, state), http.StatusFound) + }) + + mux.HandleFunc("/token", func(w http.ResponseWriter, r *http.Request) { + // Parse the form data + if err := r.ParseForm(); err != nil { + http.Error(w, "Invalid request", http.StatusBadRequest) + return + } + + // Check the authorization code + code := r.FormValue("code") + if code != "test-auth-code" { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + _ = json.NewEncoder(w).Encode(map[string]any{ + "error": "invalid_grant", + "error_description": "Invalid authorization code", + }) + return + } + + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "access_token": "test-access-token", + "token_type": "Bearer", + "expires_in": 3600, + }) + }) + + mux.HandleFunc("/userinfo", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "email": "test@test.com", + "hd": "test.com", + }) + }) + + server := &http.Server{ + Addr: ":" + port, + Handler: mux, + } + + return &FakeGCPServer{ + server: server, + port: port, + } +} + +// Start starts the fake GCP server +func (m *FakeGCPServer) Start() error { + go func() { + if err := m.server.ListenAndServe(); err != nil && err != http.ErrServerClosed { + panic(err) + } + }() + + time.Sleep(100 * time.Millisecond) + return nil +} + +// Stop stops the fake GCP server +func (m *FakeGCPServer) Stop() error { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + return m.server.Shutdown(ctx) +} + +// FakeServiceOAuthServer provides a fake OAuth server for external services (like Linear, GitHub) +type FakeServiceOAuthServer struct { + server *http.Server + port string +} + +// NewFakeServiceOAuthServer creates a new fake service OAuth server +func NewFakeServiceOAuthServer(port string) *FakeServiceOAuthServer { + mux := http.NewServeMux() + + mux.HandleFunc("/oauth/authorize", func(w http.ResponseWriter, r *http.Request) { + redirectURI := r.URL.Query().Get("redirect_uri") + state := r.URL.Query().Get("state") + http.Redirect(w, r, fmt.Sprintf("%s?code=service-auth-code&state=%s", redirectURI, state), http.StatusFound) + }) + + mux.HandleFunc("/oauth/token", func(w http.ResponseWriter, r *http.Request) { + if err := r.ParseForm(); err != nil { + http.Error(w, "Invalid request", http.StatusBadRequest) + return + } + + code := r.FormValue("code") + if code != "service-auth-code" { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + _ = json.NewEncoder(w).Encode(map[string]any{ + "error": "invalid_grant", + "error_description": "Invalid authorization code", + }) + return + } + + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "access_token": "service-oauth-access-token", + "refresh_token": "service-oauth-refresh-token", + "token_type": "Bearer", + "expires_in": 3600, + }) + }) + + server := &http.Server{ + Addr: ":" + port, + Handler: mux, + } + + return &FakeServiceOAuthServer{ + server: server, + port: port, + } +} + +// Start starts the fake service OAuth server +func (s *FakeServiceOAuthServer) Start() error { + go func() { + if err := s.server.ListenAndServe(); err != nil && err != http.ErrServerClosed { + panic(err) + } + }() + + time.Sleep(100 * time.Millisecond) + return nil +} + +// Stop stops the fake service OAuth server +func (s *FakeServiceOAuthServer) Stop() error { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + return s.server.Shutdown(ctx) +} + +// FakeGitHubServer simulates GitHub's OAuth and API endpoints for integration testing. +type FakeGitHubServer struct { + server *http.Server + port string +} + +// NewFakeGitHubServer creates a new fake GitHub server. +// orgs controls what organizations the /user/orgs endpoint returns. +func NewFakeGitHubServer(port string, orgs []string) *FakeGitHubServer { + mux := http.NewServeMux() + + mux.HandleFunc("/login/oauth/authorize", func(w http.ResponseWriter, r *http.Request) { + redirectURI := r.URL.Query().Get("redirect_uri") + state := r.URL.Query().Get("state") + http.Redirect(w, r, fmt.Sprintf("%s?code=github-test-code&state=%s", redirectURI, state), http.StatusFound) + }) + + mux.HandleFunc("/login/oauth/access_token", func(w http.ResponseWriter, r *http.Request) { + if err := r.ParseForm(); err != nil { + http.Error(w, "Invalid request", http.StatusBadRequest) + return + } + + code := r.FormValue("code") + if code != "github-test-code" { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + _ = json.NewEncoder(w).Encode(map[string]any{ + "error": "bad_verification_code", + "error_description": "Invalid authorization code", + }) + return + } + + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "access_token": "github-test-token", + "token_type": "bearer", + "scope": "user:email,read:org", + }) + }) + + mux.HandleFunc("/user", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "id": 12345, + "login": "testuser", + "email": "test@test.com", + "name": "Test User", + "avatar_url": "https://github.com/avatar.jpg", + }) + }) + + mux.HandleFunc("/user/emails", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode([]map[string]any{ + {"email": "test@test.com", "primary": true, "verified": true}, + }) + }) + + mux.HandleFunc("/user/orgs", func(w http.ResponseWriter, r *http.Request) { + orgList := make([]map[string]any, len(orgs)) + for i, org := range orgs { + orgList[i] = map[string]any{"login": org} + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(orgList) + }) + + server := &http.Server{ + Addr: ":" + port, + Handler: mux, + } + + return &FakeGitHubServer{ + server: server, + port: port, + } +} + +// Start starts the fake GitHub server +func (s *FakeGitHubServer) Start() error { + go func() { + if err := s.server.ListenAndServe(); err != nil && err != http.ErrServerClosed { + panic(err) + } + }() + + time.Sleep(100 * time.Millisecond) + return nil +} + +// Stop stops the fake GitHub server +func (s *FakeGitHubServer) Stop() error { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + return s.server.Shutdown(ctx) +} + +// FakeOIDCServer simulates a generic OIDC provider for integration testing. +// Used for both generic OIDC and Azure tests (Azure is OIDC-compliant). +type FakeOIDCServer struct { + server *http.Server + port string +} + +// NewFakeOIDCServer creates a new fake OIDC server. +func NewFakeOIDCServer(port string) *FakeOIDCServer { + mux := http.NewServeMux() + + baseURL := "http://localhost:" + port + + mux.HandleFunc("/.well-known/openid-configuration", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "issuer": baseURL, + "authorization_endpoint": baseURL + "/authorize", + "token_endpoint": baseURL + "/token", + "userinfo_endpoint": baseURL + "/userinfo", + }) + }) + + mux.HandleFunc("/authorize", func(w http.ResponseWriter, r *http.Request) { + redirectURI := r.URL.Query().Get("redirect_uri") + state := r.URL.Query().Get("state") + http.Redirect(w, r, fmt.Sprintf("%s?code=oidc-test-code&state=%s", redirectURI, state), http.StatusFound) + }) + + mux.HandleFunc("/token", func(w http.ResponseWriter, r *http.Request) { + if err := r.ParseForm(); err != nil { + http.Error(w, "Invalid request", http.StatusBadRequest) + return + } + + code := r.FormValue("code") + if code != "oidc-test-code" { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + _ = json.NewEncoder(w).Encode(map[string]any{ + "error": "invalid_grant", + "error_description": "Invalid authorization code", + }) + return + } + + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "access_token": "oidc-test-token", + "token_type": "Bearer", + "expires_in": 3600, + }) + }) + + mux.HandleFunc("/userinfo", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "sub": "oidc-12345", + "email": "test@oidc-test.com", + "email_verified": true, + "name": "OIDC User", + }) + }) + + server := &http.Server{ + Addr: ":" + port, + Handler: mux, + } + + return &FakeOIDCServer{ + server: server, + port: port, + } +} + +// Start starts the fake OIDC server +func (s *FakeOIDCServer) Start() error { + go func() { + if err := s.server.ListenAndServe(); err != nil && err != http.ErrServerClosed { + panic(err) + } + }() + + time.Sleep(100 * time.Millisecond) + return nil +} + +// Stop stops the fake OIDC server +func (s *FakeOIDCServer) Stop() error { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + return s.server.Shutdown(ctx) +} diff --git a/integration/test_helpers.go b/integration/test_helpers.go new file mode 100644 index 0000000..1b8715f --- /dev/null +++ b/integration/test_helpers.go @@ -0,0 +1,463 @@ +package integration + +import ( + "encoding/json" + "fmt" + "net/http" + "os" + "os/exec" + "slices" + "strings" + "syscall" + "testing" + "time" +) + +// ToolboxImage is the Docker image for the MCP Toolbox for Databases. +// Used as the MCP server backing integration tests. All test configs +// that reference a postgres MCP server should use this image. +const ToolboxImage = "us-central1-docker.pkg.dev/database-toolbox/toolbox/toolbox:latest" + +// testPostgresDockerArgs returns the Docker args for running the toolbox +// as a stdio MCP server against the test postgres database. +func testPostgresDockerArgs() []string { + return []string{ + "run", "--rm", "-i", "--network", "host", + "-e", "POSTGRES_HOST=localhost", + "-e", "POSTGRES_PORT=15432", + "-e", "POSTGRES_DATABASE=testdb", + "-e", "POSTGRES_USER=testuser", + "-e", "POSTGRES_PASSWORD=testpass", + ToolboxImage, + "--stdio", "--prebuilt", "postgres", + } +} + +// testPostgresServer returns an MCP server config for the test postgres database. +// Options can customize auth, logging, etc. +func testPostgresServer(opts ...serverOption) map[string]any { + args := make([]any, len(testPostgresDockerArgs())) + for i, a := range testPostgresDockerArgs() { + args[i] = a + } + s := map[string]any{ + "transportType": "stdio", + "command": "docker", + "args": args, + } + for _, opt := range opts { + opt(s) + } + return s +} + +type serverOption func(map[string]any) + +func withBearerTokens(tokens ...string) serverOption { + return func(s map[string]any) { + s["serviceAuths"] = []map[string]any{ + {"type": "bearer", "tokens": tokens}, + } + } +} + +func withBasicAuth(username, passwordEnvVar string) serverOption { + return func(s map[string]any) { + auths, _ := s["serviceAuths"].([]map[string]any) + auths = append(auths, map[string]any{ + "type": "basic", + "username": username, + "password": map[string]string{"$env": passwordEnvVar}, + }) + s["serviceAuths"] = auths + } +} + +func withLogEnabled() serverOption { + return func(s map[string]any) { + s["options"] = map[string]any{"logEnabled": true} + } +} + +func withUserToken() serverOption { + return func(s map[string]any) { + s["env"] = map[string]any{ + "USER_TOKEN": map[string]string{"$userToken": "{{token}}"}, + } + s["requiresUserToken"] = true + s["userAuthentication"] = map[string]any{ + "type": "manual", + "displayName": "Test Service", + "instructions": "Enter your test token", + "helpUrl": "https://example.com/help", + } + } +} + +// testOAuthConfigFromEnv returns an OAuth auth config that reads secrets from env vars. +func testOAuthConfigFromEnv() map[string]any { + return map[string]any{ + "kind": "oauth", + "issuer": "http://localhost:8080", + "gcpProject": "test-project", + "idp": map[string]any{ + "provider": "google", + "clientId": map[string]string{"$env": "GOOGLE_CLIENT_ID"}, + "clientSecret": map[string]string{"$env": "GOOGLE_CLIENT_SECRET"}, + "redirectUri": "http://localhost:8080/oauth/callback", + "authorizationUrl": "http://localhost:9090/auth", + "tokenUrl": "http://localhost:9090/token", + "userInfoUrl": "http://localhost:9090/userinfo", + }, + "allowedDomains": []string{"test.com", "stainless.com", "claude.ai"}, + "allowedOrigins": []string{"https://claude.ai"}, + "tokenTtl": "1h", + "storage": "memory", + "jwtSecret": map[string]string{"$env": "JWT_SECRET"}, + "encryptionKey": map[string]string{"$env": "ENCRYPTION_KEY"}, + } +} + +// testGitHubOAuthConfig returns an OAuth auth config for GitHub IDP testing. +// Secrets use $env references; callers must pass JWT_SECRET, ENCRYPTION_KEY, +// and GITHUB_CLIENT_SECRET env vars. +func testGitHubOAuthConfig(allowedOrgs ...string) map[string]any { + idpCfg := map[string]any{ + "provider": "github", + "clientId": "test-github-client-id", + "clientSecret": map[string]string{"$env": "GITHUB_CLIENT_SECRET"}, + "redirectUri": "http://localhost:8080/oauth/callback", + "authorizationUrl": "http://localhost:9092/login/oauth/authorize", + "tokenUrl": "http://localhost:9092/login/oauth/access_token", + "userInfoUrl": "http://localhost:9092", + } + if len(allowedOrgs) > 0 { + idpCfg["allowedOrgs"] = allowedOrgs + } + return map[string]any{ + "kind": "oauth", + "issuer": "http://localhost:8080", + "gcpProject": "test-project", + "idp": idpCfg, + "allowedDomains": []string{"test.com"}, + "allowedOrigins": []string{"https://claude.ai"}, + "tokenTtl": "1h", + "storage": "memory", + "jwtSecret": map[string]string{"$env": "JWT_SECRET"}, + "encryptionKey": map[string]string{"$env": "ENCRYPTION_KEY"}, + } +} + +// testOIDCOAuthConfig returns an OAuth auth config for generic OIDC IDP testing. +// Secrets use $env references; callers must pass JWT_SECRET, ENCRYPTION_KEY, +// and OIDC_CLIENT_SECRET env vars. +func testOIDCOAuthConfig() map[string]any { + return map[string]any{ + "kind": "oauth", + "issuer": "http://localhost:8080", + "gcpProject": "test-project", + "idp": map[string]any{ + "provider": "oidc", + "clientId": "test-oidc-client-id", + "clientSecret": map[string]string{"$env": "OIDC_CLIENT_SECRET"}, + "redirectUri": "http://localhost:8080/oauth/callback", + "authorizationUrl": "http://localhost:9093/authorize", + "tokenUrl": "http://localhost:9093/token", + "userInfoUrl": "http://localhost:9093/userinfo", + }, + "allowedDomains": []string{"oidc-test.com"}, + "allowedOrigins": []string{"https://claude.ai"}, + "tokenTtl": "1h", + "storage": "memory", + "jwtSecret": map[string]string{"$env": "JWT_SECRET"}, + "encryptionKey": map[string]string{"$env": "ENCRYPTION_KEY"}, + } +} + +// testAzureOAuthConfig returns an OAuth auth config for Azure IDP testing. +// Secrets use $env references; callers must pass JWT_SECRET, ENCRYPTION_KEY, +// and AZURE_CLIENT_SECRET env vars. +func testAzureOAuthConfig() map[string]any { + return map[string]any{ + "kind": "oauth", + "issuer": "http://localhost:8080", + "gcpProject": "test-project", + "idp": map[string]any{ + "provider": "azure", + "tenantId": "test-tenant", + "clientId": "test-azure-client-id", + "clientSecret": map[string]string{"$env": "AZURE_CLIENT_SECRET"}, + "redirectUri": "http://localhost:8080/oauth/callback", + "authorizationUrl": "http://localhost:9093/authorize", + "tokenUrl": "http://localhost:9093/token", + "userInfoUrl": "http://localhost:9093/userinfo", + }, + "allowedDomains": []string{"oidc-test.com"}, + "allowedOrigins": []string{"https://claude.ai"}, + "tokenTtl": "1h", + "storage": "memory", + "jwtSecret": map[string]string{"$env": "JWT_SECRET"}, + "encryptionKey": map[string]string{"$env": "ENCRYPTION_KEY"}, + } +} + +// writeTestConfig writes a config map to a temporary JSON file and returns its path. +// The file is automatically cleaned up when the test finishes. +func writeTestConfig(t *testing.T, cfg map[string]any) string { + t.Helper() + data, err := json.MarshalIndent(cfg, "", " ") + if err != nil { + t.Fatalf("Failed to marshal test config: %v", err) + } + f, err := os.CreateTemp(t.TempDir(), "config-*.json") + if err != nil { + t.Fatalf("Failed to create temp config file: %v", err) + } + if _, err := f.Write(data); err != nil { + t.Fatalf("Failed to write temp config: %v", err) + } + if err := f.Close(); err != nil { + t.Fatalf("Failed to close temp config: %v", err) + } + return f.Name() +} + +// buildTestConfig builds a complete mcp-front config map. +func buildTestConfig(baseURL, name string, auth map[string]any, mcpServers map[string]any) map[string]any { + proxy := map[string]any{ + "baseURL": baseURL, + "addr": ":8080", + "name": name, + } + if auth != nil { + proxy["auth"] = auth + } + return map[string]any{ + "version": "v0.0.1-DEV_EDITION_EXPECT_CHANGES", + "proxy": proxy, + "mcpServers": mcpServers, + } +} + +// TestConfig holds all timeout configurations for integration tests +type TestConfig struct { + SessionTimeout string + CleanupInterval string + CleanupWaitTime string + TimerResetWaitTime string + MultiUserWaitTime string +} + +// GetTestConfig returns test configuration from environment variables or defaults +func GetTestConfig() TestConfig { + c := TestConfig{ + SessionTimeout: "10s", + CleanupInterval: "2s", + CleanupWaitTime: "15s", + TimerResetWaitTime: "12s", + MultiUserWaitTime: "15s", + } + + // Override from environment if set + if v := os.Getenv("SESSION_TIMEOUT"); v != "" { + c.SessionTimeout = v + } + if v := os.Getenv("SESSION_CLEANUP_INTERVAL"); v != "" { + c.CleanupInterval = v + } + if v := os.Getenv("TEST_CLEANUP_WAIT_TIME"); v != "" { + c.CleanupWaitTime = v + } + if v := os.Getenv("TEST_TIMER_RESET_WAIT_TIME"); v != "" { + c.TimerResetWaitTime = v + } + if v := os.Getenv("TEST_MULTI_USER_WAIT_TIME"); v != "" { + c.MultiUserWaitTime = v + } + + return c +} + +func waitForDB(t *testing.T) { + waitForSec := 5 + for range waitForSec { + // Check if container is running + psCmd := exec.Command("docker", "compose", "ps", "-q", "test-postgres") + if output, err := psCmd.Output(); err != nil || len(output) == 0 { + time.Sleep(1 * time.Second) + continue + } + + // Check if database is ready + checkCmd := exec.Command("docker", "compose", "exec", "-T", "test-postgres", "pg_isready", "-U", "testuser", "-d", "testdb") + if err := checkCmd.Run(); err == nil { + return + } + time.Sleep(1 * time.Second) + } + + t.Fatalf("Database failed to become ready after %d seconds", waitForSec) +} + +// trace logs a message if TRACE environment variable is set +func trace(t *testing.T, format string, args ...any) { + if os.Getenv("TRACE") == "1" { + t.Logf("TRACE: "+format, args...) + } +} + +// tracef logs a formatted message to stdout if TRACE is set (for use outside tests) +func tracef(format string, args ...any) { + if os.Getenv("TRACE") == "1" { + fmt.Printf("TRACE: "+format+"\n", args...) + } +} + +// startMCPFront starts the mcp-front server with the given config +func startMCPFront(t *testing.T, configPath string, extraEnv ...string) { + mcpCmd := exec.Command("../cmd/mcp-front/mcp-front", "-config", configPath) + + // Get test config for session timeouts + testConfig := GetTestConfig() + + // Build default environment with test timeouts + defaultEnv := []string{ + "SESSION_TIMEOUT=" + testConfig.SessionTimeout, + "SESSION_CLEANUP_INTERVAL=" + testConfig.CleanupInterval, + } + + // Start with system environment + mcpCmd.Env = os.Environ() + + // Apply defaults first + mcpCmd.Env = append(mcpCmd.Env, defaultEnv...) + + // Apply extra env (can override defaults) + mcpCmd.Env = append(mcpCmd.Env, extraEnv...) + + // Pass through LOG_LEVEL and LOG_FORMAT if set + if logLevel := os.Getenv("LOG_LEVEL"); logLevel != "" { + mcpCmd.Env = append(mcpCmd.Env, "LOG_LEVEL="+logLevel) + } + if logFormat := os.Getenv("LOG_FORMAT"); logFormat != "" { + mcpCmd.Env = append(mcpCmd.Env, "LOG_FORMAT="+logFormat) + } + + // Capture output to log file if MCP_LOG_FILE is set + if logFile := os.Getenv("MCP_LOG_FILE"); logFile != "" { + f, err := os.OpenFile(logFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644) + if err == nil { + mcpCmd.Stderr = f + mcpCmd.Stdout = f + t.Cleanup(func() { f.Close() }) + } + } + + if err := mcpCmd.Start(); err != nil { + t.Fatalf("Failed to start mcp-front: %v", err) + } + + // Register cleanup that runs even if test is killed + t.Cleanup(func() { + stopMCPFront(mcpCmd) + }) +} + +// stopMCPFront stops the mcp-front server gracefully +func stopMCPFront(cmd *exec.Cmd) { + if cmd == nil || cmd.Process == nil { + return + } + + // Try graceful shutdown first (SIGINT) + if err := cmd.Process.Signal(syscall.SIGINT); err != nil { + // If SIGINT fails, force kill immediately + _ = cmd.Process.Kill() + _ = cmd.Wait() + return + } + + // Wait up to 5 seconds for graceful shutdown + done := make(chan error, 1) + go func() { + done <- cmd.Wait() + }() + + select { + case <-done: + // Graceful shutdown completed + return + case <-time.After(5 * time.Second): + // Timeout, force kill + _ = cmd.Process.Kill() + _ = cmd.Wait() + } +} + +// waitForMCPFront waits for the mcp-front server to be ready +func waitForMCPFront(t *testing.T) { + t.Helper() + for range 10 { + resp, err := http.Get("http://localhost:8080/health") + if err == nil && resp.StatusCode == 200 { + resp.Body.Close() + return + } + if resp != nil { + resp.Body.Close() + } + time.Sleep(1 * time.Second) + } + t.Fatal("mcp-front failed to become ready after 10 seconds") +} + +// getMCPContainers returns a list of running toolbox container IDs +func getMCPContainers() []string { + cmd := exec.Command("docker", "ps", "--format", "{{.ID}}", "--filter", "ancestor="+ToolboxImage) + output, err := cmd.Output() + if err != nil { + return nil + } + + var containers []string + for line := range strings.SplitSeq(strings.TrimSpace(string(output)), "\n") { + if line != "" { + containers = append(containers, line) + } + } + return containers +} + +// cleanupContainers forces cleanup of containers that weren't in the initial set +func cleanupContainers(t *testing.T, initialContainers []string) { + time.Sleep(2 * time.Second) + containers := getMCPContainers() + for _, container := range containers { + isInitial := slices.Contains(initialContainers, container) + if !isInitial { + t.Logf("Force stopping container: %s...", container) + if err := exec.Command("docker", "stop", container).Run(); err != nil { + t.Logf("Failed to stop container %s: %v", container, err) + } else { + t.Logf("Stopped container: %s", container) + } + } + } +} + +// TestQuickSmoke provides a fast validation test +func TestQuickSmoke(t *testing.T) { + t.Log("Running quick smoke test...") + + // Just verify the test infrastructure works + client := NewMCPSSEClient("http://localhost:8080") + if client == nil { + t.Fatal("Failed to create client") + } + + if err := client.Authenticate(); err != nil { + t.Fatal("Failed to set up authentication") + } + + t.Log("Quick smoke test passed - test infrastructure is working") +} diff --git a/integration/test_utils.go b/integration/test_utils.go deleted file mode 100644 index 7d21f61..0000000 --- a/integration/test_utils.go +++ /dev/null @@ -1,917 +0,0 @@ -package integration - -import ( - "bufio" - "bytes" - "context" - "encoding/json" - "fmt" - "io" - "net/http" - "net/url" - "os" - "os/exec" - "slices" - "strings" - "sync" - "syscall" - "testing" - "time" -) - -// getDockerComposeCommand returns the appropriate docker compose command -func getDockerComposeCommand() string { - // Check if docker compose v2 is available - cmd := exec.Command("docker", "compose", "version") - if err := cmd.Run(); err == nil { - return "docker compose" - } - return "docker-compose" -} - -// execDockerCompose executes docker compose with the given arguments -func execDockerCompose(args ...string) *exec.Cmd { - dcCmd := getDockerComposeCommand() - if dcCmd == "docker compose" { - allArgs := append([]string{"compose"}, args...) - return exec.Command("docker", allArgs...) - } - return exec.Command("docker-compose", args...) -} - -// MCPSSEClient simulates an MCP client for testing -type MCPSSEClient struct { - baseURL string - token string - sseConn io.ReadCloser - messageEndpoint string - sseScanner *bufio.Scanner - sessionID string -} - -// NewMCPSSEClient creates a new MCP client for testing -func NewMCPSSEClient(baseURL string) *MCPSSEClient { - return &MCPSSEClient{ - baseURL: baseURL, - } -} - -// Authenticate sets up authentication for the client -func (c *MCPSSEClient) Authenticate() error { - c.token = "test-token" - return nil -} - -// SetAuthToken sets a specific auth token for the client -func (c *MCPSSEClient) SetAuthToken(token string) { - c.token = token -} - -// Connect establishes an SSE connection and retrieves the message endpoint -func (c *MCPSSEClient) Connect() error { - return c.ConnectToServer("postgres") -} - -// ConnectToServer establishes an SSE connection to a specific server -func (c *MCPSSEClient) ConnectToServer(serverName string) error { - // Close any existing connection - if c.sseConn != nil { - c.sseConn.Close() - c.sseConn = nil - c.messageEndpoint = "" - } - - sseURL := c.baseURL + "/" + serverName + "/sse" - tracef("ConnectToServer: requesting %s", sseURL) - - req, err := http.NewRequest("GET", sseURL, nil) - if err != nil { - return fmt.Errorf("failed to create SSE request: %v", err) - } - - req.Header.Set("Accept", "text/event-stream") - req.Header.Set("Authorization", "Bearer "+c.token) - req.Header.Set("Cache-Control", "no-cache") - tracef("ConnectToServer: headers set, making request") - - // Don't use a timeout on the client for SSE - client := &http.Client{} - resp, err := client.Do(req) - if err != nil { - return fmt.Errorf("SSE connection failed: %v", err) - } - - tracef("ConnectToServer: got response status %d", resp.StatusCode) - if resp.StatusCode != 200 { - body, _ := io.ReadAll(resp.Body) - resp.Body.Close() - return fmt.Errorf("SSE connection returned %d: %s", resp.StatusCode, string(body)) - } - - // Store the connection - c.sseConn = resp.Body - c.sseScanner = bufio.NewScanner(resp.Body) - - // Read initial SSE messages to get the endpoint - // For inline servers, we don't get a message endpoint - we use the server path directly - gotEndpointMessage := false - for c.sseScanner.Scan() { - line := c.sseScanner.Text() - tracef("ConnectToServer: SSE line: %s", line) - - // Look for data lines - if after, ok := strings.CutPrefix(line, "data: "); ok { - data := after - - // Check if it's an endpoint message (for inline servers) - if strings.Contains(data, `"type":"endpoint"`) { - gotEndpointMessage = true - // For inline servers, construct the message endpoint - c.messageEndpoint = c.baseURL + "/" + serverName + "/message" - tracef("ConnectToServer: inline server detected, using endpoint: %s", c.messageEndpoint) - break - } - - // Check if it's a message endpoint URL (for stdio servers) - if strings.Contains(data, "http://") || strings.Contains(data, "https://") { - c.messageEndpoint = data - - // Extract session ID from endpoint URL - if u, err := url.Parse(data); err == nil { - c.sessionID = u.Query().Get("sessionId") - } - - tracef("ConnectToServer: found endpoint: %s", c.messageEndpoint) - break - } - } - } - - if c.messageEndpoint == "" && !gotEndpointMessage { - c.sseConn.Close() - c.sseConn = nil - return fmt.Errorf("no message endpoint received") - } - - tracef("Connect: successfully connected to MCP server") - return nil -} - -// ValidateBackendConnectivity checks if we can connect to the MCP server -func (c *MCPSSEClient) ValidateBackendConnectivity() error { - return c.Connect() -} - -// Close closes the SSE connection -func (c *MCPSSEClient) Close() { - if c.sseConn != nil { - c.sseConn.Close() - c.sseConn = nil - c.messageEndpoint = "" - c.sseScanner = nil - } -} - -// SendMCPRequest sends an MCP JSON-RPC request and returns the response -func (c *MCPSSEClient) SendMCPRequest(method string, params any) (map[string]any, error) { - // Ensure we have a connection - if c.messageEndpoint == "" { - if err := c.Connect(); err != nil { - return nil, fmt.Errorf("failed to connect: %v", err) - } - } - - // Send MCP request to the message endpoint - request := map[string]any{ - "jsonrpc": "2.0", - "id": 1, - "method": method, - "params": params, - } - - reqBody, err := json.Marshal(request) - if err != nil { - return nil, err - } - - msgReq, err := http.NewRequest("POST", c.messageEndpoint, bytes.NewBuffer(reqBody)) - if err != nil { - return nil, err - } - - msgReq.Header.Set("Content-Type", "application/json") - msgReq.Header.Set("Authorization", "Bearer "+c.token) - - client := &http.Client{Timeout: 30 * time.Second} - msgResp, err := client.Do(msgReq) - if err != nil { - return nil, err - } - defer msgResp.Body.Close() - - respBody, err := io.ReadAll(msgResp.Body) - if err != nil { - return nil, err - } - - if msgResp.StatusCode != 200 && msgResp.StatusCode != 202 { - return nil, fmt.Errorf("MCP request failed: %d - %s", msgResp.StatusCode, string(respBody)) - } - - // Handle 202 and empty responses - read response from SSE stream - if msgResp.StatusCode == 202 || len(respBody) == 0 { - // Read response from SSE stream - for c.sseScanner.Scan() { - line := c.sseScanner.Text() - - if after, ok := strings.CutPrefix(line, "data: "); ok { - data := after - // Try to parse as JSON - var msg map[string]any - if err := json.Unmarshal([]byte(data), &msg); err == nil { - // Check if this is our response (matching ID) - if id, ok := msg["id"]; ok && id == float64(1) { - return msg, nil - } - } - } - } - - if err := c.sseScanner.Err(); err != nil { - return nil, fmt.Errorf("SSE scanner error: %v", err) - } - - return nil, fmt.Errorf("no response received from SSE stream") - } - - var result map[string]any - if err := json.Unmarshal(respBody, &result); err != nil { - return nil, fmt.Errorf("failed to parse response: %v - %s", err, string(respBody)) - } - - return result, nil -} - -// MCPStreamableClient is a test client for HTTP-Streamable MCP servers -type MCPStreamableClient struct { - baseURL string - serverName string - token string - httpClient *http.Client - - // For GET SSE streaming - sseConn io.ReadCloser - sseScanner *bufio.Scanner - sseCancel chan struct{} - - mu sync.Mutex -} - -// NewMCPStreamableClient creates a new streamable-http test client -func NewMCPStreamableClient(baseURL string) *MCPStreamableClient { - return &MCPStreamableClient{ - baseURL: baseURL, - httpClient: &http.Client{ - Timeout: 30 * time.Second, - }, - } -} - -// SetAuthToken sets the authentication token -func (c *MCPStreamableClient) SetAuthToken(token string) { - c.token = token -} - -// ConnectToServer establishes connection to a streamable-http server -func (c *MCPStreamableClient) ConnectToServer(serverName string) error { - c.mu.Lock() - defer c.mu.Unlock() - - // Close any existing connection - c.close() - - c.serverName = serverName - - // For streamable-http, we can optionally open a GET SSE stream for server-initiated messages - // But it's not required for basic request/response - return c.openSSEStream() -} - -// openSSEStream opens a GET SSE connection for receiving server-initiated messages -func (c *MCPStreamableClient) openSSEStream() error { - url := c.baseURL + "/" + c.serverName + "/sse" - - req, err := http.NewRequest("GET", url, nil) - if err != nil { - return fmt.Errorf("failed to create GET request: %v", err) - } - - req.Header.Set("Accept", "text/event-stream") - req.Header.Set("Authorization", "Bearer "+c.token) - req.Header.Set("Cache-Control", "no-cache") - - // Use a client without timeout for SSE - sseClient := &http.Client{} - resp, err := sseClient.Do(req) - if err != nil { - return fmt.Errorf("SSE connection failed: %v", err) - } - - if resp.StatusCode != 200 { - body, _ := io.ReadAll(resp.Body) - resp.Body.Close() - return fmt.Errorf("SSE connection returned %d: %s", resp.StatusCode, string(body)) - } - - c.sseConn = resp.Body - c.sseScanner = bufio.NewScanner(resp.Body) - c.sseCancel = make(chan struct{}) - - // Start reading SSE messages in background - go c.readSSEMessages() - - return nil -} - -// readSSEMessages reads server-initiated messages from the SSE stream -func (c *MCPStreamableClient) readSSEMessages() { - for { - select { - case <-c.sseCancel: - return - default: - if c.sseScanner.Scan() { - line := c.sseScanner.Text() - if after, ok := strings.CutPrefix(line, "data: "); ok { - data := after - // In a real implementation, we'd process server-initiated messages here - tracef("StreamableClient: received SSE message: %s", data) - } - } else { - // Scanner stopped - connection closed or error - return - } - } - } -} - -// SendMCPRequest sends a JSON-RPC request via POST -func (c *MCPStreamableClient) SendMCPRequest(method string, params any) (map[string]any, error) { - c.mu.Lock() - serverName := c.serverName - c.mu.Unlock() - - if serverName == "" { - return nil, fmt.Errorf("not connected to any server") - } - - // For streamable-http, we POST to the server endpoint - url := c.baseURL + "/" + serverName + "/sse" - - // Construct JSON-RPC request - request := map[string]any{ - "jsonrpc": "2.0", - "id": 1, - "method": method, - "params": params, - } - - body, err := json.Marshal(request) - if err != nil { - return nil, fmt.Errorf("failed to marshal request: %v", err) - } - - req, err := http.NewRequest("POST", url, bytes.NewReader(body)) - if err != nil { - return nil, fmt.Errorf("failed to create POST request: %v", err) - } - - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+c.token) - // Accept both JSON and SSE responses - req.Header.Set("Accept", "application/json, text/event-stream") - - resp, err := c.httpClient.Do(req) - if err != nil { - return nil, fmt.Errorf("request failed: %v", err) - } - defer resp.Body.Close() - - // Check content type to determine response format - contentType := resp.Header.Get("Content-Type") - - if strings.HasPrefix(contentType, "text/event-stream") { - // Handle SSE response - return c.handleSSEResponse(resp.Body) - } else { - // Handle JSON response - return c.handleJSONResponse(resp.Body) - } -} - -// handleJSONResponse processes a regular JSON response -func (c *MCPStreamableClient) handleJSONResponse(body io.Reader) (map[string]any, error) { - var response map[string]any - if err := json.NewDecoder(body).Decode(&response); err != nil { - return nil, fmt.Errorf("failed to decode JSON response: %v", err) - } - return response, nil -} - -// handleSSEResponse processes an SSE stream response from a POST -func (c *MCPStreamableClient) handleSSEResponse(body io.Reader) (map[string]any, error) { - scanner := bufio.NewScanner(body) - var lastResponse map[string]any - - for scanner.Scan() { - line := scanner.Text() - if after, ok := strings.CutPrefix(line, "data: "); ok { - data := after - var msg map[string]any - if err := json.Unmarshal([]byte(data), &msg); err == nil { - // Keep the last response with an ID (not a notification) - if _, hasID := msg["id"]; hasID { - lastResponse = msg - } - } - } - } - - if lastResponse == nil { - return nil, fmt.Errorf("no response received in SSE stream") - } - - return lastResponse, nil -} - -// Close closes all connections -func (c *MCPStreamableClient) Close() { - c.mu.Lock() - defer c.mu.Unlock() - c.close() -} - -// close is the internal close method (must be called with lock held) -func (c *MCPStreamableClient) close() { - if c.sseCancel != nil { - close(c.sseCancel) - c.sseCancel = nil - } - - if c.sseConn != nil { - c.sseConn.Close() - c.sseConn = nil - c.sseScanner = nil - } - - c.serverName = "" -} - -// FakeGCPServer provides a fake GCP OAuth server for testing -type FakeGCPServer struct { - server *http.Server - port string -} - -// NewFakeGCPServer creates a new fake GCP server -func NewFakeGCPServer(port string) *FakeGCPServer { - mux := http.NewServeMux() - - mux.HandleFunc("/auth", func(w http.ResponseWriter, r *http.Request) { - redirectURI := r.URL.Query().Get("redirect_uri") - state := r.URL.Query().Get("state") - http.Redirect(w, r, fmt.Sprintf("%s?code=test-auth-code&state=%s", redirectURI, state), http.StatusFound) - }) - - mux.HandleFunc("/token", func(w http.ResponseWriter, r *http.Request) { - // Parse the form data - if err := r.ParseForm(); err != nil { - http.Error(w, "Invalid request", http.StatusBadRequest) - return - } - - // Check the authorization code - code := r.FormValue("code") - if code != "test-auth-code" { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusBadRequest) - _ = json.NewEncoder(w).Encode(map[string]any{ - "error": "invalid_grant", - "error_description": "Invalid authorization code", - }) - return - } - - w.Header().Set("Content-Type", "application/json") - _ = json.NewEncoder(w).Encode(map[string]any{ - "access_token": "test-access-token", - "token_type": "Bearer", - "expires_in": 3600, - }) - }) - - mux.HandleFunc("/userinfo", func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - _ = json.NewEncoder(w).Encode(map[string]any{ - "email": "test@test.com", - "hd": "test.com", - }) - }) - - server := &http.Server{ - Addr: ":" + port, - Handler: mux, - } - - return &FakeGCPServer{ - server: server, - port: port, - } -} - -// Start starts the fake GCP server -func (m *FakeGCPServer) Start() error { - go func() { - if err := m.server.ListenAndServe(); err != nil && err != http.ErrServerClosed { - panic(err) - } - }() - - time.Sleep(100 * time.Millisecond) - return nil -} - -// Stop stops the fake GCP server -func (m *FakeGCPServer) Stop() error { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - return m.server.Shutdown(ctx) -} - -// FakeServiceOAuthServer provides a fake OAuth server for external services (like Linear, GitHub) -type FakeServiceOAuthServer struct { - server *http.Server - port string -} - -// NewFakeServiceOAuthServer creates a new fake service OAuth server -func NewFakeServiceOAuthServer(port string) *FakeServiceOAuthServer { - mux := http.NewServeMux() - - mux.HandleFunc("/oauth/authorize", func(w http.ResponseWriter, r *http.Request) { - redirectURI := r.URL.Query().Get("redirect_uri") - state := r.URL.Query().Get("state") - http.Redirect(w, r, fmt.Sprintf("%s?code=service-auth-code&state=%s", redirectURI, state), http.StatusFound) - }) - - mux.HandleFunc("/oauth/token", func(w http.ResponseWriter, r *http.Request) { - if err := r.ParseForm(); err != nil { - http.Error(w, "Invalid request", http.StatusBadRequest) - return - } - - code := r.FormValue("code") - if code != "service-auth-code" { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusBadRequest) - _ = json.NewEncoder(w).Encode(map[string]any{ - "error": "invalid_grant", - "error_description": "Invalid authorization code", - }) - return - } - - w.Header().Set("Content-Type", "application/json") - _ = json.NewEncoder(w).Encode(map[string]any{ - "access_token": "service-oauth-access-token", - "refresh_token": "service-oauth-refresh-token", - "token_type": "Bearer", - "expires_in": 3600, - }) - }) - - server := &http.Server{ - Addr: ":" + port, - Handler: mux, - } - - return &FakeServiceOAuthServer{ - server: server, - port: port, - } -} - -// Start starts the fake service OAuth server -func (s *FakeServiceOAuthServer) Start() error { - go func() { - if err := s.server.ListenAndServe(); err != nil && err != http.ErrServerClosed { - panic(err) - } - }() - - time.Sleep(100 * time.Millisecond) - return nil -} - -// Stop stops the fake service OAuth server -func (s *FakeServiceOAuthServer) Stop() error { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - return s.server.Shutdown(ctx) -} - -// TestEnvironment manages the complete test environment -type TestEnvironment struct { - dbCmd *exec.Cmd - mcpCmd *exec.Cmd - fakeGCP *FakeGCPServer - client *MCPSSEClient -} - -// SetupTestEnvironment creates and starts all components needed for testing -func SetupTestEnvironment(t *testing.T) *TestEnvironment { - env := &TestEnvironment{} - - // Start test database - t.Log("🚀 Starting test database...") - env.dbCmd = execDockerCompose("-f", "config/docker-compose.test.yml", "up", "-d") - if err := env.dbCmd.Run(); err != nil { - t.Fatalf("Failed to start test database: %v", err) - } - - time.Sleep(10 * time.Second) - - // Start mock GCP server - t.Log("🚀 Starting mock GCP server...") - env.fakeGCP = NewFakeGCPServer("9090") - if err := env.fakeGCP.Start(); err != nil { - t.Fatalf("Failed to start mock GCP server: %v", err) - } - - // Start mcp-front - t.Log("🚀 Starting mcp-front...") - env.mcpCmd = exec.Command("../cmd/mcp-front/mcp-front", "-config", "config/config.test.json") - - // Capture stderr to log file if MCP_LOG_FILE is set - if logFile := os.Getenv("MCP_LOG_FILE"); logFile != "" { - f, err := os.OpenFile(logFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644) - if err == nil { - env.mcpCmd.Stderr = f - env.mcpCmd.Stdout = f - t.Cleanup(func() { f.Close() }) - } - } - - if err := env.mcpCmd.Start(); err != nil { - t.Fatalf("Failed to start mcp-front: %v", err) - } - - time.Sleep(15 * time.Second) - - // Create and authenticate client - env.client = NewMCPSSEClient("http://localhost:8080") - if err := env.client.Authenticate(); err != nil { - t.Fatalf("Authentication failed: %v", err) - } - - return env -} - -// Cleanup stops all test environment components -func (env *TestEnvironment) Cleanup() { - if env.mcpCmd != nil && env.mcpCmd.Process != nil { - _ = env.mcpCmd.Process.Kill() - } - - if env.fakeGCP != nil { - _ = env.fakeGCP.Stop() - } - - if env.dbCmd != nil { - downCmd := execDockerCompose("-f", "config/docker-compose.test.yml", "down", "-v") - _ = downCmd.Run() - } -} - -// TestConfig holds all timeout configurations for integration tests -type TestConfig struct { - SessionTimeout string - CleanupInterval string - CleanupWaitTime string - TimerResetWaitTime string - MultiUserWaitTime string -} - -// GetTestConfig returns test configuration from environment variables or defaults -func GetTestConfig() TestConfig { - c := TestConfig{ - SessionTimeout: "10s", - CleanupInterval: "2s", - CleanupWaitTime: "15s", - TimerResetWaitTime: "12s", - MultiUserWaitTime: "15s", - } - - // Override from environment if set - if v := os.Getenv("SESSION_TIMEOUT"); v != "" { - c.SessionTimeout = v - } - if v := os.Getenv("SESSION_CLEANUP_INTERVAL"); v != "" { - c.CleanupInterval = v - } - if v := os.Getenv("TEST_CLEANUP_WAIT_TIME"); v != "" { - c.CleanupWaitTime = v - } - if v := os.Getenv("TEST_TIMER_RESET_WAIT_TIME"); v != "" { - c.TimerResetWaitTime = v - } - if v := os.Getenv("TEST_MULTI_USER_WAIT_TIME"); v != "" { - c.MultiUserWaitTime = v - } - - return c -} - -func waitForDB(t *testing.T) { - waitForSec := 5 - for range waitForSec { - // Check if container is running - psCmd := exec.Command("docker", "compose", "ps", "-q", "test-postgres") - if output, err := psCmd.Output(); err != nil || len(output) == 0 { - time.Sleep(1 * time.Second) - continue - } - - // Check if database is ready - checkCmd := exec.Command("docker", "compose", "exec", "-T", "test-postgres", "pg_isready", "-U", "testuser", "-d", "testdb") - if err := checkCmd.Run(); err == nil { - return - } - time.Sleep(1 * time.Second) - } - - t.Fatalf("Database failed to become ready after %d seconds", waitForSec) -} - -// trace logs a message if TRACE environment variable is set -func trace(t *testing.T, format string, args ...any) { - if os.Getenv("TRACE") == "1" { - t.Logf("TRACE: "+format, args...) - } -} - -// tracef logs a formatted message to stdout if TRACE is set (for use outside tests) -func tracef(format string, args ...any) { - if os.Getenv("TRACE") == "1" { - fmt.Printf("TRACE: "+format+"\n", args...) - } -} - -// startMCPFront starts the mcp-front server with the given config -func startMCPFront(t *testing.T, configPath string, extraEnv ...string) { - mcpCmd := exec.Command("../cmd/mcp-front/mcp-front", "-config", configPath) - - // Get test config for session timeouts - testConfig := GetTestConfig() - - // Build default environment with test timeouts - defaultEnv := []string{ - "SESSION_TIMEOUT=" + testConfig.SessionTimeout, - "SESSION_CLEANUP_INTERVAL=" + testConfig.CleanupInterval, - } - - // Start with system environment - mcpCmd.Env = os.Environ() - - // Apply defaults first - mcpCmd.Env = append(mcpCmd.Env, defaultEnv...) - - // Apply extra env (can override defaults) - mcpCmd.Env = append(mcpCmd.Env, extraEnv...) - - // Pass through LOG_LEVEL and LOG_FORMAT if set - if logLevel := os.Getenv("LOG_LEVEL"); logLevel != "" { - mcpCmd.Env = append(mcpCmd.Env, "LOG_LEVEL="+logLevel) - } - if logFormat := os.Getenv("LOG_FORMAT"); logFormat != "" { - mcpCmd.Env = append(mcpCmd.Env, "LOG_FORMAT="+logFormat) - } - - // Capture output to log file if MCP_LOG_FILE is set - if logFile := os.Getenv("MCP_LOG_FILE"); logFile != "" { - f, err := os.OpenFile(logFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644) - if err == nil { - mcpCmd.Stderr = f - mcpCmd.Stdout = f - t.Cleanup(func() { f.Close() }) - } - } - - if err := mcpCmd.Start(); err != nil { - t.Fatalf("Failed to start mcp-front: %v", err) - } - - // Register cleanup that runs even if test is killed - t.Cleanup(func() { - stopMCPFront(mcpCmd) - }) -} - -// stopMCPFront stops the mcp-front server gracefully -func stopMCPFront(cmd *exec.Cmd) { - if cmd == nil || cmd.Process == nil { - return - } - - // Try graceful shutdown first (SIGINT) - if err := cmd.Process.Signal(syscall.SIGINT); err != nil { - // If SIGINT fails, force kill immediately - _ = cmd.Process.Kill() - _ = cmd.Wait() - return - } - - // Wait up to 5 seconds for graceful shutdown - done := make(chan error, 1) - go func() { - done <- cmd.Wait() - }() - - select { - case <-done: - // Graceful shutdown completed - return - case <-time.After(5 * time.Second): - // Timeout, force kill - _ = cmd.Process.Kill() - _ = cmd.Wait() - } -} - -// waitForMCPFront waits for the mcp-front server to be ready -func waitForMCPFront(t *testing.T) { - t.Helper() - for range 10 { - resp, err := http.Get("http://localhost:8080/health") - if err == nil && resp.StatusCode == 200 { - resp.Body.Close() - return - } - if resp != nil { - resp.Body.Close() - } - time.Sleep(1 * time.Second) - } - t.Fatal("mcp-front failed to become ready after 10 seconds") -} - -// getMCPContainers returns a list of running mcp/postgres container IDs -func getMCPContainers() []string { - cmd := exec.Command("docker", "ps", "--format", "{{.ID}}", "--filter", "ancestor=mcp/postgres") - output, err := cmd.Output() - if err != nil { - return nil - } - - var containers []string - for line := range strings.SplitSeq(strings.TrimSpace(string(output)), "\n") { - if line != "" { - containers = append(containers, line) - } - } - return containers -} - -// cleanupContainers forces cleanup of containers that weren't in the initial set -func cleanupContainers(t *testing.T, initialContainers []string) { - time.Sleep(2 * time.Second) - containers := getMCPContainers() - for _, container := range containers { - isInitial := slices.Contains(initialContainers, container) - if !isInitial { - t.Logf("Force stopping container: %s...", container) - if err := exec.Command("docker", "stop", container).Run(); err != nil { - t.Logf("Failed to stop container %s: %v", container, err) - } else { - t.Logf("Stopped container: %s", container) - } - } - } -} - -// TestQuickSmoke provides a fast validation test -func TestQuickSmoke(t *testing.T) { - t.Log("Running quick smoke test...") - - // Just verify the test infrastructure works - client := NewMCPSSEClient("http://localhost:8080") - if client == nil { - t.Fatal("Failed to create client") - } - - if err := client.Authenticate(); err != nil { - t.Fatal("Failed to set up authentication") - } - - t.Log("Quick smoke test passed - test infrastructure is working") -} diff --git a/internal/adminauth/admin.go b/internal/adminauth/admin.go index 5f35c18..977c1a5 100644 --- a/internal/adminauth/admin.go +++ b/internal/adminauth/admin.go @@ -22,8 +22,16 @@ func IsAdmin(ctx context.Context, email string, adminConfig *config.AdminConfig, return true } - // Check if user is a promoted admin in storage + // Check if user is a promoted admin in storage. + // Try exact match first (O(1)), fall back to case-insensitive scan + // since emails may be stored with different casing. if store != nil { + user, err := store.GetUser(ctx, normalizedEmail) + if err == nil { + return user.IsAdmin + } + + // Fallback: case-insensitive scan for emails stored with different casing users, err := store.GetAllUsers(ctx) if err == nil { for _, user := range users { diff --git a/internal/browserauth/session.go b/internal/browserauth/session.go deleted file mode 100644 index a1876d9..0000000 --- a/internal/browserauth/session.go +++ /dev/null @@ -1,15 +0,0 @@ -package browserauth - -import "time" - -// SessionCookie represents the data stored in encrypted browser session cookies -type SessionCookie struct { - Email string `json:"email"` - Expires time.Time `json:"expires"` -} - -// AuthorizationState represents the OAuth authorization code flow state parameter -type AuthorizationState struct { - Nonce string `json:"nonce"` - ReturnURL string `json:"return_url"` -} diff --git a/internal/envutil/envutil.go b/internal/config/env.go similarity index 94% rename from internal/envutil/envutil.go rename to internal/config/env.go index ef5acfc..6297707 100644 --- a/internal/envutil/envutil.go +++ b/internal/config/env.go @@ -1,4 +1,4 @@ -package envutil +package config import ( "os" diff --git a/internal/config/load.go b/internal/config/load.go index 326cec1..768ebf0 100644 --- a/internal/config/load.go +++ b/internal/config/load.go @@ -58,11 +58,11 @@ func validateRawConfig(rawConfig map[string]any) error { if proxy, ok := rawConfig["proxy"].(map[string]any); ok { if auth, ok := proxy["auth"].(map[string]any); ok { if kind, ok := auth["kind"].(string); ok && kind == "oauth" { + // Validate top-level auth secrets secrets := []struct { name string required bool }{ - {"googleClientSecret", true}, {"jwtSecret", true}, {"encryptionKey", true}, // Always required for OAuth } @@ -88,6 +88,20 @@ func validateRawConfig(rawConfig map[string]any) error { } } } + + // Validate IDP secret (clientSecret must be env ref) + if idp, ok := auth["idp"].(map[string]any); ok { + if clientSecret, exists := idp["clientSecret"]; exists { + if _, isString := clientSecret.(string); isString { + return fmt.Errorf("idp.clientSecret must use environment variable reference for security") + } + if refMap, isMap := clientSecret.(map[string]any); isMap { + if _, hasEnv := refMap["$env"]; !hasEnv { + return fmt.Errorf("idp.clientSecret must use {\"$env\": \"VAR_NAME\"} format") + } + } + } + } } } } @@ -148,24 +162,54 @@ func validateOAuthConfig(oauth *OAuthAuthConfig) error { if oauth.Issuer == "" { return fmt.Errorf("issuer is required") } - if oauth.GoogleClientID == "" { - return fmt.Errorf("googleClientId is required") + + // Validate IDP configuration + if oauth.IDP.Provider == "" { + return fmt.Errorf("idp.provider is required (google, azure, github, or oidc)") + } + if oauth.IDP.ClientID == "" { + return fmt.Errorf("idp.clientId is required") } - if oauth.GoogleClientSecret == "" { - return fmt.Errorf("googleClientSecret is required") + if oauth.IDP.ClientSecret == "" { + return fmt.Errorf("idp.clientSecret is required") } - if oauth.GoogleRedirectURI == "" { - return fmt.Errorf("googleRedirectUri is required") + if oauth.IDP.RedirectURI == "" { + return fmt.Errorf("idp.redirectUri is required") } + + // Provider-specific validation + switch oauth.IDP.Provider { + case "google": + // No additional validation needed + case "azure": + if oauth.IDP.TenantID == "" { + return fmt.Errorf("idp.tenantId is required for Azure AD") + } + case "github": + // No additional validation needed + case "oidc": + // Either discoveryUrl or all manual endpoints required + if oauth.IDP.DiscoveryURL == "" { + if oauth.IDP.AuthorizationURL == "" || oauth.IDP.TokenURL == "" || oauth.IDP.UserInfoURL == "" { + return fmt.Errorf("idp.discoveryUrl or all of (authorizationUrl, tokenUrl, userInfoUrl) required for OIDC") + } + } + default: + return fmt.Errorf("unsupported idp.provider: %s (must be google, azure, github, or oidc)", oauth.IDP.Provider) + } + if len(oauth.JWTSecret) < 32 { return fmt.Errorf("jwtSecret must be at least 32 characters (got %d). Generate with: openssl rand -base64 32", len(oauth.JWTSecret)) } if len(oauth.EncryptionKey) != 32 { return fmt.Errorf("encryptionKey must be exactly 32 characters (got %d). Generate with: openssl rand -base64 32 | head -c 32", len(oauth.EncryptionKey)) } - if len(oauth.AllowedDomains) == 0 { - return fmt.Errorf("at least one allowed domain is required") + + // Domain or org validation - at least one access control mechanism required + if len(oauth.AllowedDomains) == 0 && len(oauth.IDP.AllowedOrgs) == 0 { + return fmt.Errorf("at least one of allowedDomains or idp.allowedOrgs is required") } + if oauth.Storage == "firestore" { if oauth.GCPProject == "" { return fmt.Errorf("gcpProject is required when using firestore storage") diff --git a/internal/config/load_test.go b/internal/config/load_test.go index a79b8e3..4fb5659 100644 --- a/internal/config/load_test.go +++ b/internal/config/load_test.go @@ -41,17 +41,20 @@ func TestValidateConfig_UserTokensRequireOAuth(t *testing.T) { BaseURL: "https://test.example.com", Addr: ":8080", Auth: &OAuthAuthConfig{ - Kind: "oauth", - Issuer: "https://auth.example.com", - GoogleClientID: "test-client", - GoogleClientSecret: "test-secret", - GoogleRedirectURI: "https://test.example.com/callback", - JWTSecret: "test-jwt-secret-must-be-32-bytes-long", - EncryptionKey: "test-encryption-key-32-bytes-ok!", - AllowedDomains: []string{"example.com"}, - AllowedOrigins: []string{"https://test.example.com"}, - TokenTTL: time.Hour, - RefreshTokenTTL: 30 * 24 * time.Hour, + Kind: "oauth", + Issuer: "https://auth.example.com", + IDP: IDPConfig{ + Provider: "google", + ClientID: "test-client", + ClientSecret: "test-secret", + RedirectURI: "https://test.example.com/callback", + }, + JWTSecret: "test-jwt-secret-must-be-32-bytes-long", + EncryptionKey: "test-encryption-key-32-bytes-ok!", + AllowedDomains: []string{"example.com"}, + AllowedOrigins: []string{"https://test.example.com"}, + TokenTTL: time.Hour, + RefreshTokenTTL: 30 * 24 * time.Hour, }, }, MCPServers: map[string]*MCPClientConfig{ @@ -98,17 +101,20 @@ func TestValidateConfig_SessionConfig(t *testing.T) { BaseURL: "https://test.example.com", Addr: ":8080", Auth: &OAuthAuthConfig{ - Kind: "oauth", - Issuer: "https://auth.example.com", - GoogleClientID: "test-client", - GoogleClientSecret: "test-secret", - GoogleRedirectURI: "https://test.example.com/callback", - JWTSecret: "test-jwt-secret-must-be-32-bytes-long", - EncryptionKey: "test-encryption-key-32-bytes-ok!", - AllowedDomains: []string{"example.com"}, - AllowedOrigins: []string{"https://test.example.com"}, - TokenTTL: time.Hour, - RefreshTokenTTL: 30 * 24 * time.Hour, + Kind: "oauth", + Issuer: "https://auth.example.com", + IDP: IDPConfig{ + Provider: "google", + ClientID: "test-client", + ClientSecret: "test-secret", + RedirectURI: "https://test.example.com/callback", + }, + JWTSecret: "test-jwt-secret-must-be-32-bytes-long", + EncryptionKey: "test-encryption-key-32-bytes-ok!", + AllowedDomains: []string{"example.com"}, + AllowedOrigins: []string{"https://test.example.com"}, + TokenTTL: time.Hour, + RefreshTokenTTL: 30 * 24 * time.Hour, }, Sessions: &SessionConfig{ Timeout: 10 * time.Minute, @@ -128,17 +134,20 @@ func TestValidateConfig_SessionConfig(t *testing.T) { BaseURL: "https://test.example.com", Addr: ":8080", Auth: &OAuthAuthConfig{ - Kind: "oauth", - Issuer: "https://auth.example.com", - GoogleClientID: "test-client", - GoogleClientSecret: "test-secret", - GoogleRedirectURI: "https://test.example.com/callback", - JWTSecret: "test-jwt-secret-must-be-32-bytes-long", - EncryptionKey: "test-encryption-key-32-bytes-ok!", - AllowedDomains: []string{"example.com"}, - AllowedOrigins: []string{"https://test.example.com"}, - TokenTTL: time.Hour, - RefreshTokenTTL: 30 * 24 * time.Hour, + Kind: "oauth", + Issuer: "https://auth.example.com", + IDP: IDPConfig{ + Provider: "google", + ClientID: "test-client", + ClientSecret: "test-secret", + RedirectURI: "https://test.example.com/callback", + }, + JWTSecret: "test-jwt-secret-must-be-32-bytes-long", + EncryptionKey: "test-encryption-key-32-bytes-ok!", + AllowedDomains: []string{"example.com"}, + AllowedOrigins: []string{"https://test.example.com"}, + TokenTTL: time.Hour, + RefreshTokenTTL: 30 * 24 * time.Hour, }, Sessions: &SessionConfig{ Timeout: -1 * time.Minute, @@ -156,17 +165,20 @@ func TestValidateConfig_SessionConfig(t *testing.T) { BaseURL: "https://test.example.com", Addr: ":8080", Auth: &OAuthAuthConfig{ - Kind: "oauth", - Issuer: "https://auth.example.com", - GoogleClientID: "test-client", - GoogleClientSecret: "test-secret", - GoogleRedirectURI: "https://test.example.com/callback", - JWTSecret: "test-jwt-secret-must-be-32-bytes-long", - EncryptionKey: "test-encryption-key-32-bytes-ok!", - AllowedDomains: []string{"example.com"}, - AllowedOrigins: []string{"https://test.example.com"}, - TokenTTL: time.Hour, - RefreshTokenTTL: 30 * 24 * time.Hour, + Kind: "oauth", + Issuer: "https://auth.example.com", + IDP: IDPConfig{ + Provider: "google", + ClientID: "test-client", + ClientSecret: "test-secret", + RedirectURI: "https://test.example.com/callback", + }, + JWTSecret: "test-jwt-secret-must-be-32-bytes-long", + EncryptionKey: "test-encryption-key-32-bytes-ok!", + AllowedDomains: []string{"example.com"}, + AllowedOrigins: []string{"https://test.example.com"}, + TokenTTL: time.Hour, + RefreshTokenTTL: 30 * 24 * time.Hour, }, Sessions: &SessionConfig{ Timeout: 10 * time.Minute, @@ -184,17 +196,20 @@ func TestValidateConfig_SessionConfig(t *testing.T) { BaseURL: "https://test.example.com", Addr: ":8080", Auth: &OAuthAuthConfig{ - Kind: "oauth", - Issuer: "https://auth.example.com", - GoogleClientID: "test-client", - GoogleClientSecret: "test-secret", - GoogleRedirectURI: "https://test.example.com/callback", - JWTSecret: "test-jwt-secret-must-be-32-bytes-long", - EncryptionKey: "test-encryption-key-32-bytes-ok!", - AllowedDomains: []string{"example.com"}, - AllowedOrigins: []string{"https://test.example.com"}, - TokenTTL: time.Hour, - RefreshTokenTTL: 30 * 24 * time.Hour, + Kind: "oauth", + Issuer: "https://auth.example.com", + IDP: IDPConfig{ + Provider: "google", + ClientID: "test-client", + ClientSecret: "test-secret", + RedirectURI: "https://test.example.com/callback", + }, + JWTSecret: "test-jwt-secret-must-be-32-bytes-long", + EncryptionKey: "test-encryption-key-32-bytes-ok!", + AllowedDomains: []string{"example.com"}, + AllowedOrigins: []string{"https://test.example.com"}, + TokenTTL: time.Hour, + RefreshTokenTTL: 30 * 24 * time.Hour, }, Sessions: &SessionConfig{ Timeout: 0, diff --git a/internal/config/types.go b/internal/config/types.go index 515f202..87d714e 100644 --- a/internal/config/types.go +++ b/internal/config/types.go @@ -187,6 +187,11 @@ type MCPClientConfig struct { InlineConfig json.RawMessage `json:"inline,omitempty"` } +// IsStdio returns true if this is a stdio-based MCP server +func (c *MCPClientConfig) IsStdio() bool { + return c.TransportType == MCPClientTypeStdio +} + // SessionConfig represents session management configuration type SessionConfig struct { Timeout time.Duration @@ -200,12 +205,39 @@ type AdminConfig struct { AdminEmails []string `json:"adminEmails"` } +// IDPConfig represents identity provider configuration. +type IDPConfig struct { + // Provider type: "google", "azure", "github", or "oidc" + Provider string `json:"provider"` + + // OAuth client configuration + ClientID string `json:"clientId"` + ClientSecret Secret `json:"clientSecret"` + RedirectURI string `json:"redirectUri"` + + // For OIDC: discovery URL or manual endpoint configuration + DiscoveryURL string `json:"discoveryUrl,omitempty"` + AuthorizationURL string `json:"authorizationUrl,omitempty"` + TokenURL string `json:"tokenUrl,omitempty"` + UserInfoURL string `json:"userInfoUrl,omitempty"` + + // Custom scopes (optional, defaults per provider) + Scopes []string `json:"scopes,omitempty"` + + // Azure-specific: tenant ID + TenantID string `json:"tenantId,omitempty"` + + // GitHub-specific: allowed organizations + AllowedOrgs []string `json:"allowedOrgs,omitempty"` +} + // OAuthAuthConfig represents OAuth 2.0 configuration with resolved values type OAuthAuthConfig struct { Kind AuthKind `json:"kind"` Issuer string `json:"issuer"` GCPProject string `json:"gcpProject"` - AllowedDomains []string `json:"allowedDomains"` // For Google OAuth email validation + IDP IDPConfig `json:"idp"` + AllowedDomains []string `json:"allowedDomains"` // For domain-based access control AllowedOrigins []string `json:"allowedOrigins"` // For CORS validation TokenTTL time.Duration `json:"tokenTtl"` RefreshTokenTTL time.Duration `json:"refreshTokenTtl"` @@ -213,9 +245,6 @@ type OAuthAuthConfig struct { Storage string `json:"storage"` // "memory" or "firestore" FirestoreDatabase string `json:"firestoreDatabase,omitempty"` // Optional: Firestore database name FirestoreCollection string `json:"firestoreCollection,omitempty"` // Optional: Firestore collection name - GoogleClientID string `json:"googleClientId"` - GoogleClientSecret Secret `json:"googleClientSecret"` - GoogleRedirectURI string `json:"googleRedirectUri"` JWTSecret Secret `json:"jwtSecret"` EncryptionKey Secret `json:"encryptionKey"` // DangerouslyAcceptIssuerAudience allows tokens with just the base issuer as audience diff --git a/internal/config/unmarshal.go b/internal/config/unmarshal.go index eb20321..0a03878 100644 --- a/internal/config/unmarshal.go +++ b/internal/config/unmarshal.go @@ -118,6 +118,7 @@ func (o *OAuthAuthConfig) UnmarshalJSON(data []byte) error { Kind AuthKind `json:"kind"` Issuer json.RawMessage `json:"issuer"` GCPProject json.RawMessage `json:"gcpProject"` + IDP json.RawMessage `json:"idp"` AllowedDomains []string `json:"allowedDomains"` AllowedOrigins []string `json:"allowedOrigins"` TokenTTL string `json:"tokenTtl"` @@ -126,9 +127,6 @@ func (o *OAuthAuthConfig) UnmarshalJSON(data []byte) error { Storage string `json:"storage"` FirestoreDatabase string `json:"firestoreDatabase,omitempty"` FirestoreCollection string `json:"firestoreCollection,omitempty"` - GoogleClientID json.RawMessage `json:"googleClientId"` - GoogleClientSecret json.RawMessage `json:"googleClientSecret"` - GoogleRedirectURI json.RawMessage `json:"googleRedirectUri"` JWTSecret json.RawMessage `json:"jwtSecret"` EncryptionKey json.RawMessage `json:"encryptionKey,omitempty"` DangerouslyAcceptIssuerAudience bool `json:"dangerouslyAcceptIssuerAudience,omitempty"` @@ -202,38 +200,13 @@ func (o *OAuthAuthConfig) UnmarshalJSON(data []byte) error { o.GCPProject = parsed.value } - if raw.GoogleClientID != nil { - parsed, err := ParseConfigValue(raw.GoogleClientID) + // Parse IDP config + if raw.IDP != nil { + idp, err := parseIDPConfig(raw.IDP) if err != nil { - return fmt.Errorf("parsing googleClientId: %w", err) + return fmt.Errorf("parsing idp: %w", err) } - if parsed.needsUserToken { - return fmt.Errorf("googleClientId cannot be a user token reference") - } - o.GoogleClientID = parsed.value - } - - if raw.GoogleRedirectURI != nil { - parsed, err := ParseConfigValue(raw.GoogleRedirectURI) - if err != nil { - return fmt.Errorf("parsing googleRedirectUri: %w", err) - } - if parsed.needsUserToken { - return fmt.Errorf("googleRedirectUri cannot be a user token reference") - } - o.GoogleRedirectURI = parsed.value - } - - // Parse secret fields - if raw.GoogleClientSecret != nil { - parsed, err := ParseConfigValue(raw.GoogleClientSecret) - if err != nil { - return fmt.Errorf("parsing googleClientSecret: %w", err) - } - if parsed.needsUserToken { - return fmt.Errorf("googleClientSecret cannot be a user token reference") - } - o.GoogleClientSecret = Secret(parsed.value) + o.IDP = *idp } if raw.JWTSecret != nil { @@ -274,6 +247,87 @@ func (o *OAuthAuthConfig) UnmarshalJSON(data []byte) error { return nil } +// parseIDPConfig parses the IDP configuration with env var references +func parseIDPConfig(data json.RawMessage) (*IDPConfig, error) { + type rawIDP struct { + Provider string `json:"provider"` + ClientID json.RawMessage `json:"clientId"` + ClientSecret json.RawMessage `json:"clientSecret"` + RedirectURI json.RawMessage `json:"redirectUri"` + DiscoveryURL json.RawMessage `json:"discoveryUrl,omitempty"` + AuthorizationURL json.RawMessage `json:"authorizationUrl,omitempty"` + TokenURL json.RawMessage `json:"tokenUrl,omitempty"` + UserInfoURL json.RawMessage `json:"userInfoUrl,omitempty"` + Scopes []string `json:"scopes,omitempty"` + TenantID json.RawMessage `json:"tenantId,omitempty"` + AllowedOrgs []string `json:"allowedOrgs,omitempty"` + } + + var raw rawIDP + if err := json.Unmarshal(data, &raw); err != nil { + return nil, err + } + + idp := &IDPConfig{ + Provider: raw.Provider, + Scopes: raw.Scopes, + AllowedOrgs: raw.AllowedOrgs, + } + + // Helper to parse a field that cannot be a user token reference + parseField := func(data json.RawMessage, fieldName string) (string, error) { + if data == nil { + return "", nil + } + parsed, err := ParseConfigValue(data) + if err != nil { + return "", fmt.Errorf("parsing %s: %w", fieldName, err) + } + if parsed.needsUserToken { + return "", fmt.Errorf("%s cannot be a user token reference", fieldName) + } + return parsed.value, nil + } + + var err error + + if idp.ClientID, err = parseField(raw.ClientID, "clientId"); err != nil { + return nil, err + } + + var clientSecret string + if clientSecret, err = parseField(raw.ClientSecret, "clientSecret"); err != nil { + return nil, err + } + idp.ClientSecret = Secret(clientSecret) + + if idp.RedirectURI, err = parseField(raw.RedirectURI, "redirectUri"); err != nil { + return nil, err + } + + if idp.DiscoveryURL, err = parseField(raw.DiscoveryURL, "discoveryUrl"); err != nil { + return nil, err + } + + if idp.AuthorizationURL, err = parseField(raw.AuthorizationURL, "authorizationUrl"); err != nil { + return nil, err + } + + if idp.TokenURL, err = parseField(raw.TokenURL, "tokenUrl"); err != nil { + return nil, err + } + + if idp.UserInfoURL, err = parseField(raw.UserInfoURL, "userInfoUrl"); err != nil { + return nil, err + } + + if idp.TenantID, err = parseField(raw.TenantID, "tenantId"); err != nil { + return nil, err + } + + return idp, nil +} + // UnmarshalJSON implements custom unmarshaling for ProxyConfig func (p *ProxyConfig) UnmarshalJSON(data []byte) error { // Use a raw type to parse references diff --git a/internal/config/unmarshal_test.go b/internal/config/unmarshal_test.go index dc0682a..bc1db55 100644 --- a/internal/config/unmarshal_test.go +++ b/internal/config/unmarshal_test.go @@ -255,9 +255,12 @@ func TestOAuthAuthConfig_UnmarshalJSON(t *testing.T) { "allowedOrigins": ["https://claude.ai", "https://example.com"], "tokenTtl": "1h", "storage": "firestore", - "googleClientId": "test-client-id", - "googleClientSecret": {"$env": "CLIENT_SECRET"}, - "googleRedirectUri": "https://example.com/callback", + "idp": { + "provider": "google", + "clientId": "test-client-id", + "clientSecret": {"$env": "CLIENT_SECRET"}, + "redirectUri": "https://example.com/callback" + }, "jwtSecret": {"$env": "JWT_SECRET"}, "encryptionKey": {"$env": "ENCRYPTION_KEY"} }` @@ -271,8 +274,10 @@ func TestOAuthAuthConfig_UnmarshalJSON(t *testing.T) { assert.Equal(t, "test-project", config.GCPProject) assert.Equal(t, []string{"example.com"}, config.AllowedDomains) assert.Equal(t, []string{"https://claude.ai", "https://example.com"}, config.AllowedOrigins) - assert.Equal(t, "test-client-id", config.GoogleClientID) - assert.Equal(t, Secret("test-secret-value"), config.GoogleClientSecret) + assert.Equal(t, "google", config.IDP.Provider) + assert.Equal(t, "test-client-id", config.IDP.ClientID) + assert.Equal(t, Secret("test-secret-value"), config.IDP.ClientSecret) + assert.Equal(t, "https://example.com/callback", config.IDP.RedirectURI) assert.Equal(t, Secret("this-is-a-very-long-jwt-secret-key"), config.JWTSecret) assert.Equal(t, Secret("exactly-32-bytes-long-encryptkey"), config.EncryptionKey) } @@ -406,9 +411,12 @@ func TestProxyConfig_SessionConfigIntegration(t *testing.T) { "auth": { "kind": "oauth", "issuer": "https://auth.example.com", - "googleClientId": "test-client", - "googleClientSecret": "test-secret", - "googleRedirectUri": "https://test.example.com/callback", + "idp": { + "provider": "google", + "clientId": "test-client", + "clientSecret": "test-secret", + "redirectUri": "https://test.example.com/callback" + }, "jwtSecret": "test-jwt-secret-must-be-32-bytes-long", "encryptionKey": "test-encryption-key-32-bytes-ok!", "allowedDomains": ["example.com"], diff --git a/internal/config/validation.go b/internal/config/validation.go index 40612d1..490a123 100644 --- a/internal/config/validation.go +++ b/internal/config/validation.go @@ -146,9 +146,6 @@ func validateAuthStructure(auth map[string]any, result *ValidationResult) { hint string }{ {"issuer", ""}, - {"googleClientId", ""}, - {"googleClientSecret", ""}, - {"googleRedirectUri", ""}, {"jwtSecret", "Hint: Must be at least 32 bytes long for HMAC-SHA256"}, {"encryptionKey", "Hint: Must be exactly 32 bytes for AES-256-GCM encryption"}, } @@ -164,12 +161,18 @@ func validateAuthStructure(auth map[string]any, result *ValidationResult) { }) } } - if domains, ok := auth["allowedDomains"].([]any); !ok || len(domains) == 0 { + + // Validate IDP configuration + idp, hasIDP := auth["idp"].(map[string]any) + if !hasIDP { result.Errors = append(result.Errors, ValidationError{ - Path: "proxy.auth.allowedDomains", - Message: "at least one allowed domain is required for OAuth", + Path: "proxy.auth.idp", + Message: "idp configuration is required for OAuth", }) + } else { + validateIDPStructure(idp, result) } + if origins, ok := auth["allowedOrigins"].([]any); !ok || len(origins) == 0 { result.Errors = append(result.Errors, ValidationError{ Path: "proxy.auth.allowedOrigins", @@ -184,6 +187,74 @@ func validateAuthStructure(auth map[string]any, result *ValidationResult) { } } +// validateIDPStructure checks identity provider configuration +func validateIDPStructure(idp map[string]any, result *ValidationResult) { + provider, ok := idp["provider"].(string) + if !ok { + result.Errors = append(result.Errors, ValidationError{ + Path: "proxy.auth.idp.provider", + Message: "provider is required. Options: google, azure, github, oidc", + }) + return + } + + // Check required fields for all providers + if _, ok := idp["clientId"]; !ok { + result.Errors = append(result.Errors, ValidationError{ + Path: "proxy.auth.idp.clientId", + Message: "clientId is required for IDP configuration", + }) + } + if _, ok := idp["clientSecret"]; !ok { + result.Errors = append(result.Errors, ValidationError{ + Path: "proxy.auth.idp.clientSecret", + Message: "clientSecret is required for IDP configuration", + }) + } + if _, ok := idp["redirectUri"]; !ok { + result.Errors = append(result.Errors, ValidationError{ + Path: "proxy.auth.idp.redirectUri", + Message: "redirectUri is required for IDP configuration", + }) + } + + // Provider-specific validation + switch provider { + case "google", "github": + // No additional required fields + case "azure": + if _, ok := idp["tenantId"]; !ok { + result.Errors = append(result.Errors, ValidationError{ + Path: "proxy.auth.idp.tenantId", + Message: "tenantId is required for Azure AD provider", + }) + } + case "oidc": + // Either discoveryUrl or manual endpoints required + hasDiscovery := false + if _, ok := idp["discoveryUrl"]; ok { + hasDiscovery = true + } + if !hasDiscovery { + // Check for manual endpoints + requiredEndpoints := []string{"authorizationUrl", "tokenUrl", "userInfoUrl"} + for _, endpoint := range requiredEndpoints { + if _, ok := idp[endpoint]; !ok { + result.Errors = append(result.Errors, ValidationError{ + Path: "proxy.auth.idp." + endpoint, + Message: fmt.Sprintf("%s is required for OIDC provider when discoveryUrl is not provided", endpoint), + }) + } + } + } + default: + result.Errors = append(result.Errors, ValidationError{ + Path: "proxy.auth.idp.provider", + Message: fmt.Sprintf("unknown provider '%s' - supported providers: google, azure, github, oidc", provider), + }) + } +} + // validateAdminStructure checks admin configuration structure func validateAdminStructure(admin map[string]any, result *ValidationResult) { enabled, ok := admin["enabled"].(bool) diff --git a/internal/config/validation_test.go b/internal/config/validation_test.go index 910bb00..a8b5694 100644 --- a/internal/config/validation_test.go +++ b/internal/config/validation_test.go @@ -51,9 +51,12 @@ func TestValidateFile(t *testing.T) { "auth": { "kind": "oauth", "issuer": "https://example.com", - "googleClientId": {"$env": "CLIENT_ID"}, - "googleClientSecret": {"$env": "CLIENT_SECRET"}, - "googleRedirectUri": "https://example.com/callback", + "idp": { + "provider": "google", + "clientId": {"$env": "CLIENT_ID"}, + "clientSecret": {"$env": "CLIENT_SECRET"}, + "redirectUri": "https://example.com/callback" + }, "jwtSecret": {"$env": "JWT_SECRET"}, "encryptionKey": {"$env": "ENCRYPTION_KEY"}, "allowedDomains": ["example.com"], @@ -225,9 +228,12 @@ func TestValidateFile(t *testing.T) { "auth": { "kind": "oauth", "issuer": "https://example.com", - "googleClientId": "id", - "googleClientSecret": "secret", - "googleRedirectUri": "https://example.com/callback", + "idp": { + "provider": "google", + "clientId": "id", + "clientSecret": "secret", + "redirectUri": "https://example.com/callback" + }, "jwtSecret": "secret", "encryptionKey": "key", "allowedDomains": ["example.com"], @@ -260,15 +266,12 @@ func TestValidateFile(t *testing.T) { }`, wantErrors: []string{ "issuer is required for OAuth", - "googleClientId is required for OAuth", - "googleClientSecret is required for OAuth", - "googleRedirectUri is required for OAuth", "jwtSecret is required for OAuth. Hint: Must be at least 32 bytes long for HMAC-SHA256", "encryptionKey is required for OAuth. Hint: Must be exactly 32 bytes for AES-256-GCM encryption", - "at least one allowed domain is required for OAuth", + "idp configuration is required for OAuth", "at least one allowed origin is required for OAuth (CORS configuration)", }, - wantErrCount: 8, + wantErrCount: 5, }, { name: "valid_manual_user_authentication", @@ -280,9 +283,12 @@ func TestValidateFile(t *testing.T) { "auth": { "kind": "oauth", "issuer": "https://example.com", - "googleClientId": "id", - "googleClientSecret": "secret", - "googleRedirectUri": "https://example.com/callback", + "idp": { + "provider": "google", + "clientId": "id", + "clientSecret": "secret", + "redirectUri": "https://example.com/callback" + }, "jwtSecret": "secret123456789012345678901234567890", "encryptionKey": "key12345678901234567890123456789", "allowedDomains": ["example.com"], @@ -315,9 +321,12 @@ func TestValidateFile(t *testing.T) { "auth": { "kind": "oauth", "issuer": "https://example.com", - "googleClientId": "id", - "googleClientSecret": "secret", - "googleRedirectUri": "https://example.com/callback", + "idp": { + "provider": "google", + "clientId": "id", + "clientSecret": "secret", + "redirectUri": "https://example.com/callback" + }, "jwtSecret": "secret123456789012345678901234567890", "encryptionKey": "key12345678901234567890123456789", "allowedDomains": ["example.com"], @@ -354,9 +363,12 @@ func TestValidateFile(t *testing.T) { "auth": { "kind": "oauth", "issuer": "https://example.com", - "googleClientId": "id", - "googleClientSecret": "secret", - "googleRedirectUri": "https://example.com/callback", + "idp": { + "provider": "google", + "clientId": "id", + "clientSecret": "secret", + "redirectUri": "https://example.com/callback" + }, "jwtSecret": "secret123456789012345678901234567890", "encryptionKey": "key12345678901234567890123456789", "allowedDomains": ["example.com"], @@ -515,9 +527,12 @@ func TestValidateFile_ImprovedErrorMessages(t *testing.T) { "auth": { "kind": "oauth", "issuer": "https://example.com", - "googleClientId": "id", - "googleClientSecret": "secret", - "googleRedirectUri": "https://example.com/callback", + "idp": { + "provider": "google", + "clientId": "id", + "clientSecret": "secret", + "redirectUri": "https://example.com/callback" + }, "jwtSecret": "secret123456789012345678901234567890", "encryptionKey": "key12345678901234567890123456789", "allowedDomains": ["example.com"], diff --git a/internal/cookie/cookie.go b/internal/cookie/cookie.go index 3be28d4..0c38ea8 100644 --- a/internal/cookie/cookie.go +++ b/internal/cookie/cookie.go @@ -4,7 +4,7 @@ import ( "net/http" "time" - "github.com/dgellow/mcp-front/internal/envutil" + "github.com/dgellow/mcp-front/internal/config" "github.com/dgellow/mcp-front/internal/log" ) @@ -16,7 +16,7 @@ const ( // SetSession sets a session cookie with appropriate security settings func SetSession(w http.ResponseWriter, value string, maxAge time.Duration) { - secure := !envutil.IsDev() + secure := !config.IsDev() http.SetCookie(w, &http.Cookie{ Name: SessionCookie, Value: value, @@ -41,7 +41,7 @@ func SetCSRF(w http.ResponseWriter, value string) { Value: value, Path: "/", HttpOnly: false, // CSRF tokens need to be readable by JavaScript - Secure: !envutil.IsDev(), + Secure: !config.IsDev(), SameSite: http.SameSiteStrictMode, MaxAge: int((24 * time.Hour).Seconds()), // 24 hours }) diff --git a/internal/crypto/csrf.go b/internal/crypto/csrf.go new file mode 100644 index 0000000..53a9dbf --- /dev/null +++ b/internal/crypto/csrf.go @@ -0,0 +1,61 @@ +package crypto + +import ( + "fmt" + "strconv" + "strings" + "time" +) + +// CSRFProtection provides stateless HMAC-based CSRF token generation and validation. +// Tokens are self-contained: nonce:timestamp:signature, with configurable expiry. +type CSRFProtection struct { + signingKey []byte + ttl time.Duration +} + +// NewCSRFProtection creates a new CSRF protection instance +func NewCSRFProtection(signingKey []byte, ttl time.Duration) CSRFProtection { + return CSRFProtection{ + signingKey: signingKey, + ttl: ttl, + } +} + +// Generate creates a new CSRF token +func (c *CSRFProtection) Generate() (string, error) { + nonce := GenerateSecureToken() + if nonce == "" { + return "", fmt.Errorf("failed to generate nonce") + } + + timestamp := strconv.FormatInt(time.Now().Unix(), 10) + data := nonce + ":" + timestamp + signature := SignData(data, c.signingKey) + + return fmt.Sprintf("%s:%s:%s", nonce, timestamp, signature), nil +} + +// Validate checks if a CSRF token is valid and not expired +func (c *CSRFProtection) Validate(token string) bool { + parts := strings.SplitN(token, ":", 3) + if len(parts) != 3 { + return false + } + + nonce := parts[0] + timestampStr := parts[1] + signature := parts[2] + + timestamp, err := strconv.ParseInt(timestampStr, 10, 64) + if err != nil { + return false + } + + if time.Since(time.Unix(timestamp, 0)) > c.ttl { + return false + } + + data := nonce + ":" + timestampStr + return ValidateSignedData(data, signature, c.signingKey) +} diff --git a/internal/googleauth/google.go b/internal/googleauth/google.go deleted file mode 100644 index 1667a49..0000000 --- a/internal/googleauth/google.go +++ /dev/null @@ -1,121 +0,0 @@ -package googleauth - -import ( - "context" - "encoding/json" - "fmt" - "net/http" - "os" - "slices" - "strings" - - "github.com/dgellow/mcp-front/internal/config" - emailutil "github.com/dgellow/mcp-front/internal/emailutil" - "golang.org/x/oauth2" - "golang.org/x/oauth2/google" -) - -// UserInfo represents Google user information -type UserInfo struct { - Email string `json:"email"` - HostedDomain string `json:"hd"` - Name string `json:"name"` - Picture string `json:"picture"` - VerifiedEmail bool `json:"verified_email"` -} - -// GoogleAuthURL generates a Google OAuth authorization URL -func GoogleAuthURL(oauthConfig config.OAuthAuthConfig, state string) string { - googleOAuth := newGoogleOAuth2Config(oauthConfig) - return googleOAuth.AuthCodeURL(state, - oauth2.AccessTypeOffline, - oauth2.ApprovalForce, - ) -} - -// ExchangeCodeForToken exchanges the authorization code for a token -func ExchangeCodeForToken(ctx context.Context, oauthConfig config.OAuthAuthConfig, code string) (*oauth2.Token, error) { - googleOAuth := newGoogleOAuth2Config(oauthConfig) - return googleOAuth.Exchange(ctx, code) -} - -// ValidateUser validates the Google OAuth token and checks domain membership -func ValidateUser(ctx context.Context, oauthConfig config.OAuthAuthConfig, token *oauth2.Token) (UserInfo, error) { - googleOAuth := newGoogleOAuth2Config(oauthConfig) - client := googleOAuth.Client(ctx, token) - userInfoURL := "https://www.googleapis.com/oauth2/v2/userinfo" - if customURL := os.Getenv("GOOGLE_USERINFO_URL"); customURL != "" { - userInfoURL = customURL - } - resp, err := client.Get(userInfoURL) - if err != nil { - return UserInfo{}, fmt.Errorf("failed to get user info: %w", err) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - return UserInfo{}, fmt.Errorf("failed to get user info: status %d", resp.StatusCode) - } - - var userInfo UserInfo - if err := json.NewDecoder(resp.Body).Decode(&userInfo); err != nil { - return UserInfo{}, fmt.Errorf("failed to decode user info: %w", err) - } - - // Validate domain if configured - if len(oauthConfig.AllowedDomains) > 0 { - userDomain := emailutil.ExtractDomain(userInfo.Email) - if !slices.Contains(oauthConfig.AllowedDomains, userDomain) { - return UserInfo{}, fmt.Errorf("domain '%s' is not allowed. Contact your administrator", userDomain) - } - } - - return userInfo, nil -} - -// ParseClientRequest parses MCP client registration metadata -func ParseClientRequest(metadata map[string]any) ([]string, []string, error) { - // Extract redirect URIs - redirectURIs := []string{} - if uris, ok := metadata["redirect_uris"].([]any); ok { - for _, uri := range uris { - if uriStr, ok := uri.(string); ok { - redirectURIs = append(redirectURIs, uriStr) - } - } - } - - if len(redirectURIs) == 0 { - return nil, nil, fmt.Errorf("no valid redirect URIs provided") - } - - // Extract scopes, default to read/write if not provided - scopes := []string{"read", "write"} // Default MCP scopes - if clientScopes, ok := metadata["scope"].(string); ok { - if strings.TrimSpace(clientScopes) != "" { - scopes = strings.Fields(clientScopes) - } - } - - return redirectURIs, scopes, nil -} - -// newGoogleOAuth2Config creates the OAuth2 config from our Config -func newGoogleOAuth2Config(oauthConfig config.OAuthAuthConfig) oauth2.Config { - // Use custom OAuth endpoints if provided (for testing) - endpoint := google.Endpoint - if authURL := os.Getenv("GOOGLE_OAUTH_AUTH_URL"); authURL != "" { - endpoint.AuthURL = authURL - } - if tokenURL := os.Getenv("GOOGLE_OAUTH_TOKEN_URL"); tokenURL != "" { - endpoint.TokenURL = tokenURL - } - - return oauth2.Config{ - ClientID: oauthConfig.GoogleClientID, - ClientSecret: string(oauthConfig.GoogleClientSecret), - RedirectURL: oauthConfig.GoogleRedirectURI, - Scopes: []string{"openid", "profile", "email"}, - Endpoint: endpoint, - } -} diff --git a/internal/googleauth/google_test.go b/internal/googleauth/google_test.go deleted file mode 100644 index 3f72d4a..0000000 --- a/internal/googleauth/google_test.go +++ /dev/null @@ -1,240 +0,0 @@ -package googleauth - -import ( - "context" - "encoding/json" - "net/http" - "net/http/httptest" - "testing" - "time" - - "github.com/dgellow/mcp-front/internal/config" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "golang.org/x/oauth2" -) - -func TestGoogleAuthURL(t *testing.T) { - oauthConfig := config.OAuthAuthConfig{ - GoogleClientID: "test-client-id", - GoogleClientSecret: config.Secret("test-client-secret"), - GoogleRedirectURI: "https://test.example.com/oauth/callback", - } - - state := "test-state-parameter" - authURL := GoogleAuthURL(oauthConfig, state) - - // Verify URL structure - assert.Contains(t, authURL, "https://accounts.google.com/o/oauth2/auth") - assert.Contains(t, authURL, "client_id=test-client-id") - assert.Contains(t, authURL, "redirect_uri=https%3A%2F%2Ftest.example.com%2Foauth%2Fcallback") - assert.Contains(t, authURL, "state=test-state-parameter") - assert.Contains(t, authURL, "access_type=offline") - assert.Contains(t, authURL, "prompt=consent") - assert.Contains(t, authURL, "scope=openid+profile+email") -} - -func TestExchangeCodeForToken(t *testing.T) { - // Create a mock OAuth token endpoint - tokenServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - assert.Equal(t, "POST", r.Method) - assert.Equal(t, "/token", r.URL.Path) - - // Parse form data - err := r.ParseForm() - require.NoError(t, err) - - assert.Equal(t, "test-code", r.FormValue("code")) - assert.Equal(t, "test-client-id", r.FormValue("client_id")) - assert.Equal(t, "test-client-secret", r.FormValue("client_secret")) - assert.Equal(t, "https://test.example.com/oauth/callback", r.FormValue("redirect_uri")) - assert.Equal(t, "authorization_code", r.FormValue("grant_type")) - - // Return mock token response - response := map[string]any{ - "access_token": "mock-access-token", - "refresh_token": "mock-refresh-token", - "token_type": "Bearer", - "expires_in": 3600, - } - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(response); err != nil { - t.Errorf("failed to encode response: %v", err) - } - })) - defer tokenServer.Close() - - // Set environment variable for custom token URL - t.Setenv("GOOGLE_OAUTH_TOKEN_URL", tokenServer.URL+"/token") - - oauthConfig := config.OAuthAuthConfig{ - GoogleClientID: "test-client-id", - GoogleClientSecret: config.Secret("test-client-secret"), - GoogleRedirectURI: "https://test.example.com/oauth/callback", - } - - token, err := ExchangeCodeForToken(context.Background(), oauthConfig, "test-code") - require.NoError(t, err) - require.NotNil(t, token) - - assert.Equal(t, "mock-access-token", token.AccessToken) - assert.Equal(t, "mock-refresh-token", token.RefreshToken) - assert.Equal(t, "Bearer", token.TokenType) - assert.WithinDuration(t, time.Now().Add(3600*time.Second), token.Expiry, 5*time.Second) -} - -func TestValidateUser(t *testing.T) { - t.Run("valid user in allowed domain", func(t *testing.T) { - // Create mock user info endpoint - userInfoServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - assert.Equal(t, "GET", r.Method) - assert.Contains(t, r.Header.Get("Authorization"), "Bearer mock-token") - - response := UserInfo{ - Email: "user@example.com", - HostedDomain: "example.com", - Name: "Test User", - Picture: "https://example.com/pic.jpg", - VerifiedEmail: true, - } - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(response); err != nil { - t.Errorf("failed to encode response: %v", err) - } - })) - defer userInfoServer.Close() - - t.Setenv("GOOGLE_USERINFO_URL", userInfoServer.URL) - - oauthConfig := config.OAuthAuthConfig{ - AllowedDomains: []string{"example.com", "test.com"}, - } - - token := &oauth2.Token{AccessToken: "mock-token"} - userInfo, err := ValidateUser(context.Background(), oauthConfig, token) - - require.NoError(t, err) - assert.Equal(t, "user@example.com", userInfo.Email) - assert.Equal(t, "example.com", userInfo.HostedDomain) - assert.Equal(t, "Test User", userInfo.Name) - assert.True(t, userInfo.VerifiedEmail) - }) - - t.Run("user from disallowed domain", func(t *testing.T) { - userInfoServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - response := UserInfo{ - Email: "user@unauthorized.com", - HostedDomain: "unauthorized.com", - VerifiedEmail: true, - } - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(response); err != nil { - t.Errorf("failed to encode response: %v", err) - } - })) - defer userInfoServer.Close() - - t.Setenv("GOOGLE_USERINFO_URL", userInfoServer.URL) - - oauthConfig := config.OAuthAuthConfig{ - AllowedDomains: []string{"example.com", "test.com"}, - } - - token := &oauth2.Token{AccessToken: "mock-token"} - _, err := ValidateUser(context.Background(), oauthConfig, token) - - require.Error(t, err) - assert.Contains(t, err.Error(), "domain 'unauthorized.com' is not allowed") - }) - - t.Run("no domain restrictions", func(t *testing.T) { - userInfoServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - response := UserInfo{ - Email: "user@anydomain.com", - VerifiedEmail: true, - } - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(response); err != nil { - t.Errorf("failed to encode response: %v", err) - } - })) - defer userInfoServer.Close() - - t.Setenv("GOOGLE_USERINFO_URL", userInfoServer.URL) - - oauthConfig := config.OAuthAuthConfig{ - AllowedDomains: []string{}, // Empty means allow all - } - - token := &oauth2.Token{AccessToken: "mock-token"} - userInfo, err := ValidateUser(context.Background(), oauthConfig, token) - - require.NoError(t, err) - assert.Equal(t, "user@anydomain.com", userInfo.Email) - }) -} - -func TestParseClientRequest(t *testing.T) { - t.Run("valid request with redirect URIs and scopes", func(t *testing.T) { - metadata := map[string]any{ - "redirect_uris": []any{ - "https://example.com/callback1", - "https://example.com/callback2", - }, - "scope": "read write admin", - } - - redirectURIs, scopes, err := ParseClientRequest(metadata) - require.NoError(t, err) - - assert.Equal(t, []string{ - "https://example.com/callback1", - "https://example.com/callback2", - }, redirectURIs) - assert.Equal(t, []string{"read", "write", "admin"}, scopes) - }) - - t.Run("default scopes when not provided", func(t *testing.T) { - metadata := map[string]any{ - "redirect_uris": []any{"https://example.com/callback"}, - } - - redirectURIs, scopes, err := ParseClientRequest(metadata) - require.NoError(t, err) - - assert.Equal(t, []string{"https://example.com/callback"}, redirectURIs) - assert.Equal(t, []string{"read", "write"}, scopes, "Should default to read/write") - }) - - t.Run("empty scope string uses default", func(t *testing.T) { - metadata := map[string]any{ - "redirect_uris": []any{"https://example.com/callback"}, - "scope": " ", // Whitespace only - } - - _, scopes, err := ParseClientRequest(metadata) - require.NoError(t, err) - - assert.Equal(t, []string{"read", "write"}, scopes) - }) - - t.Run("missing redirect URIs", func(t *testing.T) { - metadata := map[string]any{ - "scope": "read write", - } - - _, _, err := ParseClientRequest(metadata) - assert.Error(t, err) - assert.Contains(t, err.Error(), "no valid redirect URIs") - }) - - t.Run("empty redirect URIs array", func(t *testing.T) { - metadata := map[string]any{ - "redirect_uris": []any{}, - } - - _, _, err := ParseClientRequest(metadata) - assert.Error(t, err) - assert.Contains(t, err.Error(), "no valid redirect URIs") - }) -} diff --git a/internal/idp/azure.go b/internal/idp/azure.go new file mode 100644 index 0000000..6694740 --- /dev/null +++ b/internal/idp/azure.go @@ -0,0 +1,34 @@ +package idp + +import "fmt" + +// NewAzureProvider creates an Azure AD provider using OIDC discovery. +// Azure AD is OIDC-compliant, so we use the generic OIDC provider with Azure's tenant-specific discovery URL. +// Optional direct endpoint overrides (authorizationURL, tokenURL, userInfoURL) skip discovery +// when all three are provided — useful for testing. +func NewAzureProvider(tenantID, clientID, clientSecret, redirectURI, authorizationURL, tokenURL, userInfoURL string) (*OIDCProvider, error) { + if tenantID == "" { + return nil, fmt.Errorf("tenantId is required for Azure AD") + } + + cfg := OIDCConfig{ + ProviderType: "azure", + ClientID: clientID, + ClientSecret: clientSecret, + RedirectURI: redirectURI, + Scopes: []string{"openid", "email", "profile"}, + } + + if authorizationURL != "" && tokenURL != "" && userInfoURL != "" { + cfg.AuthorizationURL = authorizationURL + cfg.TokenURL = tokenURL + cfg.UserInfoURL = userInfoURL + } else { + cfg.DiscoveryURL = fmt.Sprintf( + "https://login.microsoftonline.com/%s/v2.0/.well-known/openid-configuration", + tenantID, + ) + } + + return NewOIDCProvider(cfg) +} diff --git a/internal/idp/azure_test.go b/internal/idp/azure_test.go new file mode 100644 index 0000000..453a23a --- /dev/null +++ b/internal/idp/azure_test.go @@ -0,0 +1,15 @@ +package idp + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewAzureProvider_MissingTenantID(t *testing.T) { + _, err := NewAzureProvider("", "client-id", "client-secret", "https://example.com/callback", "", "", "") + + require.Error(t, err) + assert.Contains(t, err.Error(), "tenantId is required") +} diff --git a/internal/idp/factory.go b/internal/idp/factory.go new file mode 100644 index 0000000..20e7e4c --- /dev/null +++ b/internal/idp/factory.go @@ -0,0 +1,59 @@ +package idp + +import ( + "fmt" + + "github.com/dgellow/mcp-front/internal/config" +) + +// NewProvider creates a Provider based on the IDPConfig. +func NewProvider(cfg config.IDPConfig) (Provider, error) { + switch cfg.Provider { + case "google": + return NewGoogleProvider( + cfg.ClientID, + string(cfg.ClientSecret), + cfg.RedirectURI, + cfg.AuthorizationURL, + cfg.TokenURL, + cfg.UserInfoURL, + ), nil + + case "azure": + return NewAzureProvider( + cfg.TenantID, + cfg.ClientID, + string(cfg.ClientSecret), + cfg.RedirectURI, + cfg.AuthorizationURL, + cfg.TokenURL, + cfg.UserInfoURL, + ) + + case "github": + return NewGitHubProvider( + cfg.ClientID, + string(cfg.ClientSecret), + cfg.RedirectURI, + cfg.AuthorizationURL, + cfg.TokenURL, + cfg.UserInfoURL, + ), nil + + case "oidc": + return NewOIDCProvider(OIDCConfig{ + ProviderType: "oidc", + DiscoveryURL: cfg.DiscoveryURL, + AuthorizationURL: cfg.AuthorizationURL, + TokenURL: cfg.TokenURL, + UserInfoURL: cfg.UserInfoURL, + ClientID: cfg.ClientID, + ClientSecret: string(cfg.ClientSecret), + RedirectURI: cfg.RedirectURI, + Scopes: cfg.Scopes, + }) + + default: + return nil, fmt.Errorf("unknown provider type: %s", cfg.Provider) + } +} diff --git a/internal/idp/factory_test.go b/internal/idp/factory_test.go new file mode 100644 index 0000000..9b38681 --- /dev/null +++ b/internal/idp/factory_test.go @@ -0,0 +1,105 @@ +package idp + +import ( + "testing" + + "github.com/dgellow/mcp-front/internal/config" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewProvider(t *testing.T) { + tests := []struct { + name string + cfg config.IDPConfig + wantType string + wantErr bool + errContains string + skipCreation bool + }{ + { + name: "google_provider", + cfg: config.IDPConfig{ + Provider: "google", + ClientID: "test-client-id", + ClientSecret: config.Secret("test-client-secret"), + RedirectURI: "https://example.com/callback", + }, + wantType: "google", + wantErr: false, + }, + { + name: "github_provider", + cfg: config.IDPConfig{ + Provider: "github", + ClientID: "test-client-id", + ClientSecret: config.Secret("test-client-secret"), + RedirectURI: "https://example.com/callback", + }, + wantType: "github", + wantErr: false, + }, + { + name: "azure_provider_missing_tenant", + cfg: config.IDPConfig{ + Provider: "azure", + ClientID: "test-client-id", + ClientSecret: config.Secret("test-client-secret"), + RedirectURI: "https://example.com/callback", + }, + wantErr: true, + errContains: "tenantId is required", + }, + { + name: "oidc_provider_missing_endpoints", + cfg: config.IDPConfig{ + Provider: "oidc", + ClientID: "test-client-id", + ClientSecret: config.Secret("test-client-secret"), + RedirectURI: "https://example.com/callback", + }, + wantErr: true, + errContains: "discoveryUrl or all endpoints", + }, + { + name: "oidc_provider_with_direct_endpoints", + cfg: config.IDPConfig{ + Provider: "oidc", + ClientID: "test-client-id", + ClientSecret: config.Secret("test-client-secret"), + RedirectURI: "https://example.com/callback", + AuthorizationURL: "https://idp.example.com/authorize", + TokenURL: "https://idp.example.com/token", + UserInfoURL: "https://idp.example.com/userinfo", + }, + wantType: "oidc", + wantErr: false, + }, + { + name: "unknown_provider", + cfg: config.IDPConfig{ + Provider: "unknown", + }, + wantErr: true, + errContains: "unknown provider type", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + provider, err := NewProvider(tt.cfg) + + if tt.wantErr { + require.Error(t, err) + if tt.errContains != "" { + assert.Contains(t, err.Error(), tt.errContains) + } + return + } + + require.NoError(t, err) + require.NotNil(t, provider) + assert.Equal(t, tt.wantType, provider.Type()) + }) + } +} diff --git a/internal/idp/github.go b/internal/idp/github.go new file mode 100644 index 0000000..b13ec04 --- /dev/null +++ b/internal/idp/github.go @@ -0,0 +1,201 @@ +package idp + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + + emailutil "github.com/dgellow/mcp-front/internal/emailutil" + "golang.org/x/oauth2" + "golang.org/x/oauth2/github" +) + +// GitHubProvider implements the Provider interface for GitHub OAuth. +// GitHub uses OAuth 2.0 (not OIDC) and has its own API for user info and org membership. +type GitHubProvider struct { + config oauth2.Config + apiBaseURL string // defaults to https://api.github.com, can be overridden for testing +} + +// githubUserResponse represents GitHub's user API response. +type githubUserResponse struct { + ID int64 `json:"id"` + Login string `json:"login"` + Email string `json:"email"` + Name string `json:"name"` + AvatarURL string `json:"avatar_url"` +} + +// githubEmailResponse represents an email from GitHub's emails API. +type githubEmailResponse struct { + Email string `json:"email"` + Primary bool `json:"primary"` + Verified bool `json:"verified"` +} + +// githubOrgResponse represents an org from GitHub's orgs API. +type githubOrgResponse struct { + Login string `json:"login"` +} + +// NewGitHubProvider creates a new GitHub OAuth provider. +// Optional endpoint overrides (authorizationURL, tokenURL, apiBaseURL) allow +// pointing at a non-GitHub server for testing. +func NewGitHubProvider(clientID, clientSecret, redirectURI, authorizationURL, tokenURL, apiBaseURL string) *GitHubProvider { + endpoint := github.Endpoint + if authorizationURL != "" { + endpoint.AuthURL = authorizationURL + } + if tokenURL != "" { + endpoint.TokenURL = tokenURL + } + + apiBase := "https://api.github.com" + if apiBaseURL != "" { + apiBase = apiBaseURL + } + + return &GitHubProvider{ + config: oauth2.Config{ + ClientID: clientID, + ClientSecret: clientSecret, + RedirectURL: redirectURI, + Scopes: []string{"user:email", "read:org"}, + Endpoint: endpoint, + }, + apiBaseURL: apiBase, + } +} + +// Type returns the provider type. +func (p *GitHubProvider) Type() string { + return "github" +} + +// AuthURL generates the authorization URL. +func (p *GitHubProvider) AuthURL(state string) string { + return p.config.AuthCodeURL(state) +} + +// ExchangeCode exchanges an authorization code for tokens. +func (p *GitHubProvider) ExchangeCode(ctx context.Context, code string) (*oauth2.Token, error) { + return p.config.Exchange(ctx, code) +} + +// UserInfo fetches user identity from GitHub's API. +// Always fetches organizations so the authorization layer can check membership. +func (p *GitHubProvider) UserInfo(ctx context.Context, token *oauth2.Token) (*Identity, error) { + client := p.config.Client(ctx, token) + + user, err := p.fetchUser(client) + if err != nil { + return nil, err + } + + // Fetch primary email if not in profile + // GitHub only shows verified emails in user profile, so if email is present it's verified + email := user.Email + emailVerified := email != "" + if email == "" { + primaryEmail, verified, err := p.fetchPrimaryEmail(client) + if err != nil { + return nil, fmt.Errorf("failed to get user email: %w", err) + } + email = primaryEmail + emailVerified = verified + } + + domain := emailutil.ExtractDomain(email) + + orgs, err := p.fetchOrganizations(client) + if err != nil { + return nil, fmt.Errorf("failed to get user organizations: %w", err) + } + + return &Identity{ + ProviderType: "github", + Subject: fmt.Sprintf("%d", user.ID), + Email: email, + EmailVerified: emailVerified, + Name: user.Name, + Picture: user.AvatarURL, + Domain: domain, + Organizations: orgs, + }, nil +} + +func (p *GitHubProvider) fetchUser(client *http.Client) (*githubUserResponse, error) { + resp, err := client.Get(p.apiBaseURL + "/user") + if err != nil { + return nil, fmt.Errorf("failed to get user: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("failed to get user: status %d", resp.StatusCode) + } + + var user githubUserResponse + if err := json.NewDecoder(resp.Body).Decode(&user); err != nil { + return nil, fmt.Errorf("failed to decode user: %w", err) + } + + return &user, nil +} + +func (p *GitHubProvider) fetchPrimaryEmail(client *http.Client) (string, bool, error) { + resp, err := client.Get(p.apiBaseURL + "/user/emails") + if err != nil { + return "", false, fmt.Errorf("failed to get emails: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return "", false, fmt.Errorf("failed to get emails: status %d", resp.StatusCode) + } + + var emails []githubEmailResponse + if err := json.NewDecoder(resp.Body).Decode(&emails); err != nil { + return "", false, fmt.Errorf("failed to decode emails: %w", err) + } + + for _, email := range emails { + if email.Primary && email.Verified { + return email.Email, true, nil + } + } + + // Fallback to first verified email + for _, email := range emails { + if email.Verified { + return email.Email, true, nil + } + } + + return "", false, fmt.Errorf("no verified email found") +} + +func (p *GitHubProvider) fetchOrganizations(client *http.Client) ([]string, error) { + resp, err := client.Get(p.apiBaseURL + "/user/orgs") + if err != nil { + return nil, fmt.Errorf("failed to get organizations: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("failed to get organizations: status %d", resp.StatusCode) + } + + var orgs []githubOrgResponse + if err := json.NewDecoder(resp.Body).Decode(&orgs); err != nil { + return nil, fmt.Errorf("failed to decode organizations: %w", err) + } + + orgNames := make([]string, len(orgs)) + for i, org := range orgs { + orgNames[i] = org.Login + } + + return orgNames, nil +} diff --git a/internal/idp/github_test.go b/internal/idp/github_test.go new file mode 100644 index 0000000..ccf0c87 --- /dev/null +++ b/internal/idp/github_test.go @@ -0,0 +1,224 @@ +package idp + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/oauth2" +) + +func TestGitHubProvider_Type(t *testing.T) { + provider := NewGitHubProvider("client-id", "client-secret", "https://example.com/callback", "", "", "") + assert.Equal(t, "github", provider.Type()) +} + +func TestGitHubProvider_AuthURL(t *testing.T) { + provider := NewGitHubProvider("client-id", "client-secret", "https://example.com/callback", "", "", "") + + authURL := provider.AuthURL("test-state") + + assert.Contains(t, authURL, "github.com") + assert.Contains(t, authURL, "state=test-state") + assert.Contains(t, authURL, "client_id=client-id") +} + +func TestGitHubProvider_UserInfo(t *testing.T) { + tests := []struct { + name string + userResp githubUserResponse + emailsResp []githubEmailResponse + orgsResp []githubOrgResponse + expectedEmail string + expectedEmailVerified bool + expectedDomain string + expectedOrgs []string + }{ + { + name: "user_with_public_email", + userResp: githubUserResponse{ + ID: 12345, + Login: "testuser", + Email: "user@company.com", + Name: "Test User", + AvatarURL: "https://github.com/avatar.jpg", + }, + orgsResp: []githubOrgResponse{{Login: "my-org"}}, + expectedEmail: "user@company.com", + expectedEmailVerified: true, + expectedDomain: "company.com", + expectedOrgs: []string{"my-org"}, + }, + { + name: "user_without_public_email_fetches_from_api", + userResp: githubUserResponse{ + ID: 12345, + Login: "testuser", + Name: "Test User", + }, + emailsResp: []githubEmailResponse{ + {Email: "secondary@other.com", Primary: false, Verified: true}, + {Email: "primary@company.com", Primary: true, Verified: true}, + }, + orgsResp: []githubOrgResponse{}, + expectedEmail: "primary@company.com", + expectedEmailVerified: true, + expectedDomain: "company.com", + expectedOrgs: []string{}, + }, + { + name: "user_with_unverified_primary_falls_back_to_verified", + userResp: githubUserResponse{ + ID: 12345, + Login: "testuser", + }, + emailsResp: []githubEmailResponse{ + {Email: "primary@company.com", Primary: true, Verified: false}, + {Email: "verified@company.com", Primary: false, Verified: true}, + }, + orgsResp: []githubOrgResponse{}, + expectedEmail: "verified@company.com", + expectedEmailVerified: true, + expectedDomain: "company.com", + expectedOrgs: []string{}, + }, + { + name: "orgs_always_populated", + userResp: githubUserResponse{ + ID: 12345, + Login: "testuser", + Email: "user@gmail.com", + }, + orgsResp: []githubOrgResponse{{Login: "org-a"}, {Login: "org-b"}}, + expectedEmail: "user@gmail.com", + expectedEmailVerified: true, + expectedDomain: "gmail.com", + expectedOrgs: []string{"org-a", "org-b"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + + switch r.URL.Path { + case "/user": + err := json.NewEncoder(w).Encode(tt.userResp) + require.NoError(t, err) + case "/user/emails": + err := json.NewEncoder(w).Encode(tt.emailsResp) + require.NoError(t, err) + case "/user/orgs": + err := json.NewEncoder(w).Encode(tt.orgsResp) + require.NoError(t, err) + default: + w.WriteHeader(http.StatusNotFound) + } + })) + defer server.Close() + + provider := &GitHubProvider{ + config: oauth2.Config{ + ClientID: "test-client", + ClientSecret: "test-secret", + RedirectURL: "https://example.com/callback", + Scopes: []string{"user:email", "read:org"}, + Endpoint: oauth2.Endpoint{ + AuthURL: server.URL + "/authorize", + TokenURL: server.URL + "/token", + }, + }, + apiBaseURL: server.URL, + } + + token := &oauth2.Token{AccessToken: "test-token"} + identity, err := provider.UserInfo(context.Background(), token) + + require.NoError(t, err) + require.NotNil(t, identity) + assert.Equal(t, "github", identity.ProviderType) + assert.Equal(t, tt.expectedEmail, identity.Email) + assert.Equal(t, tt.expectedEmailVerified, identity.EmailVerified) + assert.Equal(t, tt.expectedDomain, identity.Domain) + assert.Equal(t, tt.expectedOrgs, identity.Organizations) + }) + } +} + +func TestGitHubProvider_UserInfo_APIErrors(t *testing.T) { + tests := []struct { + name string + userStatus int + errContains string + }{ + { + name: "user_api_error", + userStatus: http.StatusInternalServerError, + errContains: "status 500", + }, + { + name: "user_unauthorized", + userStatus: http.StatusUnauthorized, + errContains: "status 401", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(tt.userStatus) + })) + defer server.Close() + + provider := &GitHubProvider{ + config: oauth2.Config{ + ClientID: "test-client", + ClientSecret: "test-secret", + }, + apiBaseURL: server.URL, + } + + token := &oauth2.Token{AccessToken: "test-token"} + _, err := provider.UserInfo(context.Background(), token) + + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errContains) + }) + } +} + +func TestGitHubProvider_UserInfo_NoVerifiedEmail(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + + switch r.URL.Path { + case "/user": + err := json.NewEncoder(w).Encode(githubUserResponse{ID: 123, Login: "test"}) + require.NoError(t, err) + case "/user/emails": + err := json.NewEncoder(w).Encode([]githubEmailResponse{ + {Email: "unverified@example.com", Primary: true, Verified: false}, + }) + require.NoError(t, err) + default: + w.WriteHeader(http.StatusNotFound) + } + })) + defer server.Close() + + provider := &GitHubProvider{ + config: oauth2.Config{ClientID: "test"}, + apiBaseURL: server.URL, + } + + token := &oauth2.Token{AccessToken: "test-token"} + _, err := provider.UserInfo(context.Background(), token) + + require.Error(t, err) + assert.Contains(t, err.Error(), "no verified email") +} diff --git a/internal/idp/google.go b/internal/idp/google.go new file mode 100644 index 0000000..82e86e2 --- /dev/null +++ b/internal/idp/google.go @@ -0,0 +1,113 @@ +package idp + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + + emailutil "github.com/dgellow/mcp-front/internal/emailutil" + "golang.org/x/oauth2" + "golang.org/x/oauth2/google" +) + +// GoogleProvider implements the Provider interface for Google OAuth. +// Google has specific quirks like `hd` for hosted domain and `verified_email` field. +type GoogleProvider struct { + config oauth2.Config + userInfoURL string +} + +// googleUserInfoResponse represents Google's userinfo response. +// Note: Google uses `hd` for hosted domain and `verified_email` instead of OIDC standard `email_verified`. +type googleUserInfoResponse struct { + Sub string `json:"sub"` + Email string `json:"email"` + VerifiedEmail bool `json:"verified_email"` + Name string `json:"name"` + Picture string `json:"picture"` + HostedDomain string `json:"hd"` +} + +// NewGoogleProvider creates a new Google OAuth provider. +// Optional endpoint overrides (authorizationURL, tokenURL, userInfoURL) allow +// pointing at a non-Google server — useful for testing or corporate proxies. +func NewGoogleProvider(clientID, clientSecret, redirectURI, authorizationURL, tokenURL, userInfoURL string) *GoogleProvider { + endpoint := google.Endpoint + if authorizationURL != "" { + endpoint.AuthURL = authorizationURL + } + if tokenURL != "" { + endpoint.TokenURL = tokenURL + } + + uiURL := "https://www.googleapis.com/oauth2/v2/userinfo" + if userInfoURL != "" { + uiURL = userInfoURL + } + + return &GoogleProvider{ + config: oauth2.Config{ + ClientID: clientID, + ClientSecret: clientSecret, + RedirectURL: redirectURI, + Scopes: []string{"openid", "profile", "email"}, + Endpoint: endpoint, + }, + userInfoURL: uiURL, + } +} + +// Type returns the provider type. +func (p *GoogleProvider) Type() string { + return "google" +} + +// AuthURL generates the authorization URL. +func (p *GoogleProvider) AuthURL(state string) string { + return p.config.AuthCodeURL(state, + oauth2.AccessTypeOffline, + oauth2.ApprovalForce, + ) +} + +// ExchangeCode exchanges an authorization code for tokens. +func (p *GoogleProvider) ExchangeCode(ctx context.Context, code string) (*oauth2.Token, error) { + return p.config.Exchange(ctx, code) +} + +// UserInfo fetches user identity from Google's userinfo endpoint. +func (p *GoogleProvider) UserInfo(ctx context.Context, token *oauth2.Token) (*Identity, error) { + client := p.config.Client(ctx, token) + + resp, err := client.Get(p.userInfoURL) + if err != nil { + return nil, fmt.Errorf("failed to get user info: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("failed to get user info: status %d", resp.StatusCode) + } + + var googleUser googleUserInfoResponse + if err := json.NewDecoder(resp.Body).Decode(&googleUser); err != nil { + return nil, fmt.Errorf("failed to decode user info: %w", err) + } + + // Use Google's hosted domain if available, otherwise derive from email + domain := googleUser.HostedDomain + if domain == "" { + domain = emailutil.ExtractDomain(googleUser.Email) + } + + return &Identity{ + ProviderType: "google", + Subject: googleUser.Sub, + Email: googleUser.Email, + EmailVerified: googleUser.VerifiedEmail, + Name: googleUser.Name, + Picture: googleUser.Picture, + Domain: domain, + }, nil +} diff --git a/internal/idp/google_test.go b/internal/idp/google_test.go new file mode 100644 index 0000000..662c194 --- /dev/null +++ b/internal/idp/google_test.go @@ -0,0 +1,121 @@ +package idp + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/oauth2" +) + +func TestGoogleProvider_Type(t *testing.T) { + provider := NewGoogleProvider("client-id", "client-secret", "https://example.com/callback", "", "", "") + assert.Equal(t, "google", provider.Type()) +} + +func TestGoogleProvider_AuthURL(t *testing.T) { + provider := NewGoogleProvider("client-id", "client-secret", "https://example.com/callback", "", "", "") + + authURL := provider.AuthURL("test-state") + + assert.Contains(t, authURL, "accounts.google.com") + assert.Contains(t, authURL, "state=test-state") + assert.Contains(t, authURL, "client_id=client-id") + assert.Contains(t, authURL, "redirect_uri=") + assert.Contains(t, authURL, "access_type=offline") +} + +func TestGoogleProvider_UserInfo(t *testing.T) { + tests := []struct { + name string + userInfoResp googleUserInfoResponse + expectedDomain string + expectedSubject string + }{ + { + name: "user_with_hosted_domain", + userInfoResp: googleUserInfoResponse{ + Sub: "12345", + Email: "user@company.com", + VerifiedEmail: true, + Name: "Test User", + Picture: "https://example.com/photo.jpg", + HostedDomain: "company.com", + }, + expectedDomain: "company.com", + expectedSubject: "12345", + }, + { + name: "user_without_hosted_domain_derives_from_email", + userInfoResp: googleUserInfoResponse{ + Sub: "12345", + Email: "user@gmail.com", + VerifiedEmail: true, + Name: "Test User", + }, + expectedDomain: "gmail.com", + expectedSubject: "12345", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + err := json.NewEncoder(w).Encode(tt.userInfoResp) + require.NoError(t, err) + })) + defer server.Close() + + provider := &GoogleProvider{ + config: oauth2.Config{ + ClientID: "test-client", + ClientSecret: "test-secret", + RedirectURL: "https://example.com/callback", + Scopes: []string{"openid", "profile", "email"}, + Endpoint: oauth2.Endpoint{ + AuthURL: server.URL + "/authorize", + TokenURL: server.URL + "/token", + }, + }, + userInfoURL: server.URL, + } + token := &oauth2.Token{AccessToken: "test-token"} + + identity, err := provider.UserInfo(context.Background(), token) + + require.NoError(t, err) + require.NotNil(t, identity) + assert.Equal(t, "google", identity.ProviderType) + assert.Equal(t, tt.expectedSubject, identity.Subject) + assert.Equal(t, tt.expectedDomain, identity.Domain) + assert.Equal(t, tt.userInfoResp.Email, identity.Email) + assert.Equal(t, tt.userInfoResp.VerifiedEmail, identity.EmailVerified) + }) + } +} + +func TestGoogleProvider_UserInfo_ServerError(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer server.Close() + + provider := &GoogleProvider{ + config: oauth2.Config{ + ClientID: "test-client", + ClientSecret: "test-secret", + }, + userInfoURL: server.URL, + } + token := &oauth2.Token{AccessToken: "test-token"} + + _, err := provider.UserInfo(context.Background(), token) + + require.Error(t, err) + assert.Contains(t, err.Error(), "status 500") +} diff --git a/internal/idp/oidc.go b/internal/idp/oidc.go new file mode 100644 index 0000000..807ac05 --- /dev/null +++ b/internal/idp/oidc.go @@ -0,0 +1,178 @@ +package idp + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "time" + + emailutil "github.com/dgellow/mcp-front/internal/emailutil" + "golang.org/x/oauth2" +) + +// OIDCConfig configures a generic OIDC provider. +type OIDCConfig struct { + // ProviderType identifies this provider (e.g., "oidc", "google", "azure"). + ProviderType string + + // Discovery URL for OIDC discovery (optional if endpoints are provided directly). + DiscoveryURL string + + // Direct endpoint configuration (used if DiscoveryURL is not set). + AuthorizationURL string + TokenURL string + UserInfoURL string + + // OAuth client configuration. + ClientID string + ClientSecret string + RedirectURI string + Scopes []string +} + +// OIDCProvider implements the Provider interface for OIDC-compliant identity providers. +type OIDCProvider struct { + providerType string + config oauth2.Config + userInfoURL string +} + +// oidcDiscoveryDocument represents the OIDC discovery document. +type oidcDiscoveryDocument struct { + AuthorizationEndpoint string `json:"authorization_endpoint"` + TokenEndpoint string `json:"token_endpoint"` + UserInfoEndpoint string `json:"userinfo_endpoint"` + Issuer string `json:"issuer"` +} + +// oidcUserInfoResponse represents the standard OIDC userinfo response. +type oidcUserInfoResponse struct { + Sub string `json:"sub"` + Email string `json:"email"` + EmailVerified bool `json:"email_verified"` + Name string `json:"name"` + Picture string `json:"picture"` +} + +// NewOIDCProvider creates a new OIDC provider. +// TODO: Add OIDC discovery caching to avoid repeated network calls. +func NewOIDCProvider(cfg OIDCConfig) (*OIDCProvider, error) { + var authURL, tokenURL, userInfoURL string + + if cfg.DiscoveryURL != "" { + discovery, err := fetchOIDCDiscovery(cfg.DiscoveryURL) + if err != nil { + return nil, fmt.Errorf("failed to fetch OIDC discovery: %w", err) + } + authURL = discovery.AuthorizationEndpoint + tokenURL = discovery.TokenEndpoint + userInfoURL = discovery.UserInfoEndpoint + } else { + if cfg.AuthorizationURL == "" || cfg.TokenURL == "" || cfg.UserInfoURL == "" { + return nil, fmt.Errorf("either discoveryUrl or all endpoints (authorizationUrl, tokenUrl, userInfoUrl) must be provided") + } + authURL = cfg.AuthorizationURL + tokenURL = cfg.TokenURL + userInfoURL = cfg.UserInfoURL + } + + scopes := cfg.Scopes + if len(scopes) == 0 { + scopes = []string{"openid", "email", "profile"} + } + + providerType := cfg.ProviderType + if providerType == "" { + providerType = "oidc" + } + + return &OIDCProvider{ + providerType: providerType, + config: oauth2.Config{ + ClientID: cfg.ClientID, + ClientSecret: cfg.ClientSecret, + RedirectURL: cfg.RedirectURI, + Scopes: scopes, + Endpoint: oauth2.Endpoint{ + AuthURL: authURL, + TokenURL: tokenURL, + }, + }, + userInfoURL: userInfoURL, + }, nil +} + +func fetchOIDCDiscovery(discoveryURL string) (*oidcDiscoveryDocument, error) { + client := &http.Client{Timeout: 10 * time.Second} + resp, err := client.Get(discoveryURL) + if err != nil { + return nil, fmt.Errorf("failed to fetch discovery document: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("discovery endpoint returned status %d", resp.StatusCode) + } + + var discovery oidcDiscoveryDocument + if err := json.NewDecoder(resp.Body).Decode(&discovery); err != nil { + return nil, fmt.Errorf("failed to decode discovery document: %w", err) + } + + if discovery.AuthorizationEndpoint == "" || discovery.TokenEndpoint == "" || discovery.UserInfoEndpoint == "" { + return nil, fmt.Errorf("discovery document missing required endpoints") + } + + return &discovery, nil +} + +// Type returns the provider type. +func (p *OIDCProvider) Type() string { + return p.providerType +} + +// AuthURL generates the authorization URL. +func (p *OIDCProvider) AuthURL(state string) string { + return p.config.AuthCodeURL(state, + oauth2.AccessTypeOffline, + oauth2.ApprovalForce, + ) +} + +// ExchangeCode exchanges an authorization code for tokens. +func (p *OIDCProvider) ExchangeCode(ctx context.Context, code string) (*oauth2.Token, error) { + return p.config.Exchange(ctx, code) +} + +// UserInfo fetches user identity from the OIDC userinfo endpoint. +// TODO: Add ID token validation as optimization (avoids network call). +func (p *OIDCProvider) UserInfo(ctx context.Context, token *oauth2.Token) (*Identity, error) { + client := p.config.Client(ctx, token) + resp, err := client.Get(p.userInfoURL) + if err != nil { + return nil, fmt.Errorf("failed to get user info: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("failed to get user info: status %d", resp.StatusCode) + } + + var userInfoResp oidcUserInfoResponse + if err := json.NewDecoder(resp.Body).Decode(&userInfoResp); err != nil { + return nil, fmt.Errorf("failed to decode user info: %w", err) + } + + domain := emailutil.ExtractDomain(userInfoResp.Email) + + return &Identity{ + ProviderType: p.providerType, + Subject: userInfoResp.Sub, + Email: userInfoResp.Email, + EmailVerified: userInfoResp.EmailVerified, + Name: userInfoResp.Name, + Picture: userInfoResp.Picture, + Domain: domain, + }, nil +} diff --git a/internal/idp/oidc_test.go b/internal/idp/oidc_test.go new file mode 100644 index 0000000..ced5b9f --- /dev/null +++ b/internal/idp/oidc_test.go @@ -0,0 +1,189 @@ +package idp + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/oauth2" +) + +func TestNewOIDCProvider_WithDirectEndpoints(t *testing.T) { + provider, err := NewOIDCProvider(OIDCConfig{ + ProviderType: "custom", + AuthorizationURL: "https://idp.example.com/authorize", + TokenURL: "https://idp.example.com/token", + UserInfoURL: "https://idp.example.com/userinfo", + ClientID: "client-id", + ClientSecret: "client-secret", + RedirectURI: "https://example.com/callback", + }) + + require.NoError(t, err) + require.NotNil(t, provider) + assert.Equal(t, "custom", provider.Type()) +} + +func TestNewOIDCProvider_WithDiscovery(t *testing.T) { + discovery := oidcDiscoveryDocument{ + Issuer: "https://idp.example.com", + AuthorizationEndpoint: "https://idp.example.com/authorize", + TokenEndpoint: "https://idp.example.com/token", + UserInfoEndpoint: "https://idp.example.com/userinfo", + } + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + err := json.NewEncoder(w).Encode(discovery) + require.NoError(t, err) + })) + defer server.Close() + + provider, err := NewOIDCProvider(OIDCConfig{ + DiscoveryURL: server.URL, + ClientID: "client-id", + ClientSecret: "client-secret", + RedirectURI: "https://example.com/callback", + }) + + require.NoError(t, err) + require.NotNil(t, provider) + assert.Equal(t, "oidc", provider.Type()) +} + +func TestNewOIDCProvider_MissingEndpoints(t *testing.T) { + _, err := NewOIDCProvider(OIDCConfig{ + ClientID: "client-id", + ClientSecret: "client-secret", + RedirectURI: "https://example.com/callback", + }) + + require.Error(t, err) + assert.Contains(t, err.Error(), "discoveryUrl or all endpoints") +} + +func TestNewOIDCProvider_PartialEndpoints(t *testing.T) { + _, err := NewOIDCProvider(OIDCConfig{ + AuthorizationURL: "https://idp.example.com/authorize", + ClientID: "client-id", + ClientSecret: "client-secret", + RedirectURI: "https://example.com/callback", + }) + + require.Error(t, err) + assert.Contains(t, err.Error(), "discoveryUrl or all endpoints") +} + +func TestNewOIDCProvider_DiscoveryMissingEndpoints(t *testing.T) { + discovery := oidcDiscoveryDocument{ + Issuer: "https://idp.example.com", + } + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + err := json.NewEncoder(w).Encode(discovery) + require.NoError(t, err) + })) + defer server.Close() + + _, err := NewOIDCProvider(OIDCConfig{ + DiscoveryURL: server.URL, + ClientID: "client-id", + ClientSecret: "client-secret", + RedirectURI: "https://example.com/callback", + }) + + require.Error(t, err) + assert.Contains(t, err.Error(), "missing required endpoints") +} + +func TestOIDCProvider_AuthURL(t *testing.T) { + provider, err := NewOIDCProvider(OIDCConfig{ + AuthorizationURL: "https://idp.example.com/authorize", + TokenURL: "https://idp.example.com/token", + UserInfoURL: "https://idp.example.com/userinfo", + ClientID: "client-id", + ClientSecret: "client-secret", + RedirectURI: "https://example.com/callback", + }) + require.NoError(t, err) + + authURL := provider.AuthURL("test-state") + + assert.Contains(t, authURL, "https://idp.example.com/authorize") + assert.Contains(t, authURL, "state=test-state") + assert.Contains(t, authURL, "client_id=client-id") +} + +func TestOIDCProvider_UserInfo(t *testing.T) { + userInfoServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + resp := oidcUserInfoResponse{ + Sub: "12345", + Email: "user@example.com", + EmailVerified: true, + Name: "Test User", + Picture: "https://example.com/photo.jpg", + } + err := json.NewEncoder(w).Encode(resp) + require.NoError(t, err) + })) + defer userInfoServer.Close() + + provider, err := NewOIDCProvider(OIDCConfig{ + AuthorizationURL: "https://idp.example.com/authorize", + TokenURL: "https://idp.example.com/token", + UserInfoURL: userInfoServer.URL, + ClientID: "client-id", + ClientSecret: "client-secret", + RedirectURI: "https://example.com/callback", + }) + require.NoError(t, err) + + token := &oauth2.Token{AccessToken: "test-token"} + identity, err := provider.UserInfo(context.Background(), token) + + require.NoError(t, err) + require.NotNil(t, identity) + assert.Equal(t, "oidc", identity.ProviderType) + assert.Equal(t, "12345", identity.Subject) + assert.Equal(t, "user@example.com", identity.Email) + assert.Equal(t, "example.com", identity.Domain) + assert.True(t, identity.EmailVerified) +} + +func TestOIDCProvider_DefaultScopes(t *testing.T) { + provider, err := NewOIDCProvider(OIDCConfig{ + AuthorizationURL: "https://idp.example.com/authorize", + TokenURL: "https://idp.example.com/token", + UserInfoURL: "https://idp.example.com/userinfo", + ClientID: "client-id", + ClientSecret: "client-secret", + RedirectURI: "https://example.com/callback", + }) + require.NoError(t, err) + + authURL := provider.AuthURL("test-state") + assert.Contains(t, authURL, "scope=openid") +} + +func TestOIDCProvider_CustomScopes(t *testing.T) { + provider, err := NewOIDCProvider(OIDCConfig{ + AuthorizationURL: "https://idp.example.com/authorize", + TokenURL: "https://idp.example.com/token", + UserInfoURL: "https://idp.example.com/userinfo", + ClientID: "client-id", + ClientSecret: "client-secret", + RedirectURI: "https://example.com/callback", + Scopes: []string{"openid", "custom_scope"}, + }) + require.NoError(t, err) + + authURL := provider.AuthURL("test-state") + assert.Contains(t, authURL, "scope=openid") + assert.Contains(t, authURL, "custom_scope") +} diff --git a/internal/idp/provider.go b/internal/idp/provider.go new file mode 100644 index 0000000..e700112 --- /dev/null +++ b/internal/idp/provider.go @@ -0,0 +1,37 @@ +package idp + +import ( + "context" + + "golang.org/x/oauth2" +) + +// Identity represents user identity as reported by an identity provider. +// Providers populate this with identity information only — access control +// (domain, org checks) is handled by the authorization layer. +type Identity struct { + ProviderType string `json:"provider_type"` + Subject string `json:"sub"` + Email string `json:"email"` + EmailVerified bool `json:"email_verified"` + Name string `json:"name"` + Picture string `json:"picture"` + Domain string `json:"domain"` + Organizations []string `json:"organizations,omitempty"` +} + +// Provider abstracts identity provider operations. +type Provider interface { + // Type returns the provider type identifier (e.g., "google", "azure", "github", "oidc"). + Type() string + + // AuthURL generates the authorization URL for the OAuth flow. + AuthURL(state string) string + + // ExchangeCode exchanges an authorization code for tokens. + ExchangeCode(ctx context.Context, code string) (*oauth2.Token, error) + + // UserInfo fetches user identity from the provider. + // Returns identity information only — no access control validation. + UserInfo(ctx context.Context, token *oauth2.Token) (*Identity, error) +} diff --git a/internal/jsonrpc/errors.go b/internal/jsonrpc/errors.go index ee3196c..8fcfcf1 100644 --- a/internal/jsonrpc/errors.go +++ b/internal/jsonrpc/errors.go @@ -46,25 +46,3 @@ func NewStandardError(code int) *Error { Message: message, } } - -// GetErrorName returns a human-readable name for standard error codes -func GetErrorName(code int) string { - switch code { - case ParseError: - return "PARSE_ERROR" - case InvalidRequest: - return "INVALID_REQUEST" - case MethodNotFound: - return "METHOD_NOT_FOUND" - case InvalidParams: - return "INVALID_PARAMS" - case InternalError: - return "INTERNAL_ERROR" - default: - // Check if it's an MCP-specific error code - if code >= -32099 && code <= -32000 { - return "SERVER_ERROR" - } - return "UNKNOWN_ERROR" - } -} diff --git a/internal/mcpfront.go b/internal/mcpfront.go index e6f5b80..ec70773 100644 --- a/internal/mcpfront.go +++ b/internal/mcpfront.go @@ -14,6 +14,7 @@ import ( "github.com/dgellow/mcp-front/internal/client" "github.com/dgellow/mcp-front/internal/config" "github.com/dgellow/mcp-front/internal/crypto" + "github.com/dgellow/mcp-front/internal/idp" "github.com/dgellow/mcp-front/internal/inline" "github.com/dgellow/mcp-front/internal/log" "github.com/dgellow/mcp-front/internal/oauth" @@ -52,7 +53,7 @@ func NewMCPFront(ctx context.Context, cfg config.Config) (*MCPFront, error) { } // Setup authentication (OAuth components and service client) - oauthProvider, sessionEncryptor, authConfig, serviceOAuthClient, err := setupAuthentication(ctx, cfg, store) + oauthProvider, idpProvider, sessionEncryptor, authConfig, serviceOAuthClient, err := setupAuthentication(ctx, cfg, store) if err != nil { return nil, fmt.Errorf("failed to setup authentication: %w", err) } @@ -98,6 +99,7 @@ func NewMCPFront(ctx context.Context, cfg config.Config) (*MCPFront, error) { cfg, store, oauthProvider, + idpProvider, sessionEncryptor, authConfig, serviceOAuthClient, @@ -227,32 +229,42 @@ func setupStorage(ctx context.Context, cfg config.Config) (storage.Storage, erro } // setupAuthentication creates individual OAuth components using clean constructors -func setupAuthentication(ctx context.Context, cfg config.Config, store storage.Storage) (fosite.OAuth2Provider, crypto.Encryptor, config.OAuthAuthConfig, *auth.ServiceOAuthClient, error) { +func setupAuthentication(ctx context.Context, cfg config.Config, store storage.Storage) (fosite.OAuth2Provider, idp.Provider, crypto.Encryptor, config.OAuthAuthConfig, *auth.ServiceOAuthClient, error) { oauthAuth := cfg.Proxy.Auth if oauthAuth == nil { // OAuth not configured - return nil, nil, config.OAuthAuthConfig{}, nil, nil + return nil, nil, nil, config.OAuthAuthConfig{}, nil, nil } log.LogDebug("initializing OAuth components") + // Create identity provider + idpProvider, err := idp.NewProvider(oauthAuth.IDP) + if err != nil { + return nil, nil, nil, config.OAuthAuthConfig{}, nil, fmt.Errorf("failed to create identity provider: %w", err) + } + + log.LogInfoWithFields("mcpfront", "Identity provider configured", map[string]any{ + "type": idpProvider.Type(), + }) + // Generate or validate JWT secret using clean constructor jwtSecret, err := oauth.GenerateJWTSecret(string(oauthAuth.JWTSecret)) if err != nil { - return nil, nil, config.OAuthAuthConfig{}, nil, fmt.Errorf("failed to setup JWT secret: %w", err) + return nil, nil, nil, config.OAuthAuthConfig{}, nil, fmt.Errorf("failed to setup JWT secret: %w", err) } // Create session encryptor using clean constructor encryptionKey := []byte(oauthAuth.EncryptionKey) sessionEncryptor, err := oauth.NewSessionEncryptor(encryptionKey) if err != nil { - return nil, nil, config.OAuthAuthConfig{}, nil, fmt.Errorf("failed to create session encryptor: %w", err) + return nil, nil, nil, config.OAuthAuthConfig{}, nil, fmt.Errorf("failed to create session encryptor: %w", err) } // Create OAuth provider using clean constructor oauthProvider, err := oauth.NewOAuthProvider(*oauthAuth, store, jwtSecret) if err != nil { - return nil, nil, config.OAuthAuthConfig{}, nil, fmt.Errorf("failed to create OAuth provider: %w", err) + return nil, nil, nil, config.OAuthAuthConfig{}, nil, fmt.Errorf("failed to create OAuth provider: %w", err) } // Create OAuth client for service authentication and token refresh @@ -279,7 +291,7 @@ func setupAuthentication(ctx context.Context, cfg config.Config, store storage.S } } - return oauthProvider, sessionEncryptor, *oauthAuth, serviceOAuthClient, nil + return oauthProvider, idpProvider, sessionEncryptor, *oauthAuth, serviceOAuthClient, nil } // buildHTTPHandler creates the complete HTTP handler with all routing and middleware @@ -287,6 +299,7 @@ func buildHTTPHandler( cfg config.Config, storage storage.Storage, oauthProvider fosite.OAuth2Provider, + idpProvider idp.Provider, sessionEncryptor crypto.Encryptor, authConfig config.OAuthAuthConfig, serviceOAuthClient *auth.ServiceOAuthClient, @@ -337,6 +350,7 @@ func buildHTTPHandler( authHandlers := server.NewAuthHandlers( oauthProvider, authConfig, + idpProvider, storage, sessionEncryptor, cfg.MCPServers, @@ -352,7 +366,7 @@ func buildHTTPHandler( // Returns 404 by default, or base issuer metadata if dangerouslyAcceptIssuerAudience is enabled mux.Handle(route("/.well-known/oauth-protected-resource"), server.ChainMiddleware(http.HandlerFunc(authHandlers.ProtectedResourceMetadataHandler), oauthMiddleware...)) mux.Handle(route("/authorize"), server.ChainMiddleware(http.HandlerFunc(authHandlers.AuthorizeHandler), oauthMiddleware...)) - mux.Handle(route("/oauth/callback"), server.ChainMiddleware(http.HandlerFunc(authHandlers.GoogleCallbackHandler), oauthMiddleware...)) + mux.Handle(route("/oauth/callback"), server.ChainMiddleware(http.HandlerFunc(authHandlers.IDPCallbackHandler), oauthMiddleware...)) mux.Handle(route("/token"), server.ChainMiddleware(http.HandlerFunc(authHandlers.TokenHandler), oauthMiddleware...)) mux.Handle(route("/register"), server.ChainMiddleware(http.HandlerFunc(authHandlers.RegisterHandler), oauthMiddleware...)) mux.Handle(route("/clients/{client_id}"), server.ChainMiddleware(http.HandlerFunc(authHandlers.ClientMetadataHandler), oauthMiddleware...)) @@ -361,12 +375,12 @@ func buildHTTPHandler( tokenMiddleware := []server.MiddlewareFunc{ corsMiddleware, tokenLogger, - server.NewBrowserSSOMiddleware(authConfig, sessionEncryptor, browserStateToken), + server.NewBrowserSSOMiddleware(authConfig, idpProvider, sessionEncryptor, browserStateToken), mcpRecover, } // Create token handlers - tokenHandlers := server.NewTokenHandlers(storage, cfg.MCPServers, true, serviceOAuthClient) + tokenHandlers := server.NewTokenHandlers(storage, cfg.MCPServers, true, serviceOAuthClient, []byte(authConfig.EncryptionKey)) // Token management UI endpoints mux.Handle(route("/my/tokens"), server.ChainMiddleware(http.HandlerFunc(tokenHandlers.ListTokensHandler), tokenMiddleware...)) @@ -407,7 +421,7 @@ func buildHTTPHandler( } } else { // For stdio/SSE servers - if isStdioServer(serverConfig) { + if serverConfig.IsStdio() { sseServer, mcpServer, err = buildStdioSSEServer(serverName, baseURL, sessionManager) if err != nil { return nil, fmt.Errorf("failed to create SSE server for %s: %w", serverName, err) @@ -475,7 +489,7 @@ func buildHTTPHandler( // Add browser SSO if OAuth is enabled if oauthProvider != nil { // Reuse the same browserStateToken created earlier for consistency - adminMiddleware = append(adminMiddleware, server.NewBrowserSSOMiddleware(authConfig, sessionEncryptor, browserStateToken)) + adminMiddleware = append(adminMiddleware, server.NewBrowserSSOMiddleware(authConfig, idpProvider, sessionEncryptor, browserStateToken)) } // Add admin check middleware @@ -591,8 +605,3 @@ func buildStdioSSEServer(serverName, baseURL string, sessionManager *client.Stdi return sseServer, mcpServer, nil } - -// isStdioServer checks if this is a stdio-based server -func isStdioServer(cfg *config.MCPClientConfig) bool { - return cfg.TransportType == config.MCPClientTypeStdio -} diff --git a/internal/oauth/client_registration.go b/internal/oauth/client_registration.go new file mode 100644 index 0000000..75510df --- /dev/null +++ b/internal/oauth/client_registration.go @@ -0,0 +1,32 @@ +package oauth + +import ( + "fmt" + "strings" +) + +// ParseClientRegistration parses MCP client registration metadata. +// This is provider-agnostic as it deals with MCP client registration, not IDP. +func ParseClientRegistration(metadata map[string]any) (redirectURIs []string, scopes []string, err error) { + redirectURIs = []string{} + if uris, ok := metadata["redirect_uris"].([]any); ok { + for _, uri := range uris { + if uriStr, ok := uri.(string); ok { + redirectURIs = append(redirectURIs, uriStr) + } + } + } + + if len(redirectURIs) == 0 { + return nil, nil, fmt.Errorf("no valid redirect URIs provided") + } + + scopes = []string{"read", "write"} + if clientScopes, ok := metadata["scope"].(string); ok { + if strings.TrimSpace(clientScopes) != "" { + scopes = strings.Fields(clientScopes) + } + } + + return redirectURIs, scopes, nil +} diff --git a/internal/oauth/client_registration_test.go b/internal/oauth/client_registration_test.go new file mode 100644 index 0000000..4aacd68 --- /dev/null +++ b/internal/oauth/client_registration_test.go @@ -0,0 +1,113 @@ +package oauth + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestParseClientRegistration(t *testing.T) { + tests := []struct { + name string + metadata map[string]any + wantRedirectURIs []string + wantScopes []string + wantErr bool + errContains string + }{ + { + name: "valid_with_single_redirect_uri", + metadata: map[string]any{ + "redirect_uris": []any{"https://example.com/callback"}, + }, + wantRedirectURIs: []string{"https://example.com/callback"}, + wantScopes: []string{"read", "write"}, + wantErr: false, + }, + { + name: "valid_with_multiple_redirect_uris", + metadata: map[string]any{ + "redirect_uris": []any{ + "https://example.com/callback", + "https://example.com/callback2", + }, + }, + wantRedirectURIs: []string{ + "https://example.com/callback", + "https://example.com/callback2", + }, + wantScopes: []string{"read", "write"}, + wantErr: false, + }, + { + name: "valid_with_custom_scopes", + metadata: map[string]any{ + "redirect_uris": []any{"https://example.com/callback"}, + "scope": "openid profile email", + }, + wantRedirectURIs: []string{"https://example.com/callback"}, + wantScopes: []string{"openid", "profile", "email"}, + wantErr: false, + }, + { + name: "valid_with_empty_scope_uses_default", + metadata: map[string]any{ + "redirect_uris": []any{"https://example.com/callback"}, + "scope": " ", + }, + wantRedirectURIs: []string{"https://example.com/callback"}, + wantScopes: []string{"read", "write"}, + wantErr: false, + }, + { + name: "missing_redirect_uris", + metadata: map[string]any{}, + wantErr: true, + errContains: "no valid redirect URIs", + }, + { + name: "empty_redirect_uris", + metadata: map[string]any{ + "redirect_uris": []any{}, + }, + wantErr: true, + errContains: "no valid redirect URIs", + }, + { + name: "redirect_uris_wrong_type", + metadata: map[string]any{ + "redirect_uris": "https://example.com/callback", + }, + wantErr: true, + errContains: "no valid redirect URIs", + }, + { + name: "redirect_uri_non_string_elements_ignored", + metadata: map[string]any{ + "redirect_uris": []any{123, "https://example.com/callback", nil}, + }, + wantRedirectURIs: []string{"https://example.com/callback"}, + wantScopes: []string{"read", "write"}, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + redirectURIs, scopes, err := ParseClientRegistration(tt.metadata) + + if tt.wantErr { + require.Error(t, err) + if tt.errContains != "" { + assert.Contains(t, err.Error(), tt.errContains) + } + return + } + + require.NoError(t, err) + assert.Equal(t, tt.wantRedirectURIs, redirectURIs) + assert.Equal(t, tt.wantScopes, scopes) + }) + } +} diff --git a/internal/oauth/metadata.go b/internal/oauth/metadata.go index e750415..078906e 100644 --- a/internal/oauth/metadata.go +++ b/internal/oauth/metadata.go @@ -53,29 +53,9 @@ func AuthorizationServerMetadata(issuer string) (map[string]any, error) { }, nil } -// ProtectedResourceMetadata builds OAuth 2.0 Protected Resource Metadata per RFC 9728 -// https://datatracker.ietf.org/doc/html/rfc9728 -// -// Deprecated: Use ServiceProtectedResourceMetadata for per-service metadata endpoints. -// This function returns the base issuer as the resource, which doesn't support -// per-service audience validation required by RFC 8707. -func ProtectedResourceMetadata(issuer string) (map[string]any, error) { - authzServerURL, err := urlutil.JoinPath(issuer, ".well-known", "oauth-authorization-server") - if err != nil { - return nil, err - } - - return map[string]any{ - "resource": issuer, - "authorization_servers": []string{ - issuer, - }, - "_links": map[string]any{ - "oauth-authorization-server": map[string]string{ - "href": authzServerURL, - }, - }, - }, nil +// AuthorizationServerMetadataURI returns the well-known URI for the authorization server metadata. +func AuthorizationServerMetadataURI(issuer string) (string, error) { + return urlutil.JoinPath(issuer, ".well-known", "oauth-authorization-server") } // ServiceProtectedResourceMetadata builds OAuth 2.0 Protected Resource Metadata per RFC 9728 @@ -92,7 +72,7 @@ func ServiceProtectedResourceMetadata(issuer string, serviceName string) (map[st return nil, err } - authzServerURL, err := urlutil.JoinPath(issuer, ".well-known", "oauth-authorization-server") + authzServerURL, err := AuthorizationServerMetadataURI(issuer) if err != nil { return nil, err } diff --git a/internal/oauth/metadata_test.go b/internal/oauth/metadata_test.go index db9e8df..4367c09 100644 --- a/internal/oauth/metadata_test.go +++ b/internal/oauth/metadata_test.go @@ -83,74 +83,6 @@ func TestAuthorizationServerMetadata(t *testing.T) { } } -func TestProtectedResourceMetadata(t *testing.T) { - tests := []struct { - name string - issuer string - wantErr bool - }{ - { - name: "valid issuer", - issuer: "https://example.com", - wantErr: false, - }, - { - name: "issuer with path", - issuer: "https://example.com/oauth", - wantErr: false, - }, - { - name: "invalid issuer", - issuer: "://invalid", - wantErr: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - metadata, err := ProtectedResourceMetadata(tt.issuer) - if (err != nil) != tt.wantErr { - t.Errorf("ProtectedResourceMetadata() error = %v, wantErr %v", err, tt.wantErr) - return - } - - if tt.wantErr { - return - } - - // Verify resource field - if metadata["resource"] != tt.issuer { - t.Errorf("resource = %v, want %v", metadata["resource"], tt.issuer) - } - - // Verify authorization_servers array - authzServers, ok := metadata["authorization_servers"].([]string) - if !ok || len(authzServers) == 0 { - t.Error("authorization_servers is missing or empty") - } - - if authzServers[0] != tt.issuer { - t.Errorf("authorization_servers[0] = %v, want %v", authzServers[0], tt.issuer) - } - - // Verify _links structure - links, ok := metadata["_links"].(map[string]any) - if !ok { - t.Error("_links is missing or wrong type") - } - - authzServerLink, ok := links["oauth-authorization-server"].(map[string]string) - if !ok { - t.Error("oauth-authorization-server link is missing or wrong type") - } - - if authzServerLink["href"] == "" { - t.Error("oauth-authorization-server href is empty") - } - }) - } -} - func TestServiceProtectedResourceMetadata(t *testing.T) { tests := []struct { name string diff --git a/internal/oauth/provider.go b/internal/oauth/provider.go index 8ceca50..09eb568 100644 --- a/internal/oauth/provider.go +++ b/internal/oauth/provider.go @@ -12,10 +12,9 @@ import ( "github.com/dgellow/mcp-front/internal/config" "github.com/dgellow/mcp-front/internal/crypto" - "github.com/dgellow/mcp-front/internal/envutil" jsonwriter "github.com/dgellow/mcp-front/internal/json" "github.com/dgellow/mcp-front/internal/log" - "github.com/dgellow/mcp-front/internal/oauthsession" + "github.com/dgellow/mcp-front/internal/session" "github.com/dgellow/mcp-front/internal/storage" "github.com/dgellow/mcp-front/internal/urlutil" "github.com/ory/fosite" @@ -48,8 +47,8 @@ func NewOAuthProvider(oauthConfig config.OAuthAuthConfig, store storage.Storage, // Determine min parameter entropy based on environment minEntropy := 8 // Production default - enforce secure state parameters (8+ chars) - log.Logf("OAuth provider initialization - MCP_FRONT_ENV=%s, isDevelopmentMode=%v", os.Getenv("MCP_FRONT_ENV"), envutil.IsDev()) - if envutil.IsDev() { + log.Logf("OAuth provider initialization - MCP_FRONT_ENV=%s, isDevelopmentMode=%v", os.Getenv("MCP_FRONT_ENV"), config.IsDev()) + if config.IsDev() { minEntropy = 0 // Development mode - allow empty state parameters log.LogWarn("Development mode enabled - OAuth security checks relaxed (state parameter entropy: %d)", minEntropy) } @@ -158,8 +157,8 @@ func NewValidateTokenMiddleware(provider fosite.OAuth2Provider, issuer string, a // - This is documented fosite behavior, not a bug // - The actual session data must be retrieved from the returned AccessRequester // See: https://github.com/ory/fosite/issues/256 - session := &oauthsession.Session{DefaultSession: &fosite.DefaultSession{}} - _, accessRequest, err := provider.IntrospectToken(ctx, token, fosite.AccessToken, session) + oauthSession := &session.OAuthSession{DefaultSession: &fosite.DefaultSession{}} + _, accessRequest, err := provider.IntrospectToken(ctx, token, fosite.AccessToken, oauthSession) if err != nil { jsonwriter.WriteUnauthorizedRFC9728(w, "Invalid or expired token", metadataURI) return @@ -180,9 +179,9 @@ func NewValidateTokenMiddleware(provider fosite.OAuth2Provider, issuer string, a // This is the correct way to retrieve session data after token introspection var userEmail string if accessRequest != nil { - if reqSession, ok := accessRequest.GetSession().(*oauthsession.Session); ok { - if reqSession.UserInfo.Email != "" { - userEmail = reqSession.UserInfo.Email + if reqSession, ok := accessRequest.GetSession().(*session.OAuthSession); ok { + if reqSession.Identity.Email != "" { + userEmail = reqSession.Identity.Email } } } diff --git a/internal/oauthsession/session.go b/internal/oauthsession/session.go deleted file mode 100644 index f8f093f..0000000 --- a/internal/oauthsession/session.go +++ /dev/null @@ -1,20 +0,0 @@ -package oauthsession - -import ( - "github.com/dgellow/mcp-front/internal/googleauth" - "github.com/ory/fosite" -) - -// Session extends DefaultSession with user information -type Session struct { - *fosite.DefaultSession - UserInfo googleauth.UserInfo `json:"user_info"` -} - -// Clone implements fosite.Session -func (s *Session) Clone() fosite.Session { - return &Session{ - DefaultSession: s.DefaultSession.Clone().(*fosite.DefaultSession), - UserInfo: s.UserInfo, - } -} diff --git a/internal/server/admin_handlers.go b/internal/server/admin_handlers.go index b76db11..4c1c20d 100644 --- a/internal/server/admin_handlers.go +++ b/internal/server/admin_handlers.go @@ -4,8 +4,6 @@ import ( "fmt" "net/http" "net/url" - "strconv" - "strings" "time" "github.com/dgellow/mcp-front/internal/adminauth" @@ -23,7 +21,7 @@ type AdminHandlers struct { storage storage.Storage config config.Config sessionManager *client.StdioSessionManager - encryptionKey []byte // For HMAC-based CSRF tokens + csrf crypto.CSRFProtection } // NewAdminHandlers creates a new admin handlers instance @@ -32,67 +30,10 @@ func NewAdminHandlers(storage storage.Storage, config config.Config, sessionMana storage: storage, config: config, sessionManager: sessionManager, - encryptionKey: []byte(encryptionKey), + csrf: crypto.NewCSRFProtection([]byte(encryptionKey), 15*time.Minute), } } -// generateCSRFToken creates a new HMAC-based CSRF token -func (h *AdminHandlers) generateCSRFToken() (string, error) { - // Generate random nonce - nonce := crypto.GenerateSecureToken() - if nonce == "" { - return "", fmt.Errorf("failed to generate nonce") - } - - // Add timestamp (Unix seconds) - timestamp := strconv.FormatInt(time.Now().Unix(), 10) - - // Create data to sign: nonce:timestamp - data := nonce + ":" + timestamp - - // Sign with HMAC - signature := crypto.SignData(data, h.encryptionKey) - - // Return format: nonce:timestamp:signature - return fmt.Sprintf("%s:%s:%s", nonce, timestamp, signature), nil -} - -// validateCSRFToken checks if a CSRF token is valid -func (h *AdminHandlers) validateCSRFToken(token string) bool { - // Parse token format: nonce:timestamp:signature - parts := strings.SplitN(token, ":", 3) - if len(parts) != 3 { - log.LogDebug("Invalid CSRF token format") - return false - } - - nonce := parts[0] - timestampStr := parts[1] - signature := parts[2] - - // Verify timestamp (15 minute expiry) - timestamp, err := strconv.ParseInt(timestampStr, 10, 64) - if err != nil { - log.LogDebug("Invalid CSRF token timestamp: %v", err) - return false - } - - now := time.Now().Unix() - if now-timestamp > 900 { // 15 minutes - log.LogDebug("CSRF token expired") - return false - } - - // Verify HMAC signature - data := nonce + ":" + timestampStr - if !crypto.ValidateSignedData(data, signature, h.encryptionKey) { - log.LogDebug("Invalid CSRF token signature") - return false - } - - return true -} - // DashboardHandler shows the admin dashboard func (h *AdminHandlers) DashboardHandler(w http.ResponseWriter, r *http.Request) { // Only accept GET @@ -152,7 +93,7 @@ func (h *AdminHandlers) DashboardHandler(w http.ResponseWriter, r *http.Request) currentLogLevel := log.GetLogLevel() // Generate CSRF token - csrfToken, err := h.generateCSRFToken() + csrfToken, err := h.csrf.Generate() if err != nil { log.LogErrorWithFields("admin", "Failed to generate CSRF token", map[string]any{ "error": err.Error(), @@ -209,7 +150,7 @@ func (h *AdminHandlers) UserActionHandler(w http.ResponseWriter, r *http.Request } // Validate CSRF - if !h.validateCSRFToken(r.FormValue("csrf_token")) { + if !h.csrf.Validate(r.FormValue("csrf_token")) { jsonwriter.WriteForbidden(w, "Invalid CSRF token") return } @@ -379,7 +320,7 @@ func (h *AdminHandlers) SessionActionHandler(w http.ResponseWriter, r *http.Requ } // Validate CSRF - if !h.validateCSRFToken(r.FormValue("csrf_token")) { + if !h.csrf.Validate(r.FormValue("csrf_token")) { jsonwriter.WriteForbidden(w, "Invalid CSRF token") return } @@ -472,7 +413,7 @@ func (h *AdminHandlers) LoggingActionHandler(w http.ResponseWriter, r *http.Requ } // Validate CSRF - if !h.validateCSRFToken(r.FormValue("csrf_token")) { + if !h.csrf.Validate(r.FormValue("csrf_token")) { jsonwriter.WriteForbidden(w, "Invalid CSRF token") return } diff --git a/internal/server/admin_handlers_test.go b/internal/server/admin_handlers_test.go index b3587fa..a2aa8ee 100644 --- a/internal/server/admin_handlers_test.go +++ b/internal/server/admin_handlers_test.go @@ -27,7 +27,7 @@ func TestAdminHandlers_CSRF(t *testing.T) { t.Run("generate and validate CSRF token", func(t *testing.T) { // Generate token - token, err := handlers.generateCSRFToken() + token, err := handlers.csrf.Generate() if err != nil { t.Fatalf("Failed to generate CSRF token: %v", err) } @@ -38,13 +38,13 @@ func TestAdminHandlers_CSRF(t *testing.T) { } // Token should be valid immediately - if !handlers.validateCSRFToken(token) { + if !handlers.csrf.Validate(token) { t.Error("Token should be valid immediately after generation") } // Token should not be valid twice (though with HMAC it actually can be) // With HMAC-based tokens, they can be validated multiple times - if !handlers.validateCSRFToken(token) { + if !handlers.csrf.Validate(token) { t.Error("HMAC token should be valid on second validation") } }) @@ -58,7 +58,7 @@ func TestAdminHandlers_CSRF(t *testing.T) { } for _, token := range invalidTokens { - if handlers.validateCSRFToken(token) { + if handlers.csrf.Validate(token) { t.Errorf("Token '%s' should be invalid", token) } } @@ -69,7 +69,7 @@ func TestAdminHandlers_CSRF(t *testing.T) { // An expired token would have a timestamp from > 15 minutes ago expiredToken := "test-nonce:0:invalid-signature" - if handlers.validateCSRFToken(expiredToken) { + if handlers.csrf.Validate(expiredToken) { t.Error("Expired token should be invalid") } }) @@ -79,18 +79,18 @@ func TestAdminHandlers_CSRF(t *testing.T) { handlers2 := NewAdminHandlers(storage, cfg, sessionManager, "different-encryption-key-32bytes") // Generate token with first handler - token1, err := handlers.generateCSRFToken() + token1, err := handlers.csrf.Generate() if err != nil { t.Fatalf("Failed to generate token: %v", err) } // Token from handler1 should not validate with handler2 - if handlers2.validateCSRFToken(token1) { + if handlers2.csrf.Validate(token1) { t.Error("Token should not validate with different encryption key") } // But should still validate with original handler - if !handlers.validateCSRFToken(token1) { + if !handlers.csrf.Validate(token1) { t.Error("Token should validate with original handler") } }) diff --git a/internal/server/auth_handlers.go b/internal/server/auth_handlers.go index 23a98d5..cdb94e8 100644 --- a/internal/server/auth_handlers.go +++ b/internal/server/auth_handlers.go @@ -7,19 +7,19 @@ import ( "fmt" "net/http" "net/url" + "slices" "strings" "time" "github.com/dgellow/mcp-front/internal/auth" - "github.com/dgellow/mcp-front/internal/browserauth" "github.com/dgellow/mcp-front/internal/config" + "github.com/dgellow/mcp-front/internal/cookie" "github.com/dgellow/mcp-front/internal/crypto" - "github.com/dgellow/mcp-front/internal/envutil" - "github.com/dgellow/mcp-front/internal/googleauth" + "github.com/dgellow/mcp-front/internal/idp" jsonwriter "github.com/dgellow/mcp-front/internal/json" "github.com/dgellow/mcp-front/internal/log" "github.com/dgellow/mcp-front/internal/oauth" - "github.com/dgellow/mcp-front/internal/oauthsession" + "github.com/dgellow/mcp-front/internal/session" "github.com/dgellow/mcp-front/internal/storage" "github.com/ory/fosite" ) @@ -28,6 +28,7 @@ import ( type AuthHandlers struct { oauthProvider fosite.OAuth2Provider authConfig config.OAuthAuthConfig + idpProvider idp.Provider storage storage.Storage sessionEncryptor crypto.Encryptor mcpServers map[string]*config.MCPClientConfig @@ -37,18 +38,19 @@ type AuthHandlers struct { // UpstreamOAuthState stores OAuth state during upstream authentication flow (MCP host → mcp-front) type UpstreamOAuthState struct { - UserInfo googleauth.UserInfo `json:"user_info"` - ClientID string `json:"client_id"` - RedirectURI string `json:"redirect_uri"` - Scopes []string `json:"scopes"` - State string `json:"state"` - ResponseType string `json:"response_type"` + Identity idp.Identity `json:"identity"` + ClientID string `json:"client_id"` + RedirectURI string `json:"redirect_uri"` + Scopes []string `json:"scopes"` + State string `json:"state"` + ResponseType string `json:"response_type"` } // NewAuthHandlers creates new auth handlers with dependency injection func NewAuthHandlers( oauthProvider fosite.OAuth2Provider, authConfig config.OAuthAuthConfig, + idpProvider idp.Provider, storage storage.Storage, sessionEncryptor crypto.Encryptor, mcpServers map[string]*config.MCPClientConfig, @@ -57,6 +59,7 @@ func NewAuthHandlers( return &AuthHandlers{ oauthProvider: oauthProvider, authConfig: authConfig, + idpProvider: idpProvider, storage: storage, sessionEncryptor: sessionEncryptor, mcpServers: mcpServers, @@ -98,18 +101,31 @@ func (h *AuthHandlers) ProtectedResourceMetadataHandler(w http.ResponseWriter, r return } - // Workaround mode: return base issuer metadata for broken clients + // Workaround mode: return base issuer metadata for broken clients. + // This intentionally uses the issuer as the resource (no per-service scoping) + // because broken clients don't implement RFC 8707 resource indicators. + issuer := h.authConfig.Issuer log.LogWarnWithFields("oauth", "Serving base protected resource metadata (dangerouslyAcceptIssuerAudience enabled)", map[string]any{ - "issuer": h.authConfig.Issuer, + "issuer": issuer, }) - metadata, err := oauth.ProtectedResourceMetadata(h.authConfig.Issuer) + authzServerURL, err := oauth.AuthorizationServerMetadataURI(issuer) if err != nil { log.LogError("Failed to build protected resource metadata: %v", err) jsonwriter.WriteInternalServerError(w, "Internal server error") return } + metadata := map[string]any{ + "resource": issuer, + "authorization_servers": []string{issuer}, + "_links": map[string]any{ + "oauth-authorization-server": map[string]string{ + "href": authzServerURL, + }, + }, + } + w.Header().Set("Content-Type", "application/json") if err := json.NewEncoder(w).Encode(metadata); err != nil { log.LogError("Failed to encode protected resource metadata: %v", err) @@ -207,7 +223,7 @@ func (h *AuthHandlers) AuthorizeHandler(w http.ResponseWriter, r *http.Request) // In development mode, generate a secure state parameter if missing // This works around bugs in OAuth clients that don't send state stateParam := r.URL.Query().Get("state") - if envutil.IsDev() && len(stateParam) == 0 { + if config.IsDev() && len(stateParam) == 0 { generatedState := crypto.GenerateSecureToken() log.LogWarn("Development mode: generating state parameter '%s' for buggy client", generatedState) q := r.URL.Query() @@ -259,14 +275,19 @@ func (h *AuthHandlers) AuthorizeHandler(w http.ResponseWriter, r *http.Request) } state := ar.GetState() - h.storage.StoreAuthorizeRequest(state, ar) + if err := h.storage.StoreAuthorizeRequest(state, ar); err != nil { + log.LogError("Failed to store authorize request: %v", err) + h.oauthProvider.WriteAuthorizeError(ctx, w, ar, fosite.ErrServerError.WithHint("Failed to store authorization request")) + return + } - authURL := googleauth.GoogleAuthURL(h.authConfig, state) + authURL := h.idpProvider.AuthURL(state) http.Redirect(w, r, authURL, http.StatusFound) } -// GoogleCallbackHandler handles the callback from Google OAuth -func (h *AuthHandlers) GoogleCallbackHandler(w http.ResponseWriter, r *http.Request) { +// IDPCallbackHandler handles the callback from the identity provider. +// It dispatches to handleBrowserCallback or handleOAuthClientCallback based on the flow type. +func (h *AuthHandlers) IDPCallbackHandler(w http.ResponseWriter, r *http.Request) { ctx := r.Context() state := r.URL.Query().Get("state") @@ -274,7 +295,7 @@ func (h *AuthHandlers) GoogleCallbackHandler(w http.ResponseWriter, r *http.Requ if errMsg := r.URL.Query().Get("error"); errMsg != "" { errDesc := r.URL.Query().Get("error_description") - log.LogError("Google OAuth error: %s - %s", errMsg, errDesc) + log.LogError("OAuth error: %s - %s", errMsg, errDesc) jsonwriter.WriteBadRequest(w, fmt.Sprintf("Authentication failed: %s", errMsg)) return } @@ -293,7 +314,7 @@ func (h *AuthHandlers) GoogleCallbackHandler(w http.ResponseWriter, r *http.Requ isBrowserFlow = true stateToken := strings.TrimPrefix(state, "browser:") - var browserState browserauth.AuthorizationState + var browserState session.AuthorizationState if err := h.oauthStateToken.Verify(stateToken, &browserState); err != nil { log.LogError("Invalid browser state: %v", err) jsonwriter.WriteBadRequest(w, "Invalid state parameter") @@ -301,7 +322,6 @@ func (h *AuthHandlers) GoogleCallbackHandler(w http.ResponseWriter, r *http.Requ } returnURL = browserState.ReturnURL } else { - // OAuth client flow - retrieve stored authorize request var found bool ar, found = h.storage.GetAuthorizeRequest(state) if !found { @@ -315,7 +335,7 @@ func (h *AuthHandlers) GoogleCallbackHandler(w http.ResponseWriter, r *http.Requ ctx, cancel := context.WithTimeout(ctx, 30*time.Second) defer cancel() - token, err := googleauth.ExchangeCodeForToken(ctx, h.authConfig, code) + token, err := h.idpProvider.ExchangeCode(ctx, code) if err != nil { log.LogError("Failed to exchange code: %v", err) if !isBrowserFlow && ar != nil { @@ -326,10 +346,19 @@ func (h *AuthHandlers) GoogleCallbackHandler(w http.ResponseWriter, r *http.Requ return } - // Validate user - userInfo, err := googleauth.ValidateUser(ctx, h.authConfig, token) + identity, err := h.idpProvider.UserInfo(ctx, token) if err != nil { - log.LogError("User validation failed: %v", err) + log.LogError("Failed to fetch user identity: %v", err) + if !isBrowserFlow && ar != nil { + h.oauthProvider.WriteAuthorizeError(ctx, w, ar, fosite.ErrServerError.WithHint("Failed to fetch user identity")) + } else { + jsonwriter.WriteInternalServerError(w, "Authentication failed") + } + return + } + + if err := h.validateAccess(identity); err != nil { + log.LogError("Access denied: %v", err) if !isBrowserFlow && ar != nil { h.oauthProvider.WriteAuthorizeError(ctx, w, ar, fosite.ErrAccessDenied.WithHint(err.Error())) } else { @@ -338,65 +367,61 @@ func (h *AuthHandlers) GoogleCallbackHandler(w http.ResponseWriter, r *http.Requ return } - log.Logf("User authenticated: %s", userInfo.Email) + log.Logf("User authenticated: %s", identity.Email) - // Store user in database - if err := h.storage.UpsertUser(ctx, userInfo.Email); err != nil { + if err := h.storage.UpsertUser(ctx, identity.Email); err != nil { log.LogWarnWithFields("auth", "Failed to track user", map[string]any{ - "email": userInfo.Email, + "email": identity.Email, "error": err.Error(), }) } if isBrowserFlow { - // Browser SSO flow - set encrypted session cookie - // Browser sessions should last longer than API tokens for better UX - sessionDuration := 24 * time.Hour - - sessionData := browserauth.SessionCookie{ - Email: userInfo.Email, - Expires: time.Now().Add(sessionDuration), - } - - // Marshal session data to JSON - jsonData, err := json.Marshal(sessionData) - if err != nil { - log.LogError("Failed to marshal session data: %v", err) - jsonwriter.WriteInternalServerError(w, "Failed to create session") - return - } + h.handleBrowserCallback(w, r, identity, returnURL) + } else { + h.handleOAuthClientCallback(ctx, w, r, ar, identity) + } +} - // Encrypt session data - encryptedData, err := h.sessionEncryptor.Encrypt(string(jsonData)) - if err != nil { - log.LogError("Failed to encrypt session: %v", err) - jsonwriter.WriteInternalServerError(w, "Failed to create session") - return - } +// handleBrowserCallback handles the browser SSO callback flow: creates an encrypted +// session cookie and redirects to the return URL. +func (h *AuthHandlers) handleBrowserCallback(w http.ResponseWriter, r *http.Request, identity *idp.Identity, returnURL string) { + sessionDuration := 24 * time.Hour - // Set secure session cookie - http.SetCookie(w, &http.Cookie{ - Name: "mcp_session", - Value: encryptedData, - Path: "/", - HttpOnly: true, - Secure: !envutil.IsDev(), - SameSite: http.SameSiteStrictMode, - MaxAge: int(sessionDuration.Seconds()), - }) + sessionData := session.BrowserCookie{ + Email: identity.Email, + Provider: identity.ProviderType, + Expires: time.Now().Add(sessionDuration), + } - log.LogInfoWithFields("auth", "Browser SSO session created", map[string]any{ - "user": userInfo.Email, - "duration": sessionDuration, - "returnURL": returnURL, - }) + jsonData, err := json.Marshal(sessionData) + if err != nil { + log.LogError("Failed to marshal session data: %v", err) + jsonwriter.WriteInternalServerError(w, "Failed to create session") + return + } - // Redirect to return URL - http.Redirect(w, r, returnURL, http.StatusFound) + encryptedData, err := h.sessionEncryptor.Encrypt(string(jsonData)) + if err != nil { + log.LogError("Failed to encrypt session: %v", err) + jsonwriter.WriteInternalServerError(w, "Failed to create session") return } - // OAuth client flow - check if any services need OAuth + cookie.SetSession(w, encryptedData, sessionDuration) + + log.LogInfoWithFields("auth", "Browser SSO session created", map[string]any{ + "user": identity.Email, + "duration": sessionDuration, + "returnURL": returnURL, + }) + + http.Redirect(w, r, returnURL, http.StatusFound) +} + +// handleOAuthClientCallback handles the OAuth client callback flow: checks for +// service auth needs, creates a fosite session, and issues the authorize response. +func (h *AuthHandlers) handleOAuthClientCallback(ctx context.Context, w http.ResponseWriter, r *http.Request, ar fosite.AuthorizeRequester, identity *idp.Identity) { needsServiceAuth := false for _, serverConfig := range h.mcpServers { if serverConfig.RequiresUserToken && @@ -408,7 +433,7 @@ func (h *AuthHandlers) GoogleCallbackHandler(w http.ResponseWriter, r *http.Requ } if needsServiceAuth { - stateData, err := h.signUpstreamOAuthState(ar, userInfo) + stateData, err := h.signUpstreamOAuthState(ar, *identity) if err != nil { log.LogError("Failed to sign OAuth state: %v", err) h.oauthProvider.WriteAuthorizeError(ctx, w, ar, fosite.ErrServerError.WithHint("Failed to prepare service authentication")) @@ -419,20 +444,16 @@ func (h *AuthHandlers) GoogleCallbackHandler(w http.ResponseWriter, r *http.Requ return } - // Create session for token issuance - // Note: Audience claims are stored in the authorize request (ar.GetGrantedAudience()) - // and will be automatically propagated to access tokens by fosite - session := &oauthsession.Session{ + session := &session.OAuthSession{ DefaultSession: &fosite.DefaultSession{ ExpiresAt: map[fosite.TokenType]time.Time{ fosite.AccessToken: time.Now().Add(h.authConfig.TokenTTL), fosite.RefreshToken: time.Now().Add(h.authConfig.RefreshTokenTTL), }, }, - UserInfo: userInfo, + Identity: *identity, } - // Accept the authorization request response, err := h.oauthProvider.NewAuthorizeResponse(ctx, ar, session) if err != nil { log.LogError("Authorize response error: %v", err) @@ -451,7 +472,7 @@ func (h *AuthHandlers) TokenHandler(w http.ResponseWriter, r *http.Request) { // Create session for the token exchange // Note: We create our custom Session type here, and fosite will populate it // with the session data from the authorization code during NewAccessRequest - session := &oauthsession.Session{DefaultSession: &fosite.DefaultSession{}} + session := &session.OAuthSession{DefaultSession: &fosite.DefaultSession{}} // Handle token request - this retrieves the session from the authorization code accessRequest, err := h.oauthProvider.NewAccessRequest(ctx, r, session) @@ -511,7 +532,7 @@ func (h *AuthHandlers) RegisterHandler(w http.ResponseWriter, r *http.Request) { } // Parse client request - redirectURIs, scopes, err := googleauth.ParseClientRequest(metadata) + redirectURIs, scopes, err := oauth.ParseClientRegistration(metadata) if err != nil { log.LogError("Client request parsing error: %v", err) jsonwriter.WriteBadRequest(w, err.Error()) @@ -560,9 +581,9 @@ func (h *AuthHandlers) RegisterHandler(w http.ResponseWriter, r *http.Request) { } // signUpstreamOAuthState signs upstream OAuth state for secure storage -func (h *AuthHandlers) signUpstreamOAuthState(ar fosite.AuthorizeRequester, userInfo googleauth.UserInfo) (string, error) { +func (h *AuthHandlers) signUpstreamOAuthState(ar fosite.AuthorizeRequester, identity idp.Identity) (string, error) { state := UpstreamOAuthState{ - UserInfo: userInfo, + Identity: identity, ClientID: ar.GetClient().GetID(), RedirectURI: ar.GetRedirectURI().String(), Scopes: ar.GetRequestedScopes(), @@ -582,6 +603,28 @@ func (h *AuthHandlers) verifyUpstreamOAuthState(signedState string) (*UpstreamOA return &state, nil } +// validateAccess checks whether an authenticated identity meets the configured access policy. +func (h *AuthHandlers) validateAccess(identity *idp.Identity) error { + if len(h.authConfig.AllowedDomains) > 0 && + !slices.Contains(h.authConfig.AllowedDomains, identity.Domain) { + return fmt.Errorf("domain '%s' is not allowed. Contact your administrator", identity.Domain) + } + if len(h.authConfig.IDP.AllowedOrgs) > 0 && + !hasOverlap(identity.Organizations, h.authConfig.IDP.AllowedOrgs) { + return fmt.Errorf("user is not a member of any allowed organization. Contact your administrator") + } + return nil +} + +func hasOverlap(a, b []string) bool { + for _, x := range a { + if slices.Contains(b, x) { + return true + } + } + return false +} + // ServiceSelectionHandler shows the interstitial page for selecting services to connect func (h *AuthHandlers) ServiceSelectionHandler(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodGet { @@ -602,7 +645,7 @@ func (h *AuthHandlers) ServiceSelectionHandler(w http.ResponseWriter, r *http.Re return } - userEmail := upstreamOAuthState.UserInfo.Email + userEmail := upstreamOAuthState.Identity.Email // Prepare template data returnURL := fmt.Sprintf("/oauth/services?state=%s", url.QueryEscape(signedState)) @@ -724,7 +767,7 @@ func (h *AuthHandlers) CompleteOAuthHandler(w http.ResponseWriter, r *http.Reque Client: client, RequestedScope: upstreamOAuthState.Scopes, GrantedScope: upstreamOAuthState.Scopes, - Session: &oauthsession.Session{DefaultSession: &fosite.DefaultSession{}}, + Session: &session.OAuthSession{DefaultSession: &fosite.DefaultSession{}}, }, } @@ -737,14 +780,14 @@ func (h *AuthHandlers) CompleteOAuthHandler(w http.ResponseWriter, r *http.Reque ar.RedirectURI = redirectURI // Create session with user info - session := &oauthsession.Session{ + session := &session.OAuthSession{ DefaultSession: &fosite.DefaultSession{ ExpiresAt: map[fosite.TokenType]time.Time{ fosite.AccessToken: time.Now().Add(h.authConfig.TokenTTL), fosite.RefreshToken: time.Now().Add(h.authConfig.RefreshTokenTTL), }, }, - UserInfo: upstreamOAuthState.UserInfo, + Identity: upstreamOAuthState.Identity, } ar.SetSession(session) diff --git a/internal/server/auth_handlers_test.go b/internal/server/auth_handlers_test.go index 6e2e98f..a8cd625 100644 --- a/internal/server/auth_handlers_test.go +++ b/internal/server/auth_handlers_test.go @@ -1,6 +1,7 @@ package server import ( + "context" "encoding/json" "net/http" "net/http/httptest" @@ -9,15 +10,43 @@ import ( "time" "github.com/dgellow/mcp-front/internal/auth" - "github.com/dgellow/mcp-front/internal/browserauth" "github.com/dgellow/mcp-front/internal/config" "github.com/dgellow/mcp-front/internal/crypto" + "github.com/dgellow/mcp-front/internal/idp" "github.com/dgellow/mcp-front/internal/oauth" + "github.com/dgellow/mcp-front/internal/session" "github.com/dgellow/mcp-front/internal/storage" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "golang.org/x/oauth2" ) +// mockIDPProvider is a mock IDP provider for testing +type mockIDPProvider struct{} + +func (m *mockIDPProvider) Type() string { + return "mock" +} + +func (m *mockIDPProvider) AuthURL(state string) string { + return "https://auth.example.com/authorize?state=" + state +} + +func (m *mockIDPProvider) ExchangeCode(ctx context.Context, code string) (*oauth2.Token, error) { + return &oauth2.Token{AccessToken: "test-token"}, nil +} + +func (m *mockIDPProvider) UserInfo(ctx context.Context, token *oauth2.Token) (*idp.Identity, error) { + return &idp.Identity{ + ProviderType: "mock", + Subject: "123", + Email: "test@example.com", + EmailVerified: true, + Name: "Test User", + Domain: "example.com", + }, nil +} + func TestAuthenticationBoundaries(t *testing.T) { tests := []struct { name string @@ -47,18 +76,21 @@ func TestAuthenticationBoundaries(t *testing.T) { // Setup test OAuth configuration oauthConfig := config.OAuthAuthConfig{ - Kind: config.AuthKindOAuth, - Issuer: "https://test.example.com", - GoogleClientID: "test-client-id", - GoogleClientSecret: config.Secret("test-client-secret"), - GoogleRedirectURI: "https://test.example.com/oauth/callback", - JWTSecret: config.Secret(strings.Repeat("a", 32)), - EncryptionKey: config.Secret(strings.Repeat("b", 32)), - TokenTTL: time.Hour, - RefreshTokenTTL: 30 * 24 * time.Hour, - Storage: "memory", - AllowedDomains: []string{"example.com"}, - AllowedOrigins: []string{"https://test.example.com"}, + Kind: config.AuthKindOAuth, + Issuer: "https://test.example.com", + IDP: config.IDPConfig{ + Provider: "google", + ClientID: "test-client-id", + ClientSecret: config.Secret("test-client-secret"), + RedirectURI: "https://test.example.com/oauth/callback", + }, + JWTSecret: config.Secret(strings.Repeat("a", 32)), + EncryptionKey: config.Secret(strings.Repeat("b", 32)), + TokenTTL: time.Hour, + RefreshTokenTTL: 30 * 24 * time.Hour, + Storage: "memory", + AllowedDomains: []string{"example.com"}, + AllowedOrigins: []string{"https://test.example.com"}, } // Create storage @@ -77,17 +109,21 @@ func TestAuthenticationBoundaries(t *testing.T) { // Create service OAuth client serviceOAuthClient := auth.NewServiceOAuthClient(store, "https://test.example.com", []byte(strings.Repeat("k", 32))) + // Create mock IDP provider for testing + mockIDP := &mockIDPProvider{} + // Create handlers authHandlers := NewAuthHandlers( oauthProvider, oauthConfig, + mockIDP, store, sessionEncryptor, map[string]*config.MCPClientConfig{}, serviceOAuthClient, ) - tokenHandlers := NewTokenHandlers(store, map[string]*config.MCPClientConfig{}, true, serviceOAuthClient) + tokenHandlers := NewTokenHandlers(store, map[string]*config.MCPClientConfig{}, true, serviceOAuthClient, []byte(oauthConfig.EncryptionKey)) // Build mux with middlewares mux := http.NewServeMux() @@ -103,7 +139,7 @@ func TestAuthenticationBoundaries(t *testing.T) { // Protected endpoints tokenMiddleware := []MiddlewareFunc{ corsMiddleware, - NewBrowserSSOMiddleware(oauthConfig, sessionEncryptor, &browserStateToken), + NewBrowserSSOMiddleware(oauthConfig, mockIDP, sessionEncryptor, &browserStateToken), } mux.Handle("/my/tokens", ChainMiddleware( @@ -154,9 +190,10 @@ func TestAuthenticationBoundaries(t *testing.T) { // Test with valid session cookie (if auth is expected) if tt.expectAuth { // Create session data - sessionData := browserauth.SessionCookie{ - Email: "test@example.com", - Expires: time.Now().Add(24 * time.Hour), + sessionData := session.BrowserCookie{ + Email: "test@example.com", + Provider: "mock", + Expires: time.Now().Add(24 * time.Hour), } jsonData, err := json.Marshal(sessionData) require.NoError(t, err) @@ -186,18 +223,21 @@ func TestAuthenticationBoundaries(t *testing.T) { func TestOAuthEndpointHandlers(t *testing.T) { oauthConfig := config.OAuthAuthConfig{ - Kind: config.AuthKindOAuth, - Issuer: "https://test.example.com", - GoogleClientID: "test-client-id", - GoogleClientSecret: config.Secret("test-client-secret"), - GoogleRedirectURI: "https://test.example.com/oauth/callback", - JWTSecret: config.Secret(strings.Repeat("a", 32)), - EncryptionKey: config.Secret(strings.Repeat("b", 32)), - TokenTTL: time.Hour, - RefreshTokenTTL: 30 * 24 * time.Hour, - Storage: "memory", - AllowedDomains: []string{"example.com"}, - AllowedOrigins: []string{"https://test.example.com"}, + Kind: config.AuthKindOAuth, + Issuer: "https://test.example.com", + IDP: config.IDPConfig{ + Provider: "google", + ClientID: "test-client-id", + ClientSecret: config.Secret("test-client-secret"), + RedirectURI: "https://test.example.com/oauth/callback", + }, + JWTSecret: config.Secret(strings.Repeat("a", 32)), + EncryptionKey: config.Secret(strings.Repeat("b", 32)), + TokenTTL: time.Hour, + RefreshTokenTTL: 30 * 24 * time.Hour, + Storage: "memory", + AllowedDomains: []string{"example.com"}, + AllowedOrigins: []string{"https://test.example.com"}, } store := storage.NewMemoryStorage() @@ -208,10 +248,12 @@ func TestOAuthEndpointHandlers(t *testing.T) { sessionEncryptor, err := oauth.NewSessionEncryptor([]byte(oauthConfig.EncryptionKey)) require.NoError(t, err) serviceOAuthClient := auth.NewServiceOAuthClient(store, "https://test.example.com", []byte(strings.Repeat("k", 32))) + mockIDP := &mockIDPProvider{} authHandlers := NewAuthHandlers( oauthProvider, oauthConfig, + mockIDP, store, sessionEncryptor, map[string]*config.MCPClientConfig{}, @@ -272,7 +314,7 @@ func TestOAuthEndpointHandlers(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/oauth/services", nil) // Add valid session cookie - sessionData := browserauth.SessionCookie{ + sessionData := session.BrowserCookie{ Email: "test@example.com", Expires: time.Now().Add(24 * time.Hour), } @@ -378,3 +420,88 @@ func TestBearerTokenAuth(t *testing.T) { }) } } + +func TestValidateAccess(t *testing.T) { + tests := []struct { + name string + allowedDomains []string + allowedOrgs []string + identity *idp.Identity + wantErr bool + errContains string + }{ + { + name: "no_restrictions", + allowedDomains: nil, + allowedOrgs: nil, + identity: &idp.Identity{Domain: "any.com", Organizations: []string{"any-org"}}, + wantErr: false, + }, + { + name: "domain_allowed", + allowedDomains: []string{"company.com"}, + identity: &idp.Identity{Domain: "company.com"}, + wantErr: false, + }, + { + name: "domain_rejected", + allowedDomains: []string{"company.com"}, + identity: &idp.Identity{Domain: "other.com"}, + wantErr: true, + errContains: "domain 'other.com' is not allowed", + }, + { + name: "org_allowed", + allowedOrgs: []string{"allowed-org"}, + identity: &idp.Identity{Domain: "any.com", Organizations: []string{"allowed-org", "other-org"}}, + wantErr: false, + }, + { + name: "org_rejected", + allowedOrgs: []string{"required-org"}, + identity: &idp.Identity{Domain: "any.com", Organizations: []string{"other-org"}}, + wantErr: true, + errContains: "not a member of any allowed organization", + }, + { + name: "domain_and_org_both_pass", + allowedDomains: []string{"company.com"}, + allowedOrgs: []string{"my-org"}, + identity: &idp.Identity{Domain: "company.com", Organizations: []string{"my-org"}}, + wantErr: false, + }, + { + name: "domain_fails_before_org_check", + allowedDomains: []string{"company.com"}, + allowedOrgs: []string{"my-org"}, + identity: &idp.Identity{Domain: "other.com", Organizations: []string{"my-org"}}, + wantErr: true, + errContains: "domain 'other.com' is not allowed", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + h := &AuthHandlers{ + authConfig: config.OAuthAuthConfig{ + AllowedDomains: tt.allowedDomains, + IDP: config.IDPConfig{ + AllowedOrgs: tt.allowedOrgs, + }, + }, + } + + err := h.validateAccess(tt.identity) + + if tt.wantErr { + require.Error(t, err) + if tt.errContains != "" { + assert.Contains(t, err.Error(), tt.errContains) + } + return + } + + require.NoError(t, err) + }) + } +} diff --git a/internal/server/http.go b/internal/server/http.go index 3bbe7dc..6a0dc68 100644 --- a/internal/server/http.go +++ b/internal/server/http.go @@ -4,86 +4,10 @@ import ( "context" "errors" "net/http" - "strings" - "time" - "github.com/dgellow/mcp-front/internal/auth" - "github.com/dgellow/mcp-front/internal/client" - "github.com/dgellow/mcp-front/internal/config" "github.com/dgellow/mcp-front/internal/log" - "github.com/dgellow/mcp-front/internal/storage" - mcpserver "github.com/mark3labs/mcp-go/server" ) -// UserTokenService handles user token retrieval and OAuth refresh -type UserTokenService struct { - storage storage.Storage - serviceOAuthClient *auth.ServiceOAuthClient -} - -// NewUserTokenService creates a new user token service -func NewUserTokenService(storage storage.Storage, serviceOAuthClient *auth.ServiceOAuthClient) *UserTokenService { - return &UserTokenService{ - storage: storage, - serviceOAuthClient: serviceOAuthClient, - } -} - -// GetUserToken retrieves and formats a user token for a service, handling OAuth refresh. -// -// Token refresh strategy: Optimistic continuation on failure. -// If refresh fails, we log a warning and continue with the current token. The external -// service will reject the expired token with 401, giving the user a clear error. -// This is acceptable because: (1) refresh failures are rare (network issues, revoked -// tokens), and (2) forcing users to re-auth is better than silently hiding auth issues. -func (uts *UserTokenService) GetUserToken(ctx context.Context, userEmail, serviceName string, serviceConfig *config.MCPClientConfig) (string, error) { - storedToken, err := uts.storage.GetUserToken(ctx, userEmail, serviceName) - if err != nil { - return "", err - } - - switch storedToken.Type { - case storage.TokenTypeManual: - // Token is already in storedToken.Value, formatUserToken will handle it - break - case storage.TokenTypeOAuth: - if storedToken.OAuthData != nil && uts.serviceOAuthClient != nil { - if err := uts.serviceOAuthClient.RefreshToken(ctx, userEmail, serviceName, serviceConfig); err != nil { - log.LogWarnWithFields("user_token", "Failed to refresh OAuth token", map[string]any{ - "service": serviceName, - "user": userEmail, - "error": err.Error(), - }) - // Continue with current token - the service will handle auth failure - } else { - // Re-fetch the updated token after refresh - refreshedToken, err := uts.storage.GetUserToken(ctx, userEmail, serviceName) - if err != nil { - log.LogErrorWithFields("user_token", "Failed to fetch token after successful refresh", map[string]any{ - "service": serviceName, - "user": userEmail, - "error": err.Error(), - }) - // Continue with original token - the service will handle auth failure - } else { - storedToken = refreshedToken - var expiresAt time.Time - if refreshedToken.OAuthData != nil { - expiresAt = refreshedToken.OAuthData.ExpiresAt - } - log.LogInfoWithFields("user_token", "OAuth token refreshed and updated", map[string]any{ - "service": serviceName, - "user": userEmail, - "expiresAt": expiresAt, - }) - } - } - } - } - - return formatUserToken(storedToken, serviceConfig.UserAuthentication), nil -} - // HTTPServer manages the HTTP server lifecycle type HTTPServer struct { server *http.Server @@ -99,8 +23,6 @@ func NewHTTPServer(handler http.Handler, addr string) *HTTPServer { } } -// Handler builders and mux assembly - // HealthHandler handles health check requests type HealthHandler struct{} @@ -143,184 +65,3 @@ func (h *HTTPServer) Stop(ctx context.Context) error { }) return nil } - -// isStdioServer checks if this is a stdio-based server -func isStdioServer(config *config.MCPClientConfig) bool { - return config.Command != "" -} - -// formatUserToken formats a stored token according to the user authentication configuration -func formatUserToken(storedToken *storage.StoredToken, auth *config.UserAuthentication) string { - if storedToken == nil { - return "" - } - - if storedToken.Type == storage.TokenTypeOAuth && storedToken.OAuthData != nil { - token := storedToken.OAuthData.AccessToken - if auth.TokenFormat != "" && auth.TokenFormat != "{{token}}" { - return strings.ReplaceAll(auth.TokenFormat, "{{token}}", token) - } - return token - } - - token := storedToken.Value - if auth != nil && auth.TokenFormat != "" && auth.TokenFormat != "{{token}}" { - return strings.ReplaceAll(auth.TokenFormat, "{{token}}", token) - } - return token -} - -// SessionHandlerKey is the context key for session handlers -type SessionHandlerKey struct{} - -// SessionRequestHandler handles session-specific logic for a request -type SessionRequestHandler struct { - h *MCPHandler - userEmail string - config *config.MCPClientConfig - mcpServer *mcpserver.MCPServer // The shared MCP server -} - -// NewSessionRequestHandler creates a new session request handler with all dependencies -func NewSessionRequestHandler(h *MCPHandler, userEmail string, config *config.MCPClientConfig, mcpServer *mcpserver.MCPServer) *SessionRequestHandler { - return &SessionRequestHandler{ - h: h, - userEmail: userEmail, - config: config, - mcpServer: mcpServer, - } -} - -// GetUserEmail returns the user email for this session -func (s *SessionRequestHandler) GetUserEmail() string { - return s.userEmail -} - -// GetServerName returns the server name for this session -func (s *SessionRequestHandler) GetServerName() string { - return s.h.serverName -} - -// GetStorage returns the storage interface -func (s *SessionRequestHandler) GetStorage() storage.Storage { - return s.h.storage -} - -// HandleSessionRegistration handles the registration of a new session -func HandleSessionRegistration( - sessionCtx context.Context, - session mcpserver.ClientSession, - handler *SessionRequestHandler, - sessionManager *client.StdioSessionManager, -) { - // Create stdio process for this session - key := client.SessionKey{ - UserEmail: handler.userEmail, - ServerName: handler.h.serverName, - SessionID: session.SessionID(), - } - - log.LogDebugWithFields("server", "Registering session", map[string]any{ - "sessionID": session.SessionID(), - "server": handler.h.serverName, - "user": handler.userEmail, - }) - - log.LogTraceWithFields("server", "Session registration started", map[string]any{ - "sessionID": session.SessionID(), - "server": handler.h.serverName, - "user": handler.userEmail, - "requiresUserToken": handler.config.RequiresUserToken, - "transportType": handler.config.TransportType, - "command": handler.config.Command, - }) - - var userToken string - if handler.config.RequiresUserToken && handler.userEmail != "" && handler.h.storage != nil { - storedToken, err := handler.h.storage.GetUserToken(sessionCtx, handler.userEmail, handler.h.serverName) - if err != nil { - log.LogDebugWithFields("server", "No user token found", map[string]any{ - "server": handler.h.serverName, - "user": handler.userEmail, - }) - } else if storedToken != nil { - if handler.config.UserAuthentication != nil { - userToken = formatUserToken(storedToken, handler.config.UserAuthentication) - } else { - userToken = storedToken.Value - } - } - } - - stdioSession, err := sessionManager.GetOrCreateSession( - sessionCtx, - key, - handler.config, - handler.h.info, - handler.h.setupBaseURL, - userToken, - ) - if err != nil { - log.LogErrorWithFields("server", "Failed to create stdio session", map[string]any{ - "error": err.Error(), - "sessionID": session.SessionID(), - "server": handler.h.serverName, - "user": handler.userEmail, - }) - return - } - - // Discover and register capabilities from the stdio process - if err := stdioSession.DiscoverAndRegisterCapabilities( - sessionCtx, - handler.mcpServer, - handler.userEmail, - handler.config.RequiresUserToken, - handler.h.storage, - handler.h.serverName, - handler.h.setupBaseURL, - handler.config.UserAuthentication, - session, - ); err != nil { - log.LogErrorWithFields("server", "Failed to discover and register capabilities", map[string]any{ - "error": err.Error(), - "sessionID": session.SessionID(), - "server": handler.h.serverName, - "user": handler.userEmail, - }) - if err := sessionManager.RemoveSession(key); err != nil { - log.LogErrorWithFields("server", "Failed to remove session on capability failure", map[string]any{ - "sessionID": session.SessionID(), - "server": handler.h.serverName, - "user": handler.userEmail, - "error": err.Error(), - }) - } - return - } - - if handler.userEmail != "" { - if handler.h.storage != nil { - activeSession := storage.ActiveSession{ - SessionID: session.SessionID(), - UserEmail: handler.userEmail, - ServerName: handler.h.serverName, - Created: time.Now(), - LastActive: time.Now(), - } - if err := handler.h.storage.TrackSession(sessionCtx, activeSession); err != nil { - log.LogWarnWithFields("server", "Failed to track session", map[string]any{ - "error": err.Error(), - "sessionID": session.SessionID(), - "user": handler.userEmail, - }) - } - } - } - - log.LogInfoWithFields("server", "Session successfully created and connected", map[string]any{ - "sessionID": session.SessionID(), - "server": handler.h.serverName, - "user": handler.userEmail, - }) -} diff --git a/internal/server/http_test.go b/internal/server/http_test.go index 822b13b..955c9da 100644 --- a/internal/server/http_test.go +++ b/internal/server/http_test.go @@ -37,17 +37,20 @@ func TestHealthEndpoint(t *testing.T) { func TestOAuthEndpointsCORS(t *testing.T) { // Setup OAuth config oauthConfig := config.OAuthAuthConfig{ - Kind: config.AuthKindOAuth, - Issuer: "https://test.example.com", - GoogleClientID: "test-client-id", - GoogleClientSecret: config.Secret("test-client-secret"), - GoogleRedirectURI: "https://test.example.com/oauth/callback", - JWTSecret: config.Secret(strings.Repeat("a", 32)), - EncryptionKey: config.Secret(strings.Repeat("b", 32)), - TokenTTL: time.Hour, - RefreshTokenTTL: 30 * 24 * time.Hour, - Storage: "memory", - AllowedOrigins: []string{"http://localhost:6274"}, + Kind: config.AuthKindOAuth, + Issuer: "https://test.example.com", + IDP: config.IDPConfig{ + Provider: "google", + ClientID: "test-client-id", + ClientSecret: config.Secret("test-client-secret"), + RedirectURI: "https://test.example.com/oauth/callback", + }, + JWTSecret: config.Secret(strings.Repeat("a", 32)), + EncryptionKey: config.Secret(strings.Repeat("b", 32)), + TokenTTL: time.Hour, + RefreshTokenTTL: 30 * 24 * time.Hour, + Storage: "memory", + AllowedOrigins: []string{"http://localhost:6274"}, } store := storage.NewMemoryStorage() @@ -58,10 +61,12 @@ func TestOAuthEndpointsCORS(t *testing.T) { sessionEncryptor, err := oauth.NewSessionEncryptor([]byte(oauthConfig.EncryptionKey)) require.NoError(t, err) serviceOAuthClient := auth.NewServiceOAuthClient(store, "https://test.example.com", []byte(strings.Repeat("k", 32))) + mockIDP := &mockIDPProvider{} authHandlers := NewAuthHandlers( oauthProvider, oauthConfig, + mockIDP, store, sessionEncryptor, map[string]*config.MCPClientConfig{}, diff --git a/internal/server/mcp_handler.go b/internal/server/mcp_handler.go index 1f517ec..c0bc2ee 100644 --- a/internal/server/mcp_handler.go +++ b/internal/server/mcp_handler.go @@ -122,7 +122,7 @@ func (h *MCPHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { log.LogInfoWithFields("mcp", "Handling message request", map[string]any{ "path": r.URL.Path, "server": h.serverName, - "isStdio": isStdioServer(serverConfig), + "isStdio": serverConfig.IsStdio(), "user": userEmail, "remoteAddr": r.RemoteAddr, "contentLength": r.ContentLength, @@ -133,7 +133,7 @@ func (h *MCPHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { log.LogInfoWithFields("mcp", "Handling SSE request", map[string]any{ "path": r.URL.Path, "server": h.serverName, - "isStdio": isStdioServer(serverConfig), + "isStdio": serverConfig.IsStdio(), "user": userEmail, "remoteAddr": r.RemoteAddr, "userAgent": r.UserAgent(), @@ -168,7 +168,7 @@ func (h *MCPHandler) trackUserAccess(ctx context.Context, userEmail string) { func (h *MCPHandler) handleSSERequest(ctx context.Context, w http.ResponseWriter, r *http.Request, userEmail string, config *config.MCPClientConfig) { h.trackUserAccess(ctx, userEmail) - if !isStdioServer(config) { + if !config.IsStdio() { // For non-stdio servers, handle normally h.handleNonStdioSSERequest(ctx, w, r, userEmail, config) return @@ -206,7 +206,7 @@ func (h *MCPHandler) handleSSERequest(ctx context.Context, w http.ResponseWriter func (h *MCPHandler) handleMessageRequest(ctx context.Context, w http.ResponseWriter, r *http.Request, userEmail string, config *config.MCPClientConfig) { h.trackUserAccess(ctx, userEmail) - if isStdioServer(config) { + if config.IsStdio() { sessionID := r.URL.Query().Get("sessionId") if sessionID == "" { jsonrpc.WriteError(w, nil, jsonrpc.InvalidParams, "missing sessionId") diff --git a/internal/server/middleware.go b/internal/server/middleware.go index bb04163..c207e44 100644 --- a/internal/server/middleware.go +++ b/internal/server/middleware.go @@ -10,15 +10,15 @@ import ( "time" "github.com/dgellow/mcp-front/internal/adminauth" - "github.com/dgellow/mcp-front/internal/browserauth" "github.com/dgellow/mcp-front/internal/config" "github.com/dgellow/mcp-front/internal/cookie" "github.com/dgellow/mcp-front/internal/crypto" - "github.com/dgellow/mcp-front/internal/googleauth" + "github.com/dgellow/mcp-front/internal/idp" jsonwriter "github.com/dgellow/mcp-front/internal/json" "github.com/dgellow/mcp-front/internal/log" "github.com/dgellow/mcp-front/internal/oauth" "github.com/dgellow/mcp-front/internal/servicecontext" + "github.com/dgellow/mcp-front/internal/session" "github.com/dgellow/mcp-front/internal/storage" "golang.org/x/crypto/bcrypt" ) @@ -289,7 +289,7 @@ func adminMiddleware(adminConfig *config.AdminConfig, store storage.Storage) Mid } // NewBrowserSSOMiddleware creates middleware for browser-based SSO authentication -func NewBrowserSSOMiddleware(authConfig config.OAuthAuthConfig, sessionEncryptor crypto.Encryptor, browserStateToken *crypto.TokenSigner) MiddlewareFunc { +func NewBrowserSSOMiddleware(authConfig config.OAuthAuthConfig, idpProvider idp.Provider, sessionEncryptor crypto.Encryptor, browserStateToken *crypto.TokenSigner) MiddlewareFunc { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Check for session cookie @@ -301,8 +301,8 @@ func NewBrowserSSOMiddleware(authConfig config.OAuthAuthConfig, sessionEncryptor jsonwriter.WriteInternalServerError(w, "Failed to generate authentication state") return } - googleURL := googleauth.GoogleAuthURL(authConfig, state) - http.Redirect(w, r, googleURL, http.StatusFound) + authURL := idpProvider.AuthURL(state) + http.Redirect(w, r, authURL, http.StatusFound) return } @@ -313,13 +313,13 @@ func NewBrowserSSOMiddleware(authConfig config.OAuthAuthConfig, sessionEncryptor log.LogDebug("Invalid session cookie: %v", err) cookie.ClearSession(w) // Clear bad cookie state := generateBrowserState(browserStateToken, r.URL.String()) - googleURL := googleauth.GoogleAuthURL(authConfig, state) - http.Redirect(w, r, googleURL, http.StatusFound) + authURL := idpProvider.AuthURL(state) + http.Redirect(w, r, authURL, http.StatusFound) return } // Parse session data - var sessionData browserauth.SessionCookie + var sessionData session.BrowserCookie if err := json.NewDecoder(strings.NewReader(decrypted)).Decode(&sessionData); err != nil { // Invalid format cookie.ClearSession(w) @@ -331,14 +331,14 @@ func NewBrowserSSOMiddleware(authConfig config.OAuthAuthConfig, sessionEncryptor if time.Now().After(sessionData.Expires) { log.LogDebug("Session expired for user %s", sessionData.Email) cookie.ClearSession(w) - // Redirect directly to Google OAuth + // Redirect directly to OAuth state := generateBrowserState(browserStateToken, r.URL.String()) if state == "" { jsonwriter.WriteInternalServerError(w, "Failed to generate authentication state") return } - googleURL := googleauth.GoogleAuthURL(authConfig, state) - http.Redirect(w, r, googleURL, http.StatusFound) + authURL := idpProvider.AuthURL(state) + http.Redirect(w, r, authURL, http.StatusFound) return } @@ -353,7 +353,7 @@ func NewBrowserSSOMiddleware(authConfig config.OAuthAuthConfig, sessionEncryptor // generateBrowserState creates a secure state parameter for browser SSO func generateBrowserState(browserStateToken *crypto.TokenSigner, returnURL string) string { - state := browserauth.AuthorizationState{ + state := session.AuthorizationState{ Nonce: crypto.GenerateSecureToken(), ReturnURL: returnURL, } diff --git a/internal/server/session_handler.go b/internal/server/session_handler.go new file mode 100644 index 0000000..58334f6 --- /dev/null +++ b/internal/server/session_handler.go @@ -0,0 +1,167 @@ +package server + +import ( + "context" + "time" + + "github.com/dgellow/mcp-front/internal/client" + "github.com/dgellow/mcp-front/internal/config" + "github.com/dgellow/mcp-front/internal/log" + "github.com/dgellow/mcp-front/internal/storage" + mcpserver "github.com/mark3labs/mcp-go/server" +) + +// SessionHandlerKey is the context key for session handlers +type SessionHandlerKey struct{} + +// SessionRequestHandler handles session-specific logic for a request +type SessionRequestHandler struct { + h *MCPHandler + userEmail string + config *config.MCPClientConfig + mcpServer *mcpserver.MCPServer // The shared MCP server +} + +// NewSessionRequestHandler creates a new session request handler with all dependencies +func NewSessionRequestHandler(h *MCPHandler, userEmail string, config *config.MCPClientConfig, mcpServer *mcpserver.MCPServer) *SessionRequestHandler { + return &SessionRequestHandler{ + h: h, + userEmail: userEmail, + config: config, + mcpServer: mcpServer, + } +} + +// GetUserEmail returns the user email for this session +func (s *SessionRequestHandler) GetUserEmail() string { + return s.userEmail +} + +// GetServerName returns the server name for this session +func (s *SessionRequestHandler) GetServerName() string { + return s.h.serverName +} + +// GetStorage returns the storage interface +func (s *SessionRequestHandler) GetStorage() storage.Storage { + return s.h.storage +} + +// HandleSessionRegistration handles the registration of a new session +func HandleSessionRegistration( + sessionCtx context.Context, + session mcpserver.ClientSession, + handler *SessionRequestHandler, + sessionManager *client.StdioSessionManager, +) { + // Create stdio process for this session + key := client.SessionKey{ + UserEmail: handler.userEmail, + ServerName: handler.h.serverName, + SessionID: session.SessionID(), + } + + log.LogDebugWithFields("server", "Registering session", map[string]any{ + "sessionID": session.SessionID(), + "server": handler.h.serverName, + "user": handler.userEmail, + }) + + log.LogTraceWithFields("server", "Session registration started", map[string]any{ + "sessionID": session.SessionID(), + "server": handler.h.serverName, + "user": handler.userEmail, + "requiresUserToken": handler.config.RequiresUserToken, + "transportType": handler.config.TransportType, + "command": handler.config.Command, + }) + + var userToken string + if handler.config.RequiresUserToken && handler.userEmail != "" && handler.h.storage != nil { + storedToken, err := handler.h.storage.GetUserToken(sessionCtx, handler.userEmail, handler.h.serverName) + if err != nil { + log.LogDebugWithFields("server", "No user token found", map[string]any{ + "server": handler.h.serverName, + "user": handler.userEmail, + }) + } else if storedToken != nil { + if handler.config.UserAuthentication != nil { + userToken = formatUserToken(storedToken, handler.config.UserAuthentication) + } else { + userToken = storedToken.Value + } + } + } + + stdioSession, err := sessionManager.GetOrCreateSession( + sessionCtx, + key, + handler.config, + handler.h.info, + handler.h.setupBaseURL, + userToken, + ) + if err != nil { + log.LogErrorWithFields("server", "Failed to create stdio session", map[string]any{ + "error": err.Error(), + "sessionID": session.SessionID(), + "server": handler.h.serverName, + "user": handler.userEmail, + }) + return + } + + // Discover and register capabilities from the stdio process + if err := stdioSession.DiscoverAndRegisterCapabilities( + sessionCtx, + handler.mcpServer, + handler.userEmail, + handler.config.RequiresUserToken, + handler.h.storage, + handler.h.serverName, + handler.h.setupBaseURL, + handler.config.UserAuthentication, + session, + ); err != nil { + log.LogErrorWithFields("server", "Failed to discover and register capabilities", map[string]any{ + "error": err.Error(), + "sessionID": session.SessionID(), + "server": handler.h.serverName, + "user": handler.userEmail, + }) + if err := sessionManager.RemoveSession(key); err != nil { + log.LogErrorWithFields("server", "Failed to remove session on capability failure", map[string]any{ + "sessionID": session.SessionID(), + "server": handler.h.serverName, + "user": handler.userEmail, + "error": err.Error(), + }) + } + return + } + + if handler.userEmail != "" { + if handler.h.storage != nil { + activeSession := storage.ActiveSession{ + SessionID: session.SessionID(), + UserEmail: handler.userEmail, + ServerName: handler.h.serverName, + Created: time.Now(), + LastActive: time.Now(), + } + if err := handler.h.storage.TrackSession(sessionCtx, activeSession); err != nil { + log.LogWarnWithFields("server", "Failed to track session", map[string]any{ + "error": err.Error(), + "sessionID": session.SessionID(), + "user": handler.userEmail, + }) + } + } + } + + log.LogInfoWithFields("server", "Session successfully created and connected", map[string]any{ + "sessionID": session.SessionID(), + "server": handler.h.serverName, + "user": handler.userEmail, + }) +} diff --git a/internal/server/token_handlers.go b/internal/server/token_handlers.go index b2cbdaa..cf33a00 100644 --- a/internal/server/token_handlers.go +++ b/internal/server/token_handlers.go @@ -4,7 +4,6 @@ import ( "fmt" "net/http" "strings" - "sync" "time" "github.com/dgellow/mcp-front/internal/auth" @@ -20,40 +19,22 @@ import ( type TokenHandlers struct { tokenStore storage.UserTokenStore mcpServers map[string]*config.MCPClientConfig - csrfTokens sync.Map // Thread-safe CSRF token storage + csrf crypto.CSRFProtection oauthEnabled bool serviceOAuthClient *auth.ServiceOAuthClient } // NewTokenHandlers creates a new token handlers instance -func NewTokenHandlers(tokenStore storage.UserTokenStore, mcpServers map[string]*config.MCPClientConfig, oauthEnabled bool, serviceOAuthClient *auth.ServiceOAuthClient) *TokenHandlers { +func NewTokenHandlers(tokenStore storage.UserTokenStore, mcpServers map[string]*config.MCPClientConfig, oauthEnabled bool, serviceOAuthClient *auth.ServiceOAuthClient, csrfKey []byte) *TokenHandlers { return &TokenHandlers{ tokenStore: tokenStore, mcpServers: mcpServers, oauthEnabled: oauthEnabled, serviceOAuthClient: serviceOAuthClient, + csrf: crypto.NewCSRFProtection(csrfKey, 15*time.Minute), } } -// generateCSRFToken creates a new CSRF token -func (h *TokenHandlers) generateCSRFToken() (string, error) { - token := crypto.GenerateSecureToken() - if token == "" { - return "", fmt.Errorf("failed to generate CSRF token") - } - h.csrfTokens.Store(token, true) - return token, nil -} - -// validateCSRFToken checks if a CSRF token is valid -func (h *TokenHandlers) validateCSRFToken(token string) bool { - if _, exists := h.csrfTokens.LoadAndDelete(token); exists { - // One-time use via LoadAndDelete - return true - } - return false -} - // ListTokensHandler shows the token management page func (h *TokenHandlers) ListTokensHandler(w http.ResponseWriter, r *http.Request) { // Only accept GET @@ -137,7 +118,7 @@ func (h *TokenHandlers) ListTokensHandler(w http.ResponseWriter, r *http.Request } // Generate CSRF token - csrfToken, err := h.generateCSRFToken() + csrfToken, err := h.csrf.Generate() if err != nil { log.LogErrorWithFields("token", "Failed to generate CSRF token", map[string]any{ "error": err.Error(), @@ -188,7 +169,7 @@ func (h *TokenHandlers) SetTokenHandler(w http.ResponseWriter, r *http.Request) // Validate CSRF token csrfToken := r.FormValue("csrf_token") - if !h.validateCSRFToken(csrfToken) { + if !h.csrf.Validate(csrfToken) { jsonwriter.WriteForbidden(w, "Invalid CSRF token") return } @@ -304,7 +285,7 @@ func (h *TokenHandlers) DeleteTokenHandler(w http.ResponseWriter, r *http.Reques // Validate CSRF token csrfToken := r.FormValue("csrf_token") - if !h.validateCSRFToken(csrfToken) { + if !h.csrf.Validate(csrfToken) { jsonwriter.WriteForbidden(w, "Invalid CSRF token") return } diff --git a/internal/server/user_token_service.go b/internal/server/user_token_service.go new file mode 100644 index 0000000..7e14d57 --- /dev/null +++ b/internal/server/user_token_service.go @@ -0,0 +1,102 @@ +package server + +import ( + "context" + "strings" + "time" + + "github.com/dgellow/mcp-front/internal/auth" + "github.com/dgellow/mcp-front/internal/config" + "github.com/dgellow/mcp-front/internal/log" + "github.com/dgellow/mcp-front/internal/storage" +) + +// UserTokenService handles user token retrieval and OAuth refresh +type UserTokenService struct { + storage storage.Storage + serviceOAuthClient *auth.ServiceOAuthClient +} + +// NewUserTokenService creates a new user token service +func NewUserTokenService(storage storage.Storage, serviceOAuthClient *auth.ServiceOAuthClient) *UserTokenService { + return &UserTokenService{ + storage: storage, + serviceOAuthClient: serviceOAuthClient, + } +} + +// GetUserToken retrieves and formats a user token for a service, handling OAuth refresh. +// +// Token refresh strategy: Optimistic continuation on failure. +// If refresh fails, we log a warning and continue with the current token. The external +// service will reject the expired token with 401, giving the user a clear error. +// This is acceptable because: (1) refresh failures are rare (network issues, revoked +// tokens), and (2) forcing users to re-auth is better than silently hiding auth issues. +func (uts *UserTokenService) GetUserToken(ctx context.Context, userEmail, serviceName string, serviceConfig *config.MCPClientConfig) (string, error) { + storedToken, err := uts.storage.GetUserToken(ctx, userEmail, serviceName) + if err != nil { + return "", err + } + + switch storedToken.Type { + case storage.TokenTypeManual: + // Token is already in storedToken.Value, formatUserToken will handle it + break + case storage.TokenTypeOAuth: + if storedToken.OAuthData != nil && uts.serviceOAuthClient != nil { + if err := uts.serviceOAuthClient.RefreshToken(ctx, userEmail, serviceName, serviceConfig); err != nil { + log.LogWarnWithFields("user_token", "Failed to refresh OAuth token", map[string]any{ + "service": serviceName, + "user": userEmail, + "error": err.Error(), + }) + // Continue with current token - the service will handle auth failure + } else { + // Re-fetch the updated token after refresh + refreshedToken, err := uts.storage.GetUserToken(ctx, userEmail, serviceName) + if err != nil { + log.LogErrorWithFields("user_token", "Failed to fetch token after successful refresh", map[string]any{ + "service": serviceName, + "user": userEmail, + "error": err.Error(), + }) + // Continue with original token - the service will handle auth failure + } else { + storedToken = refreshedToken + var expiresAt time.Time + if refreshedToken.OAuthData != nil { + expiresAt = refreshedToken.OAuthData.ExpiresAt + } + log.LogInfoWithFields("user_token", "OAuth token refreshed and updated", map[string]any{ + "service": serviceName, + "user": userEmail, + "expiresAt": expiresAt, + }) + } + } + } + } + + return formatUserToken(storedToken, serviceConfig.UserAuthentication), nil +} + +// formatUserToken formats a stored token according to the user authentication configuration +func formatUserToken(storedToken *storage.StoredToken, auth *config.UserAuthentication) string { + if storedToken == nil { + return "" + } + + if storedToken.Type == storage.TokenTypeOAuth && storedToken.OAuthData != nil { + token := storedToken.OAuthData.AccessToken + if auth.TokenFormat != "" && auth.TokenFormat != "{{token}}" { + return strings.ReplaceAll(auth.TokenFormat, "{{token}}", token) + } + return token + } + + token := storedToken.Value + if auth != nil && auth.TokenFormat != "" && auth.TokenFormat != "{{token}}" { + return strings.ReplaceAll(auth.TokenFormat, "{{token}}", token) + } + return token +} diff --git a/internal/session/session.go b/internal/session/session.go new file mode 100644 index 0000000..226b160 --- /dev/null +++ b/internal/session/session.go @@ -0,0 +1,35 @@ +package session + +import ( + "time" + + "github.com/dgellow/mcp-front/internal/idp" + "github.com/ory/fosite" +) + +// OAuthSession extends DefaultSession with user information for the OAuth flow +type OAuthSession struct { + *fosite.DefaultSession + Identity idp.Identity `json:"identity"` +} + +// Clone implements fosite.Session +func (s *OAuthSession) Clone() fosite.Session { + return &OAuthSession{ + DefaultSession: s.DefaultSession.Clone().(*fosite.DefaultSession), + Identity: s.Identity, + } +} + +// BrowserCookie represents the data stored in encrypted browser session cookies +type BrowserCookie struct { + Email string `json:"email"` + Provider string `json:"provider"` // IDP that authenticated this user (e.g., "google", "azure", "github") + Expires time.Time `json:"expires"` +} + +// AuthorizationState represents the OAuth authorization code flow state parameter +type AuthorizationState struct { + Nonce string `json:"nonce"` + ReturnURL string `json:"return_url"` +} diff --git a/internal/browserauth/session_test.go b/internal/session/session_test.go similarity index 56% rename from internal/browserauth/session_test.go rename to internal/session/session_test.go index f7e15de..0294040 100644 --- a/internal/browserauth/session_test.go +++ b/internal/session/session_test.go @@ -1,4 +1,4 @@ -package browserauth +package session import ( "encoding/json" @@ -9,43 +9,42 @@ import ( "github.com/stretchr/testify/require" ) -func TestSessionCookie_MarshalUnmarshal(t *testing.T) { - original := SessionCookie{ - Email: "user@example.com", - Expires: time.Now().Add(24 * time.Hour).Truncate(time.Second), +func TestBrowserCookie_MarshalUnmarshal(t *testing.T) { + original := BrowserCookie{ + Email: "user@example.com", + Provider: "google", + Expires: time.Now().Add(24 * time.Hour).Truncate(time.Second), } - // Marshal to JSON data, err := json.Marshal(original) require.NoError(t, err) - // Unmarshal back - var unmarshaled SessionCookie + var unmarshaled BrowserCookie err = json.Unmarshal(data, &unmarshaled) require.NoError(t, err) - // Truncate for comparison (JSON time serialization) assert.Equal(t, original.Email, unmarshaled.Email) + assert.Equal(t, original.Provider, unmarshaled.Provider) assert.WithinDuration(t, original.Expires, unmarshaled.Expires, time.Second) } -func TestSessionCookie_Expiry(t *testing.T) { +func TestBrowserCookie_Expiry(t *testing.T) { t.Run("not expired", func(t *testing.T) { - session := SessionCookie{ - Email: "user@example.com", - Expires: time.Now().Add(1 * time.Hour), + s := BrowserCookie{ + Email: "user@example.com", + Provider: "google", + Expires: time.Now().Add(1 * time.Hour), } - - assert.True(t, session.Expires.After(time.Now())) + assert.True(t, s.Expires.After(time.Now())) }) t.Run("expired", func(t *testing.T) { - session := SessionCookie{ - Email: "user@example.com", - Expires: time.Now().Add(-1 * time.Hour), + s := BrowserCookie{ + Email: "user@example.com", + Provider: "google", + Expires: time.Now().Add(-1 * time.Hour), } - - assert.True(t, session.Expires.Before(time.Now())) + assert.True(t, s.Expires.Before(time.Now())) }) } @@ -55,11 +54,9 @@ func TestAuthorizationState_MarshalUnmarshal(t *testing.T) { ReturnURL: "/my/tokens", } - // Marshal to JSON data, err := json.Marshal(original) require.NoError(t, err) - // Unmarshal back var unmarshaled AuthorizationState err = json.Unmarshal(data, &unmarshaled) require.NoError(t, err) diff --git a/internal/storage/firestore.go b/internal/storage/firestore.go index 1311ee6..0b49554 100644 --- a/internal/storage/firestore.go +++ b/internal/storage/firestore.go @@ -16,6 +16,11 @@ import ( "google.golang.org/grpc/status" ) +const ( + usersCollection = "mcp_front_users" + sessionsCollection = "mcp_front_sessions" +) + // FirestoreStorage implements OAuth client storage using Google Cloud Firestore. // // Error handling strategy: @@ -195,8 +200,9 @@ func (s *FirestoreStorage) loadClientsFromFirestore(ctx context.Context) error { } // StoreAuthorizeRequest stores an authorize request with state (in memory only - short-lived) -func (s *FirestoreStorage) StoreAuthorizeRequest(state string, req fosite.AuthorizeRequester) { +func (s *FirestoreStorage) StoreAuthorizeRequest(state string, req fosite.AuthorizeRequester) error { s.stateCache.Store(state, req) + return nil } // GetAuthorizeRequest retrieves an authorize request by state (one-time use) @@ -533,7 +539,7 @@ func (s *FirestoreStorage) UpsertUser(ctx context.Context, email string) error { } // Try to get existing user first - doc, err := s.client.Collection("mcp_front_users").Doc(email).Get(ctx) + doc, err := s.client.Collection(usersCollection).Doc(email).Get(ctx) if err == nil { // User exists, update LastSeen _, err = doc.Ref.Update(ctx, []firestore.Update{ @@ -547,16 +553,35 @@ func (s *FirestoreStorage) UpsertUser(ctx context.Context, email string) error { userDoc.FirstSeen = time.Now() userDoc.Enabled = true userDoc.IsAdmin = false - _, err = s.client.Collection("mcp_front_users").Doc(email).Set(ctx, userDoc) + _, err = s.client.Collection(usersCollection).Doc(email).Set(ctx, userDoc) return err } return err } +// GetUser returns a single user by email +func (s *FirestoreStorage) GetUser(ctx context.Context, email string) (*UserInfo, error) { + doc, err := s.client.Collection(usersCollection).Doc(email).Get(ctx) + if err != nil { + if status.Code(err) == codes.NotFound { + return nil, ErrUserNotFound + } + return nil, fmt.Errorf("failed to get user from Firestore: %w", err) + } + + var userDoc UserDoc + if err := doc.DataTo(&userDoc); err != nil { + return nil, fmt.Errorf("failed to unmarshal user: %w", err) + } + + user := UserInfo(userDoc) + return &user, nil +} + // GetAllUsers returns all users func (s *FirestoreStorage) GetAllUsers(ctx context.Context) ([]UserInfo, error) { - iter := s.client.Collection("mcp_front_users").Documents(ctx) + iter := s.client.Collection(usersCollection).Documents(ctx) defer iter.Stop() var users []UserInfo @@ -583,7 +608,7 @@ func (s *FirestoreStorage) GetAllUsers(ctx context.Context) ([]UserInfo, error) // UpdateUserStatus updates a user's enabled status func (s *FirestoreStorage) UpdateUserStatus(ctx context.Context, email string, enabled bool) error { - _, err := s.client.Collection("mcp_front_users").Doc(email).Update(ctx, []firestore.Update{ + _, err := s.client.Collection(usersCollection).Doc(email).Update(ctx, []firestore.Update{ {Path: "enabled", Value: enabled}, }) if status.Code(err) == codes.NotFound { @@ -595,7 +620,7 @@ func (s *FirestoreStorage) UpdateUserStatus(ctx context.Context, email string, e // DeleteUser removes a user from storage func (s *FirestoreStorage) DeleteUser(ctx context.Context, email string) error { // Delete user document - _, err := s.client.Collection("mcp_front_users").Doc(email).Delete(ctx) + _, err := s.client.Collection(usersCollection).Doc(email).Delete(ctx) if err != nil && status.Code(err) != codes.NotFound { return err } @@ -625,7 +650,7 @@ func (s *FirestoreStorage) DeleteUser(ctx context.Context, email string) error { // SetUserAdmin updates a user's admin status func (s *FirestoreStorage) SetUserAdmin(ctx context.Context, email string, isAdmin bool) error { - _, err := s.client.Collection("mcp_front_users").Doc(email).Update(ctx, []firestore.Update{ + _, err := s.client.Collection(usersCollection).Doc(email).Update(ctx, []firestore.Update{ {Path: "is_admin", Value: isAdmin}, }) if status.Code(err) == codes.NotFound { @@ -645,7 +670,7 @@ func (s *FirestoreStorage) TrackSession(ctx context.Context, session ActiveSessi } // Check if session exists - doc, err := s.client.Collection("mcp_front_sessions").Doc(session.SessionID).Get(ctx) + doc, err := s.client.Collection(sessionsCollection).Doc(session.SessionID).Get(ctx) if err == nil { // Session exists, update LastActive _, err = doc.Ref.Update(ctx, []firestore.Update{ @@ -657,7 +682,7 @@ func (s *FirestoreStorage) TrackSession(ctx context.Context, session ActiveSessi // Session doesn't exist, create new if status.Code(err) == codes.NotFound { sessionDoc.Created = time.Now() - _, err = s.client.Collection("mcp_front_sessions").Doc(session.SessionID).Set(ctx, sessionDoc) + _, err = s.client.Collection(sessionsCollection).Doc(session.SessionID).Set(ctx, sessionDoc) return err } @@ -666,7 +691,7 @@ func (s *FirestoreStorage) TrackSession(ctx context.Context, session ActiveSessi // GetActiveSessions returns all active sessions func (s *FirestoreStorage) GetActiveSessions(ctx context.Context) ([]ActiveSession, error) { - iter := s.client.Collection("mcp_front_sessions").Documents(ctx) + iter := s.client.Collection(sessionsCollection).Documents(ctx) defer iter.Stop() var sessions []ActiveSession @@ -693,7 +718,7 @@ func (s *FirestoreStorage) GetActiveSessions(ctx context.Context) ([]ActiveSessi // RevokeSession removes a session func (s *FirestoreStorage) RevokeSession(ctx context.Context, sessionID string) error { - _, err := s.client.Collection("mcp_front_sessions").Doc(sessionID).Delete(ctx) + _, err := s.client.Collection(sessionsCollection).Doc(sessionID).Delete(ctx) if err != nil && status.Code(err) != codes.NotFound { return err } diff --git a/internal/storage/memory.go b/internal/storage/memory.go index 6a1dc79..63c2c54 100644 --- a/internal/storage/memory.go +++ b/internal/storage/memory.go @@ -43,8 +43,9 @@ func NewMemoryStorage() *MemoryStorage { } // StoreAuthorizeRequest stores an authorize request with state -func (s *MemoryStorage) StoreAuthorizeRequest(state string, req fosite.AuthorizeRequester) { +func (s *MemoryStorage) StoreAuthorizeRequest(state string, req fosite.AuthorizeRequester) error { s.stateCache.Store(state, req) + return nil } // GetAuthorizeRequest retrieves an authorize request by state (one-time use) @@ -204,6 +205,19 @@ func (s *MemoryStorage) UpsertUser(ctx context.Context, email string) error { return nil } +// GetUser returns a single user by email +func (s *MemoryStorage) GetUser(ctx context.Context, email string) (*UserInfo, error) { + s.usersMutex.RLock() + defer s.usersMutex.RUnlock() + + user, exists := s.users[email] + if !exists { + return nil, ErrUserNotFound + } + userCopy := *user + return &userCopy, nil +} + // GetAllUsers returns all users func (s *MemoryStorage) GetAllUsers(ctx context.Context) ([]UserInfo, error) { s.usersMutex.RLock() diff --git a/internal/storage/storage.go b/internal/storage/storage.go index 1f9a780..e44cb02 100644 --- a/internal/storage/storage.go +++ b/internal/storage/storage.go @@ -76,7 +76,7 @@ type Storage interface { fosite.Storage // OAuth state management - StoreAuthorizeRequest(state string, req fosite.AuthorizeRequester) + StoreAuthorizeRequest(state string, req fosite.AuthorizeRequester) error GetAuthorizeRequest(state string) (fosite.AuthorizeRequester, bool) // OAuth client management @@ -89,6 +89,7 @@ type Storage interface { // User tracking (upserted when users access MCP endpoints) UpsertUser(ctx context.Context, email string) error + GetUser(ctx context.Context, email string) (*UserInfo, error) GetAllUsers(ctx context.Context) ([]UserInfo, error) UpdateUserStatus(ctx context.Context, email string, enabled bool) error DeleteUser(ctx context.Context, email string) error