diff --git a/CLAUDE.md b/CLAUDE.md
index ec5900f..e0c17dd 100644
--- a/CLAUDE.md
+++ b/CLAUDE.md
@@ -245,15 +245,16 @@ When refactoring for better design:
```bash
# Build everything
-make build
+./scripts/build
# Format everything
-make format
+./scripts/format
# Lint everything
-make lint
+./scripts/lint
# Test mcp-front
+./scripts/test
go test ./internal/... -v
go test ./integration -v
@@ -261,7 +262,7 @@ go test ./integration -v
./mcp-front -config config.json
# Start docs dev server
-make doc
+./scripts/docs
```
## Documentation Site Guidelines
diff --git a/cmd/mcp-front/main.go b/cmd/mcp-front/main.go
index 7fd26a3..2cdb86d 100644
--- a/cmd/mcp-front/main.go
+++ b/cmd/mcp-front/main.go
@@ -22,17 +22,20 @@ func generateDefaultConfig(path string) error {
"addr": ":8080",
"name": "mcp-front",
"auth": map[string]any{
- "kind": "oauth",
- "issuer": "https://mcp.yourcompany.com",
- "allowedDomains": []string{"yourcompany.com"},
- "allowedOrigins": []string{"https://claude.ai"},
- "tokenTtl": "24h",
- "storage": "memory",
- "googleClientId": map[string]string{"$env": "GOOGLE_CLIENT_ID"},
- "googleClientSecret": map[string]string{"$env": "GOOGLE_CLIENT_SECRET"},
- "googleRedirectUri": "https://mcp.yourcompany.com/oauth/callback",
- "jwtSecret": map[string]string{"$env": "JWT_SECRET"},
- "encryptionKey": map[string]string{"$env": "ENCRYPTION_KEY"},
+ "kind": "oauth",
+ "issuer": "https://mcp.yourcompany.com",
+ "allowedDomains": []string{"yourcompany.com"},
+ "allowedOrigins": []string{"https://claude.ai"},
+ "tokenTtl": "24h",
+ "storage": "memory",
+ "idp": map[string]any{
+ "provider": "google",
+ "clientId": map[string]string{"$env": "GOOGLE_CLIENT_ID"},
+ "clientSecret": map[string]string{"$env": "GOOGLE_CLIENT_SECRET"},
+ "redirectUri": "https://mcp.yourcompany.com/oauth/callback",
+ },
+ "jwtSecret": map[string]string{"$env": "JWT_SECRET"},
+ "encryptionKey": map[string]string{"$env": "ENCRYPTION_KEY"},
},
},
"mcpServers": map[string]any{
@@ -40,9 +43,21 @@ func generateDefaultConfig(path string) error {
"transportType": "stdio",
"command": "docker",
"args": []any{
- "run", "--rm", "-i",
- "mcp/postgres:latest",
- map[string]string{"$env": "POSTGRES_URL"},
+ "run", "--rm", "-i", "--network", "host",
+ "-e", "POSTGRES_HOST",
+ "-e", "POSTGRES_PORT",
+ "-e", "POSTGRES_DATABASE",
+ "-e", "POSTGRES_USER",
+ "-e", "POSTGRES_PASSWORD",
+ "us-central1-docker.pkg.dev/database-toolbox/toolbox/toolbox:latest",
+ "--stdio", "--prebuilt", "postgres",
+ },
+ "env": map[string]any{
+ "POSTGRES_HOST": map[string]string{"$env": "POSTGRES_HOST"},
+ "POSTGRES_PORT": map[string]string{"$env": "POSTGRES_PORT"},
+ "POSTGRES_DATABASE": map[string]string{"$env": "POSTGRES_DATABASE"},
+ "POSTGRES_USER": map[string]string{"$env": "POSTGRES_USER"},
+ "POSTGRES_PASSWORD": map[string]string{"$env": "POSTGRES_PASSWORD"},
},
},
},
diff --git a/config-admin-example.json b/config-admin-example.json
index e0fd169..b35934c 100644
--- a/config-admin-example.json
+++ b/config-admin-example.json
@@ -7,13 +7,16 @@
"auth": {
"kind": "oauth",
"issuer": "https://mcp.example.com",
+ "idp": {
+ "provider": "google",
+ "clientId": {"$env": "GOOGLE_CLIENT_ID"},
+ "clientSecret": {"$env": "GOOGLE_CLIENT_SECRET"},
+ "redirectUri": "https://mcp.example.com/oauth/callback"
+ },
"allowedDomains": ["example.com"],
"allowedOrigins": ["https://claude.ai"],
"tokenTtl": "1h",
"storage": "memory",
- "googleClientId": {"$env": "GOOGLE_CLIENT_ID"},
- "googleClientSecret": {"$env": "GOOGLE_CLIENT_SECRET"},
- "googleRedirectUri": "https://mcp.example.com/oauth/callback",
"jwtSecret": {"$env": "JWT_SECRET"},
"encryptionKey": {"$env": "ENCRYPTION_KEY"}
},
diff --git a/config-inline-example.json b/config-inline-example.json
index 44038cd..cfa2466 100644
--- a/config-inline-example.json
+++ b/config-inline-example.json
@@ -7,13 +7,16 @@
"auth": {
"kind": "oauth",
"issuer": "https://mcp.example.com",
+ "idp": {
+ "provider": "google",
+ "clientId": {"$env": "GOOGLE_CLIENT_ID"},
+ "clientSecret": {"$env": "GOOGLE_CLIENT_SECRET"},
+ "redirectUri": "https://mcp.example.com/oauth/callback"
+ },
"allowedDomains": ["example.com"],
"allowedOrigins": ["https://claude.ai"],
"tokenTtl": "1h",
"storage": "memory",
- "googleClientId": {"$env": "GOOGLE_CLIENT_ID"},
- "googleClientSecret": {"$env": "GOOGLE_CLIENT_SECRET"},
- "googleRedirectUri": "https://mcp.example.com/oauth/callback",
"jwtSecret": {"$env": "JWT_SECRET"},
"encryptionKey": {"$env": "ENCRYPTION_KEY"}
}
diff --git a/config-inline-test.json b/config-inline-test.json
index 1d005ba..9f7ac5a 100644
--- a/config-inline-test.json
+++ b/config-inline-test.json
@@ -7,13 +7,16 @@
"auth": {
"kind": "oauth",
"issuer": "http://localhost:8080",
+ "idp": {
+ "provider": "google",
+ "clientId": {"$env": "GOOGLE_CLIENT_ID"},
+ "clientSecret": {"$env": "GOOGLE_CLIENT_SECRET"},
+ "redirectUri": "http://localhost:8080/oauth/callback"
+ },
"allowedDomains": ["gmail.com"],
"allowedOrigins": ["https://claude.ai"],
"tokenTtl": "1h",
"storage": "memory",
- "googleClientId": {"$env": "GOOGLE_CLIENT_ID"},
- "googleClientSecret": {"$env": "GOOGLE_CLIENT_SECRET"},
- "googleRedirectUri": "http://localhost:8080/oauth/callback",
"jwtSecret": {"$env": "JWT_SECRET"},
"encryptionKey": {"$env": "ENCRYPTION_KEY"}
}
diff --git a/config-oauth-firestore.example.json b/config-oauth-firestore.example.json
index 72697a6..aaa7d4d 100644
--- a/config-oauth-firestore.example.json
+++ b/config-oauth-firestore.example.json
@@ -8,13 +8,16 @@
"kind": "oauth",
"issuer": {"$env": "OAUTH_ISSUER"},
"gcpProject": {"$env": "GCP_PROJECT"},
+ "idp": {
+ "provider": "google",
+ "clientId": {"$env": "GOOGLE_CLIENT_ID"},
+ "clientSecret": {"$env": "GOOGLE_CLIENT_SECRET"},
+ "redirectUri": {"$env": "GOOGLE_REDIRECT_URI"}
+ },
"allowedDomains": ["yourcompany.com", "contractors.yourcompany.com"],
"allowedOrigins": ["https://claude.ai", "https://yourcompany.com"],
"tokenTtl": "24h",
"storage": "firestore",
- "googleClientId": {"$env": "GOOGLE_CLIENT_ID"},
- "googleClientSecret": {"$env": "GOOGLE_CLIENT_SECRET"},
- "googleRedirectUri": {"$env": "GOOGLE_REDIRECT_URI"},
"jwtSecret": {"$env": "JWT_SECRET"},
"encryptionKey": {"$env": "ENCRYPTION_KEY"}
}
@@ -25,11 +28,20 @@
"command": "docker",
"args": [
"run", "--rm", "-i", "--network", "host",
- "mcp/postgres",
- {"$env": "DATABASE_URL"}
+ "-e", "POSTGRES_HOST",
+ "-e", "POSTGRES_PORT",
+ "-e", "POSTGRES_DATABASE",
+ "-e", "POSTGRES_USER",
+ "-e", "POSTGRES_PASSWORD",
+ "us-central1-docker.pkg.dev/database-toolbox/toolbox/toolbox:latest",
+ "--stdio", "--prebuilt", "postgres"
],
"env": {
- "PGPASSWORD": {"$env": "POSTGRES_PASSWORD"}
+ "POSTGRES_HOST": {"$env": "POSTGRES_HOST"},
+ "POSTGRES_PORT": {"$env": "POSTGRES_PORT"},
+ "POSTGRES_DATABASE": {"$env": "POSTGRES_DATABASE"},
+ "POSTGRES_USER": {"$env": "POSTGRES_USER"},
+ "POSTGRES_PASSWORD": {"$env": "POSTGRES_PASSWORD"}
}
},
"notion": {
diff --git a/config-oauth.example.json b/config-oauth.example.json
index 8b2b1ff..80b522d 100644
--- a/config-oauth.example.json
+++ b/config-oauth.example.json
@@ -8,13 +8,16 @@
"kind": "oauth",
"issuer": {"$env": "OAUTH_ISSUER"},
"gcpProject": {"$env": "GCP_PROJECT"},
+ "idp": {
+ "provider": "google",
+ "clientId": {"$env": "GOOGLE_CLIENT_ID"},
+ "clientSecret": {"$env": "GOOGLE_CLIENT_SECRET"},
+ "redirectUri": {"$env": "GOOGLE_REDIRECT_URI"}
+ },
"allowedDomains": ["yourcompany.com", "contractors.yourcompany.com"],
"allowedOrigins": ["https://claude.ai", "https://yourcompany.com"],
"tokenTtl": "24h",
"storage": "memory",
- "googleClientId": {"$env": "GOOGLE_CLIENT_ID"},
- "googleClientSecret": {"$env": "GOOGLE_CLIENT_SECRET"},
- "googleRedirectUri": {"$env": "GOOGLE_REDIRECT_URI"},
"jwtSecret": {"$env": "JWT_SECRET"},
"encryptionKey": {"$env": "ENCRYPTION_KEY"}
}
@@ -25,11 +28,20 @@
"command": "docker",
"args": [
"run", "--rm", "-i", "--network", "host",
- "mcp/postgres",
- {"$env": "DATABASE_URL"}
+ "-e", "POSTGRES_HOST",
+ "-e", "POSTGRES_PORT",
+ "-e", "POSTGRES_DATABASE",
+ "-e", "POSTGRES_USER",
+ "-e", "POSTGRES_PASSWORD",
+ "us-central1-docker.pkg.dev/database-toolbox/toolbox/toolbox:latest",
+ "--stdio", "--prebuilt", "postgres"
],
"env": {
- "PGPASSWORD": {"$env": "POSTGRES_PASSWORD"}
+ "POSTGRES_HOST": {"$env": "POSTGRES_HOST"},
+ "POSTGRES_PORT": {"$env": "POSTGRES_PORT"},
+ "POSTGRES_DATABASE": {"$env": "POSTGRES_DATABASE"},
+ "POSTGRES_USER": {"$env": "POSTGRES_USER"},
+ "POSTGRES_PASSWORD": {"$env": "POSTGRES_PASSWORD"}
}
},
"notion": {
diff --git a/config-oauth.json b/config-oauth.json
index dec7a3d..a49ba37 100644
--- a/config-oauth.json
+++ b/config-oauth.json
@@ -8,13 +8,16 @@
"kind": "oauth",
"issuer": "https://mcp-internal.yourcompany.org",
"gcpProject": {"$env": "GCP_PROJECT"},
+ "idp": {
+ "provider": "google",
+ "clientId": {"$env": "GOOGLE_CLIENT_ID"},
+ "clientSecret": {"$env": "GOOGLE_CLIENT_SECRET"},
+ "redirectUri": "https://mcp-internal.yourcompany.org/oauth/callback"
+ },
"allowedDomains": ["yourcompany.com"],
"allowedOrigins": ["https://claude.ai"],
"tokenTtl": "24h",
"storage": "memory",
- "googleClientId": {"$env": "GOOGLE_CLIENT_ID"},
- "googleClientSecret": {"$env": "GOOGLE_CLIENT_SECRET"},
- "googleRedirectUri": "https://mcp-internal.yourcompany.org/oauth/callback",
"jwtSecret": {"$env": "JWT_SECRET"},
"encryptionKey": {"$env": "ENCRYPTION_KEY"}
}
@@ -22,12 +25,24 @@
"mcpServers": {
"postgres": {
"transportType": "stdio",
- "command": "docker",
+ "command": "docker",
"args": [
- "run", "--rm", "-i",
- "mcp/postgres:latest",
- "postgresql://user:password@localhost:5432/database"
- ]
+ "run", "--rm", "-i", "--network", "host",
+ "-e", "POSTGRES_HOST",
+ "-e", "POSTGRES_PORT",
+ "-e", "POSTGRES_DATABASE",
+ "-e", "POSTGRES_USER",
+ "-e", "POSTGRES_PASSWORD",
+ "us-central1-docker.pkg.dev/database-toolbox/toolbox/toolbox:latest",
+ "--stdio", "--prebuilt", "postgres"
+ ],
+ "env": {
+ "POSTGRES_HOST": {"$env": "POSTGRES_HOST"},
+ "POSTGRES_PORT": {"$env": "POSTGRES_PORT"},
+ "POSTGRES_DATABASE": {"$env": "POSTGRES_DATABASE"},
+ "POSTGRES_USER": {"$env": "POSTGRES_USER"},
+ "POSTGRES_PASSWORD": {"$env": "POSTGRES_PASSWORD"}
+ }
},
"notion": {
"transportType": "stdio",
diff --git a/config-token.example.json b/config-token.example.json
index b13702e..fc07d09 100644
--- a/config-token.example.json
+++ b/config-token.example.json
@@ -11,8 +11,13 @@
"command": "docker",
"args": [
"run", "--rm", "-i", "--network", "host",
- "mcp/postgres",
- "postgresql://testuser:testpass@localhost:5432/testdb"
+ "-e", "POSTGRES_HOST=localhost",
+ "-e", "POSTGRES_PORT=5432",
+ "-e", "POSTGRES_DATABASE=testdb",
+ "-e", "POSTGRES_USER=testuser",
+ "-e", "POSTGRES_PASSWORD=testpass",
+ "us-central1-docker.pkg.dev/database-toolbox/toolbox/toolbox:latest",
+ "--stdio", "--prebuilt", "postgres"
],
"serviceAuths": [
{
diff --git a/config-user-tokens-example.json b/config-user-tokens-example.json
index bd93f36..8abab60 100644
--- a/config-user-tokens-example.json
+++ b/config-user-tokens-example.json
@@ -8,13 +8,16 @@
"kind": "oauth",
"issuer": {"$env": "OAUTH_ISSUER"},
"gcpProject": {"$env": "GCP_PROJECT"},
+ "idp": {
+ "provider": "google",
+ "clientId": {"$env": "GOOGLE_CLIENT_ID"},
+ "clientSecret": {"$env": "GOOGLE_CLIENT_SECRET"},
+ "redirectUri": {"$env": "GOOGLE_REDIRECT_URI"}
+ },
"allowedDomains": ["yourcompany.com"],
"allowedOrigins": ["https://claude.ai"],
"tokenTtl": "24h",
"storage": "memory",
- "googleClientId": {"$env": "GOOGLE_CLIENT_ID"},
- "googleClientSecret": {"$env": "GOOGLE_CLIENT_SECRET"},
- "googleRedirectUri": {"$env": "GOOGLE_REDIRECT_URI"},
"jwtSecret": {"$env": "JWT_SECRET"},
"encryptionKey": {"$env": "ENCRYPTION_KEY"}
}
@@ -57,11 +60,20 @@
"command": "docker",
"args": [
"run", "--rm", "-i", "--network", "host",
- "mcp/postgres",
- {"$env": "DATABASE_URL"}
+ "-e", "POSTGRES_HOST",
+ "-e", "POSTGRES_PORT",
+ "-e", "POSTGRES_DATABASE",
+ "-e", "POSTGRES_USER",
+ "-e", "POSTGRES_PASSWORD",
+ "us-central1-docker.pkg.dev/database-toolbox/toolbox/toolbox:latest",
+ "--stdio", "--prebuilt", "postgres"
],
"env": {
- "PGPASSWORD": {"$env": "POSTGRES_PASSWORD"}
+ "POSTGRES_HOST": {"$env": "POSTGRES_HOST"},
+ "POSTGRES_PORT": {"$env": "POSTGRES_PORT"},
+ "POSTGRES_DATABASE": {"$env": "POSTGRES_DATABASE"},
+ "POSTGRES_USER": {"$env": "POSTGRES_USER"},
+ "POSTGRES_PASSWORD": {"$env": "POSTGRES_PASSWORD"}
}
}
}
diff --git a/docs-site/src/content/docs/configuration.md b/docs-site/src/content/docs/configuration.md
index 1351707..7a2f908 100644
--- a/docs-site/src/content/docs/configuration.md
+++ b/docs-site/src/content/docs/configuration.md
@@ -49,13 +49,16 @@ For production, use OAuth with Google. Claude redirects users to Google for auth
"auth": {
"kind": "oauth",
"issuer": "https://mcp.company.com",
+ "idp": {
+ "provider": "google",
+ "clientId": { "$env": "GOOGLE_CLIENT_ID" },
+ "clientSecret": { "$env": "GOOGLE_CLIENT_SECRET" },
+ "redirectUri": "https://mcp.company.com/oauth/callback"
+ },
"allowedDomains": ["company.com"],
"allowedOrigins": ["https://claude.ai"],
"tokenTtl": "24h",
"storage": "memory",
- "googleClientId": { "$env": "GOOGLE_CLIENT_ID" },
- "googleClientSecret": { "$env": "GOOGLE_CLIENT_SECRET" },
- "googleRedirectUri": "https://mcp.company.com/oauth/callback",
"jwtSecret": { "$env": "JWT_SECRET" },
"encryptionKey": { "$env": "ENCRYPTION_KEY" }
}
@@ -66,7 +69,7 @@ The `issuer` should match your `baseURL`. `allowedDomains` restricts access to s
`tokenTtl` controls how long JWT tokens are valid. Shorter times are more secure but require more frequent logins.
-Security requirements: `googleClientSecret`, `jwtSecret`, and `encryptionKey` must be environment variables. The JWT secret must be at least 32 bytes. The encryption key must be exactly 32 bytes.
+Security requirements: `idp.clientSecret`, `jwtSecret`, and `encryptionKey` must be environment variables. The JWT secret must be at least 32 bytes. The encryption key must be exactly 32 bytes.
For production, set `storage` to "firestore" and add `gcpProject`, `firestoreDatabase`, and `firestoreCollection` fields.
@@ -345,13 +348,16 @@ Set `MCP_FRONT_ENV=development` when testing OAuth locally. It allows http:// UR
"auth": {
"kind": "oauth",
"issuer": "https://mcp.company.com",
+ "idp": {
+ "provider": "google",
+ "clientId": { "$env": "GOOGLE_CLIENT_ID" },
+ "clientSecret": { "$env": "GOOGLE_CLIENT_SECRET" },
+ "redirectUri": "https://mcp.company.com/oauth/callback"
+ },
"allowedDomains": ["company.com"],
"allowedOrigins": ["https://claude.ai"],
"tokenTtl": "4h",
"storage": "firestore",
- "googleClientId": { "$env": "GOOGLE_CLIENT_ID" },
- "googleClientSecret": { "$env": "GOOGLE_CLIENT_SECRET" },
- "googleRedirectUri": "https://mcp.company.com/oauth/callback",
"jwtSecret": { "$env": "JWT_SECRET" },
"encryptionKey": { "$env": "ENCRYPTION_KEY" },
"gcpProject": { "$env": "GOOGLE_CLOUD_PROJECT" },
diff --git a/docs-site/src/content/docs/examples/oauth-google.md b/docs-site/src/content/docs/examples/oauth-google.md
index d4598eb..d35aea1 100644
--- a/docs-site/src/content/docs/examples/oauth-google.md
+++ b/docs-site/src/content/docs/examples/oauth-google.md
@@ -39,15 +39,26 @@ Create `config.json`:
```json
{
- "version": "1.0",
+ "version": "v0.0.1-DEV_EDITION_EXPECT_CHANGES",
"proxy": {
"name": "Company MCP Proxy",
- "baseUrl": "https://mcp.company.com",
+ "baseURL": "https://mcp.company.com",
"addr": ":8080",
"auth": {
"kind": "oauth",
"issuer": "https://mcp.company.com",
- "allowedDomains": ["company.com"]
+ "idp": {
+ "provider": "google",
+ "clientId": { "$env": "GOOGLE_CLIENT_ID" },
+ "clientSecret": { "$env": "GOOGLE_CLIENT_SECRET" },
+ "redirectUri": "https://mcp.company.com/oauth/callback"
+ },
+ "allowedDomains": ["company.com"],
+ "allowedOrigins": ["https://claude.ai"],
+ "tokenTtl": "24h",
+ "storage": "memory",
+ "jwtSecret": { "$env": "JWT_SECRET" },
+ "encryptionKey": { "$env": "ENCRYPTION_KEY" }
}
},
"mcpServers": {
@@ -80,6 +91,7 @@ Create `config.json`:
export GOOGLE_CLIENT_ID="your-client-id.apps.googleusercontent.com"
export GOOGLE_CLIENT_SECRET="your-client-secret"
export JWT_SECRET=$(openssl rand -base64 32)
+export ENCRYPTION_KEY=$(openssl rand -base64 32)
```
## 4. Run with Docker
@@ -89,6 +101,7 @@ docker run -p 8080:8080 \
-e GOOGLE_CLIENT_ID \
-e GOOGLE_CLIENT_SECRET \
-e JWT_SECRET \
+ -e ENCRYPTION_KEY \
-v $(pwd)/config.json:/config.json \
ghcr.io/dgellow/mcp-front:latest
```
diff --git a/docs-site/src/content/docs/index.mdx b/docs-site/src/content/docs/index.mdx
index c7275ee..f1d79fb 100644
--- a/docs-site/src/content/docs/index.mdx
+++ b/docs-site/src/content/docs/index.mdx
@@ -41,9 +41,13 @@ Claude redirects users to Google for authentication, and MCP Front validates the
"auth": {
"kind": "oauth",
"issuer": "https://mcp.company.com",
+ "idp": {
+ "provider": "google",
+ "clientId": { "$env": "GOOGLE_CLIENT_ID" },
+ "clientSecret": { "$env": "GOOGLE_CLIENT_SECRET" },
+ "redirectUri": "https://mcp.company.com/oauth/callback"
+ },
"allowedDomains": ["company.com"],
- "googleClientId": { "$env": "GOOGLE_CLIENT_ID" },
- "googleClientSecret": { "$env": "GOOGLE_CLIENT_SECRET" },
"jwtSecret": { "$env": "JWT_SECRET" },
"encryptionKey": { "$env": "ENCRYPTION_KEY" }
}
diff --git a/integration/base_path_test.go b/integration/base_path_test.go
index 88b8e13..ca0ec3c 100644
--- a/integration/base_path_test.go
+++ b/integration/base_path_test.go
@@ -11,7 +11,12 @@ import (
func TestBasePathRouting(t *testing.T) {
waitForDB(t)
- startMCPFront(t, "config/config.base-path-test.json")
+ cfg := buildTestConfig(
+ "http://localhost:8080/mcp-api", "mcp-front-base-path-test",
+ nil,
+ map[string]any{"postgres": testPostgresServer(withBearerTokens("test-token"))},
+ )
+ startMCPFront(t, writeTestConfig(t, cfg))
waitForMCPFront(t)
initialContainers := getMCPContainers()
diff --git a/integration/basic_auth_test.go b/integration/basic_auth_test.go
index 1f31a91..4368f1a 100644
--- a/integration/basic_auth_test.go
+++ b/integration/basic_auth_test.go
@@ -10,8 +10,11 @@ import (
)
func TestBasicAuth(t *testing.T) {
- // Start mcp-front with basic auth config
- startMCPFront(t, "config/config.basic-auth-test.json",
+ cfg := buildTestConfig("http://localhost:8080", "mcp-front-basic-auth-test",
+ nil,
+ map[string]any{"postgres": testPostgresServer(withBasicAuth("admin", "ADMIN_PASSWORD"), withBasicAuth("user", "USER_PASSWORD"))},
+ )
+ startMCPFront(t, writeTestConfig(t, cfg),
"ADMIN_PASSWORD=adminpass123",
"USER_PASSWORD=userpass456",
)
diff --git a/integration/config/config.base-path-test.json b/integration/config/config.base-path-test.json
deleted file mode 100644
index 5b60230..0000000
--- a/integration/config/config.base-path-test.json
+++ /dev/null
@@ -1,28 +0,0 @@
-{
- "version": "v0.0.1-DEV_EDITION_EXPECT_CHANGES",
- "proxy": {
- "baseURL": "http://localhost:8080/mcp-api",
- "addr": ":8080",
- "name": "mcp-front-base-path-test"
- },
- "mcpServers": {
- "postgres": {
- "transportType": "stdio",
- "command": "docker",
- "args": [
- "run",
- "-i",
- "--network",
- "host",
- "mcp/postgres",
- "postgresql://testuser:testpass@localhost:15432/testdb"
- ],
- "serviceAuths": [
- {
- "type": "bearer",
- "tokens": ["test-token"]
- }
- ]
- }
- }
-}
diff --git a/integration/config/config.basic-auth-test.json b/integration/config/config.basic-auth-test.json
deleted file mode 100644
index 8ef2e53..0000000
--- a/integration/config/config.basic-auth-test.json
+++ /dev/null
@@ -1,34 +0,0 @@
-{
- "version": "v0.0.1-DEV_EDITION_EXPECT_CHANGES",
- "proxy": {
- "baseURL": "http://localhost:8080",
- "addr": ":8080",
- "name": "mcp-front-basic-auth-test"
- },
- "mcpServers": {
- "postgres": {
- "transportType": "stdio",
- "command": "docker",
- "args": [
- "run",
- "-i",
- "--network",
- "host",
- "mcp/postgres",
- "postgresql://testuser:testpass@localhost:15432/testdb"
- ],
- "serviceAuths": [
- {
- "type": "basic",
- "username": "admin",
- "password": {"$env": "ADMIN_PASSWORD"}
- },
- {
- "type": "basic",
- "username": "user",
- "password": {"$env": "USER_PASSWORD"}
- }
- ]
- }
- }
-}
\ No newline at end of file
diff --git a/integration/config/config.demo-token.json b/integration/config/config.demo-token.json
index a9b07b5..7610c11 100644
--- a/integration/config/config.demo-token.json
+++ b/integration/config/config.demo-token.json
@@ -11,8 +11,13 @@
"command": "docker",
"args": [
"run", "--rm", "-i", "--network", "host",
- "mcp/postgres",
- "postgresql://testuser:testpass@localhost:15432/testdb"
+ "-e", "POSTGRES_HOST=localhost",
+ "-e", "POSTGRES_PORT=15432",
+ "-e", "POSTGRES_DATABASE=testdb",
+ "-e", "POSTGRES_USER=testuser",
+ "-e", "POSTGRES_PASSWORD=testpass",
+ "us-central1-docker.pkg.dev/database-toolbox/toolbox/toolbox:latest",
+ "--stdio", "--prebuilt", "postgres"
],
"options": {
"logEnabled": true
diff --git a/integration/config/config.oauth-integration-test.json b/integration/config/config.oauth-integration-test.json
deleted file mode 100644
index 0a40a0f..0000000
--- a/integration/config/config.oauth-integration-test.json
+++ /dev/null
@@ -1,41 +0,0 @@
-{
- "version": "v0.0.1-DEV_EDITION_EXPECT_CHANGES",
- "proxy": {
- "baseURL": "http://localhost:8080",
- "addr": ":8080",
- "name": "mcp-front-oauth-test",
- "auth": {
- "kind": "oauth",
- "issuer": "http://localhost:8080",
- "gcpProject": "test-project",
- "allowedDomains": [
- "test.com"
- ],
- "allowedOrigins": [
- "https://claude.ai"
- ],
- "tokenTtl": "1h",
- "storage": "memory",
- "googleClientId": "test-client-id",
- "googleClientSecret": "test-client-secret-for-integration-testing",
- "googleRedirectUri": "http://localhost:8080/oauth/callback",
- "jwtSecret": "test-jwt-secret-for-integration-testing-32-chars-long",
- "encryptionKey": "test-encryption-key-32-bytes-aes"
- }
- },
- "mcpServers": {
- "postgres": {
- "transportType": "stdio",
- "command": "docker",
- "args": [
- "run",
- "--rm",
- "-i",
- "--network",
- "host",
- "mcp/postgres",
- "postgresql://testuser:testpass@localhost:15432/testdb"
- ]
- }
- }
-}
\ No newline at end of file
diff --git a/integration/config/config.oauth-rfc8707-test.json b/integration/config/config.oauth-rfc8707-test.json
index f169a64..75d1162 100644
--- a/integration/config/config.oauth-rfc8707-test.json
+++ b/integration/config/config.oauth-rfc8707-test.json
@@ -8,13 +8,19 @@
"kind": "oauth",
"issuer": "http://localhost:8080",
"gcpProject": "test-project",
+ "idp": {
+ "provider": "google",
+ "clientId": {"$env": "GOOGLE_CLIENT_ID"},
+ "clientSecret": {"$env": "GOOGLE_CLIENT_SECRET"},
+ "redirectUri": "http://localhost:8080/oauth/callback",
+ "authorizationUrl": "http://localhost:9090/auth",
+ "tokenUrl": "http://localhost:9090/token",
+ "userInfoUrl": "http://localhost:9090/userinfo"
+ },
"allowedDomains": ["test.com"],
"allowedOrigins": ["https://claude.ai"],
"tokenTtl": "1h",
"storage": "memory",
- "googleClientId": {"$env": "GOOGLE_CLIENT_ID"},
- "googleClientSecret": {"$env": "GOOGLE_CLIENT_SECRET"},
- "googleRedirectUri": "http://localhost:8080/oauth/callback",
"jwtSecret": {"$env": "JWT_SECRET"},
"encryptionKey": {"$env": "ENCRYPTION_KEY"}
}
diff --git a/integration/config/config.oauth-service-integration-test.json b/integration/config/config.oauth-service-integration-test.json
index e46b23c..e1c4e61 100644
--- a/integration/config/config.oauth-service-integration-test.json
+++ b/integration/config/config.oauth-service-integration-test.json
@@ -8,13 +8,19 @@
"kind": "oauth",
"issuer": "http://localhost:8080",
"gcpProject": "test-project",
+ "idp": {
+ "provider": "google",
+ "clientId": {"$env": "GOOGLE_CLIENT_ID"},
+ "clientSecret": {"$env": "GOOGLE_CLIENT_SECRET"},
+ "redirectUri": "http://localhost:8080/oauth/callback",
+ "authorizationUrl": "http://localhost:9090/auth",
+ "tokenUrl": "http://localhost:9090/token",
+ "userInfoUrl": "http://localhost:9090/userinfo"
+ },
"allowedDomains": ["test.com"],
"allowedOrigins": ["https://claude.ai"],
"tokenTtl": "1h",
"storage": "memory",
- "googleClientId": {"$env": "GOOGLE_CLIENT_ID"},
- "googleClientSecret": {"$env": "GOOGLE_CLIENT_SECRET"},
- "googleRedirectUri": "http://localhost:8080/oauth/callback",
"jwtSecret": {"$env": "JWT_SECRET"},
"encryptionKey": {"$env": "ENCRYPTION_KEY"}
}
diff --git a/integration/config/config.oauth-service-test.json b/integration/config/config.oauth-service-test.json
deleted file mode 100644
index 9d2f6d4..0000000
--- a/integration/config/config.oauth-service-test.json
+++ /dev/null
@@ -1,47 +0,0 @@
-{
- "version": "v0.0.1-DEV_EDITION_EXPECT_CHANGES",
- "proxy": {
- "baseURL": "http://localhost:8080",
- "addr": ":8080",
- "name": "mcp-front-oauth-usertoken-test",
- "auth": {
- "kind": "oauth",
- "issuer": "http://localhost:8080",
- "gcpProject": "test-project",
- "allowedDomains": ["test.com"],
- "allowedOrigins": ["https://claude.ai"],
- "tokenTtl": "1h",
- "storage": "memory",
- "googleClientId": {"$env": "GOOGLE_CLIENT_ID"},
- "googleClientSecret": {"$env": "GOOGLE_CLIENT_SECRET"},
- "googleRedirectUri": "http://localhost:8080/oauth/callback",
- "jwtSecret": {"$env": "JWT_SECRET"},
- "encryptionKey": {"$env": "ENCRYPTION_KEY"}
- }
- },
- "mcpServers": {
- "postgres": {
- "transportType": "stdio",
- "command": "docker",
- "args": [
- "run",
- "--rm",
- "-i",
- "--network",
- "host",
- "mcp/postgres",
- "postgresql://testuser:testpass@localhost:15432/testdb"
- ],
- "env": {
- "USER_TOKEN": {"$userToken": "{{token}}"}
- },
- "requiresUserToken": true,
- "userAuthentication": {
- "type": "manual",
- "displayName": "Test Service",
- "instructions": "Enter your test token",
- "helpUrl": "https://example.com/help"
- }
- }
- }
-}
\ No newline at end of file
diff --git a/integration/config/config.oauth-test.json b/integration/config/config.oauth-test.json
deleted file mode 100644
index d08faca..0000000
--- a/integration/config/config.oauth-test.json
+++ /dev/null
@@ -1,33 +0,0 @@
-{
- "version": "v0.0.1-DEV_EDITION_EXPECT_CHANGES",
- "proxy": {
- "baseURL": "http://localhost:8080",
- "addr": ":8080",
- "name": "mcp-front-oauth-test",
- "auth": {
- "kind": "oauth",
- "issuer": "http://localhost:8080",
- "gcpProject": "test-project",
- "allowedDomains": ["test.com", "stainless.com", "claude.ai"],
- "allowedOrigins": ["https://claude.ai"],
- "tokenTtl": "1h",
- "storage": "memory",
- "googleClientId": {"$env": "GOOGLE_CLIENT_ID"},
- "googleClientSecret": {"$env": "GOOGLE_CLIENT_SECRET"},
- "googleRedirectUri": "http://localhost:8080/oauth/callback",
- "jwtSecret": {"$env": "JWT_SECRET"},
- "encryptionKey": {"$env": "ENCRYPTION_KEY"}
- }
- },
- "mcpServers": {
- "postgres": {
- "transportType": "stdio",
- "command": "docker",
- "args": [
- "run", "--rm", "-i", "--network", "host",
- "mcp/postgres",
- "postgresql://testuser:testpass@localhost:15432/testdb"
- ]
- }
- }
-}
\ No newline at end of file
diff --git a/integration/config/config.oauth-token-test.json b/integration/config/config.oauth-token-test.json
index 5d39ce4..f3874b7 100644
--- a/integration/config/config.oauth-token-test.json
+++ b/integration/config/config.oauth-token-test.json
@@ -8,13 +8,19 @@
"kind": "oauth",
"issuer": "http://localhost:8080",
"gcpProject": "test-project",
+ "idp": {
+ "provider": "google",
+ "clientId": {"$env": "GOOGLE_CLIENT_ID"},
+ "clientSecret": {"$env": "GOOGLE_CLIENT_SECRET"},
+ "redirectUri": "http://localhost:8080/oauth/callback",
+ "authorizationUrl": "http://localhost:9090/auth",
+ "tokenUrl": "http://localhost:9090/token",
+ "userInfoUrl": "http://localhost:9090/userinfo"
+ },
"allowedDomains": ["test.com"],
"allowedOrigins": ["https://claude.ai"],
"tokenTtl": "1h",
"storage": "memory",
- "googleClientId": {"$env": "GOOGLE_CLIENT_ID"},
- "googleClientSecret": {"$env": "GOOGLE_CLIENT_SECRET"},
- "googleRedirectUri": "http://localhost:8080/oauth/callback",
"jwtSecret": {"$env": "JWT_SECRET"},
"encryptionKey": {"$env": "ENCRYPTION_KEY"}
}
diff --git a/integration/config/config.oauth-usertoken-tools-test.json b/integration/config/config.oauth-usertoken-tools-test.json
deleted file mode 100644
index 9d2f6d4..0000000
--- a/integration/config/config.oauth-usertoken-tools-test.json
+++ /dev/null
@@ -1,47 +0,0 @@
-{
- "version": "v0.0.1-DEV_EDITION_EXPECT_CHANGES",
- "proxy": {
- "baseURL": "http://localhost:8080",
- "addr": ":8080",
- "name": "mcp-front-oauth-usertoken-test",
- "auth": {
- "kind": "oauth",
- "issuer": "http://localhost:8080",
- "gcpProject": "test-project",
- "allowedDomains": ["test.com"],
- "allowedOrigins": ["https://claude.ai"],
- "tokenTtl": "1h",
- "storage": "memory",
- "googleClientId": {"$env": "GOOGLE_CLIENT_ID"},
- "googleClientSecret": {"$env": "GOOGLE_CLIENT_SECRET"},
- "googleRedirectUri": "http://localhost:8080/oauth/callback",
- "jwtSecret": {"$env": "JWT_SECRET"},
- "encryptionKey": {"$env": "ENCRYPTION_KEY"}
- }
- },
- "mcpServers": {
- "postgres": {
- "transportType": "stdio",
- "command": "docker",
- "args": [
- "run",
- "--rm",
- "-i",
- "--network",
- "host",
- "mcp/postgres",
- "postgresql://testuser:testpass@localhost:15432/testdb"
- ],
- "env": {
- "USER_TOKEN": {"$userToken": "{{token}}"}
- },
- "requiresUserToken": true,
- "userAuthentication": {
- "type": "manual",
- "displayName": "Test Service",
- "instructions": "Enter your test token",
- "helpUrl": "https://example.com/help"
- }
- }
- }
-}
\ No newline at end of file
diff --git a/integration/config/config.test.json b/integration/config/config.test.json
deleted file mode 100644
index 773d78c..0000000
--- a/integration/config/config.test.json
+++ /dev/null
@@ -1,31 +0,0 @@
-{
- "version": "v0.0.1-DEV_EDITION_EXPECT_CHANGES",
- "proxy": {
- "baseURL": "http://localhost:8080",
- "addr": ":8080",
- "name": "mcp-front-test"
- },
- "mcpServers": {
- "postgres": {
- "transportType": "stdio",
- "command": "docker",
- "args": [
- "run",
- "-i",
- "--network",
- "host",
- "mcp/postgres",
- "postgresql://testuser:testpass@localhost:15432/testdb"
- ],
- "options": {
- "logEnabled": true
- },
- "serviceAuths": [
- {
- "type": "bearer",
- "tokens": ["test-token", "alt-test-token"]
- }
- ]
- }
- }
-}
\ No newline at end of file
diff --git a/integration/integration_test.go b/integration/integration_test.go
index 72ff8b4..2bd6473 100644
--- a/integration/integration_test.go
+++ b/integration/integration_test.go
@@ -15,11 +15,11 @@ func TestIntegration(t *testing.T) {
waitForDB(t)
trace(t, "Starting mcp-front")
- startMCPFront(t, "config/config.test.json",
- "GOOGLE_OAUTH_AUTH_URL=http://localhost:9090/auth",
- "GOOGLE_OAUTH_TOKEN_URL=http://localhost:9090/token",
- "GOOGLE_USERINFO_URL=http://localhost:9090/userinfo",
+ cfg := buildTestConfig("http://localhost:8080", "mcp-front-test",
+ nil,
+ map[string]any{"postgres": testPostgresServer(withBearerTokens("test-token", "alt-test-token"), withLogEnabled())},
)
+ startMCPFront(t, writeTestConfig(t, cfg))
waitForMCPFront(t)
trace(t, "mcp-front is ready")
@@ -49,7 +49,7 @@ func TestIntegration(t *testing.T) {
t.Log("Connected to MCP server with session")
queryParams := map[string]any{
- "name": "query",
+ "name": "execute_sql",
"arguments": map[string]any{
"sql": "SELECT COUNT(*) as user_count FROM users",
},
@@ -75,23 +75,31 @@ func TestIntegration(t *testing.T) {
assert.NotEmpty(t, content, "Query result missing content")
t.Log("Query executed successfully")
- // Test resources list
- resourcesResult, err := client.SendMCPRequest("resources/list", map[string]any{})
- require.NoError(t, err, "Failed to list resources")
+ // Test tools list
+ toolsResult, err := client.SendMCPRequest("tools/list", map[string]any{})
+ require.NoError(t, err, "Failed to list tools")
- t.Logf("Resources response: %+v", resourcesResult)
+ t.Logf("Tools response: %+v", toolsResult)
- // Check for error in resources response
- errorMap, hasError = resourcesResult["error"].(map[string]any)
- assert.False(t, hasError, "Resources list returned error: %v", errorMap)
+ errorMap, hasError = toolsResult["error"].(map[string]any)
+ assert.False(t, hasError, "Tools list returned error: %v", errorMap)
- // Verify we got resources
- resultMap, ok = resourcesResult["result"].(map[string]any)
- require.True(t, ok, "Expected result in resources response")
+ resultMap, ok = toolsResult["result"].(map[string]any)
+ require.True(t, ok, "Expected result in tools response")
- resources, ok := resultMap["resources"].([]any)
- require.True(t, ok, "Expected resources array in result")
- assert.NotEmpty(t, resources, "Expected at least one resource")
- t.Logf("Found %d resources", len(resources))
+ tools, ok := resultMap["tools"].([]any)
+ require.True(t, ok, "Expected tools array in result")
+ assert.NotEmpty(t, tools, "Expected at least one tool")
+ t.Logf("Found %d tools", len(tools))
+ // Verify execute_sql tool is present
+ var toolNames []string
+ for _, tool := range tools {
+ if toolMap, ok := tool.(map[string]any); ok {
+ if name, ok := toolMap["name"].(string); ok {
+ toolNames = append(toolNames, name)
+ }
+ }
+ }
+ assert.Contains(t, toolNames, "execute_sql", "Should have execute_sql tool")
}
diff --git a/integration/isolation_test.go b/integration/isolation_test.go
index 41d5d0f..1036e91 100644
--- a/integration/isolation_test.go
+++ b/integration/isolation_test.go
@@ -18,12 +18,16 @@ func TestMultiUserSessionIsolation(t *testing.T) {
// Start mcp-front with bearer token auth
trace(t, "Starting mcp-front")
- startMCPFront(t, "config/config.test.json")
+ cfg := buildTestConfig("http://localhost:8080", "mcp-front-test",
+ nil,
+ map[string]any{"postgres": testPostgresServer(withBearerTokens("test-token", "alt-test-token"), withLogEnabled())},
+ )
+ startMCPFront(t, writeTestConfig(t, cfg))
waitForMCPFront(t)
// Get initial container count
initialContainers := getMCPContainers()
- t.Logf("Initial mcp/postgres containers: %d", len(initialContainers))
+ t.Logf("Initial toolbox containers: %d", len(initialContainers))
// Create two clients with different auth tokens
client1 := NewMCPSSEClient("http://localhost:8080")
@@ -69,7 +73,7 @@ func TestMultiUserSessionIsolation(t *testing.T) {
}
query1Result, err := client1.SendMCPRequest("tools/call", map[string]any{
- "name": "query",
+ "name": "execute_sql",
"arguments": map[string]any{
"sql": "SELECT 'user1-query1' as test_id, COUNT(*) as count FROM users",
},
@@ -116,13 +120,13 @@ func TestMultiUserSessionIsolation(t *testing.T) {
// Verify that client1 and client2 have different containers
if client1Container != "" && client2Container != "" && client1Container == client2Container {
t.Errorf("CRITICAL: Both users are using the same Docker container! Container ID: %s", client1Container)
- t.Error("This indicates session isolation is NOT working - users are sharing the same mcp/postgres instance")
+ t.Error("This indicates session isolation is NOT working - users are sharing the same toolbox instance")
} else if client1Container != "" && client2Container != "" {
t.Logf("Confirmed different stdio processes: User1 container=%s, User2 container=%s", client1Container, client2Container)
}
query2Result, err := client2.SendMCPRequest("tools/call", map[string]any{
- "name": "query",
+ "name": "execute_sql",
"arguments": map[string]any{
"sql": "SELECT 'user2-query1' as test_id, COUNT(*) as count FROM orders",
},
@@ -135,7 +139,7 @@ func TestMultiUserSessionIsolation(t *testing.T) {
// Step 3: First user sends another query
t.Log("\nStep 3: First user sends another query")
query3Result, err := client1.SendMCPRequest("tools/call", map[string]any{
- "name": "query",
+ "name": "execute_sql",
"arguments": map[string]any{
"sql": "SELECT 'user1-query2' as test_id, current_timestamp as ts",
},
@@ -148,7 +152,7 @@ func TestMultiUserSessionIsolation(t *testing.T) {
// Step 4: First user sends another query
t.Log("\nStep 4: First user sends another query")
query4Result, err := client1.SendMCPRequest("tools/call", map[string]any{
- "name": "query",
+ "name": "execute_sql",
"arguments": map[string]any{
"sql": "SELECT 'user1-query3' as test_id, version() as db_version",
},
@@ -161,7 +165,7 @@ func TestMultiUserSessionIsolation(t *testing.T) {
// Step 5: Second user sends a query
t.Log("\nStep 5: Second user sends a query")
query5Result, err := client2.SendMCPRequest("tools/call", map[string]any{
- "name": "query",
+ "name": "execute_sql",
"arguments": map[string]any{
"sql": "SELECT 'user2-query2' as test_id, current_database() as db_name",
},
@@ -208,12 +212,16 @@ func TestSessionCleanupAfterTimeout(t *testing.T) {
// Start mcp-front with test timeout configuration
trace(t, "Starting mcp-front with test session timeout")
- startMCPFront(t, "config/config.test.json")
+ cfg := buildTestConfig("http://localhost:8080", "mcp-front-test",
+ nil,
+ map[string]any{"postgres": testPostgresServer(withBearerTokens("test-token", "alt-test-token"), withLogEnabled())},
+ )
+ startMCPFront(t, writeTestConfig(t, cfg))
waitForMCPFront(t)
// Get initial container count
initialContainers := getMCPContainers()
- t.Logf("Initial mcp/postgres containers: %d", len(initialContainers))
+ t.Logf("Initial toolbox containers: %d", len(initialContainers))
// Create a client and connect
client := NewMCPSSEClient("http://localhost:8080")
@@ -237,7 +245,7 @@ func TestSessionCleanupAfterTimeout(t *testing.T) {
// Send a query to ensure session is active
_, err = client.SendMCPRequest("tools/call", map[string]any{
- "name": "query",
+ "name": "execute_sql",
"arguments": map[string]any{
"sql": "SELECT 'test' as test_id",
},
@@ -277,12 +285,16 @@ func TestSessionTimerReset(t *testing.T) {
// Start mcp-front with test timeout configuration
trace(t, "Starting mcp-front with test session timeout")
- startMCPFront(t, "config/config.test.json")
+ cfg := buildTestConfig("http://localhost:8080", "mcp-front-test",
+ nil,
+ map[string]any{"postgres": testPostgresServer(withBearerTokens("test-token", "alt-test-token"), withLogEnabled())},
+ )
+ startMCPFront(t, writeTestConfig(t, cfg))
waitForMCPFront(t)
// Get initial container count
initialContainers := getMCPContainers()
- t.Logf("Initial mcp/postgres containers: %d", len(initialContainers))
+ t.Logf("Initial toolbox containers: %d", len(initialContainers))
// Create a client and connect
client := NewMCPSSEClient("http://localhost:8080")
@@ -308,7 +320,7 @@ func TestSessionTimerReset(t *testing.T) {
for i := range 3 {
t.Logf("Sending keepalive query %d/3...", i+1)
_, err := client.SendMCPRequest("tools/call", map[string]any{
- "name": "query",
+ "name": "execute_sql",
"arguments": map[string]any{
"sql": "SELECT 'keepalive' as status, NOW() as timestamp",
},
@@ -357,12 +369,16 @@ func TestMultiUserTimerIndependence(t *testing.T) {
// Start mcp-front with test timeout configuration
trace(t, "Starting mcp-front with test session timeout")
- startMCPFront(t, "config/config.test.json")
+ cfg := buildTestConfig("http://localhost:8080", "mcp-front-test",
+ nil,
+ map[string]any{"postgres": testPostgresServer(withBearerTokens("test-token", "alt-test-token"), withLogEnabled())},
+ )
+ startMCPFront(t, writeTestConfig(t, cfg))
waitForMCPFront(t)
// Get initial container count
initialContainers := getMCPContainers()
- t.Logf("Initial mcp/postgres containers: %d", len(initialContainers))
+ t.Logf("Initial toolbox containers: %d", len(initialContainers))
// Create two clients
client1 := NewMCPSSEClient("http://localhost:8080")
@@ -411,7 +427,7 @@ func TestMultiUserTimerIndependence(t *testing.T) {
for i := range 4 {
time.Sleep(4 * time.Second)
_, err := client2.SendMCPRequest("tools/call", map[string]any{
- "name": "query",
+ "name": "execute_sql",
"arguments": map[string]any{
"sql": "SELECT 'client2-keepalive' as status",
},
diff --git a/integration/main_test.go b/integration/main_test.go
index 49eb876..39527d7 100644
--- a/integration/main_test.go
+++ b/integration/main_test.go
@@ -22,6 +22,15 @@ func TestMain(m *testing.M) {
os.Exit(1)
}
+ // Pull the toolbox image so the first test doesn't timeout on image pull
+ fmt.Println("Pulling toolbox image...")
+ pullCmd := exec.Command("docker", "pull", ToolboxImage)
+ pullCmd.Stdout = os.Stdout
+ pullCmd.Stderr = os.Stderr
+ if err := pullCmd.Run(); err != nil {
+ fmt.Printf("Warning: failed to pull toolbox image: %v\n", err)
+ }
+
// Set up local log file for mcp-front output
logFile := "mcp-front-test.log"
os.Setenv("MCP_LOG_FILE", logFile)
@@ -53,7 +62,7 @@ func TestMain(m *testing.M) {
os.Exit(exitCode)
}()
- // Start fake GCP server for OAuth
+ // Start fake GCP server for OAuth (port 9090)
fakeGCP := NewFakeGCPServer("9090")
err := fakeGCP.Start()
if err != nil {
@@ -65,6 +74,28 @@ func TestMain(m *testing.M) {
_ = fakeGCP.Stop()
}()
+ // Start fake GitHub server (port 9092)
+ fakeGitHub := NewFakeGitHubServer("9092", []string{"test-org", "another-org"})
+ if err := fakeGitHub.Start(); err != nil {
+ fmt.Printf("Failed to start fake GitHub server: %v\n", err)
+ exitCode = 1
+ return
+ }
+ defer func() {
+ _ = fakeGitHub.Stop()
+ }()
+
+ // Start fake OIDC server (port 9093) — used for both OIDC and Azure tests
+ fakeOIDC := NewFakeOIDCServer("9093")
+ if err := fakeOIDC.Start(); err != nil {
+ fmt.Printf("Failed to start fake OIDC server: %v\n", err)
+ exitCode = 1
+ return
+ }
+ defer func() {
+ _ = fakeOIDC.Stop()
+ }()
+
// Wait for database to be ready
fmt.Println("Waiting for database to be ready...")
for i := range 30 { // Wait up to 30 seconds
diff --git a/integration/oauth_flow_test.go b/integration/oauth_flow_test.go
new file mode 100644
index 0000000..c6e8217
--- /dev/null
+++ b/integration/oauth_flow_test.go
@@ -0,0 +1,789 @@
+package integration
+
+import (
+ "bufio"
+ "bytes"
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+ "net/url"
+ "os"
+ "os/exec"
+ "regexp"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+// TestBasicOAuthFlow tests the basic OAuth server functionality
+func TestBasicOAuthFlow(t *testing.T) {
+ cfg := buildTestConfig("http://localhost:8080", "mcp-front-oauth-test",
+ testOAuthConfigFromEnv(),
+ map[string]any{"postgres": testPostgresServer()},
+ )
+ startMCPFront(t, writeTestConfig(t, cfg),
+ "JWT_SECRET=test-jwt-secret-32-bytes-exactly!",
+ "ENCRYPTION_KEY=test-encryption-key-32-bytes-ok!",
+ "GOOGLE_CLIENT_ID=test-client-id-for-oauth",
+ "GOOGLE_CLIENT_SECRET=test-client-secret-for-oauth",
+ "MCP_FRONT_ENV=development",
+ )
+
+ // Wait for startup
+ waitForMCPFront(t)
+
+ // Test OAuth discovery
+ resp, err := http.Get("http://localhost:8080/.well-known/oauth-authorization-server")
+ require.NoError(t, err, "Failed to get OAuth discovery")
+ defer resp.Body.Close()
+
+ assert.Equal(t, 200, resp.StatusCode, "OAuth discovery failed")
+
+ var discovery map[string]any
+ err = json.NewDecoder(resp.Body).Decode(&discovery)
+ require.NoError(t, err, "Failed to decode discovery")
+
+ // Verify required endpoints
+ requiredEndpoints := []string{
+ "issuer",
+ "authorization_endpoint",
+ "token_endpoint",
+ "registration_endpoint",
+ }
+
+ for _, endpoint := range requiredEndpoints {
+ _, ok := discovery[endpoint]
+ assert.True(t, ok, "Missing required endpoint: %s", endpoint)
+ }
+
+ // Verify client_secret_post is advertised
+ authMethods, ok := discovery["token_endpoint_auth_methods_supported"].([]any)
+ assert.True(t, ok, "token_endpoint_auth_methods_supported should be present")
+
+ var hasNone, hasClientSecretPost bool
+ for _, method := range authMethods {
+ if method == "none" {
+ hasNone = true
+ }
+ if method == "client_secret_post" {
+ hasClientSecretPost = true
+ }
+ }
+ assert.True(t, hasNone, "Should support 'none' auth method for public clients")
+ assert.True(t, hasClientSecretPost, "Should support 'client_secret_post' auth method for confidential clients")
+}
+
+// TestJWTSecretValidation tests JWT secret length requirements
+func TestJWTSecretValidation(t *testing.T) {
+ oauthCfg := buildTestConfig("http://localhost:8080", "mcp-front-oauth-test",
+ testOAuthConfigFromEnv(),
+ map[string]any{"postgres": testPostgresServer()},
+ )
+ configPath := writeTestConfig(t, oauthCfg)
+
+ tests := []struct {
+ name string
+ secret string
+ shouldFail bool
+ }{
+ {"Short 3-byte secret", "123", true},
+ {"Short 16-byte secret", "sixteen-byte-key", true},
+ {"Valid 32-byte secret", "demo-jwt-secret-32-bytes-exactly!", false},
+ {"Long 64-byte secret", "demo-jwt-secret-32-bytes-exactly!demo-jwt-secret-32-bytes-exactly!", false},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+
+ mcpCmd := exec.Command("../cmd/mcp-front/mcp-front", "-config", configPath)
+ mcpCmd.Env = []string{
+ "PATH=" + os.Getenv("PATH"),
+ "JWT_SECRET=" + tt.secret,
+ "ENCRYPTION_KEY=test-encryption-key-32-bytes-ok!",
+ "GOOGLE_CLIENT_ID=test-client-id",
+ "GOOGLE_CLIENT_SECRET=test-client-secret",
+ "MCP_FRONT_ENV=development",
+ }
+
+ // Capture stderr
+ stderrPipe, _ := mcpCmd.StderrPipe()
+ scanner := bufio.NewScanner(stderrPipe)
+
+ if err := mcpCmd.Start(); err != nil {
+ t.Fatalf("Failed to start mcp-front: %v", err)
+ }
+
+ // Read stderr to check for errors
+ errorFound := false
+ go func() {
+ for scanner.Scan() {
+ line := scanner.Text()
+ if contains(line, "JWT secret must be at least") {
+ errorFound = true
+ }
+ }
+ }()
+
+ // Give it time to start or fail
+ time.Sleep(2 * time.Second)
+
+ // Check if it's running
+ healthy := checkHealth()
+
+ // Clean up
+ if mcpCmd.Process != nil {
+ _ = mcpCmd.Process.Kill()
+ _ = mcpCmd.Wait()
+ }
+
+ if tt.shouldFail {
+ assert.False(t, healthy && !errorFound, "Expected failure with short JWT secret but server started successfully")
+ } else {
+ assert.True(t, healthy, "Expected success with valid JWT secret but server failed to start")
+ }
+ })
+ }
+}
+
+// TestClientRegistration tests dynamic client registration (RFC 7591)
+func TestClientRegistration(t *testing.T) {
+ // Start OAuth server
+ mcpCmd := startOAuthServer(t, map[string]string{
+ "MCP_FRONT_ENV": "development",
+ })
+ defer stopServer(mcpCmd)
+
+ if !waitForHealthCheck(30) {
+ t.Fatal("OAuth server failed to start")
+ }
+
+ t.Run("PublicClientRegistration", func(t *testing.T) {
+ // Register a public client (no secret)
+ clientReq := map[string]any{
+ "redirect_uris": []string{"http://127.0.0.1:6274/oauth/callback/debug"},
+ "scope": "read write",
+ }
+
+ body, _ := json.Marshal(clientReq)
+ resp, err := http.Post(
+ "http://localhost:8080/register",
+ "application/json",
+ bytes.NewBuffer(body),
+ )
+ if err != nil {
+ t.Fatalf("Failed to register client: %v", err)
+ }
+ defer resp.Body.Close()
+
+ if resp.StatusCode != 201 {
+ body, _ := io.ReadAll(resp.Body)
+ t.Fatalf("Client registration failed with status %d: %s", resp.StatusCode, string(body))
+ }
+
+ var clientResp map[string]any
+ if err := json.NewDecoder(resp.Body).Decode(&clientResp); err != nil {
+ t.Fatalf("Failed to decode response: %v", err)
+ }
+
+ // Verify response
+ if clientResp["client_id"] == "" {
+ t.Error("Client ID should not be empty")
+ }
+ if clientResp["client_secret"] != nil {
+ t.Error("Public client should not have a secret")
+ }
+ if scope, ok := clientResp["scope"].(string); !ok || scope != "read write" {
+ t.Errorf("Expected scope 'read write' as string, got: %v", clientResp["scope"])
+ }
+ })
+
+ t.Run("MultipleRegistrations", func(t *testing.T) {
+ // Register multiple clients and verify they get different IDs
+ var clientIDs []string
+
+ for i := range 3 {
+ clientReq := map[string]any{
+ "redirect_uris": []string{fmt.Sprintf("http://example.com/callback%d", i)},
+ "scope": "read",
+ }
+
+ body, _ := json.Marshal(clientReq)
+ resp, err := http.Post(
+ "http://localhost:8080/register",
+ "application/json",
+ bytes.NewBuffer(body),
+ )
+ if err != nil {
+ t.Fatalf("Failed to register client %d: %v", i, err)
+ }
+ defer resp.Body.Close()
+
+ var clientResp map[string]any
+ _ = json.NewDecoder(resp.Body).Decode(&clientResp)
+ clientIDs = append(clientIDs, clientResp["client_id"].(string))
+ }
+
+ // Verify all IDs are unique
+ for i := 0; i < len(clientIDs); i++ {
+ for j := i + 1; j < len(clientIDs); j++ {
+ if clientIDs[i] == clientIDs[j] {
+ t.Errorf("Client IDs should be unique, but got duplicate: %s", clientIDs[i])
+ }
+ }
+ }
+
+ })
+
+ t.Run("ConfidentialClientRegistration", func(t *testing.T) {
+ // Register a confidential client with client_secret_post
+ clientReq := map[string]any{
+ "redirect_uris": []string{"https://example.com/callback"},
+ "scope": "read write",
+ "token_endpoint_auth_method": "client_secret_post",
+ }
+
+ body, _ := json.Marshal(clientReq)
+ resp, err := http.Post(
+ "http://localhost:8080/register",
+ "application/json",
+ bytes.NewBuffer(body),
+ )
+ if err != nil {
+ t.Fatalf("Failed to register confidential client: %v", err)
+ }
+ defer resp.Body.Close()
+
+ if resp.StatusCode != 201 {
+ body, _ := io.ReadAll(resp.Body)
+ t.Fatalf("Confidential client registration failed with status %d: %s", resp.StatusCode, string(body))
+ }
+
+ var clientResp map[string]any
+ if err := json.NewDecoder(resp.Body).Decode(&clientResp); err != nil {
+ t.Fatalf("Failed to decode response: %v", err)
+ }
+
+ // Verify response includes client_secret
+ if clientResp["client_id"] == "" {
+ t.Error("Client ID should not be empty")
+ }
+ clientSecret, ok := clientResp["client_secret"].(string)
+ if !ok || clientSecret == "" {
+ t.Error("Confidential client should receive a client_secret")
+ }
+ // Verify secret has reasonable length (base64 of 32 bytes)
+ if len(clientSecret) < 40 {
+ t.Errorf("Client secret seems too short: %d chars", len(clientSecret))
+ }
+
+ tokenAuthMethod, ok := clientResp["token_endpoint_auth_method"].(string)
+ if !ok || tokenAuthMethod != "client_secret_post" {
+ t.Errorf("Expected token_endpoint_auth_method 'client_secret_post', got: %v", clientResp["token_endpoint_auth_method"])
+ }
+
+ // Verify scope is returned as string
+ if scope, ok := clientResp["scope"].(string); !ok || scope != "read write" {
+ t.Errorf("Expected scope 'read write' as string, got: %v", clientResp["scope"])
+ }
+ })
+
+ t.Run("PublicVsConfidentialClients", func(t *testing.T) {
+ // Test that public clients don't get secrets and confidential ones do
+
+ // First, create a public client
+ publicReq := map[string]any{
+ "redirect_uris": []string{"https://public.example.com/callback"},
+ "scope": "read",
+ // No token_endpoint_auth_method specified - defaults to "none"
+ }
+
+ body, _ := json.Marshal(publicReq)
+ resp, err := http.Post(
+ "http://localhost:8080/register",
+ "application/json",
+ bytes.NewBuffer(body),
+ )
+ require.NoError(t, err)
+ defer resp.Body.Close()
+
+ var publicResp map[string]any
+ _ = json.NewDecoder(resp.Body).Decode(&publicResp)
+
+ // Verify public client has no secret
+ if _, hasSecret := publicResp["client_secret"]; hasSecret {
+ t.Error("Public client should not have a secret")
+ }
+ if authMethod := publicResp["token_endpoint_auth_method"]; authMethod != "none" {
+ t.Errorf("Public client should have auth method 'none', got: %v", authMethod)
+ }
+
+ // Now create a confidential client
+ confidentialReq := map[string]any{
+ "redirect_uris": []string{"https://confidential.example.com/callback"},
+ "scope": "read write",
+ "token_endpoint_auth_method": "client_secret_post",
+ }
+
+ body, _ = json.Marshal(confidentialReq)
+ resp, err = http.Post(
+ "http://localhost:8080/register",
+ "application/json",
+ bytes.NewBuffer(body),
+ )
+ require.NoError(t, err)
+ defer resp.Body.Close()
+
+ var confResp map[string]any
+ _ = json.NewDecoder(resp.Body).Decode(&confResp)
+
+ // Verify confidential client has a secret
+ if secret, ok := confResp["client_secret"].(string); !ok || secret == "" {
+ t.Error("Confidential client should have a secret")
+ }
+ if authMethod := confResp["token_endpoint_auth_method"]; authMethod != "client_secret_post" {
+ t.Errorf("Confidential client should have auth method 'client_secret_post', got: %v", authMethod)
+ }
+ })
+}
+
+// TestStateParameterHandling tests OAuth state parameter requirements
+func TestStateParameterHandling(t *testing.T) {
+ tests := []struct {
+ name string
+ environment string
+ state string
+ expectError bool
+ }{
+ {"Production without state", "production", "", true},
+ {"Production with state", "production", "secure-random-state", false},
+ {"Development without state", "development", "", false}, // Should auto-generate
+ {"Development with state", "development", "test-state", false},
+ }
+
+ for _, tt := range tests {
+ // capture range variable
+ t.Run(tt.name, func(t *testing.T) {
+ // Start server with specific environment
+ mcpCmd := startOAuthServer(t, map[string]string{
+ "MCP_FRONT_ENV": tt.environment,
+ })
+ defer stopServer(mcpCmd)
+
+ if !waitForHealthCheck(10) {
+ t.Fatal("Server failed to start")
+ }
+
+ // Register a client first
+ clientID := registerTestClient(t)
+
+ // Create authorization request
+ params := url.Values{
+ "response_type": {"code"},
+ "client_id": {clientID},
+ "redirect_uri": {"http://127.0.0.1:6274/oauth/callback"},
+ "code_challenge": {"test-challenge"},
+ "code_challenge_method": {"S256"},
+ "scope": {"read write"},
+ }
+ if tt.state != "" {
+ params.Set("state", tt.state)
+ }
+
+ authURL := fmt.Sprintf("http://localhost:8080/authorize?%s", params.Encode())
+
+ // Use a client that doesn't follow redirects
+ client := &http.Client{
+ CheckRedirect: func(req *http.Request, via []*http.Request) error {
+ return http.ErrUseLastResponse
+ },
+ }
+ resp, err := client.Get(authURL)
+ if err != nil {
+ t.Fatalf("Authorization request failed: %v", err)
+ }
+ defer resp.Body.Close()
+
+ if tt.expectError {
+ // OAuth errors are returned as redirects with error parameters
+ if resp.StatusCode == 302 || resp.StatusCode == 303 {
+ location := resp.Header.Get("Location")
+ if strings.Contains(location, "error=") {
+ } else {
+ t.Errorf("Expected error redirect for %s, got redirect without error", tt.name)
+ }
+ } else if resp.StatusCode >= 400 {
+ } else {
+ t.Errorf("Expected error for %s, got status %d", tt.name, resp.StatusCode)
+ }
+ } else {
+ if resp.StatusCode == 302 || resp.StatusCode == 303 {
+ location := resp.Header.Get("Location")
+ if strings.Contains(location, "error=") {
+ t.Errorf("Unexpected error redirect for %s: %s", tt.name, location)
+ }
+ } else if resp.StatusCode < 400 {
+ } else {
+ body, _ := io.ReadAll(resp.Body)
+ t.Errorf("Expected success for %s, got status %d: %s", tt.name, resp.StatusCode, string(body))
+ }
+ }
+ })
+ }
+}
+
+// TestEnvironmentModes tests development vs production mode differences
+func TestEnvironmentModes(t *testing.T) {
+ t.Run("DevelopmentMode", func(t *testing.T) {
+ mcpCmd := startOAuthServer(t, map[string]string{
+ "MCP_FRONT_ENV": "development",
+ })
+ defer stopServer(mcpCmd)
+
+ if !waitForHealthCheck(30) {
+ t.Fatal("Server failed to start")
+ }
+
+ // In development mode, missing state should be auto-generated
+ clientID := registerTestClient(t)
+
+ params := url.Values{
+ "response_type": {"code"},
+ "client_id": {clientID},
+ "redirect_uri": {"http://127.0.0.1:6274/oauth/callback"},
+ "code_challenge": {"test-challenge"},
+ "code_challenge_method": {"S256"},
+ "scope": {"read"},
+ // Intentionally omitting state parameter
+ }
+
+ // Use a client that doesn't follow redirects
+ client := &http.Client{
+ CheckRedirect: func(req *http.Request, via []*http.Request) error {
+ return http.ErrUseLastResponse
+ },
+ }
+ resp, err := client.Get("http://localhost:8080/authorize?" + params.Encode())
+ if err != nil {
+ t.Fatalf("Failed to make auth request: %v", err)
+ }
+ defer resp.Body.Close()
+
+ // Should redirect (302) not error
+ if resp.StatusCode >= 400 && resp.StatusCode != 302 {
+ t.Errorf("Development mode should handle missing state, got status %d", resp.StatusCode)
+ }
+ })
+
+ t.Run("ProductionMode", func(t *testing.T) {
+ mcpCmd := startOAuthServer(t, map[string]string{
+ "MCP_FRONT_ENV": "production",
+ })
+ defer stopServer(mcpCmd)
+
+ if !waitForHealthCheck(30) {
+ t.Fatal("Server failed to start")
+ }
+
+ // In production mode, state should be required
+ clientID := registerTestClient(t)
+
+ params := url.Values{
+ "response_type": {"code"},
+ "client_id": {clientID},
+ "redirect_uri": {"http://127.0.0.1:6274/oauth/callback"},
+ "code_challenge": {"test-challenge"},
+ "code_challenge_method": {"S256"},
+ "scope": {"read"},
+ // Intentionally omitting state parameter
+ }
+
+ // Use a client that doesn't follow redirects
+ client := &http.Client{
+ CheckRedirect: func(req *http.Request, via []*http.Request) error {
+ return http.ErrUseLastResponse
+ },
+ }
+ resp, err := client.Get("http://localhost:8080/authorize?" + params.Encode())
+ if err != nil {
+ t.Fatalf("Failed to make auth request: %v", err)
+ }
+ defer resp.Body.Close()
+
+ // Should error - OAuth errors are returned as redirects
+ if resp.StatusCode == 302 || resp.StatusCode == 303 {
+ location := resp.Header.Get("Location")
+ if strings.Contains(location, "error=") {
+ } else {
+ t.Errorf("Expected error redirect in production mode, got redirect without error")
+ }
+ } else if resp.StatusCode >= 400 {
+ } else {
+ t.Errorf("Production mode should require state parameter, got status %d", resp.StatusCode)
+ }
+ })
+}
+
+// TestOAuthEndpoints tests all OAuth endpoints comprehensively
+func TestOAuthEndpoints(t *testing.T) {
+ mcpCmd := startOAuthServer(t, map[string]string{
+ "MCP_FRONT_ENV": "development",
+ })
+ defer stopServer(mcpCmd)
+
+ if !waitForHealthCheck(10) {
+ t.Fatal("Server failed to start")
+ }
+
+ t.Run("Discovery", func(t *testing.T) {
+ resp, err := http.Get("http://localhost:8080/.well-known/oauth-authorization-server")
+ if err != nil {
+ t.Fatalf("Discovery request failed: %v", err)
+ }
+ defer resp.Body.Close()
+
+ if resp.StatusCode != 200 {
+ t.Fatalf("Discovery failed with status %d", resp.StatusCode)
+ }
+
+ var discovery map[string]any
+ if err := json.NewDecoder(resp.Body).Decode(&discovery); err != nil {
+ t.Fatalf("Failed to decode discovery response: %v", err)
+ }
+
+ // Verify all required fields
+ required := []string{
+ "issuer",
+ "authorization_endpoint",
+ "token_endpoint",
+ "registration_endpoint",
+ "response_types_supported",
+ "grant_types_supported",
+ "code_challenge_methods_supported",
+ }
+
+ for _, field := range required {
+ if _, ok := discovery[field]; !ok {
+ t.Errorf("Missing required discovery field: %s", field)
+ }
+ }
+
+ })
+
+ t.Run("HealthCheck", func(t *testing.T) {
+ resp, err := http.Get("http://localhost:8080/health")
+ if err != nil {
+ t.Fatalf("Health check failed: %v", err)
+ }
+ defer resp.Body.Close()
+
+ if resp.StatusCode != 200 {
+ t.Errorf("Health check should return 200, got %d", resp.StatusCode)
+ }
+
+ var health map[string]string
+ if err := json.NewDecoder(resp.Body).Decode(&health); err != nil {
+ t.Fatalf("Failed to decode health response: %v", err)
+ }
+ if health["status"] != "ok" {
+ t.Errorf("Expected status 'ok', got '%s'", health["status"])
+ }
+
+ })
+}
+
+// TestCORSHeaders tests CORS headers for Claude.ai compatibility
+func TestCORSHeaders(t *testing.T) {
+ mcpCmd := startOAuthServer(t, map[string]string{
+ "MCP_FRONT_ENV": "development",
+ })
+ defer stopServer(mcpCmd)
+
+ if !waitForHealthCheck(10) {
+ t.Fatal("Server failed to start")
+ }
+
+ // Test preflight request
+ req, _ := http.NewRequest("OPTIONS", "http://localhost:8080/register", nil)
+ req.Header.Set("Origin", "https://claude.ai")
+ req.Header.Set("Access-Control-Request-Method", "POST")
+ req.Header.Set("Access-Control-Request-Headers", "content-type")
+
+ resp, err := http.DefaultClient.Do(req)
+ if err != nil {
+ t.Fatalf("Preflight request failed: %v", err)
+ }
+ defer resp.Body.Close()
+
+ if resp.StatusCode != 200 {
+ t.Errorf("Preflight should return 200, got %d", resp.StatusCode)
+ }
+
+ // Check CORS headers
+ expectedHeaders := map[string]string{
+ "Access-Control-Allow-Origin": "https://claude.ai",
+ "Access-Control-Allow-Methods": "GET, POST, OPTIONS",
+ "Access-Control-Allow-Headers": "Content-Type, Authorization, Cache-Control, mcp-protocol-version",
+ }
+
+ for header, expected := range expectedHeaders {
+ actual := resp.Header.Get(header)
+ if actual != expected {
+ t.Errorf("Expected %s: '%s', got '%s'", header, expected, actual)
+ }
+ }
+
+}
+
+// Shared helpers for OAuth tests
+
+func startOAuthServer(t *testing.T, env map[string]string) *exec.Cmd {
+ cfg := buildTestConfig("http://localhost:8080", "mcp-front-oauth-test",
+ testOAuthConfigFromEnv(),
+ map[string]any{"postgres": testPostgresServer()},
+ )
+ configPath := writeTestConfig(t, cfg)
+
+ mcpCmd := exec.Command("../cmd/mcp-front/mcp-front", "-config", configPath)
+
+ // Set default environment
+ mcpCmd.Env = []string{
+ "PATH=" + os.Getenv("PATH"),
+ "JWT_SECRET=demo-jwt-secret-32-bytes-exactly!",
+ "ENCRYPTION_KEY=test-encryption-key-32-bytes-ok!",
+ "GOOGLE_CLIENT_ID=test-client-id-oauth",
+ "GOOGLE_CLIENT_SECRET=test-client-secret-oauth",
+ }
+
+ // Override with provided env
+ for key, value := range env {
+ mcpCmd.Env = append(mcpCmd.Env, fmt.Sprintf("%s=%s", key, value))
+ }
+
+ // Capture stderr for debugging and also output to test log
+ var stderr bytes.Buffer
+ mcpCmd.Stderr = io.MultiWriter(&stderr, os.Stderr)
+
+ if err := mcpCmd.Start(); err != nil {
+ t.Fatalf("Failed to start OAuth server: %v", err)
+ }
+
+ // Give a moment for immediate failures
+ time.Sleep(100 * time.Millisecond)
+
+ // Check if process died immediately
+ if mcpCmd.ProcessState != nil {
+ t.Fatalf("OAuth server died immediately: %s", stderr.String())
+ }
+
+ return mcpCmd
+}
+
+// startOAuthServerWithTokenConfig starts the OAuth server with user token configuration
+func startOAuthServerWithTokenConfig(t *testing.T) *exec.Cmd {
+ // Start with user token config
+ mcpCmd := exec.Command("../cmd/mcp-front/mcp-front", "-config", "config/config.oauth-token-test.json")
+
+ // Set default environment
+ mcpCmd.Env = []string{
+ "PATH=" + os.Getenv("PATH"),
+ "JWT_SECRET=demo-jwt-secret-32-bytes-exactly!",
+ "ENCRYPTION_KEY=test-encryption-key-32-bytes-ok!",
+ "GOOGLE_CLIENT_ID=test-client-id-oauth",
+ "GOOGLE_CLIENT_SECRET=test-client-secret-oauth",
+ "MCP_FRONT_ENV=development",
+ }
+
+ // Capture stderr for debugging and also output to test log
+ var stderr bytes.Buffer
+ mcpCmd.Stderr = io.MultiWriter(&stderr, os.Stderr)
+
+ if err := mcpCmd.Start(); err != nil {
+ t.Fatalf("Failed to start OAuth server: %v", err)
+ }
+
+ // Give a moment for immediate failures
+ time.Sleep(100 * time.Millisecond)
+
+ // Check if process died immediately
+ if mcpCmd.ProcessState != nil {
+ t.Fatalf("OAuth server died immediately: %s", stderr.String())
+ }
+
+ return mcpCmd
+}
+
+func stopServer(cmd *exec.Cmd) {
+ if cmd != nil && cmd.Process != nil {
+ _ = cmd.Process.Kill()
+ _ = cmd.Wait()
+ // Give the OS time to release the port
+ time.Sleep(100 * time.Millisecond)
+ }
+}
+
+func waitForHealthCheck(seconds int) bool {
+ for range seconds {
+ if checkHealth() {
+ return true
+ }
+ time.Sleep(1 * time.Second)
+ }
+ return false
+}
+
+func checkHealth() bool {
+ resp, err := http.Get("http://localhost:8080/health")
+ if err == nil && resp.StatusCode == 200 {
+ resp.Body.Close()
+ return true
+ }
+ if resp != nil {
+ resp.Body.Close()
+ }
+ return false
+}
+
+func registerTestClient(t *testing.T) string {
+ clientReq := map[string]any{
+ "redirect_uris": []string{"http://127.0.0.1:6274/oauth/callback"},
+ "scope": "openid email profile read write",
+ }
+
+ body, _ := json.Marshal(clientReq)
+ resp, err := http.Post(
+ "http://localhost:8080/register",
+ "application/json",
+ bytes.NewBuffer(body),
+ )
+ if err != nil {
+ t.Fatalf("Failed to register client: %v", err)
+ }
+ defer resp.Body.Close()
+
+ if resp.StatusCode != 201 {
+ body, _ := io.ReadAll(resp.Body)
+ t.Fatalf("Client registration failed: %d - %s", resp.StatusCode, string(body))
+ }
+
+ var clientResp map[string]any
+ _ = json.NewDecoder(resp.Body).Decode(&clientResp)
+ return clientResp["client_id"].(string)
+}
+
+// extractCSRFToken extracts the CSRF token from the HTML response
+func extractCSRFToken(t *testing.T, html string) string {
+ // Look for
+ re := regexp.MustCompile(`]+name="csrf_token"[^>]+value="([^"]+)"`)
+ matches := re.FindStringSubmatch(html)
+ require.GreaterOrEqual(t, len(matches), 2, "CSRF token not found in response")
+ return matches[1]
+}
+
+// contains is a simple helper to check if string contains substring
+func contains(s, substr string) bool {
+ return strings.Contains(s, substr)
+}
diff --git a/integration/oauth_idp_test.go b/integration/oauth_idp_test.go
new file mode 100644
index 0000000..86b2171
--- /dev/null
+++ b/integration/oauth_idp_test.go
@@ -0,0 +1,256 @@
+package integration
+
+import (
+ "io"
+ "net/http"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+// idpEnvGitHub is the set of env vars needed for GitHub IDP integration tests.
+var idpEnvGitHub = []string{
+ "JWT_SECRET=test-jwt-secret-32-bytes-exactly!",
+ "ENCRYPTION_KEY=test-encryption-key-32-bytes-ok!",
+ "GITHUB_CLIENT_SECRET=test-github-client-secret",
+ "MCP_FRONT_ENV=development",
+}
+
+// idpEnvOIDC is the set of env vars needed for OIDC IDP integration tests.
+var idpEnvOIDC = []string{
+ "JWT_SECRET=test-jwt-secret-32-bytes-exactly!",
+ "ENCRYPTION_KEY=test-encryption-key-32-bytes-ok!",
+ "OIDC_CLIENT_SECRET=test-oidc-client-secret",
+ "MCP_FRONT_ENV=development",
+}
+
+// idpEnvAzure is the set of env vars needed for Azure IDP integration tests.
+var idpEnvAzure = []string{
+ "JWT_SECRET=test-jwt-secret-32-bytes-exactly!",
+ "ENCRYPTION_KEY=test-encryption-key-32-bytes-ok!",
+ "AZURE_CLIENT_SECRET=test-azure-client-secret",
+ "MCP_FRONT_ENV=development",
+}
+
+// TestGitHubOAuthFlow tests the full OAuth flow using the GitHub IDP.
+func TestGitHubOAuthFlow(t *testing.T) {
+ cfg := buildTestConfig("http://localhost:8080", "mcp-front-github-test",
+ testGitHubOAuthConfig("test-org"),
+ map[string]any{"postgres": testPostgresServer()},
+ )
+ startMCPFront(t, writeTestConfig(t, cfg), idpEnvGitHub...)
+ waitForMCPFront(t)
+
+ accessToken := getOAuthAccessTokenForIDP(t, "http://localhost:8080/postgres", "localhost:9092")
+
+ mcpClient := NewMCPSSEClient("http://localhost:8080")
+ mcpClient.SetAuthToken(accessToken)
+
+ err := mcpClient.Connect()
+ require.NoError(t, err, "Should connect to postgres with GitHub-issued token")
+ defer mcpClient.Close()
+
+ toolsResp, err := mcpClient.SendMCPRequest("tools/list", map[string]any{})
+ require.NoError(t, err, "Should list tools with GitHub-issued token")
+
+ resultMap, ok := toolsResp["result"].(map[string]any)
+ require.True(t, ok, "Expected result in tools response")
+ tools, ok := resultMap["tools"].([]any)
+ require.True(t, ok, "Expected tools array")
+ assert.NotEmpty(t, tools, "Should have tools")
+}
+
+// TestGitHubOrgDenial verifies that users not in allowed orgs are denied.
+func TestGitHubOrgDenial(t *testing.T) {
+ cfg := buildTestConfig("http://localhost:8080", "mcp-front-github-org-deny",
+ testGitHubOAuthConfig("org-that-doesnt-match"),
+ map[string]any{"postgres": testPostgresServer()},
+ )
+ startMCPFront(t, writeTestConfig(t, cfg), idpEnvGitHub...)
+ waitForMCPFront(t)
+
+ // Register client and start auth flow
+ clientID := registerTestClient(t)
+
+ client := &http.Client{
+ CheckRedirect: func(req *http.Request, via []*http.Request) error {
+ return http.ErrUseLastResponse
+ },
+ }
+
+ // Step 1: Authorize
+ authResp, err := client.Get("http://localhost:8080/authorize?response_type=code&client_id=" + clientID +
+ "&redirect_uri=http://127.0.0.1:6274/oauth/callback&code_challenge=test-challenge&code_challenge_method=S256" +
+ "&scope=openid+email+profile&state=test-state&resource=http://localhost:8080/postgres")
+ require.NoError(t, err)
+ defer authResp.Body.Close()
+ require.Contains(t, []int{302, 303}, authResp.StatusCode)
+
+ // Step 2: Follow to GitHub fake
+ location := authResp.Header.Get("Location")
+ idpResp, err := client.Get(location)
+ require.NoError(t, err)
+ defer idpResp.Body.Close()
+
+ // Step 3: Follow callback — should get an error (access denied), not a code
+ callbackLocation := idpResp.Header.Get("Location")
+ callbackResp, err := client.Get(callbackLocation)
+ require.NoError(t, err)
+ defer callbackResp.Body.Close()
+
+ // The callback should either:
+ // - Return a 403 directly, or
+ // - Redirect to the redirect_uri with an error parameter
+ if callbackResp.StatusCode == 302 || callbackResp.StatusCode == 303 {
+ finalLocation := callbackResp.Header.Get("Location")
+ assert.Contains(t, finalLocation, "error=", "Should redirect with error for org denial")
+ } else {
+ // Direct error response
+ body, _ := io.ReadAll(callbackResp.Body)
+ assert.Contains(t, string(body), "denied", "Should contain denial message")
+ }
+}
+
+// TestOIDCOAuthFlow tests the full OAuth flow using a generic OIDC provider.
+func TestOIDCOAuthFlow(t *testing.T) {
+ cfg := buildTestConfig("http://localhost:8080", "mcp-front-oidc-test",
+ testOIDCOAuthConfig(),
+ map[string]any{"postgres": testPostgresServer()},
+ )
+ startMCPFront(t, writeTestConfig(t, cfg), idpEnvOIDC...)
+ waitForMCPFront(t)
+
+ accessToken := getOAuthAccessTokenForIDP(t, "http://localhost:8080/postgres", "localhost:9093")
+
+ mcpClient := NewMCPSSEClient("http://localhost:8080")
+ mcpClient.SetAuthToken(accessToken)
+
+ err := mcpClient.Connect()
+ require.NoError(t, err, "Should connect to postgres with OIDC-issued token")
+ defer mcpClient.Close()
+
+ toolsResp, err := mcpClient.SendMCPRequest("tools/list", map[string]any{})
+ require.NoError(t, err, "Should list tools with OIDC-issued token")
+
+ resultMap, ok := toolsResp["result"].(map[string]any)
+ require.True(t, ok, "Expected result in tools response")
+ tools, ok := resultMap["tools"].([]any)
+ require.True(t, ok, "Expected tools array")
+ assert.NotEmpty(t, tools, "Should have tools")
+}
+
+// TestAzureOAuthFlow tests the full OAuth flow using the Azure IDP (backed by the OIDC fake server).
+func TestAzureOAuthFlow(t *testing.T) {
+ cfg := buildTestConfig("http://localhost:8080", "mcp-front-azure-test",
+ testAzureOAuthConfig(),
+ map[string]any{"postgres": testPostgresServer()},
+ )
+ startMCPFront(t, writeTestConfig(t, cfg), idpEnvAzure...)
+ waitForMCPFront(t)
+
+ accessToken := getOAuthAccessTokenForIDP(t, "http://localhost:8080/postgres", "localhost:9093")
+
+ mcpClient := NewMCPSSEClient("http://localhost:8080")
+ mcpClient.SetAuthToken(accessToken)
+
+ err := mcpClient.Connect()
+ require.NoError(t, err, "Should connect to postgres with Azure-issued token")
+ defer mcpClient.Close()
+
+ toolsResp, err := mcpClient.SendMCPRequest("tools/list", map[string]any{})
+ require.NoError(t, err, "Should list tools with Azure-issued token")
+
+ resultMap, ok := toolsResp["result"].(map[string]any)
+ require.True(t, ok, "Expected result in tools response")
+ tools, ok := resultMap["tools"].([]any)
+ require.True(t, ok, "Expected tools array")
+ assert.NotEmpty(t, tools, "Should have tools")
+}
+
+// TestIDPDomainDenial verifies that domain restrictions are enforced across providers.
+func TestIDPDomainDenial(t *testing.T) {
+ tests := []struct {
+ name string
+ authConfig map[string]any
+ env []string
+ idpHost string
+ }{
+ {
+ name: "GitHub",
+ authConfig: func() map[string]any {
+ cfg := testGitHubOAuthConfig("test-org")
+ cfg["allowedDomains"] = []string{"wrong-domain.com"}
+ return cfg
+ }(),
+ env: idpEnvGitHub,
+ idpHost: "localhost:9092",
+ },
+ {
+ name: "OIDC",
+ authConfig: func() map[string]any {
+ cfg := testOIDCOAuthConfig()
+ cfg["allowedDomains"] = []string{"wrong-domain.com"}
+ return cfg
+ }(),
+ env: idpEnvOIDC,
+ idpHost: "localhost:9093",
+ },
+ {
+ name: "Azure",
+ authConfig: func() map[string]any {
+ cfg := testAzureOAuthConfig()
+ cfg["allowedDomains"] = []string{"wrong-domain.com"}
+ return cfg
+ }(),
+ env: idpEnvAzure,
+ idpHost: "localhost:9093",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ cfg := buildTestConfig("http://localhost:8080", "mcp-front-domain-deny-"+tt.name,
+ tt.authConfig,
+ map[string]any{"postgres": testPostgresServer()},
+ )
+ startMCPFront(t, writeTestConfig(t, cfg), tt.env...)
+ waitForMCPFront(t)
+
+ clientID := registerTestClient(t)
+
+ client := &http.Client{
+ CheckRedirect: func(req *http.Request, via []*http.Request) error {
+ return http.ErrUseLastResponse
+ },
+ }
+
+ authResp, err := client.Get("http://localhost:8080/authorize?response_type=code&client_id=" + clientID +
+ "&redirect_uri=http://127.0.0.1:6274/oauth/callback&code_challenge=test-challenge&code_challenge_method=S256" +
+ "&scope=openid+email+profile&state=test-state&resource=http://localhost:8080/postgres")
+ require.NoError(t, err)
+ defer authResp.Body.Close()
+ require.Contains(t, []int{302, 303}, authResp.StatusCode)
+
+ location := authResp.Header.Get("Location")
+ require.Contains(t, location, tt.idpHost, "Should redirect to expected IDP")
+
+ idpResp, err := client.Get(location)
+ require.NoError(t, err)
+ defer idpResp.Body.Close()
+
+ callbackLocation := idpResp.Header.Get("Location")
+ callbackResp, err := client.Get(callbackLocation)
+ require.NoError(t, err)
+ defer callbackResp.Body.Close()
+
+ if callbackResp.StatusCode == 302 || callbackResp.StatusCode == 303 {
+ finalLocation := callbackResp.Header.Get("Location")
+ assert.Contains(t, finalLocation, "error=", "Should redirect with error for domain denial")
+ } else {
+ body, _ := io.ReadAll(callbackResp.Body)
+ assert.Contains(t, string(body), "denied", "Should contain denial message")
+ }
+ })
+ }
+}
diff --git a/integration/oauth_rfc8707_test.go b/integration/oauth_rfc8707_test.go
new file mode 100644
index 0000000..94c01f2
--- /dev/null
+++ b/integration/oauth_rfc8707_test.go
@@ -0,0 +1,198 @@
+package integration
+
+import (
+ "crypto/sha256"
+ "encoding/base64"
+ "encoding/json"
+ "net/http"
+ "net/url"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+// TestRFC8707ResourceIndicators validates RFC 8707 resource indicator functionality
+func TestRFC8707ResourceIndicators(t *testing.T) {
+ startMCPFront(t, "config/config.oauth-rfc8707-test.json",
+ "JWT_SECRET=test-jwt-secret-32-bytes-exactly!",
+ "ENCRYPTION_KEY=test-encryption-key-32-bytes-ok!",
+ "GOOGLE_CLIENT_ID=test-client-id-for-oauth",
+ "GOOGLE_CLIENT_SECRET=test-client-secret-for-oauth",
+ "MCP_FRONT_ENV=development",
+ )
+
+ waitForMCPFront(t)
+
+ t.Run("BaseProtectedResourceMetadataReturns404", func(t *testing.T) {
+ // Base metadata endpoint should return 404, directing clients to per-service endpoints
+ resp, err := http.Get("http://localhost:8080/.well-known/oauth-protected-resource")
+ require.NoError(t, err)
+ defer resp.Body.Close()
+
+ assert.Equal(t, 404, resp.StatusCode, "Base protected resource metadata endpoint should return 404")
+
+ var errResp map[string]any
+ err = json.NewDecoder(resp.Body).Decode(&errResp)
+ require.NoError(t, err)
+
+ assert.Contains(t, errResp["message"], "per-service", "Error message should direct to per-service endpoints")
+ })
+
+ t.Run("PerServiceProtectedResourceMetadataEndpoint", func(t *testing.T) {
+ // Per-service metadata endpoint should return service-specific resource URI
+ resp, err := http.Get("http://localhost:8080/.well-known/oauth-protected-resource/test-sse")
+ require.NoError(t, err)
+ defer resp.Body.Close()
+
+ assert.Equal(t, 200, resp.StatusCode, "Per-service protected resource metadata endpoint should exist")
+
+ var metadata map[string]any
+ err = json.NewDecoder(resp.Body).Decode(&metadata)
+ require.NoError(t, err)
+
+ // Resource should be service-specific, not base URL
+ assert.Equal(t, "http://localhost:8080/test-sse", metadata["resource"],
+ "Resource should be service-specific URL")
+
+ authzServers, ok := metadata["authorization_servers"].([]any)
+ require.True(t, ok, "Should have authorization_servers array")
+ require.NotEmpty(t, authzServers)
+ assert.Equal(t, "http://localhost:8080", authzServers[0],
+ "Authorization server should be base issuer")
+ })
+
+ t.Run("UnknownServiceReturns404", func(t *testing.T) {
+ resp, err := http.Get("http://localhost:8080/.well-known/oauth-protected-resource/nonexistent-service")
+ require.NoError(t, err)
+ defer resp.Body.Close()
+
+ assert.Equal(t, 404, resp.StatusCode, "Unknown service should return 404")
+ })
+
+ t.Run("TokenWithResourceParameter", func(t *testing.T) {
+ clientID := registerTestClient(t)
+
+ codeVerifier := "test-code-verifier-that-is-at-least-43-characters-long"
+ h := sha256.New()
+ h.Write([]byte(codeVerifier))
+ codeChallenge := base64.RawURLEncoding.EncodeToString(h.Sum(nil))
+
+ authParams := url.Values{
+ "response_type": {"code"},
+ "client_id": {clientID},
+ "redirect_uri": {"http://127.0.0.1:6274/oauth/callback"},
+ "code_challenge": {codeChallenge},
+ "code_challenge_method": {"S256"},
+ "scope": {"openid email profile"},
+ "state": {"test-state"},
+ "resource": {"http://localhost:8080/test-sse"},
+ }
+
+ client := &http.Client{
+ CheckRedirect: func(req *http.Request, via []*http.Request) error {
+ return http.ErrUseLastResponse
+ },
+ }
+
+ authResp, err := client.Get("http://localhost:8080/authorize?" + authParams.Encode())
+ require.NoError(t, err)
+ defer authResp.Body.Close()
+
+ assert.Contains(t, []int{302, 303}, authResp.StatusCode, "Should redirect to Google OAuth")
+
+ location := authResp.Header.Get("Location")
+ googleResp, err := client.Get(location)
+ require.NoError(t, err)
+ defer googleResp.Body.Close()
+
+ callbackLocation := googleResp.Header.Get("Location")
+ callbackResp, err := client.Get(callbackLocation)
+ require.NoError(t, err)
+ defer callbackResp.Body.Close()
+
+ finalURL, err := url.Parse(callbackResp.Header.Get("Location"))
+ require.NoError(t, err)
+ authCode := finalURL.Query().Get("code")
+ require.NotEmpty(t, authCode, "Should have authorization code")
+
+ tokenParams := url.Values{
+ "grant_type": {"authorization_code"},
+ "code": {authCode},
+ "redirect_uri": {"http://127.0.0.1:6274/oauth/callback"},
+ "client_id": {clientID},
+ "code_verifier": {codeVerifier},
+ }
+
+ tokenResp, err := http.PostForm("http://localhost:8080/token", tokenParams)
+ require.NoError(t, err)
+ defer tokenResp.Body.Close()
+
+ require.Equal(t, 200, tokenResp.StatusCode, "Token exchange should succeed")
+
+ var tokenData map[string]any
+ err = json.NewDecoder(tokenResp.Body).Decode(&tokenData)
+ require.NoError(t, err)
+
+ testSSEToken := tokenData["access_token"].(string)
+ require.NotEmpty(t, testSSEToken, "Should have access token")
+
+ t.Logf("Got token with test-sse audience: %s", testSSEToken[:20]+"...")
+
+ // Verify token works for test-sse (matching audience)
+ req, _ := http.NewRequest("GET", "http://localhost:8080/test-sse/sse", nil)
+ req.Header.Set("Authorization", "Bearer "+testSSEToken)
+ req.Header.Set("Accept", "text/event-stream")
+
+ sseResp, err := client.Do(req)
+ require.NoError(t, err)
+ defer sseResp.Body.Close()
+
+ assert.Equal(t, 200, sseResp.StatusCode,
+ "Token with test-sse audience should access /test-sse/sse")
+
+ // Verify token does NOT work for test-streamable (wrong audience)
+ req, _ = http.NewRequest("GET", "http://localhost:8080/test-streamable/sse", nil)
+ req.Header.Set("Authorization", "Bearer "+testSSEToken)
+ req.Header.Set("Accept", "text/event-stream")
+
+ streamableResp, err := client.Do(req)
+ require.NoError(t, err)
+ defer streamableResp.Body.Close()
+
+ assert.Equal(t, 401, streamableResp.StatusCode,
+ "Token with test-sse audience should NOT access /test-streamable/sse")
+
+ wwwAuth := streamableResp.Header.Get("WWW-Authenticate")
+ assert.Contains(t, wwwAuth, "Bearer resource_metadata=",
+ "401 response should include RFC 9728 WWW-Authenticate header")
+ // Per RFC 9728 Section 5.2, the metadata URI should be service-specific
+ assert.Contains(t, wwwAuth, "/.well-known/oauth-protected-resource/test-streamable",
+ "401 response should point to per-service metadata endpoint")
+ })
+
+ t.Run("401ResponseIncludesServiceSpecificMetadataURI", func(t *testing.T) {
+ // Request to a protected endpoint without token should get 401
+ // with service-specific metadata URI in WWW-Authenticate header
+ client := &http.Client{
+ CheckRedirect: func(req *http.Request, via []*http.Request) error {
+ return http.ErrUseLastResponse
+ },
+ }
+
+ req, _ := http.NewRequest("GET", "http://localhost:8080/test-sse/sse", nil)
+ req.Header.Set("Accept", "text/event-stream")
+
+ resp, err := client.Do(req)
+ require.NoError(t, err)
+ defer resp.Body.Close()
+
+ assert.Equal(t, 401, resp.StatusCode, "Request without token should return 401")
+
+ wwwAuth := resp.Header.Get("WWW-Authenticate")
+ assert.Contains(t, wwwAuth, "Bearer resource_metadata=",
+ "401 response should include RFC 9728 WWW-Authenticate header")
+ assert.Contains(t, wwwAuth, "/.well-known/oauth-protected-resource/test-sse",
+ "401 response should point to test-sse specific metadata endpoint")
+ })
+}
diff --git a/integration/oauth_service_test.go b/integration/oauth_service_test.go
new file mode 100644
index 0000000..20a8ca1
--- /dev/null
+++ b/integration/oauth_service_test.go
@@ -0,0 +1,88 @@
+package integration
+
+import (
+ "io"
+ "net/http"
+ "net/http/cookiejar"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+// TestServiceOAuthIntegration validates the complete OAuth flow for external services
+func TestServiceOAuthIntegration(t *testing.T) {
+ // Start fake service OAuth provider on port 9091
+ fakeService := NewFakeServiceOAuthServer("9091")
+ err := fakeService.Start()
+ require.NoError(t, err)
+ defer func() { _ = fakeService.Stop() }()
+
+ // Start mcp-front with OAuth service configuration
+ startMCPFront(t, "config/config.oauth-service-integration-test.json",
+ "JWT_SECRET=demo-jwt-secret-32-bytes-exactly!",
+ "ENCRYPTION_KEY=test-encryption-key-32-bytes-ok!",
+ "GOOGLE_CLIENT_ID=test-client-id-oauth",
+ "GOOGLE_CLIENT_SECRET=test-client-secret-oauth",
+ "TEST_SERVICE_CLIENT_ID=service-client-id",
+ "TEST_SERVICE_CLIENT_SECRET=service-client-secret",
+ "MCP_FRONT_ENV=development",
+ )
+
+ if !waitForHealthCheck(30) {
+ t.Fatal("Server failed to start")
+ }
+
+ // For this test, we use browser SSO instead of OAuth client flow
+ // This simulates a user in the browser connecting services
+ jar, _ := cookiejar.New(nil)
+ client := &http.Client{Jar: jar}
+
+ // Complete Google OAuth to get browser session
+ // Access /my/tokens which triggers SSO flow
+ resp, err := client.Get("http://localhost:8080/my/tokens")
+ require.NoError(t, err)
+ defer resp.Body.Close()
+
+ // Should have completed SSO and landed on /my/tokens
+ require.Equal(t, http.StatusOK, resp.StatusCode)
+
+ t.Run("ServiceOAuthConnectFlow", func(t *testing.T) {
+ // User clicks "Connect" for the service
+ req, _ := http.NewRequest("GET", "http://localhost:8080/oauth/connect?service=test-service", nil)
+
+ resp, err := client.Do(req)
+ require.NoError(t, err)
+ defer resp.Body.Close()
+
+ // Should complete OAuth flow and redirect back with success
+ // The http.Client automatically follows redirects:
+ // 1. /oauth/connect → redirects to localhost:9091/oauth/authorize
+ // 2. Fake service → redirects to /oauth/callback/test-service?code=...
+ // 3. Callback → stores token, redirects to /my/tokens with success message
+
+ body, _ := io.ReadAll(resp.Body)
+ bodyStr := string(body)
+
+ // Final page should show success
+ assert.Contains(t, bodyStr, "Successfully connected", "Should show success message after OAuth flow")
+ assert.Contains(t, bodyStr, "Test OAuth Service", "Should mention service name")
+ })
+
+ t.Run("ConnectedServiceShownOnTokenPage", func(t *testing.T) {
+ // After OAuth connection, service should appear as connected
+ req, _ := http.NewRequest("GET", "http://localhost:8080/my/tokens", nil)
+
+ resp, err := client.Do(req)
+ require.NoError(t, err)
+ defer resp.Body.Close()
+
+ body, _ := io.ReadAll(resp.Body)
+ bodyStr := string(body)
+
+ // Should show the service with connected status
+ assert.Contains(t, bodyStr, "Test OAuth Service")
+ // OAuth-connected services show disconnect button, not connect
+ assert.Contains(t, bodyStr, "Disconnect", "OAuth-connected service should show Disconnect button")
+ })
+}
diff --git a/integration/oauth_test.go b/integration/oauth_test.go
deleted file mode 100644
index 0b2949f..0000000
--- a/integration/oauth_test.go
+++ /dev/null
@@ -1,1577 +0,0 @@
-package integration
-
-import (
- "bufio"
- "bytes"
- "context"
- "crypto/sha256"
- "encoding/base64"
- "encoding/json"
- "fmt"
- "io"
- "net/http"
- "net/http/cookiejar"
- "net/url"
- "os"
- "os/exec"
- "regexp"
- "strings"
- "testing"
- "time"
-
- "github.com/stretchr/testify/assert"
- "github.com/stretchr/testify/mock"
- "github.com/stretchr/testify/require"
-)
-
-// TestBasicOAuthFlow tests the basic OAuth server functionality
-func TestBasicOAuthFlow(t *testing.T) {
- // Start mcp-front with OAuth config
- startMCPFront(t, "config/config.oauth-test.json",
- "JWT_SECRET=test-jwt-secret-32-bytes-exactly!",
- "ENCRYPTION_KEY=test-encryption-key-32-bytes-ok!",
- "GOOGLE_CLIENT_ID=test-client-id-for-oauth",
- "GOOGLE_CLIENT_SECRET=test-client-secret-for-oauth",
- "MCP_FRONT_ENV=development",
- "GOOGLE_OAUTH_AUTH_URL=http://localhost:9090/auth",
- "GOOGLE_OAUTH_TOKEN_URL=http://localhost:9090/token",
- "GOOGLE_USERINFO_URL=http://localhost:9090/userinfo",
- )
-
- // Wait for startup
- waitForMCPFront(t)
-
- // Test OAuth discovery
- resp, err := http.Get("http://localhost:8080/.well-known/oauth-authorization-server")
- require.NoError(t, err, "Failed to get OAuth discovery")
- defer resp.Body.Close()
-
- assert.Equal(t, 200, resp.StatusCode, "OAuth discovery failed")
-
- var discovery map[string]any
- err = json.NewDecoder(resp.Body).Decode(&discovery)
- require.NoError(t, err, "Failed to decode discovery")
-
- // Verify required endpoints
- requiredEndpoints := []string{
- "issuer",
- "authorization_endpoint",
- "token_endpoint",
- "registration_endpoint",
- }
-
- for _, endpoint := range requiredEndpoints {
- _, ok := discovery[endpoint]
- assert.True(t, ok, "Missing required endpoint: %s", endpoint)
- }
-
- // Verify client_secret_post is advertised
- authMethods, ok := discovery["token_endpoint_auth_methods_supported"].([]any)
- assert.True(t, ok, "token_endpoint_auth_methods_supported should be present")
-
- var hasNone, hasClientSecretPost bool
- for _, method := range authMethods {
- if method == "none" {
- hasNone = true
- }
- if method == "client_secret_post" {
- hasClientSecretPost = true
- }
- }
- assert.True(t, hasNone, "Should support 'none' auth method for public clients")
- assert.True(t, hasClientSecretPost, "Should support 'client_secret_post' auth method for confidential clients")
-}
-
-// TestJWTSecretValidation tests JWT secret length requirements
-func TestJWTSecretValidation(t *testing.T) {
- tests := []struct {
- name string
- secret string
- shouldFail bool
- }{
- {"Short 3-byte secret", "123", true},
- {"Short 16-byte secret", "sixteen-byte-key", true},
- {"Valid 32-byte secret", "demo-jwt-secret-32-bytes-exactly!", false},
- {"Long 64-byte secret", "demo-jwt-secret-32-bytes-exactly!demo-jwt-secret-32-bytes-exactly!", false},
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
-
- // Start mcp-front with specific JWT secret
- mcpCmd := exec.Command("../cmd/mcp-front/mcp-front", "-config", "config/config.oauth-test.json")
- mcpCmd.Env = []string{
- "PATH=" + os.Getenv("PATH"),
- "JWT_SECRET=" + tt.secret,
- "ENCRYPTION_KEY=test-encryption-key-32-bytes-ok!",
- "GOOGLE_CLIENT_ID=test-client-id",
- "GOOGLE_CLIENT_SECRET=test-client-secret",
- "MCP_FRONT_ENV=development",
- }
-
- // Capture stderr
- stderrPipe, _ := mcpCmd.StderrPipe()
- scanner := bufio.NewScanner(stderrPipe)
-
- if err := mcpCmd.Start(); err != nil {
- t.Fatalf("Failed to start mcp-front: %v", err)
- }
-
- // Read stderr to check for errors
- errorFound := false
- go func() {
- for scanner.Scan() {
- line := scanner.Text()
- if contains(line, "JWT secret must be at least") {
- errorFound = true
- }
- }
- }()
-
- // Give it time to start or fail
- time.Sleep(2 * time.Second)
-
- // Check if it's running
- healthy := checkHealth()
-
- // Clean up
- if mcpCmd.Process != nil {
- _ = mcpCmd.Process.Kill()
- _ = mcpCmd.Wait()
- }
-
- if tt.shouldFail {
- assert.False(t, healthy && !errorFound, "Expected failure with short JWT secret but server started successfully")
- } else {
- assert.True(t, healthy, "Expected success with valid JWT secret but server failed to start")
- }
- })
- }
-}
-
-// TestClientRegistration tests dynamic client registration (RFC 7591)
-func TestClientRegistration(t *testing.T) {
- // Start OAuth server
- mcpCmd := startOAuthServer(t, map[string]string{
- "MCP_FRONT_ENV": "development",
- })
- defer stopServer(mcpCmd)
-
- if !waitForHealthCheck(30) {
- t.Fatal("OAuth server failed to start")
- }
-
- t.Run("PublicClientRegistration", func(t *testing.T) {
- // Register a public client (no secret)
- clientReq := map[string]any{
- "redirect_uris": []string{"http://127.0.0.1:6274/oauth/callback/debug"},
- "scope": "read write",
- }
-
- body, _ := json.Marshal(clientReq)
- resp, err := http.Post(
- "http://localhost:8080/register",
- "application/json",
- bytes.NewBuffer(body),
- )
- if err != nil {
- t.Fatalf("Failed to register client: %v", err)
- }
- defer resp.Body.Close()
-
- if resp.StatusCode != 201 {
- body, _ := io.ReadAll(resp.Body)
- t.Fatalf("Client registration failed with status %d: %s", resp.StatusCode, string(body))
- }
-
- var clientResp map[string]any
- if err := json.NewDecoder(resp.Body).Decode(&clientResp); err != nil {
- t.Fatalf("Failed to decode response: %v", err)
- }
-
- // Verify response
- if clientResp["client_id"] == "" {
- t.Error("Client ID should not be empty")
- }
- if clientResp["client_secret"] != nil {
- t.Error("Public client should not have a secret")
- }
- if scope, ok := clientResp["scope"].(string); !ok || scope != "read write" {
- t.Errorf("Expected scope 'read write' as string, got: %v", clientResp["scope"])
- }
- })
-
- t.Run("MultipleRegistrations", func(t *testing.T) {
- // Register multiple clients and verify they get different IDs
- var clientIDs []string
-
- for i := range 3 {
- clientReq := map[string]any{
- "redirect_uris": []string{fmt.Sprintf("http://example.com/callback%d", i)},
- "scope": "read",
- }
-
- body, _ := json.Marshal(clientReq)
- resp, err := http.Post(
- "http://localhost:8080/register",
- "application/json",
- bytes.NewBuffer(body),
- )
- if err != nil {
- t.Fatalf("Failed to register client %d: %v", i, err)
- }
- defer resp.Body.Close()
-
- var clientResp map[string]any
- _ = json.NewDecoder(resp.Body).Decode(&clientResp)
- clientIDs = append(clientIDs, clientResp["client_id"].(string))
- }
-
- // Verify all IDs are unique
- for i := 0; i < len(clientIDs); i++ {
- for j := i + 1; j < len(clientIDs); j++ {
- if clientIDs[i] == clientIDs[j] {
- t.Errorf("Client IDs should be unique, but got duplicate: %s", clientIDs[i])
- }
- }
- }
-
- })
-
- t.Run("ConfidentialClientRegistration", func(t *testing.T) {
- // Register a confidential client with client_secret_post
- clientReq := map[string]any{
- "redirect_uris": []string{"https://example.com/callback"},
- "scope": "read write",
- "token_endpoint_auth_method": "client_secret_post",
- }
-
- body, _ := json.Marshal(clientReq)
- resp, err := http.Post(
- "http://localhost:8080/register",
- "application/json",
- bytes.NewBuffer(body),
- )
- if err != nil {
- t.Fatalf("Failed to register confidential client: %v", err)
- }
- defer resp.Body.Close()
-
- if resp.StatusCode != 201 {
- body, _ := io.ReadAll(resp.Body)
- t.Fatalf("Confidential client registration failed with status %d: %s", resp.StatusCode, string(body))
- }
-
- var clientResp map[string]any
- if err := json.NewDecoder(resp.Body).Decode(&clientResp); err != nil {
- t.Fatalf("Failed to decode response: %v", err)
- }
-
- // Verify response includes client_secret
- if clientResp["client_id"] == "" {
- t.Error("Client ID should not be empty")
- }
- clientSecret, ok := clientResp["client_secret"].(string)
- if !ok || clientSecret == "" {
- t.Error("Confidential client should receive a client_secret")
- }
- // Verify secret has reasonable length (base64 of 32 bytes)
- if len(clientSecret) < 40 {
- t.Errorf("Client secret seems too short: %d chars", len(clientSecret))
- }
-
- tokenAuthMethod, ok := clientResp["token_endpoint_auth_method"].(string)
- if !ok || tokenAuthMethod != "client_secret_post" {
- t.Errorf("Expected token_endpoint_auth_method 'client_secret_post', got: %v", clientResp["token_endpoint_auth_method"])
- }
-
- // Verify scope is returned as string
- if scope, ok := clientResp["scope"].(string); !ok || scope != "read write" {
- t.Errorf("Expected scope 'read write' as string, got: %v", clientResp["scope"])
- }
- })
-
- t.Run("PublicVsConfidentialClients", func(t *testing.T) {
- // Test that public clients don't get secrets and confidential ones do
-
- // First, create a public client
- publicReq := map[string]any{
- "redirect_uris": []string{"https://public.example.com/callback"},
- "scope": "read",
- // No token_endpoint_auth_method specified - defaults to "none"
- }
-
- body, _ := json.Marshal(publicReq)
- resp, err := http.Post(
- "http://localhost:8080/register",
- "application/json",
- bytes.NewBuffer(body),
- )
- require.NoError(t, err)
- defer resp.Body.Close()
-
- var publicResp map[string]any
- _ = json.NewDecoder(resp.Body).Decode(&publicResp)
-
- // Verify public client has no secret
- if _, hasSecret := publicResp["client_secret"]; hasSecret {
- t.Error("Public client should not have a secret")
- }
- if authMethod := publicResp["token_endpoint_auth_method"]; authMethod != "none" {
- t.Errorf("Public client should have auth method 'none', got: %v", authMethod)
- }
-
- // Now create a confidential client
- confidentialReq := map[string]any{
- "redirect_uris": []string{"https://confidential.example.com/callback"},
- "scope": "read write",
- "token_endpoint_auth_method": "client_secret_post",
- }
-
- body, _ = json.Marshal(confidentialReq)
- resp, err = http.Post(
- "http://localhost:8080/register",
- "application/json",
- bytes.NewBuffer(body),
- )
- require.NoError(t, err)
- defer resp.Body.Close()
-
- var confResp map[string]any
- _ = json.NewDecoder(resp.Body).Decode(&confResp)
-
- // Verify confidential client has a secret
- if secret, ok := confResp["client_secret"].(string); !ok || secret == "" {
- t.Error("Confidential client should have a secret")
- }
- if authMethod := confResp["token_endpoint_auth_method"]; authMethod != "client_secret_post" {
- t.Errorf("Confidential client should have auth method 'client_secret_post', got: %v", authMethod)
- }
- })
-}
-
-// TestUserTokenFlow tests the user token management functionality with browser-based SSO
-// This test expects the /my/* routes to work with Google SSO (session-based auth),
-// not Bearer token auth.
-func TestUserTokenFlow(t *testing.T) {
- // Start OAuth server with user token configuration
- mcpCmd := startOAuthServerWithTokenConfig(t)
- defer stopServer(mcpCmd)
-
- if !waitForHealthCheck(30) {
- t.Fatal("Server failed to start")
- }
-
- // Create a client with cookie jar to simulate browser behavior
- jar, _ := cookiejar.New(nil)
- client := &http.Client{
- Jar: jar,
- CheckRedirect: func(req *http.Request, via []*http.Request) error {
- // Allow up to 10 redirects
- if len(via) >= 10 {
- return fmt.Errorf("too many redirects")
- }
- return nil
- },
- }
-
- t.Run("UnauthenticatedRedirectsToSSO", func(t *testing.T) {
- // Create a client that doesn't follow redirects to test the initial redirect
- noRedirectClient := &http.Client{
- CheckRedirect: func(req *http.Request, via []*http.Request) error {
- return http.ErrUseLastResponse
- },
- }
-
- // Try to access /my/tokens without authentication
- resp, err := noRedirectClient.Get("http://localhost:8080/my/tokens")
- require.NoError(t, err)
- defer resp.Body.Close()
-
- // Should get a redirect response
- assert.Equal(t, http.StatusFound, resp.StatusCode, "Should get redirect status")
-
- // Check the redirect location
- location := resp.Header.Get("Location")
- assert.Contains(t, location, "localhost:9090/auth", "Should redirect to Google OAuth")
- assert.Contains(t, location, "client_id=", "Should include client_id")
- assert.Contains(t, location, "redirect_uri=", "Should include redirect_uri")
- // Extract and validate the state parameter
- parsedURL, err := url.Parse(location)
- require.NoError(t, err)
- stateParam := parsedURL.Query().Get("state")
- require.NotEmpty(t, stateParam, "State parameter should be present")
-
- // State format: "browser:" prefix followed by signed token
- // We verify structure but not internal format (that's implementation detail)
- assert.True(t, strings.HasPrefix(stateParam, "browser:"), "State should start with browser:")
- assert.Greater(t, len(stateParam), len("browser:"), "State should have content after prefix")
-
- // Verify state contains signature (has dot separator indicating signed data)
- stateContent := strings.TrimPrefix(stateParam, "browser:")
- assert.Contains(t, stateContent, ".", "Signed state should contain signature separator")
- })
-
- t.Run("AuthenticatedUserCanAccessTokens", func(t *testing.T) {
- // The client with cookie jar will automatically follow the full SSO flow:
- // 1. GET /my/tokens -> redirect to Google OAuth
- // 2. Google OAuth redirects to /oauth/callback with code
- // 3. Callback sets session cookie and redirects to /my/tokens
- // 4. Client follows redirect with cookie and gets the page
-
- resp, err := client.Get("http://localhost:8080/my/tokens")
- require.NoError(t, err)
- defer resp.Body.Close()
-
- // After following all redirects, we should be at /my/tokens with 200 OK
- assert.Equal(t, http.StatusOK, resp.StatusCode, "Should access /my/tokens after SSO")
- finalURL := resp.Request.URL.String()
- assert.Contains(t, finalURL, "/my/tokens", "Should end up at /my/tokens after SSO")
-
- // Read response body
- body, err := io.ReadAll(resp.Body)
- require.NoError(t, err)
- bodyStr := string(body)
-
- // Should show both services without tokens
- assert.Contains(t, bodyStr, "Notion", "Expected Notion service in response")
- assert.Contains(t, bodyStr, "GitHub", "Expected GitHub service in response")
- })
-
- t.Run("SetTokenWithValidation", func(t *testing.T) {
- // Assume we're already authenticated from previous test
- // Get CSRF token first
- resp, err := client.Get("http://localhost:8080/my/tokens")
- require.NoError(t, err)
- defer resp.Body.Close()
-
- body, err := io.ReadAll(resp.Body)
- require.NoError(t, err)
-
- // Extract CSRF token from response
- csrfToken := extractCSRFToken(t, string(body))
-
- // Try to set invalid Notion token
- form := url.Values{
- "service": []string{"notion"},
- "token": []string{"invalid-token"},
- "csrf_token": []string{csrfToken},
- }
-
- req, _ := http.NewRequest("POST", "http://localhost:8080/my/tokens/set", strings.NewReader(form.Encode()))
- req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
-
- // Use custom client that doesn't follow redirects for this test
- noRedirectClient := &http.Client{
- Jar: jar, // Use same cookie jar
- CheckRedirect: func(req *http.Request, via []*http.Request) error {
- return http.ErrUseLastResponse
- },
- }
-
- resp, err = noRedirectClient.Do(req)
- require.NoError(t, err)
- defer resp.Body.Close()
-
- // Should redirect with error
- assert.Equal(t, http.StatusSeeOther, resp.StatusCode, "Expected redirect")
- location := resp.Header.Get("Location")
- assert.Contains(t, location, "error", "Expected error in redirect")
-
- // Get new CSRF token
- resp, err = client.Get("http://localhost:8080/my/tokens")
- require.NoError(t, err)
- defer resp.Body.Close()
-
- body, err = io.ReadAll(resp.Body)
- require.NoError(t, err)
- csrfToken = extractCSRFToken(t, string(body))
-
- // Set valid Notion token (regex expects exactly 43 chars after "secret_")
- form = url.Values{
- "service": {"notion"},
- "token": {"secret_1234567890123456789012345678901234567890123"},
- "csrf_token": {csrfToken},
- }
-
- req, _ = http.NewRequest("POST", "http://localhost:8080/my/tokens/set", strings.NewReader(form.Encode()))
- req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
-
- resp, err = noRedirectClient.Do(req)
- require.NoError(t, err)
- defer resp.Body.Close()
-
- // Should redirect with success
- assert.Equal(t, http.StatusSeeOther, resp.StatusCode, "Expected redirect")
- location = resp.Header.Get("Location")
- assert.Contains(t, location, "success", "Expected success in redirect")
- })
-}
-
-// TestStateParameterHandling tests OAuth state parameter requirements
-func TestStateParameterHandling(t *testing.T) {
- tests := []struct {
- name string
- environment string
- state string
- expectError bool
- }{
- {"Production without state", "production", "", true},
- {"Production with state", "production", "secure-random-state", false},
- {"Development without state", "development", "", false}, // Should auto-generate
- {"Development with state", "development", "test-state", false},
- }
-
- for _, tt := range tests {
- // capture range variable
- t.Run(tt.name, func(t *testing.T) {
- // Start server with specific environment
- mcpCmd := startOAuthServer(t, map[string]string{
- "MCP_FRONT_ENV": tt.environment,
- })
- defer stopServer(mcpCmd)
-
- if !waitForHealthCheck(10) {
- t.Fatal("Server failed to start")
- }
-
- // Register a client first
- clientID := registerTestClient(t)
-
- // Create authorization request
- params := url.Values{
- "response_type": {"code"},
- "client_id": {clientID},
- "redirect_uri": {"http://127.0.0.1:6274/oauth/callback"},
- "code_challenge": {"test-challenge"},
- "code_challenge_method": {"S256"},
- "scope": {"read write"},
- }
- if tt.state != "" {
- params.Set("state", tt.state)
- }
-
- authURL := fmt.Sprintf("http://localhost:8080/authorize?%s", params.Encode())
-
- // Use a client that doesn't follow redirects
- client := &http.Client{
- CheckRedirect: func(req *http.Request, via []*http.Request) error {
- return http.ErrUseLastResponse
- },
- }
- resp, err := client.Get(authURL)
- if err != nil {
- t.Fatalf("Authorization request failed: %v", err)
- }
- defer resp.Body.Close()
-
- if tt.expectError {
- // OAuth errors are returned as redirects with error parameters
- if resp.StatusCode == 302 || resp.StatusCode == 303 {
- location := resp.Header.Get("Location")
- if strings.Contains(location, "error=") {
- } else {
- t.Errorf("Expected error redirect for %s, got redirect without error", tt.name)
- }
- } else if resp.StatusCode >= 400 {
- } else {
- t.Errorf("Expected error for %s, got status %d", tt.name, resp.StatusCode)
- }
- } else {
- if resp.StatusCode == 302 || resp.StatusCode == 303 {
- location := resp.Header.Get("Location")
- if strings.Contains(location, "error=") {
- t.Errorf("Unexpected error redirect for %s: %s", tt.name, location)
- }
- } else if resp.StatusCode < 400 {
- } else {
- body, _ := io.ReadAll(resp.Body)
- t.Errorf("Expected success for %s, got status %d: %s", tt.name, resp.StatusCode, string(body))
- }
- }
- })
- }
-}
-
-// TestEnvironmentModes tests development vs production mode differences
-func TestEnvironmentModes(t *testing.T) {
- t.Run("DevelopmentMode", func(t *testing.T) {
- mcpCmd := startOAuthServer(t, map[string]string{
- "MCP_FRONT_ENV": "development",
- })
- defer stopServer(mcpCmd)
-
- if !waitForHealthCheck(30) {
- t.Fatal("Server failed to start")
- }
-
- // In development mode, missing state should be auto-generated
- clientID := registerTestClient(t)
-
- params := url.Values{
- "response_type": {"code"},
- "client_id": {clientID},
- "redirect_uri": {"http://127.0.0.1:6274/oauth/callback"},
- "code_challenge": {"test-challenge"},
- "code_challenge_method": {"S256"},
- "scope": {"read"},
- // Intentionally omitting state parameter
- }
-
- // Use a client that doesn't follow redirects
- client := &http.Client{
- CheckRedirect: func(req *http.Request, via []*http.Request) error {
- return http.ErrUseLastResponse
- },
- }
- resp, err := client.Get("http://localhost:8080/authorize?" + params.Encode())
- if err != nil {
- t.Fatalf("Failed to make auth request: %v", err)
- }
- defer resp.Body.Close()
-
- // Should redirect (302) not error
- if resp.StatusCode >= 400 && resp.StatusCode != 302 {
- t.Errorf("Development mode should handle missing state, got status %d", resp.StatusCode)
- }
- })
-
- t.Run("ProductionMode", func(t *testing.T) {
- mcpCmd := startOAuthServer(t, map[string]string{
- "MCP_FRONT_ENV": "production",
- })
- defer stopServer(mcpCmd)
-
- if !waitForHealthCheck(30) {
- t.Fatal("Server failed to start")
- }
-
- // In production mode, state should be required
- clientID := registerTestClient(t)
-
- params := url.Values{
- "response_type": {"code"},
- "client_id": {clientID},
- "redirect_uri": {"http://127.0.0.1:6274/oauth/callback"},
- "code_challenge": {"test-challenge"},
- "code_challenge_method": {"S256"},
- "scope": {"read"},
- // Intentionally omitting state parameter
- }
-
- // Use a client that doesn't follow redirects
- client := &http.Client{
- CheckRedirect: func(req *http.Request, via []*http.Request) error {
- return http.ErrUseLastResponse
- },
- }
- resp, err := client.Get("http://localhost:8080/authorize?" + params.Encode())
- if err != nil {
- t.Fatalf("Failed to make auth request: %v", err)
- }
- defer resp.Body.Close()
-
- // Should error - OAuth errors are returned as redirects
- if resp.StatusCode == 302 || resp.StatusCode == 303 {
- location := resp.Header.Get("Location")
- if strings.Contains(location, "error=") {
- } else {
- t.Errorf("Expected error redirect in production mode, got redirect without error")
- }
- } else if resp.StatusCode >= 400 {
- } else {
- t.Errorf("Production mode should require state parameter, got status %d", resp.StatusCode)
- }
- })
-}
-
-// TestOAuthEndpoints tests all OAuth endpoints comprehensively
-func TestOAuthEndpoints(t *testing.T) {
- mcpCmd := startOAuthServer(t, map[string]string{
- "MCP_FRONT_ENV": "development",
- })
- defer stopServer(mcpCmd)
-
- if !waitForHealthCheck(10) {
- t.Fatal("Server failed to start")
- }
-
- t.Run("Discovery", func(t *testing.T) {
- resp, err := http.Get("http://localhost:8080/.well-known/oauth-authorization-server")
- if err != nil {
- t.Fatalf("Discovery request failed: %v", err)
- }
- defer resp.Body.Close()
-
- if resp.StatusCode != 200 {
- t.Fatalf("Discovery failed with status %d", resp.StatusCode)
- }
-
- var discovery map[string]any
- if err := json.NewDecoder(resp.Body).Decode(&discovery); err != nil {
- t.Fatalf("Failed to decode discovery response: %v", err)
- }
-
- // Verify all required fields
- required := []string{
- "issuer",
- "authorization_endpoint",
- "token_endpoint",
- "registration_endpoint",
- "response_types_supported",
- "grant_types_supported",
- "code_challenge_methods_supported",
- }
-
- for _, field := range required {
- if _, ok := discovery[field]; !ok {
- t.Errorf("Missing required discovery field: %s", field)
- }
- }
-
- })
-
- t.Run("HealthCheck", func(t *testing.T) {
- resp, err := http.Get("http://localhost:8080/health")
- if err != nil {
- t.Fatalf("Health check failed: %v", err)
- }
- defer resp.Body.Close()
-
- if resp.StatusCode != 200 {
- t.Errorf("Health check should return 200, got %d", resp.StatusCode)
- }
-
- var health map[string]string
- if err := json.NewDecoder(resp.Body).Decode(&health); err != nil {
- t.Fatalf("Failed to decode health response: %v", err)
- }
- if health["status"] != "ok" {
- t.Errorf("Expected status 'ok', got '%s'", health["status"])
- }
-
- })
-}
-
-// TestCORSHeaders tests CORS headers for Claude.ai compatibility
-func TestCORSHeaders(t *testing.T) {
- mcpCmd := startOAuthServer(t, map[string]string{
- "MCP_FRONT_ENV": "development",
- })
- defer stopServer(mcpCmd)
-
- if !waitForHealthCheck(10) {
- t.Fatal("Server failed to start")
- }
-
- // Test preflight request
- req, _ := http.NewRequest("OPTIONS", "http://localhost:8080/register", nil)
- req.Header.Set("Origin", "https://claude.ai")
- req.Header.Set("Access-Control-Request-Method", "POST")
- req.Header.Set("Access-Control-Request-Headers", "content-type")
-
- resp, err := http.DefaultClient.Do(req)
- if err != nil {
- t.Fatalf("Preflight request failed: %v", err)
- }
- defer resp.Body.Close()
-
- if resp.StatusCode != 200 {
- t.Errorf("Preflight should return 200, got %d", resp.StatusCode)
- }
-
- // Check CORS headers
- expectedHeaders := map[string]string{
- "Access-Control-Allow-Origin": "https://claude.ai",
- "Access-Control-Allow-Methods": "GET, POST, OPTIONS",
- "Access-Control-Allow-Headers": "Content-Type, Authorization, Cache-Control, mcp-protocol-version",
- }
-
- for header, expected := range expectedHeaders {
- actual := resp.Header.Get(header)
- if actual != expected {
- t.Errorf("Expected %s: '%s', got '%s'", header, expected, actual)
- }
- }
-
-}
-
-// TestToolAdvertisementWithUserTokens tests that tools are advertised even without user tokens
-// but fail gracefully when invoked without the required token, and succeed with the token
-func TestToolAdvertisementWithUserTokens(t *testing.T) {
- // Start OAuth server with user token configuration
- startMCPFront(t, "config/config.oauth-usertoken-tools-test.json",
- "JWT_SECRET=demo-jwt-secret-32-bytes-exactly!",
- "ENCRYPTION_KEY=test-encryption-key-32-bytes-ok!",
- "GOOGLE_CLIENT_ID=test-client-id-oauth",
- "GOOGLE_CLIENT_SECRET=test-client-secret-oauth",
- "GOOGLE_OAUTH_AUTH_URL=http://localhost:9090/auth",
- "GOOGLE_OAUTH_TOKEN_URL=http://localhost:9090/token",
- "GOOGLE_USERINFO_URL=http://localhost:9090/userinfo",
- "MCP_FRONT_ENV=development",
- "LOG_LEVEL=debug",
- )
-
- if !waitForHealthCheck(30) {
- t.Fatal("Server failed to start")
- }
-
- // Complete OAuth flow to get a valid access token
- accessToken := getOAuthAccessToken(t, "http://localhost:8080/postgres")
-
- t.Run("ToolsAdvertisedWithoutToken", func(t *testing.T) {
- // Create MCP client with OAuth token
- mcpClient := NewMCPSSEClient("http://localhost:8080")
- mcpClient.SetAuthToken(accessToken)
-
- // Connect to postgres SSE endpoint
- err := mcpClient.Connect()
- require.NoError(t, err, "Should connect to postgres SSE endpoint without user token")
- defer mcpClient.Close()
-
- // Request tools list
- toolsResp, err := mcpClient.SendMCPRequest("tools/list", map[string]any{})
- require.NoError(t, err, "Should list tools without user token")
-
- // Verify we got tools
- resultMap, ok := toolsResp["result"].(map[string]any)
- require.True(t, ok, "Expected result in tools response")
-
- tools, ok := resultMap["tools"].([]any)
- require.True(t, ok, "Expected tools array in result")
- assert.NotEmpty(t, tools, "Should have tools advertised")
-
- // Check for common postgres tools
- var toolNames []string
- for _, tool := range tools {
- if toolMap, ok := tool.(map[string]any); ok {
- if name, ok := toolMap["name"].(string); ok {
- toolNames = append(toolNames, name)
- }
- }
- }
-
- assert.Contains(t, toolNames, "query", "Should have query tool")
- t.Logf("Successfully advertised tools without user token: %v", toolNames)
- })
-
- t.Run("ToolInvocationFailsWithoutToken", func(t *testing.T) {
- // Create MCP client with OAuth token
- mcpClient := NewMCPSSEClient("http://localhost:8080")
- mcpClient.SetAuthToken(accessToken)
-
- // Connect to postgres SSE endpoint
- err := mcpClient.Connect()
- require.NoError(t, err)
- defer mcpClient.Close()
-
- // Try to invoke a tool without user token
- queryParams := map[string]any{
- "name": "query",
- "arguments": map[string]any{
- "sql": "SELECT 1",
- },
- }
-
- result, err := mcpClient.SendMCPRequest("tools/call", queryParams)
- require.NoError(t, err, "Should get response even without token")
-
- // MCP protocol returns errors as successful responses with error content
- require.NotNil(t, result["result"], "Should have result in response")
-
- resultMap := result["result"].(map[string]any)
- content := resultMap["content"].([]any)
- require.NotEmpty(t, content, "Should have content in result")
-
- contentItem := content[0].(map[string]any)
- errorJSON := contentItem["text"].(string)
-
- // Parse the error JSON
- var errorData map[string]any
- err = json.Unmarshal([]byte(errorJSON), &errorData)
- require.NoError(t, err, "Error should be valid JSON")
-
- // Verify error structure
- errorInfo := errorData["error"].(map[string]any)
- assert.Equal(t, "token_required", errorInfo["code"], "Error code should be token_required")
-
- errorMessage := errorInfo["message"].(string)
- assert.Contains(t, errorMessage, "token required", "Error should mention token required")
- assert.Contains(t, errorMessage, "/my/tokens", "Error should mention token setup URL")
- assert.Contains(t, errorMessage, "Test Service", "Error should mention service name")
-
- // Verify error data
- errData := errorInfo["data"].(map[string]any)
- assert.Equal(t, "postgres", errData["service"], "Should identify the service")
- assert.Contains(t, errData["tokenSetupUrl"].(string), "/my/tokens", "Should include token setup URL")
-
- // Verify instructions
- instructions := errData["instructions"].(map[string]any)
- assert.Contains(t, instructions["ai"].(string), "CRITICAL", "Should have AI instructions")
- assert.Contains(t, instructions["human"].(string), "token required", "Should have human instructions")
- })
-
- t.Run("ToolInvocationSucceedsWithUserToken", func(t *testing.T) {
- // Step 1: GET /my/tokens to extract CSRF token
- jar, err := cookiejar.New(nil)
- require.NoError(t, err)
- client := &http.Client{
- Jar: jar, // Need cookie jar for CSRF
- CheckRedirect: func(req *http.Request, via []*http.Request) error {
- return http.ErrUseLastResponse // Don't follow redirects
- },
- }
-
- req, err := http.NewRequest("GET", "http://localhost:8080/my/tokens", nil)
- require.NoError(t, err)
- req.Header.Set("Authorization", "Bearer "+accessToken)
-
- resp, err := client.Do(req)
- require.NoError(t, err)
- defer resp.Body.Close()
-
- // Check if we got the page or a redirect
- if resp.StatusCode == 302 || resp.StatusCode == 303 {
- // Follow the redirect
- location := resp.Header.Get("Location")
- t.Logf("Got redirect to: %s", location)
-
- // Allow redirects for this request
- client = &http.Client{
- Jar: jar,
- }
-
- req, err = http.NewRequest("GET", "http://localhost:8080/my/tokens", nil)
- require.NoError(t, err)
- req.Header.Set("Authorization", "Bearer "+accessToken)
-
- resp, err = client.Do(req)
- require.NoError(t, err)
- defer resp.Body.Close()
- }
-
- require.Equal(t, 200, resp.StatusCode, "Should be able to access token page")
-
- // Extract CSRF token from HTML
- body, err := io.ReadAll(resp.Body)
- require.NoError(t, err)
-
- // Look for the CSRF token in the form
- csrfRegex := regexp.MustCompile(`name="csrf_token" value="([^"]+)"`)
- matches := csrfRegex.FindSubmatch(body)
- require.Len(t, matches, 2, "Should find CSRF token in form")
- csrfToken := string(matches[1])
-
- // Step 2: POST to /my/tokens/set with test token
- formData := url.Values{
- "service": {"postgres"},
- "token": {"test-user-token-12345"},
- "csrf_token": {csrfToken},
- }
-
- req, err = http.NewRequest("POST", "http://localhost:8080/my/tokens/set", strings.NewReader(formData.Encode()))
- require.NoError(t, err)
- req.Header.Set("Authorization", "Bearer "+accessToken)
- req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
-
- resp, err = client.Do(req)
- require.NoError(t, err)
- defer resp.Body.Close()
-
- // Check the response - it might be 200 if following redirects
- switch resp.StatusCode {
- case 200:
- // That's fine, it means the token was set and we got the page back
- t.Log("Token set successfully, got page response")
- case 302, 303:
- // Also fine, redirect means success
- t.Log("Token set successfully, got redirect")
- default:
- body, _ := io.ReadAll(resp.Body)
- t.Fatalf("Unexpected response setting token: status=%d, body=%s", resp.StatusCode, string(body))
- }
-
- // Step 3: Now test tool invocation with the token
- mcpClient := NewMCPSSEClient("http://localhost:8080")
- mcpClient.SetAuthToken(accessToken)
-
- err = mcpClient.Connect()
- require.NoError(t, err, "Should connect to postgres SSE endpoint")
- defer mcpClient.Close()
-
- // Call the query tool with a simple query
- queryParams := map[string]any{
- "name": "query",
- "arguments": map[string]any{
- "sql": "SELECT 1 as test",
- },
- }
-
- result, err := mcpClient.SendMCPRequest("tools/call", queryParams)
- require.NoError(t, err, "Should successfully call tool with token")
-
- // Verify we got a successful result, not an error
- require.NotNil(t, result["result"], "Should have result in response")
-
- resultMap := result["result"].(map[string]any)
- content := resultMap["content"].([]any)
- require.NotEmpty(t, content, "Should have content in result")
-
- contentItem := content[0].(map[string]any)
- resultText := contentItem["text"].(string)
-
- // The result should contain actual query results, not an error
- assert.NotContains(t, resultText, "token_required", "Should not have token error")
- assert.NotContains(t, resultText, "Token Required", "Should not have token error message")
-
- // Postgres query result should contain our test value
- assert.Contains(t, resultText, "1", "Should contain query result")
-
- t.Log("Successfully invoked tool with user token")
- })
-}
-
-// Helper functions
-
-func startOAuthServer(t *testing.T, env map[string]string) *exec.Cmd {
- // Start with OAuth config
- mcpCmd := exec.Command("../cmd/mcp-front/mcp-front", "-config", "config/config.oauth-test.json")
-
- // Set default environment
- mcpCmd.Env = []string{
- "PATH=" + os.Getenv("PATH"),
- "JWT_SECRET=demo-jwt-secret-32-bytes-exactly!",
- "ENCRYPTION_KEY=test-encryption-key-32-bytes-ok!",
- "GOOGLE_CLIENT_ID=test-client-id-oauth",
- "GOOGLE_CLIENT_SECRET=test-client-secret-oauth",
- "GOOGLE_OAUTH_AUTH_URL=http://localhost:9090/auth",
- "GOOGLE_OAUTH_TOKEN_URL=http://localhost:9090/token",
- "GOOGLE_USERINFO_URL=http://localhost:9090/userinfo",
- }
-
- // Override with provided env
- for key, value := range env {
- mcpCmd.Env = append(mcpCmd.Env, fmt.Sprintf("%s=%s", key, value))
- }
-
- // Capture stderr for debugging and also output to test log
- var stderr bytes.Buffer
- mcpCmd.Stderr = io.MultiWriter(&stderr, os.Stderr)
-
- if err := mcpCmd.Start(); err != nil {
- t.Fatalf("Failed to start OAuth server: %v", err)
- }
-
- // Give a moment for immediate failures
- time.Sleep(100 * time.Millisecond)
-
- // Check if process died immediately
- if mcpCmd.ProcessState != nil {
- t.Fatalf("OAuth server died immediately: %s", stderr.String())
- }
-
- return mcpCmd
-}
-
-// startOAuthServerWithTokenConfig starts the OAuth server with user token configuration
-func startOAuthServerWithTokenConfig(t *testing.T) *exec.Cmd {
- // Start with user token config
- mcpCmd := exec.Command("../cmd/mcp-front/mcp-front", "-config", "config/config.oauth-token-test.json")
-
- // Set default environment
- mcpCmd.Env = []string{
- "PATH=" + os.Getenv("PATH"),
- "JWT_SECRET=demo-jwt-secret-32-bytes-exactly!",
- "ENCRYPTION_KEY=test-encryption-key-32-bytes-ok!",
- "GOOGLE_CLIENT_ID=test-client-id-oauth",
- "GOOGLE_CLIENT_SECRET=test-client-secret-oauth",
- "GOOGLE_OAUTH_AUTH_URL=http://localhost:9090/auth",
- "GOOGLE_OAUTH_TOKEN_URL=http://localhost:9090/token",
- "GOOGLE_USERINFO_URL=http://localhost:9090/userinfo",
- "MCP_FRONT_ENV=development",
- }
-
- // Capture stderr for debugging and also output to test log
- var stderr bytes.Buffer
- mcpCmd.Stderr = io.MultiWriter(&stderr, os.Stderr)
-
- if err := mcpCmd.Start(); err != nil {
- t.Fatalf("Failed to start OAuth server: %v", err)
- }
-
- // Give a moment for immediate failures
- time.Sleep(100 * time.Millisecond)
-
- // Check if process died immediately
- if mcpCmd.ProcessState != nil {
- t.Fatalf("OAuth server died immediately: %s", stderr.String())
- }
-
- return mcpCmd
-}
-
-func stopServer(cmd *exec.Cmd) {
- if cmd != nil && cmd.Process != nil {
- _ = cmd.Process.Kill()
- _ = cmd.Wait()
- // Give the OS time to release the port
- time.Sleep(100 * time.Millisecond)
- }
-}
-
-func waitForHealthCheck(seconds int) bool {
- for range seconds {
- if checkHealth() {
- return true
- }
- time.Sleep(1 * time.Second)
- }
- return false
-}
-
-func checkHealth() bool {
- resp, err := http.Get("http://localhost:8080/health")
- if err == nil && resp.StatusCode == 200 {
- resp.Body.Close()
- return true
- }
- if resp != nil {
- resp.Body.Close()
- }
- return false
-}
-
-func registerTestClient(t *testing.T) string {
- clientReq := map[string]any{
- "redirect_uris": []string{"http://127.0.0.1:6274/oauth/callback"},
- "scope": "openid email profile read write",
- }
-
- body, _ := json.Marshal(clientReq)
- resp, err := http.Post(
- "http://localhost:8080/register",
- "application/json",
- bytes.NewBuffer(body),
- )
- if err != nil {
- t.Fatalf("Failed to register client: %v", err)
- }
- defer resp.Body.Close()
-
- if resp.StatusCode != 201 {
- body, _ := io.ReadAll(resp.Body)
- t.Fatalf("Client registration failed: %d - %s", resp.StatusCode, string(body))
- }
-
- var clientResp map[string]any
- _ = json.NewDecoder(resp.Body).Decode(&clientResp)
- return clientResp["client_id"].(string)
-}
-
-// extractCSRFToken extracts the CSRF token from the HTML response
-func extractCSRFToken(t *testing.T, html string) string {
- // Look for
- re := regexp.MustCompile(`]+name="csrf_token"[^>]+value="([^"]+)"`)
- matches := re.FindStringSubmatch(html)
- require.GreaterOrEqual(t, len(matches), 2, "CSRF token not found in response")
- return matches[1]
-}
-
-// contains is a simple helper to check if string contains substring
-func contains(s, substr string) bool {
- return strings.Contains(s, substr)
-}
-
-// TestServiceOAuthIntegration validates the complete OAuth flow for external services
-func TestServiceOAuthIntegration(t *testing.T) {
- // Start fake service OAuth provider on port 9091
- fakeService := NewFakeServiceOAuthServer("9091")
- err := fakeService.Start()
- require.NoError(t, err)
- defer func() { _ = fakeService.Stop() }()
-
- // Start mcp-front with OAuth service configuration
- startMCPFront(t, "config/config.oauth-service-integration-test.json",
- "JWT_SECRET=demo-jwt-secret-32-bytes-exactly!",
- "ENCRYPTION_KEY=test-encryption-key-32-bytes-ok!",
- "GOOGLE_CLIENT_ID=test-client-id-oauth",
- "GOOGLE_CLIENT_SECRET=test-client-secret-oauth",
- "TEST_SERVICE_CLIENT_ID=service-client-id",
- "TEST_SERVICE_CLIENT_SECRET=service-client-secret",
- "GOOGLE_OAUTH_AUTH_URL=http://localhost:9090/auth",
- "GOOGLE_OAUTH_TOKEN_URL=http://localhost:9090/token",
- "GOOGLE_USERINFO_URL=http://localhost:9090/userinfo",
- "MCP_FRONT_ENV=development",
- )
-
- if !waitForHealthCheck(30) {
- t.Fatal("Server failed to start")
- }
-
- // For this test, we use browser SSO instead of OAuth client flow
- // This simulates a user in the browser connecting services
- jar, _ := cookiejar.New(nil)
- client := &http.Client{Jar: jar}
-
- // Complete Google OAuth to get browser session
- // Access /my/tokens which triggers SSO flow
- resp, err := client.Get("http://localhost:8080/my/tokens")
- require.NoError(t, err)
- defer resp.Body.Close()
-
- // Should have completed SSO and landed on /my/tokens
- require.Equal(t, http.StatusOK, resp.StatusCode)
-
- t.Run("ServiceOAuthConnectFlow", func(t *testing.T) {
- // User clicks "Connect" for the service
- req, _ := http.NewRequest("GET", "http://localhost:8080/oauth/connect?service=test-service", nil)
-
- resp, err := client.Do(req)
- require.NoError(t, err)
- defer resp.Body.Close()
-
- // Should complete OAuth flow and redirect back with success
- // The http.Client automatically follows redirects:
- // 1. /oauth/connect → redirects to localhost:9091/oauth/authorize
- // 2. Fake service → redirects to /oauth/callback/test-service?code=...
- // 3. Callback → stores token, redirects to /my/tokens with success message
-
- body, _ := io.ReadAll(resp.Body)
- bodyStr := string(body)
-
- // Final page should show success
- assert.Contains(t, bodyStr, "Successfully connected", "Should show success message after OAuth flow")
- assert.Contains(t, bodyStr, "Test OAuth Service", "Should mention service name")
- })
-
- t.Run("ConnectedServiceShownOnTokenPage", func(t *testing.T) {
- // After OAuth connection, service should appear as connected
- req, _ := http.NewRequest("GET", "http://localhost:8080/my/tokens", nil)
-
- resp, err := client.Do(req)
- require.NoError(t, err)
- defer resp.Body.Close()
-
- body, _ := io.ReadAll(resp.Body)
- bodyStr := string(body)
-
- // Should show the service with connected status
- assert.Contains(t, bodyStr, "Test OAuth Service")
- // OAuth-connected services show disconnect button, not connect
- assert.Contains(t, bodyStr, "Disconnect", "OAuth-connected service should show Disconnect button")
- })
-}
-
-// getOAuthAccessToken completes the OAuth flow and returns a valid access token
-func getOAuthAccessToken(t *testing.T, resource string) string {
- // Register a test client
- clientID := registerTestClient(t)
-
- // Generate PKCE challenge (must be at least 43 characters)
- codeVerifier := "test-code-verifier-that-is-at-least-43-characters-long"
- h := sha256.New()
- h.Write([]byte(codeVerifier))
- codeChallenge := base64.RawURLEncoding.EncodeToString(h.Sum(nil))
-
- // Step 1: Authorization request
- authParams := url.Values{
- "response_type": {"code"},
- "client_id": {clientID},
- "redirect_uri": {"http://127.0.0.1:6274/oauth/callback"},
- "code_challenge": {codeChallenge},
- "code_challenge_method": {"S256"},
- "scope": {"openid email profile"},
- "state": {"test-state"},
- "resource": {resource},
- }
-
- client := &http.Client{
- CheckRedirect: func(req *http.Request, via []*http.Request) error {
- return http.ErrUseLastResponse
- },
- }
-
- authResp, err := client.Get("http://localhost:8080/authorize?" + authParams.Encode())
- require.NoError(t, err)
- defer authResp.Body.Close()
-
- // Should redirect to Google OAuth (our mock)
- require.Contains(t, []int{302, 303}, authResp.StatusCode, "Should redirect to Google OAuth")
- location := authResp.Header.Get("Location")
- require.Contains(t, location, "http://localhost:9090/auth", "Should redirect to mock Google OAuth")
-
- // Parse the redirect to get the state parameter
- redirectURL, err := url.Parse(location)
- require.NoError(t, err)
- _ = redirectURL.Query().Get("state") // state is included in the redirect but not needed for this test
-
- // Step 2: Follow redirect to mock Google OAuth (which immediately redirects back)
- googleResp, err := client.Get(location)
- require.NoError(t, err)
- defer googleResp.Body.Close()
-
- // Mock Google OAuth redirects back to callback
- require.Contains(t, []int{302, 303}, googleResp.StatusCode, "Mock Google should redirect back")
- callbackLocation := googleResp.Header.Get("Location")
- require.Contains(t, callbackLocation, "/oauth/callback", "Should redirect to callback")
-
- // Step 3: Follow callback redirect
- callbackResp, err := client.Get(callbackLocation)
- require.NoError(t, err)
- defer callbackResp.Body.Close()
-
- // Should redirect to the original redirect_uri with authorization code
- require.Contains(t, []int{302, 303}, callbackResp.StatusCode, "Callback should redirect with code")
- finalLocation := callbackResp.Header.Get("Location")
-
- // Parse authorization code from final redirect
- finalURL, err := url.Parse(finalLocation)
- require.NoError(t, err)
- authCode := finalURL.Query().Get("code")
- require.NotEmpty(t, authCode, "Should have authorization code")
-
- // Step 4: Exchange code for token
- tokenParams := url.Values{
- "grant_type": {"authorization_code"},
- "code": {authCode},
- "redirect_uri": {"http://127.0.0.1:6274/oauth/callback"},
- "client_id": {clientID},
- "code_verifier": {codeVerifier},
- }
-
- tokenResp, err := http.PostForm("http://localhost:8080/token", tokenParams)
- require.NoError(t, err)
- defer tokenResp.Body.Close()
-
- if tokenResp.StatusCode != 200 {
- body, _ := io.ReadAll(tokenResp.Body)
- t.Logf("Token exchange failed with status %d: %s", tokenResp.StatusCode, string(body))
- }
-
- require.Equal(t, 200, tokenResp.StatusCode, "Token exchange should succeed")
-
- var tokenData map[string]any
- err = json.NewDecoder(tokenResp.Body).Decode(&tokenData)
- require.NoError(t, err)
-
- accessToken := tokenData["access_token"].(string)
- require.NotEmpty(t, accessToken, "Should have access token")
-
- return accessToken
-}
-
-// MockUserTokenStore mocks the UserTokenStore interface for testing
-type MockUserTokenStore struct {
- mock.Mock
-}
-
-func (m *MockUserTokenStore) GetUserToken(ctx context.Context, email, service string) (string, error) {
- args := m.Called(ctx, email, service)
- return args.String(0), args.Error(1)
-}
-
-func (m *MockUserTokenStore) SetUserToken(ctx context.Context, email, service, token string) error {
- args := m.Called(ctx, email, service, token)
- return args.Error(0)
-}
-
-func (m *MockUserTokenStore) DeleteUserToken(ctx context.Context, email, service string) error {
- args := m.Called(ctx, email, service)
- return args.Error(0)
-}
-
-func (m *MockUserTokenStore) ListUserServices(ctx context.Context, email string) ([]string, error) {
- args := m.Called(ctx, email)
- if args.Get(0) == nil {
- return nil, args.Error(1)
- }
- return args.Get(0).([]string), args.Error(1)
-}
-
-// TestRFC8707ResourceIndicators validates RFC 8707 resource indicator functionality
-func TestRFC8707ResourceIndicators(t *testing.T) {
- startMCPFront(t, "config/config.oauth-rfc8707-test.json",
- "JWT_SECRET=test-jwt-secret-32-bytes-exactly!",
- "ENCRYPTION_KEY=test-encryption-key-32-bytes-ok!",
- "GOOGLE_CLIENT_ID=test-client-id-for-oauth",
- "GOOGLE_CLIENT_SECRET=test-client-secret-for-oauth",
- "MCP_FRONT_ENV=development",
- "GOOGLE_OAUTH_AUTH_URL=http://localhost:9090/auth",
- "GOOGLE_OAUTH_TOKEN_URL=http://localhost:9090/token",
- "GOOGLE_USERINFO_URL=http://localhost:9090/userinfo",
- )
-
- waitForMCPFront(t)
-
- t.Run("BaseProtectedResourceMetadataReturns404", func(t *testing.T) {
- // Base metadata endpoint should return 404, directing clients to per-service endpoints
- resp, err := http.Get("http://localhost:8080/.well-known/oauth-protected-resource")
- require.NoError(t, err)
- defer resp.Body.Close()
-
- assert.Equal(t, 404, resp.StatusCode, "Base protected resource metadata endpoint should return 404")
-
- var errResp map[string]any
- err = json.NewDecoder(resp.Body).Decode(&errResp)
- require.NoError(t, err)
-
- assert.Contains(t, errResp["message"], "per-service", "Error message should direct to per-service endpoints")
- })
-
- t.Run("PerServiceProtectedResourceMetadataEndpoint", func(t *testing.T) {
- // Per-service metadata endpoint should return service-specific resource URI
- resp, err := http.Get("http://localhost:8080/.well-known/oauth-protected-resource/test-sse")
- require.NoError(t, err)
- defer resp.Body.Close()
-
- assert.Equal(t, 200, resp.StatusCode, "Per-service protected resource metadata endpoint should exist")
-
- var metadata map[string]any
- err = json.NewDecoder(resp.Body).Decode(&metadata)
- require.NoError(t, err)
-
- // Resource should be service-specific, not base URL
- assert.Equal(t, "http://localhost:8080/test-sse", metadata["resource"],
- "Resource should be service-specific URL")
-
- authzServers, ok := metadata["authorization_servers"].([]any)
- require.True(t, ok, "Should have authorization_servers array")
- require.NotEmpty(t, authzServers)
- assert.Equal(t, "http://localhost:8080", authzServers[0],
- "Authorization server should be base issuer")
- })
-
- t.Run("UnknownServiceReturns404", func(t *testing.T) {
- resp, err := http.Get("http://localhost:8080/.well-known/oauth-protected-resource/nonexistent-service")
- require.NoError(t, err)
- defer resp.Body.Close()
-
- assert.Equal(t, 404, resp.StatusCode, "Unknown service should return 404")
- })
-
- t.Run("TokenWithResourceParameter", func(t *testing.T) {
- clientID := registerTestClient(t)
-
- codeVerifier := "test-code-verifier-that-is-at-least-43-characters-long"
- h := sha256.New()
- h.Write([]byte(codeVerifier))
- codeChallenge := base64.RawURLEncoding.EncodeToString(h.Sum(nil))
-
- authParams := url.Values{
- "response_type": {"code"},
- "client_id": {clientID},
- "redirect_uri": {"http://127.0.0.1:6274/oauth/callback"},
- "code_challenge": {codeChallenge},
- "code_challenge_method": {"S256"},
- "scope": {"openid email profile"},
- "state": {"test-state"},
- "resource": {"http://localhost:8080/test-sse"},
- }
-
- client := &http.Client{
- CheckRedirect: func(req *http.Request, via []*http.Request) error {
- return http.ErrUseLastResponse
- },
- }
-
- authResp, err := client.Get("http://localhost:8080/authorize?" + authParams.Encode())
- require.NoError(t, err)
- defer authResp.Body.Close()
-
- assert.Contains(t, []int{302, 303}, authResp.StatusCode, "Should redirect to Google OAuth")
-
- location := authResp.Header.Get("Location")
- googleResp, err := client.Get(location)
- require.NoError(t, err)
- defer googleResp.Body.Close()
-
- callbackLocation := googleResp.Header.Get("Location")
- callbackResp, err := client.Get(callbackLocation)
- require.NoError(t, err)
- defer callbackResp.Body.Close()
-
- finalURL, err := url.Parse(callbackResp.Header.Get("Location"))
- require.NoError(t, err)
- authCode := finalURL.Query().Get("code")
- require.NotEmpty(t, authCode, "Should have authorization code")
-
- tokenParams := url.Values{
- "grant_type": {"authorization_code"},
- "code": {authCode},
- "redirect_uri": {"http://127.0.0.1:6274/oauth/callback"},
- "client_id": {clientID},
- "code_verifier": {codeVerifier},
- }
-
- tokenResp, err := http.PostForm("http://localhost:8080/token", tokenParams)
- require.NoError(t, err)
- defer tokenResp.Body.Close()
-
- require.Equal(t, 200, tokenResp.StatusCode, "Token exchange should succeed")
-
- var tokenData map[string]any
- err = json.NewDecoder(tokenResp.Body).Decode(&tokenData)
- require.NoError(t, err)
-
- testSSEToken := tokenData["access_token"].(string)
- require.NotEmpty(t, testSSEToken, "Should have access token")
-
- t.Logf("Got token with test-sse audience: %s", testSSEToken[:20]+"...")
-
- // Verify token works for test-sse (matching audience)
- req, _ := http.NewRequest("GET", "http://localhost:8080/test-sse/sse", nil)
- req.Header.Set("Authorization", "Bearer "+testSSEToken)
- req.Header.Set("Accept", "text/event-stream")
-
- sseResp, err := client.Do(req)
- require.NoError(t, err)
- defer sseResp.Body.Close()
-
- assert.Equal(t, 200, sseResp.StatusCode,
- "Token with test-sse audience should access /test-sse/sse")
-
- // Verify token does NOT work for test-streamable (wrong audience)
- req, _ = http.NewRequest("GET", "http://localhost:8080/test-streamable/sse", nil)
- req.Header.Set("Authorization", "Bearer "+testSSEToken)
- req.Header.Set("Accept", "text/event-stream")
-
- streamableResp, err := client.Do(req)
- require.NoError(t, err)
- defer streamableResp.Body.Close()
-
- assert.Equal(t, 401, streamableResp.StatusCode,
- "Token with test-sse audience should NOT access /test-streamable/sse")
-
- wwwAuth := streamableResp.Header.Get("WWW-Authenticate")
- assert.Contains(t, wwwAuth, "Bearer resource_metadata=",
- "401 response should include RFC 9728 WWW-Authenticate header")
- // Per RFC 9728 Section 5.2, the metadata URI should be service-specific
- assert.Contains(t, wwwAuth, "/.well-known/oauth-protected-resource/test-streamable",
- "401 response should point to per-service metadata endpoint")
- })
-
- t.Run("401ResponseIncludesServiceSpecificMetadataURI", func(t *testing.T) {
- // Request to a protected endpoint without token should get 401
- // with service-specific metadata URI in WWW-Authenticate header
- client := &http.Client{
- CheckRedirect: func(req *http.Request, via []*http.Request) error {
- return http.ErrUseLastResponse
- },
- }
-
- req, _ := http.NewRequest("GET", "http://localhost:8080/test-sse/sse", nil)
- req.Header.Set("Accept", "text/event-stream")
-
- resp, err := client.Do(req)
- require.NoError(t, err)
- defer resp.Body.Close()
-
- assert.Equal(t, 401, resp.StatusCode, "Request without token should return 401")
-
- wwwAuth := resp.Header.Get("WWW-Authenticate")
- assert.Contains(t, wwwAuth, "Bearer resource_metadata=",
- "401 response should include RFC 9728 WWW-Authenticate header")
- assert.Contains(t, wwwAuth, "/.well-known/oauth-protected-resource/test-sse",
- "401 response should point to test-sse specific metadata endpoint")
- })
-}
diff --git a/integration/oauth_user_tokens_test.go b/integration/oauth_user_tokens_test.go
new file mode 100644
index 0000000..6612d9d
--- /dev/null
+++ b/integration/oauth_user_tokens_test.go
@@ -0,0 +1,536 @@
+package integration
+
+import (
+ "context"
+ "crypto/sha256"
+ "encoding/base64"
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+ "net/http/cookiejar"
+ "net/url"
+ "regexp"
+ "strings"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/mock"
+ "github.com/stretchr/testify/require"
+)
+
+// TestUserTokenFlow tests the user token management functionality with browser-based SSO
+// This test expects the /my/* routes to work with Google SSO (session-based auth),
+// not Bearer token auth.
+func TestUserTokenFlow(t *testing.T) {
+ // Start OAuth server with user token configuration
+ mcpCmd := startOAuthServerWithTokenConfig(t)
+ defer stopServer(mcpCmd)
+
+ if !waitForHealthCheck(30) {
+ t.Fatal("Server failed to start")
+ }
+
+ // Create a client with cookie jar to simulate browser behavior
+ jar, _ := cookiejar.New(nil)
+ client := &http.Client{
+ Jar: jar,
+ CheckRedirect: func(req *http.Request, via []*http.Request) error {
+ // Allow up to 10 redirects
+ if len(via) >= 10 {
+ return fmt.Errorf("too many redirects")
+ }
+ return nil
+ },
+ }
+
+ t.Run("UnauthenticatedRedirectsToSSO", func(t *testing.T) {
+ // Create a client that doesn't follow redirects to test the initial redirect
+ noRedirectClient := &http.Client{
+ CheckRedirect: func(req *http.Request, via []*http.Request) error {
+ return http.ErrUseLastResponse
+ },
+ }
+
+ // Try to access /my/tokens without authentication
+ resp, err := noRedirectClient.Get("http://localhost:8080/my/tokens")
+ require.NoError(t, err)
+ defer resp.Body.Close()
+
+ // Should get a redirect response
+ assert.Equal(t, http.StatusFound, resp.StatusCode, "Should get redirect status")
+
+ // Check the redirect location
+ location := resp.Header.Get("Location")
+ assert.Contains(t, location, "localhost:9090/auth", "Should redirect to Google OAuth")
+ assert.Contains(t, location, "client_id=", "Should include client_id")
+ assert.Contains(t, location, "redirect_uri=", "Should include redirect_uri")
+ // Extract and validate the state parameter
+ parsedURL, err := url.Parse(location)
+ require.NoError(t, err)
+ stateParam := parsedURL.Query().Get("state")
+ require.NotEmpty(t, stateParam, "State parameter should be present")
+
+ // State format: "browser:" prefix followed by signed token
+ // We verify structure but not internal format (that's implementation detail)
+ assert.True(t, strings.HasPrefix(stateParam, "browser:"), "State should start with browser:")
+ assert.Greater(t, len(stateParam), len("browser:"), "State should have content after prefix")
+
+ // Verify state contains signature (has dot separator indicating signed data)
+ stateContent := strings.TrimPrefix(stateParam, "browser:")
+ assert.Contains(t, stateContent, ".", "Signed state should contain signature separator")
+ })
+
+ t.Run("AuthenticatedUserCanAccessTokens", func(t *testing.T) {
+ // The client with cookie jar will automatically follow the full SSO flow:
+ // 1. GET /my/tokens -> redirect to Google OAuth
+ // 2. Google OAuth redirects to /oauth/callback with code
+ // 3. Callback sets session cookie and redirects to /my/tokens
+ // 4. Client follows redirect with cookie and gets the page
+
+ resp, err := client.Get("http://localhost:8080/my/tokens")
+ require.NoError(t, err)
+ defer resp.Body.Close()
+
+ // After following all redirects, we should be at /my/tokens with 200 OK
+ assert.Equal(t, http.StatusOK, resp.StatusCode, "Should access /my/tokens after SSO")
+ finalURL := resp.Request.URL.String()
+ assert.Contains(t, finalURL, "/my/tokens", "Should end up at /my/tokens after SSO")
+
+ // Read response body
+ body, err := io.ReadAll(resp.Body)
+ require.NoError(t, err)
+ bodyStr := string(body)
+
+ // Should show both services without tokens
+ assert.Contains(t, bodyStr, "Notion", "Expected Notion service in response")
+ assert.Contains(t, bodyStr, "GitHub", "Expected GitHub service in response")
+ })
+
+ t.Run("SetTokenWithValidation", func(t *testing.T) {
+ // Assume we're already authenticated from previous test
+ // Get CSRF token first
+ resp, err := client.Get("http://localhost:8080/my/tokens")
+ require.NoError(t, err)
+ defer resp.Body.Close()
+
+ body, err := io.ReadAll(resp.Body)
+ require.NoError(t, err)
+
+ // Extract CSRF token from response
+ csrfToken := extractCSRFToken(t, string(body))
+
+ // Try to set invalid Notion token
+ form := url.Values{
+ "service": []string{"notion"},
+ "token": []string{"invalid-token"},
+ "csrf_token": []string{csrfToken},
+ }
+
+ req, _ := http.NewRequest("POST", "http://localhost:8080/my/tokens/set", strings.NewReader(form.Encode()))
+ req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+
+ // Use custom client that doesn't follow redirects for this test
+ noRedirectClient := &http.Client{
+ Jar: jar, // Use same cookie jar
+ CheckRedirect: func(req *http.Request, via []*http.Request) error {
+ return http.ErrUseLastResponse
+ },
+ }
+
+ resp, err = noRedirectClient.Do(req)
+ require.NoError(t, err)
+ defer resp.Body.Close()
+
+ // Should redirect with error
+ assert.Equal(t, http.StatusSeeOther, resp.StatusCode, "Expected redirect")
+ location := resp.Header.Get("Location")
+ assert.Contains(t, location, "error", "Expected error in redirect")
+
+ // Get new CSRF token
+ resp, err = client.Get("http://localhost:8080/my/tokens")
+ require.NoError(t, err)
+ defer resp.Body.Close()
+
+ body, err = io.ReadAll(resp.Body)
+ require.NoError(t, err)
+ csrfToken = extractCSRFToken(t, string(body))
+
+ // Set valid Notion token (regex expects exactly 43 chars after "secret_")
+ form = url.Values{
+ "service": {"notion"},
+ "token": {"secret_1234567890123456789012345678901234567890123"},
+ "csrf_token": {csrfToken},
+ }
+
+ req, _ = http.NewRequest("POST", "http://localhost:8080/my/tokens/set", strings.NewReader(form.Encode()))
+ req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+
+ resp, err = noRedirectClient.Do(req)
+ require.NoError(t, err)
+ defer resp.Body.Close()
+
+ // Should redirect with success
+ assert.Equal(t, http.StatusSeeOther, resp.StatusCode, "Expected redirect")
+ location = resp.Header.Get("Location")
+ assert.Contains(t, location, "success", "Expected success in redirect")
+ })
+}
+
+// TestToolAdvertisementWithUserTokens tests that tools are advertised even without user tokens
+// but fail gracefully when invoked without the required token, and succeed with the token
+func TestToolAdvertisementWithUserTokens(t *testing.T) {
+ cfg := buildTestConfig("http://localhost:8080", "mcp-front-oauth-usertoken-test",
+ testOAuthConfigFromEnv(),
+ map[string]any{"postgres": testPostgresServer(withUserToken())},
+ )
+ startMCPFront(t, writeTestConfig(t, cfg),
+ "JWT_SECRET=demo-jwt-secret-32-bytes-exactly!",
+ "ENCRYPTION_KEY=test-encryption-key-32-bytes-ok!",
+ "GOOGLE_CLIENT_ID=test-client-id-oauth",
+ "GOOGLE_CLIENT_SECRET=test-client-secret-oauth",
+ "MCP_FRONT_ENV=development",
+ "LOG_LEVEL=debug",
+ )
+
+ if !waitForHealthCheck(30) {
+ t.Fatal("Server failed to start")
+ }
+
+ // Complete OAuth flow to get a valid access token
+ accessToken := getOAuthAccessToken(t, "http://localhost:8080/postgres")
+
+ t.Run("ToolsAdvertisedWithoutToken", func(t *testing.T) {
+ // Create MCP client with OAuth token
+ mcpClient := NewMCPSSEClient("http://localhost:8080")
+ mcpClient.SetAuthToken(accessToken)
+
+ // Connect to postgres SSE endpoint
+ err := mcpClient.Connect()
+ require.NoError(t, err, "Should connect to postgres SSE endpoint without user token")
+ defer mcpClient.Close()
+
+ // Request tools list
+ toolsResp, err := mcpClient.SendMCPRequest("tools/list", map[string]any{})
+ require.NoError(t, err, "Should list tools without user token")
+
+ // Verify we got tools
+ resultMap, ok := toolsResp["result"].(map[string]any)
+ require.True(t, ok, "Expected result in tools response")
+
+ tools, ok := resultMap["tools"].([]any)
+ require.True(t, ok, "Expected tools array in result")
+ assert.NotEmpty(t, tools, "Should have tools advertised")
+
+ // Check for common postgres tools
+ var toolNames []string
+ for _, tool := range tools {
+ if toolMap, ok := tool.(map[string]any); ok {
+ if name, ok := toolMap["name"].(string); ok {
+ toolNames = append(toolNames, name)
+ }
+ }
+ }
+
+ assert.Contains(t, toolNames, "execute_sql", "Should have execute_sql tool")
+ t.Logf("Successfully advertised tools without user token: %v", toolNames)
+ })
+
+ t.Run("ToolInvocationFailsWithoutToken", func(t *testing.T) {
+ // Create MCP client with OAuth token
+ mcpClient := NewMCPSSEClient("http://localhost:8080")
+ mcpClient.SetAuthToken(accessToken)
+
+ // Connect to postgres SSE endpoint
+ err := mcpClient.Connect()
+ require.NoError(t, err)
+ defer mcpClient.Close()
+
+ // Try to invoke a tool without user token
+ queryParams := map[string]any{
+ "name": "execute_sql",
+ "arguments": map[string]any{
+ "sql": "SELECT 1",
+ },
+ }
+
+ result, err := mcpClient.SendMCPRequest("tools/call", queryParams)
+ require.NoError(t, err, "Should get response even without token")
+
+ // MCP protocol returns errors as successful responses with error content
+ require.NotNil(t, result["result"], "Should have result in response")
+
+ resultMap := result["result"].(map[string]any)
+ content := resultMap["content"].([]any)
+ require.NotEmpty(t, content, "Should have content in result")
+
+ contentItem := content[0].(map[string]any)
+ errorJSON := contentItem["text"].(string)
+
+ // Parse the error JSON
+ var errorData map[string]any
+ err = json.Unmarshal([]byte(errorJSON), &errorData)
+ require.NoError(t, err, "Error should be valid JSON")
+
+ // Verify error structure
+ errorInfo := errorData["error"].(map[string]any)
+ assert.Equal(t, "token_required", errorInfo["code"], "Error code should be token_required")
+
+ errorMessage := errorInfo["message"].(string)
+ assert.Contains(t, errorMessage, "token required", "Error should mention token required")
+ assert.Contains(t, errorMessage, "/my/tokens", "Error should mention token setup URL")
+ assert.Contains(t, errorMessage, "Test Service", "Error should mention service name")
+
+ // Verify error data
+ errData := errorInfo["data"].(map[string]any)
+ assert.Equal(t, "postgres", errData["service"], "Should identify the service")
+ assert.Contains(t, errData["tokenSetupUrl"].(string), "/my/tokens", "Should include token setup URL")
+
+ // Verify instructions
+ instructions := errData["instructions"].(map[string]any)
+ assert.Contains(t, instructions["ai"].(string), "CRITICAL", "Should have AI instructions")
+ assert.Contains(t, instructions["human"].(string), "token required", "Should have human instructions")
+ })
+
+ t.Run("ToolInvocationSucceedsWithUserToken", func(t *testing.T) {
+ // Step 1: GET /my/tokens to extract CSRF token
+ jar, err := cookiejar.New(nil)
+ require.NoError(t, err)
+ client := &http.Client{
+ Jar: jar, // Need cookie jar for CSRF
+ CheckRedirect: func(req *http.Request, via []*http.Request) error {
+ return http.ErrUseLastResponse // Don't follow redirects
+ },
+ }
+
+ req, err := http.NewRequest("GET", "http://localhost:8080/my/tokens", nil)
+ require.NoError(t, err)
+ req.Header.Set("Authorization", "Bearer "+accessToken)
+
+ resp, err := client.Do(req)
+ require.NoError(t, err)
+ defer resp.Body.Close()
+
+ // Check if we got the page or a redirect
+ if resp.StatusCode == 302 || resp.StatusCode == 303 {
+ // Follow the redirect
+ location := resp.Header.Get("Location")
+ t.Logf("Got redirect to: %s", location)
+
+ // Allow redirects for this request
+ client = &http.Client{
+ Jar: jar,
+ }
+
+ req, err = http.NewRequest("GET", "http://localhost:8080/my/tokens", nil)
+ require.NoError(t, err)
+ req.Header.Set("Authorization", "Bearer "+accessToken)
+
+ resp, err = client.Do(req)
+ require.NoError(t, err)
+ defer resp.Body.Close()
+ }
+
+ require.Equal(t, 200, resp.StatusCode, "Should be able to access token page")
+
+ // Extract CSRF token from HTML
+ body, err := io.ReadAll(resp.Body)
+ require.NoError(t, err)
+
+ // Look for the CSRF token in the form
+ csrfRegex := regexp.MustCompile(`name="csrf_token" value="([^"]+)"`)
+ matches := csrfRegex.FindSubmatch(body)
+ require.Len(t, matches, 2, "Should find CSRF token in form")
+ csrfToken := string(matches[1])
+
+ // Step 2: POST to /my/tokens/set with test token
+ formData := url.Values{
+ "service": {"postgres"},
+ "token": {"test-user-token-12345"},
+ "csrf_token": {csrfToken},
+ }
+
+ req, err = http.NewRequest("POST", "http://localhost:8080/my/tokens/set", strings.NewReader(formData.Encode()))
+ require.NoError(t, err)
+ req.Header.Set("Authorization", "Bearer "+accessToken)
+ req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+
+ resp, err = client.Do(req)
+ require.NoError(t, err)
+ defer resp.Body.Close()
+
+ // Check the response - it might be 200 if following redirects
+ switch resp.StatusCode {
+ case 200:
+ // That's fine, it means the token was set and we got the page back
+ t.Log("Token set successfully, got page response")
+ case 302, 303:
+ // Also fine, redirect means success
+ t.Log("Token set successfully, got redirect")
+ default:
+ body, _ := io.ReadAll(resp.Body)
+ t.Fatalf("Unexpected response setting token: status=%d, body=%s", resp.StatusCode, string(body))
+ }
+
+ // Step 3: Now test tool invocation with the token
+ mcpClient := NewMCPSSEClient("http://localhost:8080")
+ mcpClient.SetAuthToken(accessToken)
+
+ err = mcpClient.Connect()
+ require.NoError(t, err, "Should connect to postgres SSE endpoint")
+ defer mcpClient.Close()
+
+ // Call the query tool with a simple query
+ queryParams := map[string]any{
+ "name": "execute_sql",
+ "arguments": map[string]any{
+ "sql": "SELECT 1 as test",
+ },
+ }
+
+ result, err := mcpClient.SendMCPRequest("tools/call", queryParams)
+ require.NoError(t, err, "Should successfully call tool with token")
+
+ // Verify we got a successful result, not an error
+ require.NotNil(t, result["result"], "Should have result in response")
+
+ resultMap := result["result"].(map[string]any)
+ content := resultMap["content"].([]any)
+ require.NotEmpty(t, content, "Should have content in result")
+
+ contentItem := content[0].(map[string]any)
+ resultText := contentItem["text"].(string)
+
+ // The result should contain actual query results, not an error
+ assert.NotContains(t, resultText, "token_required", "Should not have token error")
+ assert.NotContains(t, resultText, "Token Required", "Should not have token error message")
+
+ // Postgres query result should contain our test value
+ assert.Contains(t, resultText, "1", "Should contain query result")
+
+ t.Log("Successfully invoked tool with user token")
+ })
+}
+
+// getOAuthAccessTokenForIDP completes the OAuth flow and returns a valid access token.
+// expectedIDPHost is the host:port of the expected IDP redirect target (e.g., "localhost:9090" for Google).
+func getOAuthAccessTokenForIDP(t *testing.T, resource, expectedIDPHost string) string {
+ t.Helper()
+
+ clientID := registerTestClient(t)
+
+ codeVerifier := "test-code-verifier-that-is-at-least-43-characters-long"
+ h := sha256.New()
+ h.Write([]byte(codeVerifier))
+ codeChallenge := base64.RawURLEncoding.EncodeToString(h.Sum(nil))
+
+ authParams := url.Values{
+ "response_type": {"code"},
+ "client_id": {clientID},
+ "redirect_uri": {"http://127.0.0.1:6274/oauth/callback"},
+ "code_challenge": {codeChallenge},
+ "code_challenge_method": {"S256"},
+ "scope": {"openid email profile"},
+ "state": {"test-state"},
+ "resource": {resource},
+ }
+
+ client := &http.Client{
+ CheckRedirect: func(req *http.Request, via []*http.Request) error {
+ return http.ErrUseLastResponse
+ },
+ }
+
+ // Step 1: Authorization request
+ authResp, err := client.Get("http://localhost:8080/authorize?" + authParams.Encode())
+ require.NoError(t, err)
+ defer authResp.Body.Close()
+
+ require.Contains(t, []int{302, 303}, authResp.StatusCode, "Should redirect to IDP")
+ location := authResp.Header.Get("Location")
+ require.Contains(t, location, expectedIDPHost, "Should redirect to expected IDP")
+
+ // Step 2: Follow redirect to IDP (which immediately redirects back)
+ idpResp, err := client.Get(location)
+ require.NoError(t, err)
+ defer idpResp.Body.Close()
+
+ require.Contains(t, []int{302, 303}, idpResp.StatusCode, "IDP should redirect back")
+ callbackLocation := idpResp.Header.Get("Location")
+ require.Contains(t, callbackLocation, "/oauth/callback", "Should redirect to callback")
+
+ // Step 3: Follow callback redirect
+ callbackResp, err := client.Get(callbackLocation)
+ require.NoError(t, err)
+ defer callbackResp.Body.Close()
+
+ require.Contains(t, []int{302, 303}, callbackResp.StatusCode, "Callback should redirect with code")
+ finalLocation := callbackResp.Header.Get("Location")
+
+ finalURL, err := url.Parse(finalLocation)
+ require.NoError(t, err)
+ authCode := finalURL.Query().Get("code")
+ require.NotEmpty(t, authCode, "Should have authorization code")
+
+ // Step 4: Exchange code for token
+ tokenParams := url.Values{
+ "grant_type": {"authorization_code"},
+ "code": {authCode},
+ "redirect_uri": {"http://127.0.0.1:6274/oauth/callback"},
+ "client_id": {clientID},
+ "code_verifier": {codeVerifier},
+ }
+
+ tokenResp, err := http.PostForm("http://localhost:8080/token", tokenParams)
+ require.NoError(t, err)
+ defer tokenResp.Body.Close()
+
+ if tokenResp.StatusCode != 200 {
+ body, _ := io.ReadAll(tokenResp.Body)
+ t.Logf("Token exchange failed with status %d: %s", tokenResp.StatusCode, string(body))
+ }
+
+ require.Equal(t, 200, tokenResp.StatusCode, "Token exchange should succeed")
+
+ var tokenData map[string]any
+ err = json.NewDecoder(tokenResp.Body).Decode(&tokenData)
+ require.NoError(t, err)
+
+ accessToken := tokenData["access_token"].(string)
+ require.NotEmpty(t, accessToken, "Should have access token")
+
+ return accessToken
+}
+
+// getOAuthAccessToken completes the OAuth flow using the Google IDP and returns a valid access token.
+func getOAuthAccessToken(t *testing.T, resource string) string {
+ return getOAuthAccessTokenForIDP(t, resource, "localhost:9090")
+}
+
+// MockUserTokenStore mocks the UserTokenStore interface for testing
+type MockUserTokenStore struct {
+ mock.Mock
+}
+
+func (m *MockUserTokenStore) GetUserToken(ctx context.Context, email, service string) (string, error) {
+ args := m.Called(ctx, email, service)
+ return args.String(0), args.Error(1)
+}
+
+func (m *MockUserTokenStore) SetUserToken(ctx context.Context, email, service, token string) error {
+ args := m.Called(ctx, email, service, token)
+ return args.Error(0)
+}
+
+func (m *MockUserTokenStore) DeleteUserToken(ctx context.Context, email, service string) error {
+ args := m.Called(ctx, email, service)
+ return args.Error(0)
+}
+
+func (m *MockUserTokenStore) ListUserServices(ctx context.Context, email string) ([]string, error) {
+ args := m.Called(ctx, email)
+ if args.Get(0) == nil {
+ return nil, args.Error(1)
+ }
+ return args.Get(0).([]string), args.Error(1)
+}
diff --git a/integration/security_test.go b/integration/security_test.go
index 2305c92..653ca14 100644
--- a/integration/security_test.go
+++ b/integration/security_test.go
@@ -15,7 +15,11 @@ func TestSecurityScenarios(t *testing.T) {
waitForDB(t)
// Start mcp-front
- startMCPFront(t, "config/config.test.json")
+ cfg := buildTestConfig("http://localhost:8080", "mcp-front-test",
+ nil,
+ map[string]any{"postgres": testPostgresServer(withBearerTokens("test-token", "alt-test-token"), withLogEnabled())},
+ )
+ startMCPFront(t, writeTestConfig(t, cfg))
// Wait for server to be ready
waitForMCPFront(t)
@@ -78,7 +82,7 @@ func TestSecurityScenarios(t *testing.T) {
})
t.Run("SQLInjectionAttempts", func(t *testing.T) {
- t.Skip("Skipping SQL injection tests, it's not a responsibility of mcp-front to guard mcp/postgres")
+ t.Skip("Skipping SQL injection tests, it's not a responsibility of mcp-front to guard toolbox")
client := NewMCPSSEClient("http://localhost:8080")
_ = client.Authenticate()
@@ -101,7 +105,7 @@ func TestSecurityScenarios(t *testing.T) {
// Try to inject via the query parameter
_, err := client.SendMCPRequest("tools/call", map[string]any{
- "name": "query",
+ "name": "execute_sql",
"arguments": map[string]any{
"query": payload,
},
@@ -297,7 +301,11 @@ func TestFailureScenarios(t *testing.T) {
// Database is already started by TestMain, just wait for readiness
waitForDB(t)
- startMCPFront(t, "config/config.test.json")
+ cfg := buildTestConfig("http://localhost:8080", "mcp-front-test",
+ nil,
+ map[string]any{"postgres": testPostgresServer(withBearerTokens("test-token", "alt-test-token"), withLogEnabled())},
+ )
+ startMCPFront(t, writeTestConfig(t, cfg))
// Wait for server to be ready
waitForMCPFront(t)
diff --git a/integration/streamable_client.go b/integration/streamable_client.go
deleted file mode 100644
index 76ab1b7..0000000
--- a/integration/streamable_client.go
+++ /dev/null
@@ -1 +0,0 @@
-package integration
diff --git a/integration/test_clients.go b/integration/test_clients.go
new file mode 100644
index 0000000..a5cb357
--- /dev/null
+++ b/integration/test_clients.go
@@ -0,0 +1,442 @@
+package integration
+
+import (
+ "bufio"
+ "bytes"
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+ "net/url"
+ "strings"
+ "sync"
+ "time"
+)
+
+// MCPSSEClient simulates an MCP client for testing
+type MCPSSEClient struct {
+ baseURL string
+ token string
+ sseConn io.ReadCloser
+ messageEndpoint string
+ sseScanner *bufio.Scanner
+ sessionID string
+}
+
+// NewMCPSSEClient creates a new MCP client for testing
+func NewMCPSSEClient(baseURL string) *MCPSSEClient {
+ return &MCPSSEClient{
+ baseURL: baseURL,
+ }
+}
+
+// Authenticate sets up authentication for the client
+func (c *MCPSSEClient) Authenticate() error {
+ c.token = "test-token"
+ return nil
+}
+
+// SetAuthToken sets a specific auth token for the client
+func (c *MCPSSEClient) SetAuthToken(token string) {
+ c.token = token
+}
+
+// Connect establishes an SSE connection and retrieves the message endpoint
+func (c *MCPSSEClient) Connect() error {
+ return c.ConnectToServer("postgres")
+}
+
+// ConnectToServer establishes an SSE connection to a specific server
+func (c *MCPSSEClient) ConnectToServer(serverName string) error {
+ // Close any existing connection
+ if c.sseConn != nil {
+ c.sseConn.Close()
+ c.sseConn = nil
+ c.messageEndpoint = ""
+ }
+
+ sseURL := c.baseURL + "/" + serverName + "/sse"
+ tracef("ConnectToServer: requesting %s", sseURL)
+
+ req, err := http.NewRequest("GET", sseURL, nil)
+ if err != nil {
+ return fmt.Errorf("failed to create SSE request: %v", err)
+ }
+
+ req.Header.Set("Accept", "text/event-stream")
+ req.Header.Set("Authorization", "Bearer "+c.token)
+ req.Header.Set("Cache-Control", "no-cache")
+ tracef("ConnectToServer: headers set, making request")
+
+ // Don't use a timeout on the client for SSE
+ client := &http.Client{}
+ resp, err := client.Do(req)
+ if err != nil {
+ return fmt.Errorf("SSE connection failed: %v", err)
+ }
+
+ tracef("ConnectToServer: got response status %d", resp.StatusCode)
+ if resp.StatusCode != 200 {
+ body, _ := io.ReadAll(resp.Body)
+ resp.Body.Close()
+ return fmt.Errorf("SSE connection returned %d: %s", resp.StatusCode, string(body))
+ }
+
+ // Store the connection
+ c.sseConn = resp.Body
+ c.sseScanner = bufio.NewScanner(resp.Body)
+
+ // Read initial SSE messages to get the endpoint
+ // For inline servers, we don't get a message endpoint - we use the server path directly
+ gotEndpointMessage := false
+ for c.sseScanner.Scan() {
+ line := c.sseScanner.Text()
+ tracef("ConnectToServer: SSE line: %s", line)
+
+ // Look for data lines
+ if after, ok := strings.CutPrefix(line, "data: "); ok {
+ data := after
+
+ // Check if it's an endpoint message (for inline servers)
+ if strings.Contains(data, `"type":"endpoint"`) {
+ gotEndpointMessage = true
+ // For inline servers, construct the message endpoint
+ c.messageEndpoint = c.baseURL + "/" + serverName + "/message"
+ tracef("ConnectToServer: inline server detected, using endpoint: %s", c.messageEndpoint)
+ break
+ }
+
+ // Check if it's a message endpoint URL (for stdio servers)
+ if strings.Contains(data, "http://") || strings.Contains(data, "https://") {
+ c.messageEndpoint = data
+
+ // Extract session ID from endpoint URL
+ if u, err := url.Parse(data); err == nil {
+ c.sessionID = u.Query().Get("sessionId")
+ }
+
+ tracef("ConnectToServer: found endpoint: %s", c.messageEndpoint)
+ break
+ }
+ }
+ }
+
+ if c.messageEndpoint == "" && !gotEndpointMessage {
+ c.sseConn.Close()
+ c.sseConn = nil
+ return fmt.Errorf("no message endpoint received")
+ }
+
+ tracef("Connect: successfully connected to MCP server")
+ return nil
+}
+
+// ValidateBackendConnectivity checks if we can connect to the MCP server
+func (c *MCPSSEClient) ValidateBackendConnectivity() error {
+ return c.Connect()
+}
+
+// Close closes the SSE connection
+func (c *MCPSSEClient) Close() {
+ if c.sseConn != nil {
+ c.sseConn.Close()
+ c.sseConn = nil
+ c.messageEndpoint = ""
+ c.sseScanner = nil
+ }
+}
+
+// SendMCPRequest sends an MCP JSON-RPC request and returns the response
+func (c *MCPSSEClient) SendMCPRequest(method string, params any) (map[string]any, error) {
+ // Ensure we have a connection
+ if c.messageEndpoint == "" {
+ if err := c.Connect(); err != nil {
+ return nil, fmt.Errorf("failed to connect: %v", err)
+ }
+ }
+
+ // Send MCP request to the message endpoint
+ request := map[string]any{
+ "jsonrpc": "2.0",
+ "id": 1,
+ "method": method,
+ "params": params,
+ }
+
+ reqBody, err := json.Marshal(request)
+ if err != nil {
+ return nil, err
+ }
+
+ msgReq, err := http.NewRequest("POST", c.messageEndpoint, bytes.NewBuffer(reqBody))
+ if err != nil {
+ return nil, err
+ }
+
+ msgReq.Header.Set("Content-Type", "application/json")
+ msgReq.Header.Set("Authorization", "Bearer "+c.token)
+
+ client := &http.Client{Timeout: 30 * time.Second}
+ msgResp, err := client.Do(msgReq)
+ if err != nil {
+ return nil, err
+ }
+ defer msgResp.Body.Close()
+
+ respBody, err := io.ReadAll(msgResp.Body)
+ if err != nil {
+ return nil, err
+ }
+
+ if msgResp.StatusCode != 200 && msgResp.StatusCode != 202 {
+ return nil, fmt.Errorf("MCP request failed: %d - %s", msgResp.StatusCode, string(respBody))
+ }
+
+ // Handle 202 and empty responses - read response from SSE stream
+ if msgResp.StatusCode == 202 || len(respBody) == 0 {
+ // Read response from SSE stream
+ for c.sseScanner.Scan() {
+ line := c.sseScanner.Text()
+
+ if after, ok := strings.CutPrefix(line, "data: "); ok {
+ data := after
+ // Try to parse as JSON
+ var msg map[string]any
+ if err := json.Unmarshal([]byte(data), &msg); err == nil {
+ // Check if this is our response (matching ID)
+ if id, ok := msg["id"]; ok && id == float64(1) {
+ return msg, nil
+ }
+ }
+ }
+ }
+
+ if err := c.sseScanner.Err(); err != nil {
+ return nil, fmt.Errorf("SSE scanner error: %v", err)
+ }
+
+ return nil, fmt.Errorf("no response received from SSE stream")
+ }
+
+ var result map[string]any
+ if err := json.Unmarshal(respBody, &result); err != nil {
+ return nil, fmt.Errorf("failed to parse response: %v - %s", err, string(respBody))
+ }
+
+ return result, nil
+}
+
+// MCPStreamableClient is a test client for HTTP-Streamable MCP servers
+type MCPStreamableClient struct {
+ baseURL string
+ serverName string
+ token string
+ httpClient *http.Client
+
+ // For GET SSE streaming
+ sseConn io.ReadCloser
+ sseScanner *bufio.Scanner
+ sseCancel chan struct{}
+
+ mu sync.Mutex
+}
+
+// NewMCPStreamableClient creates a new streamable-http test client
+func NewMCPStreamableClient(baseURL string) *MCPStreamableClient {
+ return &MCPStreamableClient{
+ baseURL: baseURL,
+ httpClient: &http.Client{
+ Timeout: 30 * time.Second,
+ },
+ }
+}
+
+// SetAuthToken sets the authentication token
+func (c *MCPStreamableClient) SetAuthToken(token string) {
+ c.token = token
+}
+
+// ConnectToServer establishes connection to a streamable-http server
+func (c *MCPStreamableClient) ConnectToServer(serverName string) error {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+
+ // Close any existing connection
+ c.close()
+
+ c.serverName = serverName
+
+ // For streamable-http, we can optionally open a GET SSE stream for server-initiated messages
+ // But it's not required for basic request/response
+ return c.openSSEStream()
+}
+
+// openSSEStream opens a GET SSE connection for receiving server-initiated messages
+func (c *MCPStreamableClient) openSSEStream() error {
+ url := c.baseURL + "/" + c.serverName + "/sse"
+
+ req, err := http.NewRequest("GET", url, nil)
+ if err != nil {
+ return fmt.Errorf("failed to create GET request: %v", err)
+ }
+
+ req.Header.Set("Accept", "text/event-stream")
+ req.Header.Set("Authorization", "Bearer "+c.token)
+ req.Header.Set("Cache-Control", "no-cache")
+
+ // Use a client without timeout for SSE
+ sseClient := &http.Client{}
+ resp, err := sseClient.Do(req)
+ if err != nil {
+ return fmt.Errorf("SSE connection failed: %v", err)
+ }
+
+ if resp.StatusCode != 200 {
+ body, _ := io.ReadAll(resp.Body)
+ resp.Body.Close()
+ return fmt.Errorf("SSE connection returned %d: %s", resp.StatusCode, string(body))
+ }
+
+ c.sseConn = resp.Body
+ c.sseScanner = bufio.NewScanner(resp.Body)
+ c.sseCancel = make(chan struct{})
+
+ // Start reading SSE messages in background
+ go c.readSSEMessages()
+
+ return nil
+}
+
+// readSSEMessages reads server-initiated messages from the SSE stream
+func (c *MCPStreamableClient) readSSEMessages() {
+ for {
+ select {
+ case <-c.sseCancel:
+ return
+ default:
+ if c.sseScanner.Scan() {
+ line := c.sseScanner.Text()
+ if after, ok := strings.CutPrefix(line, "data: "); ok {
+ data := after
+ // In a real implementation, we'd process server-initiated messages here
+ tracef("StreamableClient: received SSE message: %s", data)
+ }
+ } else {
+ // Scanner stopped - connection closed or error
+ return
+ }
+ }
+ }
+}
+
+// SendMCPRequest sends a JSON-RPC request via POST
+func (c *MCPStreamableClient) SendMCPRequest(method string, params any) (map[string]any, error) {
+ c.mu.Lock()
+ serverName := c.serverName
+ c.mu.Unlock()
+
+ if serverName == "" {
+ return nil, fmt.Errorf("not connected to any server")
+ }
+
+ // For streamable-http, we POST to the server endpoint
+ url := c.baseURL + "/" + serverName + "/sse"
+
+ // Construct JSON-RPC request
+ request := map[string]any{
+ "jsonrpc": "2.0",
+ "id": 1,
+ "method": method,
+ "params": params,
+ }
+
+ body, err := json.Marshal(request)
+ if err != nil {
+ return nil, fmt.Errorf("failed to marshal request: %v", err)
+ }
+
+ req, err := http.NewRequest("POST", url, bytes.NewReader(body))
+ if err != nil {
+ return nil, fmt.Errorf("failed to create POST request: %v", err)
+ }
+
+ req.Header.Set("Content-Type", "application/json")
+ req.Header.Set("Authorization", "Bearer "+c.token)
+ // Accept both JSON and SSE responses
+ req.Header.Set("Accept", "application/json, text/event-stream")
+
+ resp, err := c.httpClient.Do(req)
+ if err != nil {
+ return nil, fmt.Errorf("request failed: %v", err)
+ }
+ defer resp.Body.Close()
+
+ // Check content type to determine response format
+ contentType := resp.Header.Get("Content-Type")
+
+ if strings.HasPrefix(contentType, "text/event-stream") {
+ // Handle SSE response
+ return c.handleSSEResponse(resp.Body)
+ } else {
+ // Handle JSON response
+ return c.handleJSONResponse(resp.Body)
+ }
+}
+
+// handleJSONResponse processes a regular JSON response
+func (c *MCPStreamableClient) handleJSONResponse(body io.Reader) (map[string]any, error) {
+ var response map[string]any
+ if err := json.NewDecoder(body).Decode(&response); err != nil {
+ return nil, fmt.Errorf("failed to decode JSON response: %v", err)
+ }
+ return response, nil
+}
+
+// handleSSEResponse processes an SSE stream response from a POST
+func (c *MCPStreamableClient) handleSSEResponse(body io.Reader) (map[string]any, error) {
+ scanner := bufio.NewScanner(body)
+ var lastResponse map[string]any
+
+ for scanner.Scan() {
+ line := scanner.Text()
+ if after, ok := strings.CutPrefix(line, "data: "); ok {
+ data := after
+ var msg map[string]any
+ if err := json.Unmarshal([]byte(data), &msg); err == nil {
+ // Keep the last response with an ID (not a notification)
+ if _, hasID := msg["id"]; hasID {
+ lastResponse = msg
+ }
+ }
+ }
+ }
+
+ if lastResponse == nil {
+ return nil, fmt.Errorf("no response received in SSE stream")
+ }
+
+ return lastResponse, nil
+}
+
+// Close closes all connections
+func (c *MCPStreamableClient) Close() {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ c.close()
+}
+
+// close is the internal close method (must be called with lock held)
+func (c *MCPStreamableClient) close() {
+ if c.sseCancel != nil {
+ close(c.sseCancel)
+ c.sseCancel = nil
+ }
+
+ if c.sseConn != nil {
+ c.sseConn.Close()
+ c.sseConn = nil
+ c.sseScanner = nil
+ }
+
+ c.serverName = ""
+}
diff --git a/integration/test_fakes.go b/integration/test_fakes.go
new file mode 100644
index 0000000..019b8c0
--- /dev/null
+++ b/integration/test_fakes.go
@@ -0,0 +1,355 @@
+package integration
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "time"
+)
+
+// FakeGCPServer provides a fake GCP OAuth server for testing
+type FakeGCPServer struct {
+ server *http.Server
+ port string
+}
+
+// NewFakeGCPServer creates a new fake GCP server
+func NewFakeGCPServer(port string) *FakeGCPServer {
+ mux := http.NewServeMux()
+
+ mux.HandleFunc("/auth", func(w http.ResponseWriter, r *http.Request) {
+ redirectURI := r.URL.Query().Get("redirect_uri")
+ state := r.URL.Query().Get("state")
+ http.Redirect(w, r, fmt.Sprintf("%s?code=test-auth-code&state=%s", redirectURI, state), http.StatusFound)
+ })
+
+ mux.HandleFunc("/token", func(w http.ResponseWriter, r *http.Request) {
+ // Parse the form data
+ if err := r.ParseForm(); err != nil {
+ http.Error(w, "Invalid request", http.StatusBadRequest)
+ return
+ }
+
+ // Check the authorization code
+ code := r.FormValue("code")
+ if code != "test-auth-code" {
+ w.Header().Set("Content-Type", "application/json")
+ w.WriteHeader(http.StatusBadRequest)
+ _ = json.NewEncoder(w).Encode(map[string]any{
+ "error": "invalid_grant",
+ "error_description": "Invalid authorization code",
+ })
+ return
+ }
+
+ w.Header().Set("Content-Type", "application/json")
+ _ = json.NewEncoder(w).Encode(map[string]any{
+ "access_token": "test-access-token",
+ "token_type": "Bearer",
+ "expires_in": 3600,
+ })
+ })
+
+ mux.HandleFunc("/userinfo", func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", "application/json")
+ _ = json.NewEncoder(w).Encode(map[string]any{
+ "email": "test@test.com",
+ "hd": "test.com",
+ })
+ })
+
+ server := &http.Server{
+ Addr: ":" + port,
+ Handler: mux,
+ }
+
+ return &FakeGCPServer{
+ server: server,
+ port: port,
+ }
+}
+
+// Start starts the fake GCP server
+func (m *FakeGCPServer) Start() error {
+ go func() {
+ if err := m.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
+ panic(err)
+ }
+ }()
+
+ time.Sleep(100 * time.Millisecond)
+ return nil
+}
+
+// Stop stops the fake GCP server
+func (m *FakeGCPServer) Stop() error {
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+ defer cancel()
+ return m.server.Shutdown(ctx)
+}
+
+// FakeServiceOAuthServer provides a fake OAuth server for external services (like Linear, GitHub)
+type FakeServiceOAuthServer struct {
+ server *http.Server
+ port string
+}
+
+// NewFakeServiceOAuthServer creates a new fake service OAuth server
+func NewFakeServiceOAuthServer(port string) *FakeServiceOAuthServer {
+ mux := http.NewServeMux()
+
+ mux.HandleFunc("/oauth/authorize", func(w http.ResponseWriter, r *http.Request) {
+ redirectURI := r.URL.Query().Get("redirect_uri")
+ state := r.URL.Query().Get("state")
+ http.Redirect(w, r, fmt.Sprintf("%s?code=service-auth-code&state=%s", redirectURI, state), http.StatusFound)
+ })
+
+ mux.HandleFunc("/oauth/token", func(w http.ResponseWriter, r *http.Request) {
+ if err := r.ParseForm(); err != nil {
+ http.Error(w, "Invalid request", http.StatusBadRequest)
+ return
+ }
+
+ code := r.FormValue("code")
+ if code != "service-auth-code" {
+ w.Header().Set("Content-Type", "application/json")
+ w.WriteHeader(http.StatusBadRequest)
+ _ = json.NewEncoder(w).Encode(map[string]any{
+ "error": "invalid_grant",
+ "error_description": "Invalid authorization code",
+ })
+ return
+ }
+
+ w.Header().Set("Content-Type", "application/json")
+ _ = json.NewEncoder(w).Encode(map[string]any{
+ "access_token": "service-oauth-access-token",
+ "refresh_token": "service-oauth-refresh-token",
+ "token_type": "Bearer",
+ "expires_in": 3600,
+ })
+ })
+
+ server := &http.Server{
+ Addr: ":" + port,
+ Handler: mux,
+ }
+
+ return &FakeServiceOAuthServer{
+ server: server,
+ port: port,
+ }
+}
+
+// Start starts the fake service OAuth server
+func (s *FakeServiceOAuthServer) Start() error {
+ go func() {
+ if err := s.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
+ panic(err)
+ }
+ }()
+
+ time.Sleep(100 * time.Millisecond)
+ return nil
+}
+
+// Stop stops the fake service OAuth server
+func (s *FakeServiceOAuthServer) Stop() error {
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+ defer cancel()
+ return s.server.Shutdown(ctx)
+}
+
+// FakeGitHubServer simulates GitHub's OAuth and API endpoints for integration testing.
+type FakeGitHubServer struct {
+ server *http.Server
+ port string
+}
+
+// NewFakeGitHubServer creates a new fake GitHub server.
+// orgs controls what organizations the /user/orgs endpoint returns.
+func NewFakeGitHubServer(port string, orgs []string) *FakeGitHubServer {
+ mux := http.NewServeMux()
+
+ mux.HandleFunc("/login/oauth/authorize", func(w http.ResponseWriter, r *http.Request) {
+ redirectURI := r.URL.Query().Get("redirect_uri")
+ state := r.URL.Query().Get("state")
+ http.Redirect(w, r, fmt.Sprintf("%s?code=github-test-code&state=%s", redirectURI, state), http.StatusFound)
+ })
+
+ mux.HandleFunc("/login/oauth/access_token", func(w http.ResponseWriter, r *http.Request) {
+ if err := r.ParseForm(); err != nil {
+ http.Error(w, "Invalid request", http.StatusBadRequest)
+ return
+ }
+
+ code := r.FormValue("code")
+ if code != "github-test-code" {
+ w.Header().Set("Content-Type", "application/json")
+ w.WriteHeader(http.StatusBadRequest)
+ _ = json.NewEncoder(w).Encode(map[string]any{
+ "error": "bad_verification_code",
+ "error_description": "Invalid authorization code",
+ })
+ return
+ }
+
+ w.Header().Set("Content-Type", "application/json")
+ _ = json.NewEncoder(w).Encode(map[string]any{
+ "access_token": "github-test-token",
+ "token_type": "bearer",
+ "scope": "user:email,read:org",
+ })
+ })
+
+ mux.HandleFunc("/user", func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", "application/json")
+ _ = json.NewEncoder(w).Encode(map[string]any{
+ "id": 12345,
+ "login": "testuser",
+ "email": "test@test.com",
+ "name": "Test User",
+ "avatar_url": "https://github.com/avatar.jpg",
+ })
+ })
+
+ mux.HandleFunc("/user/emails", func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", "application/json")
+ _ = json.NewEncoder(w).Encode([]map[string]any{
+ {"email": "test@test.com", "primary": true, "verified": true},
+ })
+ })
+
+ mux.HandleFunc("/user/orgs", func(w http.ResponseWriter, r *http.Request) {
+ orgList := make([]map[string]any, len(orgs))
+ for i, org := range orgs {
+ orgList[i] = map[string]any{"login": org}
+ }
+ w.Header().Set("Content-Type", "application/json")
+ _ = json.NewEncoder(w).Encode(orgList)
+ })
+
+ server := &http.Server{
+ Addr: ":" + port,
+ Handler: mux,
+ }
+
+ return &FakeGitHubServer{
+ server: server,
+ port: port,
+ }
+}
+
+// Start starts the fake GitHub server
+func (s *FakeGitHubServer) Start() error {
+ go func() {
+ if err := s.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
+ panic(err)
+ }
+ }()
+
+ time.Sleep(100 * time.Millisecond)
+ return nil
+}
+
+// Stop stops the fake GitHub server
+func (s *FakeGitHubServer) Stop() error {
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+ defer cancel()
+ return s.server.Shutdown(ctx)
+}
+
+// FakeOIDCServer simulates a generic OIDC provider for integration testing.
+// Used for both generic OIDC and Azure tests (Azure is OIDC-compliant).
+type FakeOIDCServer struct {
+ server *http.Server
+ port string
+}
+
+// NewFakeOIDCServer creates a new fake OIDC server.
+func NewFakeOIDCServer(port string) *FakeOIDCServer {
+ mux := http.NewServeMux()
+
+ baseURL := "http://localhost:" + port
+
+ mux.HandleFunc("/.well-known/openid-configuration", func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", "application/json")
+ _ = json.NewEncoder(w).Encode(map[string]any{
+ "issuer": baseURL,
+ "authorization_endpoint": baseURL + "/authorize",
+ "token_endpoint": baseURL + "/token",
+ "userinfo_endpoint": baseURL + "/userinfo",
+ })
+ })
+
+ mux.HandleFunc("/authorize", func(w http.ResponseWriter, r *http.Request) {
+ redirectURI := r.URL.Query().Get("redirect_uri")
+ state := r.URL.Query().Get("state")
+ http.Redirect(w, r, fmt.Sprintf("%s?code=oidc-test-code&state=%s", redirectURI, state), http.StatusFound)
+ })
+
+ mux.HandleFunc("/token", func(w http.ResponseWriter, r *http.Request) {
+ if err := r.ParseForm(); err != nil {
+ http.Error(w, "Invalid request", http.StatusBadRequest)
+ return
+ }
+
+ code := r.FormValue("code")
+ if code != "oidc-test-code" {
+ w.Header().Set("Content-Type", "application/json")
+ w.WriteHeader(http.StatusBadRequest)
+ _ = json.NewEncoder(w).Encode(map[string]any{
+ "error": "invalid_grant",
+ "error_description": "Invalid authorization code",
+ })
+ return
+ }
+
+ w.Header().Set("Content-Type", "application/json")
+ _ = json.NewEncoder(w).Encode(map[string]any{
+ "access_token": "oidc-test-token",
+ "token_type": "Bearer",
+ "expires_in": 3600,
+ })
+ })
+
+ mux.HandleFunc("/userinfo", func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", "application/json")
+ _ = json.NewEncoder(w).Encode(map[string]any{
+ "sub": "oidc-12345",
+ "email": "test@oidc-test.com",
+ "email_verified": true,
+ "name": "OIDC User",
+ })
+ })
+
+ server := &http.Server{
+ Addr: ":" + port,
+ Handler: mux,
+ }
+
+ return &FakeOIDCServer{
+ server: server,
+ port: port,
+ }
+}
+
+// Start starts the fake OIDC server
+func (s *FakeOIDCServer) Start() error {
+ go func() {
+ if err := s.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
+ panic(err)
+ }
+ }()
+
+ time.Sleep(100 * time.Millisecond)
+ return nil
+}
+
+// Stop stops the fake OIDC server
+func (s *FakeOIDCServer) Stop() error {
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+ defer cancel()
+ return s.server.Shutdown(ctx)
+}
diff --git a/integration/test_helpers.go b/integration/test_helpers.go
new file mode 100644
index 0000000..1b8715f
--- /dev/null
+++ b/integration/test_helpers.go
@@ -0,0 +1,463 @@
+package integration
+
+import (
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "os"
+ "os/exec"
+ "slices"
+ "strings"
+ "syscall"
+ "testing"
+ "time"
+)
+
+// ToolboxImage is the Docker image for the MCP Toolbox for Databases.
+// Used as the MCP server backing integration tests. All test configs
+// that reference a postgres MCP server should use this image.
+const ToolboxImage = "us-central1-docker.pkg.dev/database-toolbox/toolbox/toolbox:latest"
+
+// testPostgresDockerArgs returns the Docker args for running the toolbox
+// as a stdio MCP server against the test postgres database.
+func testPostgresDockerArgs() []string {
+ return []string{
+ "run", "--rm", "-i", "--network", "host",
+ "-e", "POSTGRES_HOST=localhost",
+ "-e", "POSTGRES_PORT=15432",
+ "-e", "POSTGRES_DATABASE=testdb",
+ "-e", "POSTGRES_USER=testuser",
+ "-e", "POSTGRES_PASSWORD=testpass",
+ ToolboxImage,
+ "--stdio", "--prebuilt", "postgres",
+ }
+}
+
+// testPostgresServer returns an MCP server config for the test postgres database.
+// Options can customize auth, logging, etc.
+func testPostgresServer(opts ...serverOption) map[string]any {
+ args := make([]any, len(testPostgresDockerArgs()))
+ for i, a := range testPostgresDockerArgs() {
+ args[i] = a
+ }
+ s := map[string]any{
+ "transportType": "stdio",
+ "command": "docker",
+ "args": args,
+ }
+ for _, opt := range opts {
+ opt(s)
+ }
+ return s
+}
+
+type serverOption func(map[string]any)
+
+func withBearerTokens(tokens ...string) serverOption {
+ return func(s map[string]any) {
+ s["serviceAuths"] = []map[string]any{
+ {"type": "bearer", "tokens": tokens},
+ }
+ }
+}
+
+func withBasicAuth(username, passwordEnvVar string) serverOption {
+ return func(s map[string]any) {
+ auths, _ := s["serviceAuths"].([]map[string]any)
+ auths = append(auths, map[string]any{
+ "type": "basic",
+ "username": username,
+ "password": map[string]string{"$env": passwordEnvVar},
+ })
+ s["serviceAuths"] = auths
+ }
+}
+
+func withLogEnabled() serverOption {
+ return func(s map[string]any) {
+ s["options"] = map[string]any{"logEnabled": true}
+ }
+}
+
+func withUserToken() serverOption {
+ return func(s map[string]any) {
+ s["env"] = map[string]any{
+ "USER_TOKEN": map[string]string{"$userToken": "{{token}}"},
+ }
+ s["requiresUserToken"] = true
+ s["userAuthentication"] = map[string]any{
+ "type": "manual",
+ "displayName": "Test Service",
+ "instructions": "Enter your test token",
+ "helpUrl": "https://example.com/help",
+ }
+ }
+}
+
+// testOAuthConfigFromEnv returns an OAuth auth config that reads secrets from env vars.
+func testOAuthConfigFromEnv() map[string]any {
+ return map[string]any{
+ "kind": "oauth",
+ "issuer": "http://localhost:8080",
+ "gcpProject": "test-project",
+ "idp": map[string]any{
+ "provider": "google",
+ "clientId": map[string]string{"$env": "GOOGLE_CLIENT_ID"},
+ "clientSecret": map[string]string{"$env": "GOOGLE_CLIENT_SECRET"},
+ "redirectUri": "http://localhost:8080/oauth/callback",
+ "authorizationUrl": "http://localhost:9090/auth",
+ "tokenUrl": "http://localhost:9090/token",
+ "userInfoUrl": "http://localhost:9090/userinfo",
+ },
+ "allowedDomains": []string{"test.com", "stainless.com", "claude.ai"},
+ "allowedOrigins": []string{"https://claude.ai"},
+ "tokenTtl": "1h",
+ "storage": "memory",
+ "jwtSecret": map[string]string{"$env": "JWT_SECRET"},
+ "encryptionKey": map[string]string{"$env": "ENCRYPTION_KEY"},
+ }
+}
+
+// testGitHubOAuthConfig returns an OAuth auth config for GitHub IDP testing.
+// Secrets use $env references; callers must pass JWT_SECRET, ENCRYPTION_KEY,
+// and GITHUB_CLIENT_SECRET env vars.
+func testGitHubOAuthConfig(allowedOrgs ...string) map[string]any {
+ idpCfg := map[string]any{
+ "provider": "github",
+ "clientId": "test-github-client-id",
+ "clientSecret": map[string]string{"$env": "GITHUB_CLIENT_SECRET"},
+ "redirectUri": "http://localhost:8080/oauth/callback",
+ "authorizationUrl": "http://localhost:9092/login/oauth/authorize",
+ "tokenUrl": "http://localhost:9092/login/oauth/access_token",
+ "userInfoUrl": "http://localhost:9092",
+ }
+ if len(allowedOrgs) > 0 {
+ idpCfg["allowedOrgs"] = allowedOrgs
+ }
+ return map[string]any{
+ "kind": "oauth",
+ "issuer": "http://localhost:8080",
+ "gcpProject": "test-project",
+ "idp": idpCfg,
+ "allowedDomains": []string{"test.com"},
+ "allowedOrigins": []string{"https://claude.ai"},
+ "tokenTtl": "1h",
+ "storage": "memory",
+ "jwtSecret": map[string]string{"$env": "JWT_SECRET"},
+ "encryptionKey": map[string]string{"$env": "ENCRYPTION_KEY"},
+ }
+}
+
+// testOIDCOAuthConfig returns an OAuth auth config for generic OIDC IDP testing.
+// Secrets use $env references; callers must pass JWT_SECRET, ENCRYPTION_KEY,
+// and OIDC_CLIENT_SECRET env vars.
+func testOIDCOAuthConfig() map[string]any {
+ return map[string]any{
+ "kind": "oauth",
+ "issuer": "http://localhost:8080",
+ "gcpProject": "test-project",
+ "idp": map[string]any{
+ "provider": "oidc",
+ "clientId": "test-oidc-client-id",
+ "clientSecret": map[string]string{"$env": "OIDC_CLIENT_SECRET"},
+ "redirectUri": "http://localhost:8080/oauth/callback",
+ "authorizationUrl": "http://localhost:9093/authorize",
+ "tokenUrl": "http://localhost:9093/token",
+ "userInfoUrl": "http://localhost:9093/userinfo",
+ },
+ "allowedDomains": []string{"oidc-test.com"},
+ "allowedOrigins": []string{"https://claude.ai"},
+ "tokenTtl": "1h",
+ "storage": "memory",
+ "jwtSecret": map[string]string{"$env": "JWT_SECRET"},
+ "encryptionKey": map[string]string{"$env": "ENCRYPTION_KEY"},
+ }
+}
+
+// testAzureOAuthConfig returns an OAuth auth config for Azure IDP testing.
+// Secrets use $env references; callers must pass JWT_SECRET, ENCRYPTION_KEY,
+// and AZURE_CLIENT_SECRET env vars.
+func testAzureOAuthConfig() map[string]any {
+ return map[string]any{
+ "kind": "oauth",
+ "issuer": "http://localhost:8080",
+ "gcpProject": "test-project",
+ "idp": map[string]any{
+ "provider": "azure",
+ "tenantId": "test-tenant",
+ "clientId": "test-azure-client-id",
+ "clientSecret": map[string]string{"$env": "AZURE_CLIENT_SECRET"},
+ "redirectUri": "http://localhost:8080/oauth/callback",
+ "authorizationUrl": "http://localhost:9093/authorize",
+ "tokenUrl": "http://localhost:9093/token",
+ "userInfoUrl": "http://localhost:9093/userinfo",
+ },
+ "allowedDomains": []string{"oidc-test.com"},
+ "allowedOrigins": []string{"https://claude.ai"},
+ "tokenTtl": "1h",
+ "storage": "memory",
+ "jwtSecret": map[string]string{"$env": "JWT_SECRET"},
+ "encryptionKey": map[string]string{"$env": "ENCRYPTION_KEY"},
+ }
+}
+
+// writeTestConfig writes a config map to a temporary JSON file and returns its path.
+// The file is automatically cleaned up when the test finishes.
+func writeTestConfig(t *testing.T, cfg map[string]any) string {
+ t.Helper()
+ data, err := json.MarshalIndent(cfg, "", " ")
+ if err != nil {
+ t.Fatalf("Failed to marshal test config: %v", err)
+ }
+ f, err := os.CreateTemp(t.TempDir(), "config-*.json")
+ if err != nil {
+ t.Fatalf("Failed to create temp config file: %v", err)
+ }
+ if _, err := f.Write(data); err != nil {
+ t.Fatalf("Failed to write temp config: %v", err)
+ }
+ if err := f.Close(); err != nil {
+ t.Fatalf("Failed to close temp config: %v", err)
+ }
+ return f.Name()
+}
+
+// buildTestConfig builds a complete mcp-front config map.
+func buildTestConfig(baseURL, name string, auth map[string]any, mcpServers map[string]any) map[string]any {
+ proxy := map[string]any{
+ "baseURL": baseURL,
+ "addr": ":8080",
+ "name": name,
+ }
+ if auth != nil {
+ proxy["auth"] = auth
+ }
+ return map[string]any{
+ "version": "v0.0.1-DEV_EDITION_EXPECT_CHANGES",
+ "proxy": proxy,
+ "mcpServers": mcpServers,
+ }
+}
+
+// TestConfig holds all timeout configurations for integration tests
+type TestConfig struct {
+ SessionTimeout string
+ CleanupInterval string
+ CleanupWaitTime string
+ TimerResetWaitTime string
+ MultiUserWaitTime string
+}
+
+// GetTestConfig returns test configuration from environment variables or defaults
+func GetTestConfig() TestConfig {
+ c := TestConfig{
+ SessionTimeout: "10s",
+ CleanupInterval: "2s",
+ CleanupWaitTime: "15s",
+ TimerResetWaitTime: "12s",
+ MultiUserWaitTime: "15s",
+ }
+
+ // Override from environment if set
+ if v := os.Getenv("SESSION_TIMEOUT"); v != "" {
+ c.SessionTimeout = v
+ }
+ if v := os.Getenv("SESSION_CLEANUP_INTERVAL"); v != "" {
+ c.CleanupInterval = v
+ }
+ if v := os.Getenv("TEST_CLEANUP_WAIT_TIME"); v != "" {
+ c.CleanupWaitTime = v
+ }
+ if v := os.Getenv("TEST_TIMER_RESET_WAIT_TIME"); v != "" {
+ c.TimerResetWaitTime = v
+ }
+ if v := os.Getenv("TEST_MULTI_USER_WAIT_TIME"); v != "" {
+ c.MultiUserWaitTime = v
+ }
+
+ return c
+}
+
+func waitForDB(t *testing.T) {
+ waitForSec := 5
+ for range waitForSec {
+ // Check if container is running
+ psCmd := exec.Command("docker", "compose", "ps", "-q", "test-postgres")
+ if output, err := psCmd.Output(); err != nil || len(output) == 0 {
+ time.Sleep(1 * time.Second)
+ continue
+ }
+
+ // Check if database is ready
+ checkCmd := exec.Command("docker", "compose", "exec", "-T", "test-postgres", "pg_isready", "-U", "testuser", "-d", "testdb")
+ if err := checkCmd.Run(); err == nil {
+ return
+ }
+ time.Sleep(1 * time.Second)
+ }
+
+ t.Fatalf("Database failed to become ready after %d seconds", waitForSec)
+}
+
+// trace logs a message if TRACE environment variable is set
+func trace(t *testing.T, format string, args ...any) {
+ if os.Getenv("TRACE") == "1" {
+ t.Logf("TRACE: "+format, args...)
+ }
+}
+
+// tracef logs a formatted message to stdout if TRACE is set (for use outside tests)
+func tracef(format string, args ...any) {
+ if os.Getenv("TRACE") == "1" {
+ fmt.Printf("TRACE: "+format+"\n", args...)
+ }
+}
+
+// startMCPFront starts the mcp-front server with the given config
+func startMCPFront(t *testing.T, configPath string, extraEnv ...string) {
+ mcpCmd := exec.Command("../cmd/mcp-front/mcp-front", "-config", configPath)
+
+ // Get test config for session timeouts
+ testConfig := GetTestConfig()
+
+ // Build default environment with test timeouts
+ defaultEnv := []string{
+ "SESSION_TIMEOUT=" + testConfig.SessionTimeout,
+ "SESSION_CLEANUP_INTERVAL=" + testConfig.CleanupInterval,
+ }
+
+ // Start with system environment
+ mcpCmd.Env = os.Environ()
+
+ // Apply defaults first
+ mcpCmd.Env = append(mcpCmd.Env, defaultEnv...)
+
+ // Apply extra env (can override defaults)
+ mcpCmd.Env = append(mcpCmd.Env, extraEnv...)
+
+ // Pass through LOG_LEVEL and LOG_FORMAT if set
+ if logLevel := os.Getenv("LOG_LEVEL"); logLevel != "" {
+ mcpCmd.Env = append(mcpCmd.Env, "LOG_LEVEL="+logLevel)
+ }
+ if logFormat := os.Getenv("LOG_FORMAT"); logFormat != "" {
+ mcpCmd.Env = append(mcpCmd.Env, "LOG_FORMAT="+logFormat)
+ }
+
+ // Capture output to log file if MCP_LOG_FILE is set
+ if logFile := os.Getenv("MCP_LOG_FILE"); logFile != "" {
+ f, err := os.OpenFile(logFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644)
+ if err == nil {
+ mcpCmd.Stderr = f
+ mcpCmd.Stdout = f
+ t.Cleanup(func() { f.Close() })
+ }
+ }
+
+ if err := mcpCmd.Start(); err != nil {
+ t.Fatalf("Failed to start mcp-front: %v", err)
+ }
+
+ // Register cleanup that runs even if test is killed
+ t.Cleanup(func() {
+ stopMCPFront(mcpCmd)
+ })
+}
+
+// stopMCPFront stops the mcp-front server gracefully
+func stopMCPFront(cmd *exec.Cmd) {
+ if cmd == nil || cmd.Process == nil {
+ return
+ }
+
+ // Try graceful shutdown first (SIGINT)
+ if err := cmd.Process.Signal(syscall.SIGINT); err != nil {
+ // If SIGINT fails, force kill immediately
+ _ = cmd.Process.Kill()
+ _ = cmd.Wait()
+ return
+ }
+
+ // Wait up to 5 seconds for graceful shutdown
+ done := make(chan error, 1)
+ go func() {
+ done <- cmd.Wait()
+ }()
+
+ select {
+ case <-done:
+ // Graceful shutdown completed
+ return
+ case <-time.After(5 * time.Second):
+ // Timeout, force kill
+ _ = cmd.Process.Kill()
+ _ = cmd.Wait()
+ }
+}
+
+// waitForMCPFront waits for the mcp-front server to be ready
+func waitForMCPFront(t *testing.T) {
+ t.Helper()
+ for range 10 {
+ resp, err := http.Get("http://localhost:8080/health")
+ if err == nil && resp.StatusCode == 200 {
+ resp.Body.Close()
+ return
+ }
+ if resp != nil {
+ resp.Body.Close()
+ }
+ time.Sleep(1 * time.Second)
+ }
+ t.Fatal("mcp-front failed to become ready after 10 seconds")
+}
+
+// getMCPContainers returns a list of running toolbox container IDs
+func getMCPContainers() []string {
+ cmd := exec.Command("docker", "ps", "--format", "{{.ID}}", "--filter", "ancestor="+ToolboxImage)
+ output, err := cmd.Output()
+ if err != nil {
+ return nil
+ }
+
+ var containers []string
+ for line := range strings.SplitSeq(strings.TrimSpace(string(output)), "\n") {
+ if line != "" {
+ containers = append(containers, line)
+ }
+ }
+ return containers
+}
+
+// cleanupContainers forces cleanup of containers that weren't in the initial set
+func cleanupContainers(t *testing.T, initialContainers []string) {
+ time.Sleep(2 * time.Second)
+ containers := getMCPContainers()
+ for _, container := range containers {
+ isInitial := slices.Contains(initialContainers, container)
+ if !isInitial {
+ t.Logf("Force stopping container: %s...", container)
+ if err := exec.Command("docker", "stop", container).Run(); err != nil {
+ t.Logf("Failed to stop container %s: %v", container, err)
+ } else {
+ t.Logf("Stopped container: %s", container)
+ }
+ }
+ }
+}
+
+// TestQuickSmoke provides a fast validation test
+func TestQuickSmoke(t *testing.T) {
+ t.Log("Running quick smoke test...")
+
+ // Just verify the test infrastructure works
+ client := NewMCPSSEClient("http://localhost:8080")
+ if client == nil {
+ t.Fatal("Failed to create client")
+ }
+
+ if err := client.Authenticate(); err != nil {
+ t.Fatal("Failed to set up authentication")
+ }
+
+ t.Log("Quick smoke test passed - test infrastructure is working")
+}
diff --git a/integration/test_utils.go b/integration/test_utils.go
deleted file mode 100644
index 7d21f61..0000000
--- a/integration/test_utils.go
+++ /dev/null
@@ -1,917 +0,0 @@
-package integration
-
-import (
- "bufio"
- "bytes"
- "context"
- "encoding/json"
- "fmt"
- "io"
- "net/http"
- "net/url"
- "os"
- "os/exec"
- "slices"
- "strings"
- "sync"
- "syscall"
- "testing"
- "time"
-)
-
-// getDockerComposeCommand returns the appropriate docker compose command
-func getDockerComposeCommand() string {
- // Check if docker compose v2 is available
- cmd := exec.Command("docker", "compose", "version")
- if err := cmd.Run(); err == nil {
- return "docker compose"
- }
- return "docker-compose"
-}
-
-// execDockerCompose executes docker compose with the given arguments
-func execDockerCompose(args ...string) *exec.Cmd {
- dcCmd := getDockerComposeCommand()
- if dcCmd == "docker compose" {
- allArgs := append([]string{"compose"}, args...)
- return exec.Command("docker", allArgs...)
- }
- return exec.Command("docker-compose", args...)
-}
-
-// MCPSSEClient simulates an MCP client for testing
-type MCPSSEClient struct {
- baseURL string
- token string
- sseConn io.ReadCloser
- messageEndpoint string
- sseScanner *bufio.Scanner
- sessionID string
-}
-
-// NewMCPSSEClient creates a new MCP client for testing
-func NewMCPSSEClient(baseURL string) *MCPSSEClient {
- return &MCPSSEClient{
- baseURL: baseURL,
- }
-}
-
-// Authenticate sets up authentication for the client
-func (c *MCPSSEClient) Authenticate() error {
- c.token = "test-token"
- return nil
-}
-
-// SetAuthToken sets a specific auth token for the client
-func (c *MCPSSEClient) SetAuthToken(token string) {
- c.token = token
-}
-
-// Connect establishes an SSE connection and retrieves the message endpoint
-func (c *MCPSSEClient) Connect() error {
- return c.ConnectToServer("postgres")
-}
-
-// ConnectToServer establishes an SSE connection to a specific server
-func (c *MCPSSEClient) ConnectToServer(serverName string) error {
- // Close any existing connection
- if c.sseConn != nil {
- c.sseConn.Close()
- c.sseConn = nil
- c.messageEndpoint = ""
- }
-
- sseURL := c.baseURL + "/" + serverName + "/sse"
- tracef("ConnectToServer: requesting %s", sseURL)
-
- req, err := http.NewRequest("GET", sseURL, nil)
- if err != nil {
- return fmt.Errorf("failed to create SSE request: %v", err)
- }
-
- req.Header.Set("Accept", "text/event-stream")
- req.Header.Set("Authorization", "Bearer "+c.token)
- req.Header.Set("Cache-Control", "no-cache")
- tracef("ConnectToServer: headers set, making request")
-
- // Don't use a timeout on the client for SSE
- client := &http.Client{}
- resp, err := client.Do(req)
- if err != nil {
- return fmt.Errorf("SSE connection failed: %v", err)
- }
-
- tracef("ConnectToServer: got response status %d", resp.StatusCode)
- if resp.StatusCode != 200 {
- body, _ := io.ReadAll(resp.Body)
- resp.Body.Close()
- return fmt.Errorf("SSE connection returned %d: %s", resp.StatusCode, string(body))
- }
-
- // Store the connection
- c.sseConn = resp.Body
- c.sseScanner = bufio.NewScanner(resp.Body)
-
- // Read initial SSE messages to get the endpoint
- // For inline servers, we don't get a message endpoint - we use the server path directly
- gotEndpointMessage := false
- for c.sseScanner.Scan() {
- line := c.sseScanner.Text()
- tracef("ConnectToServer: SSE line: %s", line)
-
- // Look for data lines
- if after, ok := strings.CutPrefix(line, "data: "); ok {
- data := after
-
- // Check if it's an endpoint message (for inline servers)
- if strings.Contains(data, `"type":"endpoint"`) {
- gotEndpointMessage = true
- // For inline servers, construct the message endpoint
- c.messageEndpoint = c.baseURL + "/" + serverName + "/message"
- tracef("ConnectToServer: inline server detected, using endpoint: %s", c.messageEndpoint)
- break
- }
-
- // Check if it's a message endpoint URL (for stdio servers)
- if strings.Contains(data, "http://") || strings.Contains(data, "https://") {
- c.messageEndpoint = data
-
- // Extract session ID from endpoint URL
- if u, err := url.Parse(data); err == nil {
- c.sessionID = u.Query().Get("sessionId")
- }
-
- tracef("ConnectToServer: found endpoint: %s", c.messageEndpoint)
- break
- }
- }
- }
-
- if c.messageEndpoint == "" && !gotEndpointMessage {
- c.sseConn.Close()
- c.sseConn = nil
- return fmt.Errorf("no message endpoint received")
- }
-
- tracef("Connect: successfully connected to MCP server")
- return nil
-}
-
-// ValidateBackendConnectivity checks if we can connect to the MCP server
-func (c *MCPSSEClient) ValidateBackendConnectivity() error {
- return c.Connect()
-}
-
-// Close closes the SSE connection
-func (c *MCPSSEClient) Close() {
- if c.sseConn != nil {
- c.sseConn.Close()
- c.sseConn = nil
- c.messageEndpoint = ""
- c.sseScanner = nil
- }
-}
-
-// SendMCPRequest sends an MCP JSON-RPC request and returns the response
-func (c *MCPSSEClient) SendMCPRequest(method string, params any) (map[string]any, error) {
- // Ensure we have a connection
- if c.messageEndpoint == "" {
- if err := c.Connect(); err != nil {
- return nil, fmt.Errorf("failed to connect: %v", err)
- }
- }
-
- // Send MCP request to the message endpoint
- request := map[string]any{
- "jsonrpc": "2.0",
- "id": 1,
- "method": method,
- "params": params,
- }
-
- reqBody, err := json.Marshal(request)
- if err != nil {
- return nil, err
- }
-
- msgReq, err := http.NewRequest("POST", c.messageEndpoint, bytes.NewBuffer(reqBody))
- if err != nil {
- return nil, err
- }
-
- msgReq.Header.Set("Content-Type", "application/json")
- msgReq.Header.Set("Authorization", "Bearer "+c.token)
-
- client := &http.Client{Timeout: 30 * time.Second}
- msgResp, err := client.Do(msgReq)
- if err != nil {
- return nil, err
- }
- defer msgResp.Body.Close()
-
- respBody, err := io.ReadAll(msgResp.Body)
- if err != nil {
- return nil, err
- }
-
- if msgResp.StatusCode != 200 && msgResp.StatusCode != 202 {
- return nil, fmt.Errorf("MCP request failed: %d - %s", msgResp.StatusCode, string(respBody))
- }
-
- // Handle 202 and empty responses - read response from SSE stream
- if msgResp.StatusCode == 202 || len(respBody) == 0 {
- // Read response from SSE stream
- for c.sseScanner.Scan() {
- line := c.sseScanner.Text()
-
- if after, ok := strings.CutPrefix(line, "data: "); ok {
- data := after
- // Try to parse as JSON
- var msg map[string]any
- if err := json.Unmarshal([]byte(data), &msg); err == nil {
- // Check if this is our response (matching ID)
- if id, ok := msg["id"]; ok && id == float64(1) {
- return msg, nil
- }
- }
- }
- }
-
- if err := c.sseScanner.Err(); err != nil {
- return nil, fmt.Errorf("SSE scanner error: %v", err)
- }
-
- return nil, fmt.Errorf("no response received from SSE stream")
- }
-
- var result map[string]any
- if err := json.Unmarshal(respBody, &result); err != nil {
- return nil, fmt.Errorf("failed to parse response: %v - %s", err, string(respBody))
- }
-
- return result, nil
-}
-
-// MCPStreamableClient is a test client for HTTP-Streamable MCP servers
-type MCPStreamableClient struct {
- baseURL string
- serverName string
- token string
- httpClient *http.Client
-
- // For GET SSE streaming
- sseConn io.ReadCloser
- sseScanner *bufio.Scanner
- sseCancel chan struct{}
-
- mu sync.Mutex
-}
-
-// NewMCPStreamableClient creates a new streamable-http test client
-func NewMCPStreamableClient(baseURL string) *MCPStreamableClient {
- return &MCPStreamableClient{
- baseURL: baseURL,
- httpClient: &http.Client{
- Timeout: 30 * time.Second,
- },
- }
-}
-
-// SetAuthToken sets the authentication token
-func (c *MCPStreamableClient) SetAuthToken(token string) {
- c.token = token
-}
-
-// ConnectToServer establishes connection to a streamable-http server
-func (c *MCPStreamableClient) ConnectToServer(serverName string) error {
- c.mu.Lock()
- defer c.mu.Unlock()
-
- // Close any existing connection
- c.close()
-
- c.serverName = serverName
-
- // For streamable-http, we can optionally open a GET SSE stream for server-initiated messages
- // But it's not required for basic request/response
- return c.openSSEStream()
-}
-
-// openSSEStream opens a GET SSE connection for receiving server-initiated messages
-func (c *MCPStreamableClient) openSSEStream() error {
- url := c.baseURL + "/" + c.serverName + "/sse"
-
- req, err := http.NewRequest("GET", url, nil)
- if err != nil {
- return fmt.Errorf("failed to create GET request: %v", err)
- }
-
- req.Header.Set("Accept", "text/event-stream")
- req.Header.Set("Authorization", "Bearer "+c.token)
- req.Header.Set("Cache-Control", "no-cache")
-
- // Use a client without timeout for SSE
- sseClient := &http.Client{}
- resp, err := sseClient.Do(req)
- if err != nil {
- return fmt.Errorf("SSE connection failed: %v", err)
- }
-
- if resp.StatusCode != 200 {
- body, _ := io.ReadAll(resp.Body)
- resp.Body.Close()
- return fmt.Errorf("SSE connection returned %d: %s", resp.StatusCode, string(body))
- }
-
- c.sseConn = resp.Body
- c.sseScanner = bufio.NewScanner(resp.Body)
- c.sseCancel = make(chan struct{})
-
- // Start reading SSE messages in background
- go c.readSSEMessages()
-
- return nil
-}
-
-// readSSEMessages reads server-initiated messages from the SSE stream
-func (c *MCPStreamableClient) readSSEMessages() {
- for {
- select {
- case <-c.sseCancel:
- return
- default:
- if c.sseScanner.Scan() {
- line := c.sseScanner.Text()
- if after, ok := strings.CutPrefix(line, "data: "); ok {
- data := after
- // In a real implementation, we'd process server-initiated messages here
- tracef("StreamableClient: received SSE message: %s", data)
- }
- } else {
- // Scanner stopped - connection closed or error
- return
- }
- }
- }
-}
-
-// SendMCPRequest sends a JSON-RPC request via POST
-func (c *MCPStreamableClient) SendMCPRequest(method string, params any) (map[string]any, error) {
- c.mu.Lock()
- serverName := c.serverName
- c.mu.Unlock()
-
- if serverName == "" {
- return nil, fmt.Errorf("not connected to any server")
- }
-
- // For streamable-http, we POST to the server endpoint
- url := c.baseURL + "/" + serverName + "/sse"
-
- // Construct JSON-RPC request
- request := map[string]any{
- "jsonrpc": "2.0",
- "id": 1,
- "method": method,
- "params": params,
- }
-
- body, err := json.Marshal(request)
- if err != nil {
- return nil, fmt.Errorf("failed to marshal request: %v", err)
- }
-
- req, err := http.NewRequest("POST", url, bytes.NewReader(body))
- if err != nil {
- return nil, fmt.Errorf("failed to create POST request: %v", err)
- }
-
- req.Header.Set("Content-Type", "application/json")
- req.Header.Set("Authorization", "Bearer "+c.token)
- // Accept both JSON and SSE responses
- req.Header.Set("Accept", "application/json, text/event-stream")
-
- resp, err := c.httpClient.Do(req)
- if err != nil {
- return nil, fmt.Errorf("request failed: %v", err)
- }
- defer resp.Body.Close()
-
- // Check content type to determine response format
- contentType := resp.Header.Get("Content-Type")
-
- if strings.HasPrefix(contentType, "text/event-stream") {
- // Handle SSE response
- return c.handleSSEResponse(resp.Body)
- } else {
- // Handle JSON response
- return c.handleJSONResponse(resp.Body)
- }
-}
-
-// handleJSONResponse processes a regular JSON response
-func (c *MCPStreamableClient) handleJSONResponse(body io.Reader) (map[string]any, error) {
- var response map[string]any
- if err := json.NewDecoder(body).Decode(&response); err != nil {
- return nil, fmt.Errorf("failed to decode JSON response: %v", err)
- }
- return response, nil
-}
-
-// handleSSEResponse processes an SSE stream response from a POST
-func (c *MCPStreamableClient) handleSSEResponse(body io.Reader) (map[string]any, error) {
- scanner := bufio.NewScanner(body)
- var lastResponse map[string]any
-
- for scanner.Scan() {
- line := scanner.Text()
- if after, ok := strings.CutPrefix(line, "data: "); ok {
- data := after
- var msg map[string]any
- if err := json.Unmarshal([]byte(data), &msg); err == nil {
- // Keep the last response with an ID (not a notification)
- if _, hasID := msg["id"]; hasID {
- lastResponse = msg
- }
- }
- }
- }
-
- if lastResponse == nil {
- return nil, fmt.Errorf("no response received in SSE stream")
- }
-
- return lastResponse, nil
-}
-
-// Close closes all connections
-func (c *MCPStreamableClient) Close() {
- c.mu.Lock()
- defer c.mu.Unlock()
- c.close()
-}
-
-// close is the internal close method (must be called with lock held)
-func (c *MCPStreamableClient) close() {
- if c.sseCancel != nil {
- close(c.sseCancel)
- c.sseCancel = nil
- }
-
- if c.sseConn != nil {
- c.sseConn.Close()
- c.sseConn = nil
- c.sseScanner = nil
- }
-
- c.serverName = ""
-}
-
-// FakeGCPServer provides a fake GCP OAuth server for testing
-type FakeGCPServer struct {
- server *http.Server
- port string
-}
-
-// NewFakeGCPServer creates a new fake GCP server
-func NewFakeGCPServer(port string) *FakeGCPServer {
- mux := http.NewServeMux()
-
- mux.HandleFunc("/auth", func(w http.ResponseWriter, r *http.Request) {
- redirectURI := r.URL.Query().Get("redirect_uri")
- state := r.URL.Query().Get("state")
- http.Redirect(w, r, fmt.Sprintf("%s?code=test-auth-code&state=%s", redirectURI, state), http.StatusFound)
- })
-
- mux.HandleFunc("/token", func(w http.ResponseWriter, r *http.Request) {
- // Parse the form data
- if err := r.ParseForm(); err != nil {
- http.Error(w, "Invalid request", http.StatusBadRequest)
- return
- }
-
- // Check the authorization code
- code := r.FormValue("code")
- if code != "test-auth-code" {
- w.Header().Set("Content-Type", "application/json")
- w.WriteHeader(http.StatusBadRequest)
- _ = json.NewEncoder(w).Encode(map[string]any{
- "error": "invalid_grant",
- "error_description": "Invalid authorization code",
- })
- return
- }
-
- w.Header().Set("Content-Type", "application/json")
- _ = json.NewEncoder(w).Encode(map[string]any{
- "access_token": "test-access-token",
- "token_type": "Bearer",
- "expires_in": 3600,
- })
- })
-
- mux.HandleFunc("/userinfo", func(w http.ResponseWriter, r *http.Request) {
- w.Header().Set("Content-Type", "application/json")
- _ = json.NewEncoder(w).Encode(map[string]any{
- "email": "test@test.com",
- "hd": "test.com",
- })
- })
-
- server := &http.Server{
- Addr: ":" + port,
- Handler: mux,
- }
-
- return &FakeGCPServer{
- server: server,
- port: port,
- }
-}
-
-// Start starts the fake GCP server
-func (m *FakeGCPServer) Start() error {
- go func() {
- if err := m.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
- panic(err)
- }
- }()
-
- time.Sleep(100 * time.Millisecond)
- return nil
-}
-
-// Stop stops the fake GCP server
-func (m *FakeGCPServer) Stop() error {
- ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
- defer cancel()
- return m.server.Shutdown(ctx)
-}
-
-// FakeServiceOAuthServer provides a fake OAuth server for external services (like Linear, GitHub)
-type FakeServiceOAuthServer struct {
- server *http.Server
- port string
-}
-
-// NewFakeServiceOAuthServer creates a new fake service OAuth server
-func NewFakeServiceOAuthServer(port string) *FakeServiceOAuthServer {
- mux := http.NewServeMux()
-
- mux.HandleFunc("/oauth/authorize", func(w http.ResponseWriter, r *http.Request) {
- redirectURI := r.URL.Query().Get("redirect_uri")
- state := r.URL.Query().Get("state")
- http.Redirect(w, r, fmt.Sprintf("%s?code=service-auth-code&state=%s", redirectURI, state), http.StatusFound)
- })
-
- mux.HandleFunc("/oauth/token", func(w http.ResponseWriter, r *http.Request) {
- if err := r.ParseForm(); err != nil {
- http.Error(w, "Invalid request", http.StatusBadRequest)
- return
- }
-
- code := r.FormValue("code")
- if code != "service-auth-code" {
- w.Header().Set("Content-Type", "application/json")
- w.WriteHeader(http.StatusBadRequest)
- _ = json.NewEncoder(w).Encode(map[string]any{
- "error": "invalid_grant",
- "error_description": "Invalid authorization code",
- })
- return
- }
-
- w.Header().Set("Content-Type", "application/json")
- _ = json.NewEncoder(w).Encode(map[string]any{
- "access_token": "service-oauth-access-token",
- "refresh_token": "service-oauth-refresh-token",
- "token_type": "Bearer",
- "expires_in": 3600,
- })
- })
-
- server := &http.Server{
- Addr: ":" + port,
- Handler: mux,
- }
-
- return &FakeServiceOAuthServer{
- server: server,
- port: port,
- }
-}
-
-// Start starts the fake service OAuth server
-func (s *FakeServiceOAuthServer) Start() error {
- go func() {
- if err := s.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
- panic(err)
- }
- }()
-
- time.Sleep(100 * time.Millisecond)
- return nil
-}
-
-// Stop stops the fake service OAuth server
-func (s *FakeServiceOAuthServer) Stop() error {
- ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
- defer cancel()
- return s.server.Shutdown(ctx)
-}
-
-// TestEnvironment manages the complete test environment
-type TestEnvironment struct {
- dbCmd *exec.Cmd
- mcpCmd *exec.Cmd
- fakeGCP *FakeGCPServer
- client *MCPSSEClient
-}
-
-// SetupTestEnvironment creates and starts all components needed for testing
-func SetupTestEnvironment(t *testing.T) *TestEnvironment {
- env := &TestEnvironment{}
-
- // Start test database
- t.Log("🚀 Starting test database...")
- env.dbCmd = execDockerCompose("-f", "config/docker-compose.test.yml", "up", "-d")
- if err := env.dbCmd.Run(); err != nil {
- t.Fatalf("Failed to start test database: %v", err)
- }
-
- time.Sleep(10 * time.Second)
-
- // Start mock GCP server
- t.Log("🚀 Starting mock GCP server...")
- env.fakeGCP = NewFakeGCPServer("9090")
- if err := env.fakeGCP.Start(); err != nil {
- t.Fatalf("Failed to start mock GCP server: %v", err)
- }
-
- // Start mcp-front
- t.Log("🚀 Starting mcp-front...")
- env.mcpCmd = exec.Command("../cmd/mcp-front/mcp-front", "-config", "config/config.test.json")
-
- // Capture stderr to log file if MCP_LOG_FILE is set
- if logFile := os.Getenv("MCP_LOG_FILE"); logFile != "" {
- f, err := os.OpenFile(logFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644)
- if err == nil {
- env.mcpCmd.Stderr = f
- env.mcpCmd.Stdout = f
- t.Cleanup(func() { f.Close() })
- }
- }
-
- if err := env.mcpCmd.Start(); err != nil {
- t.Fatalf("Failed to start mcp-front: %v", err)
- }
-
- time.Sleep(15 * time.Second)
-
- // Create and authenticate client
- env.client = NewMCPSSEClient("http://localhost:8080")
- if err := env.client.Authenticate(); err != nil {
- t.Fatalf("Authentication failed: %v", err)
- }
-
- return env
-}
-
-// Cleanup stops all test environment components
-func (env *TestEnvironment) Cleanup() {
- if env.mcpCmd != nil && env.mcpCmd.Process != nil {
- _ = env.mcpCmd.Process.Kill()
- }
-
- if env.fakeGCP != nil {
- _ = env.fakeGCP.Stop()
- }
-
- if env.dbCmd != nil {
- downCmd := execDockerCompose("-f", "config/docker-compose.test.yml", "down", "-v")
- _ = downCmd.Run()
- }
-}
-
-// TestConfig holds all timeout configurations for integration tests
-type TestConfig struct {
- SessionTimeout string
- CleanupInterval string
- CleanupWaitTime string
- TimerResetWaitTime string
- MultiUserWaitTime string
-}
-
-// GetTestConfig returns test configuration from environment variables or defaults
-func GetTestConfig() TestConfig {
- c := TestConfig{
- SessionTimeout: "10s",
- CleanupInterval: "2s",
- CleanupWaitTime: "15s",
- TimerResetWaitTime: "12s",
- MultiUserWaitTime: "15s",
- }
-
- // Override from environment if set
- if v := os.Getenv("SESSION_TIMEOUT"); v != "" {
- c.SessionTimeout = v
- }
- if v := os.Getenv("SESSION_CLEANUP_INTERVAL"); v != "" {
- c.CleanupInterval = v
- }
- if v := os.Getenv("TEST_CLEANUP_WAIT_TIME"); v != "" {
- c.CleanupWaitTime = v
- }
- if v := os.Getenv("TEST_TIMER_RESET_WAIT_TIME"); v != "" {
- c.TimerResetWaitTime = v
- }
- if v := os.Getenv("TEST_MULTI_USER_WAIT_TIME"); v != "" {
- c.MultiUserWaitTime = v
- }
-
- return c
-}
-
-func waitForDB(t *testing.T) {
- waitForSec := 5
- for range waitForSec {
- // Check if container is running
- psCmd := exec.Command("docker", "compose", "ps", "-q", "test-postgres")
- if output, err := psCmd.Output(); err != nil || len(output) == 0 {
- time.Sleep(1 * time.Second)
- continue
- }
-
- // Check if database is ready
- checkCmd := exec.Command("docker", "compose", "exec", "-T", "test-postgres", "pg_isready", "-U", "testuser", "-d", "testdb")
- if err := checkCmd.Run(); err == nil {
- return
- }
- time.Sleep(1 * time.Second)
- }
-
- t.Fatalf("Database failed to become ready after %d seconds", waitForSec)
-}
-
-// trace logs a message if TRACE environment variable is set
-func trace(t *testing.T, format string, args ...any) {
- if os.Getenv("TRACE") == "1" {
- t.Logf("TRACE: "+format, args...)
- }
-}
-
-// tracef logs a formatted message to stdout if TRACE is set (for use outside tests)
-func tracef(format string, args ...any) {
- if os.Getenv("TRACE") == "1" {
- fmt.Printf("TRACE: "+format+"\n", args...)
- }
-}
-
-// startMCPFront starts the mcp-front server with the given config
-func startMCPFront(t *testing.T, configPath string, extraEnv ...string) {
- mcpCmd := exec.Command("../cmd/mcp-front/mcp-front", "-config", configPath)
-
- // Get test config for session timeouts
- testConfig := GetTestConfig()
-
- // Build default environment with test timeouts
- defaultEnv := []string{
- "SESSION_TIMEOUT=" + testConfig.SessionTimeout,
- "SESSION_CLEANUP_INTERVAL=" + testConfig.CleanupInterval,
- }
-
- // Start with system environment
- mcpCmd.Env = os.Environ()
-
- // Apply defaults first
- mcpCmd.Env = append(mcpCmd.Env, defaultEnv...)
-
- // Apply extra env (can override defaults)
- mcpCmd.Env = append(mcpCmd.Env, extraEnv...)
-
- // Pass through LOG_LEVEL and LOG_FORMAT if set
- if logLevel := os.Getenv("LOG_LEVEL"); logLevel != "" {
- mcpCmd.Env = append(mcpCmd.Env, "LOG_LEVEL="+logLevel)
- }
- if logFormat := os.Getenv("LOG_FORMAT"); logFormat != "" {
- mcpCmd.Env = append(mcpCmd.Env, "LOG_FORMAT="+logFormat)
- }
-
- // Capture output to log file if MCP_LOG_FILE is set
- if logFile := os.Getenv("MCP_LOG_FILE"); logFile != "" {
- f, err := os.OpenFile(logFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644)
- if err == nil {
- mcpCmd.Stderr = f
- mcpCmd.Stdout = f
- t.Cleanup(func() { f.Close() })
- }
- }
-
- if err := mcpCmd.Start(); err != nil {
- t.Fatalf("Failed to start mcp-front: %v", err)
- }
-
- // Register cleanup that runs even if test is killed
- t.Cleanup(func() {
- stopMCPFront(mcpCmd)
- })
-}
-
-// stopMCPFront stops the mcp-front server gracefully
-func stopMCPFront(cmd *exec.Cmd) {
- if cmd == nil || cmd.Process == nil {
- return
- }
-
- // Try graceful shutdown first (SIGINT)
- if err := cmd.Process.Signal(syscall.SIGINT); err != nil {
- // If SIGINT fails, force kill immediately
- _ = cmd.Process.Kill()
- _ = cmd.Wait()
- return
- }
-
- // Wait up to 5 seconds for graceful shutdown
- done := make(chan error, 1)
- go func() {
- done <- cmd.Wait()
- }()
-
- select {
- case <-done:
- // Graceful shutdown completed
- return
- case <-time.After(5 * time.Second):
- // Timeout, force kill
- _ = cmd.Process.Kill()
- _ = cmd.Wait()
- }
-}
-
-// waitForMCPFront waits for the mcp-front server to be ready
-func waitForMCPFront(t *testing.T) {
- t.Helper()
- for range 10 {
- resp, err := http.Get("http://localhost:8080/health")
- if err == nil && resp.StatusCode == 200 {
- resp.Body.Close()
- return
- }
- if resp != nil {
- resp.Body.Close()
- }
- time.Sleep(1 * time.Second)
- }
- t.Fatal("mcp-front failed to become ready after 10 seconds")
-}
-
-// getMCPContainers returns a list of running mcp/postgres container IDs
-func getMCPContainers() []string {
- cmd := exec.Command("docker", "ps", "--format", "{{.ID}}", "--filter", "ancestor=mcp/postgres")
- output, err := cmd.Output()
- if err != nil {
- return nil
- }
-
- var containers []string
- for line := range strings.SplitSeq(strings.TrimSpace(string(output)), "\n") {
- if line != "" {
- containers = append(containers, line)
- }
- }
- return containers
-}
-
-// cleanupContainers forces cleanup of containers that weren't in the initial set
-func cleanupContainers(t *testing.T, initialContainers []string) {
- time.Sleep(2 * time.Second)
- containers := getMCPContainers()
- for _, container := range containers {
- isInitial := slices.Contains(initialContainers, container)
- if !isInitial {
- t.Logf("Force stopping container: %s...", container)
- if err := exec.Command("docker", "stop", container).Run(); err != nil {
- t.Logf("Failed to stop container %s: %v", container, err)
- } else {
- t.Logf("Stopped container: %s", container)
- }
- }
- }
-}
-
-// TestQuickSmoke provides a fast validation test
-func TestQuickSmoke(t *testing.T) {
- t.Log("Running quick smoke test...")
-
- // Just verify the test infrastructure works
- client := NewMCPSSEClient("http://localhost:8080")
- if client == nil {
- t.Fatal("Failed to create client")
- }
-
- if err := client.Authenticate(); err != nil {
- t.Fatal("Failed to set up authentication")
- }
-
- t.Log("Quick smoke test passed - test infrastructure is working")
-}
diff --git a/internal/adminauth/admin.go b/internal/adminauth/admin.go
index 5f35c18..977c1a5 100644
--- a/internal/adminauth/admin.go
+++ b/internal/adminauth/admin.go
@@ -22,8 +22,16 @@ func IsAdmin(ctx context.Context, email string, adminConfig *config.AdminConfig,
return true
}
- // Check if user is a promoted admin in storage
+ // Check if user is a promoted admin in storage.
+ // Try exact match first (O(1)), fall back to case-insensitive scan
+ // since emails may be stored with different casing.
if store != nil {
+ user, err := store.GetUser(ctx, normalizedEmail)
+ if err == nil {
+ return user.IsAdmin
+ }
+
+ // Fallback: case-insensitive scan for emails stored with different casing
users, err := store.GetAllUsers(ctx)
if err == nil {
for _, user := range users {
diff --git a/internal/browserauth/session.go b/internal/browserauth/session.go
deleted file mode 100644
index a1876d9..0000000
--- a/internal/browserauth/session.go
+++ /dev/null
@@ -1,15 +0,0 @@
-package browserauth
-
-import "time"
-
-// SessionCookie represents the data stored in encrypted browser session cookies
-type SessionCookie struct {
- Email string `json:"email"`
- Expires time.Time `json:"expires"`
-}
-
-// AuthorizationState represents the OAuth authorization code flow state parameter
-type AuthorizationState struct {
- Nonce string `json:"nonce"`
- ReturnURL string `json:"return_url"`
-}
diff --git a/internal/envutil/envutil.go b/internal/config/env.go
similarity index 94%
rename from internal/envutil/envutil.go
rename to internal/config/env.go
index ef5acfc..6297707 100644
--- a/internal/envutil/envutil.go
+++ b/internal/config/env.go
@@ -1,4 +1,4 @@
-package envutil
+package config
import (
"os"
diff --git a/internal/config/load.go b/internal/config/load.go
index 326cec1..768ebf0 100644
--- a/internal/config/load.go
+++ b/internal/config/load.go
@@ -58,11 +58,11 @@ func validateRawConfig(rawConfig map[string]any) error {
if proxy, ok := rawConfig["proxy"].(map[string]any); ok {
if auth, ok := proxy["auth"].(map[string]any); ok {
if kind, ok := auth["kind"].(string); ok && kind == "oauth" {
+ // Validate top-level auth secrets
secrets := []struct {
name string
required bool
}{
- {"googleClientSecret", true},
{"jwtSecret", true},
{"encryptionKey", true}, // Always required for OAuth
}
@@ -88,6 +88,20 @@ func validateRawConfig(rawConfig map[string]any) error {
}
}
}
+
+ // Validate IDP secret (clientSecret must be env ref)
+ if idp, ok := auth["idp"].(map[string]any); ok {
+ if clientSecret, exists := idp["clientSecret"]; exists {
+ if _, isString := clientSecret.(string); isString {
+ return fmt.Errorf("idp.clientSecret must use environment variable reference for security")
+ }
+ if refMap, isMap := clientSecret.(map[string]any); isMap {
+ if _, hasEnv := refMap["$env"]; !hasEnv {
+ return fmt.Errorf("idp.clientSecret must use {\"$env\": \"VAR_NAME\"} format")
+ }
+ }
+ }
+ }
}
}
}
@@ -148,24 +162,54 @@ func validateOAuthConfig(oauth *OAuthAuthConfig) error {
if oauth.Issuer == "" {
return fmt.Errorf("issuer is required")
}
- if oauth.GoogleClientID == "" {
- return fmt.Errorf("googleClientId is required")
+
+ // Validate IDP configuration
+ if oauth.IDP.Provider == "" {
+ return fmt.Errorf("idp.provider is required (google, azure, github, or oidc)")
+ }
+ if oauth.IDP.ClientID == "" {
+ return fmt.Errorf("idp.clientId is required")
}
- if oauth.GoogleClientSecret == "" {
- return fmt.Errorf("googleClientSecret is required")
+ if oauth.IDP.ClientSecret == "" {
+ return fmt.Errorf("idp.clientSecret is required")
}
- if oauth.GoogleRedirectURI == "" {
- return fmt.Errorf("googleRedirectUri is required")
+ if oauth.IDP.RedirectURI == "" {
+ return fmt.Errorf("idp.redirectUri is required")
}
+
+ // Provider-specific validation
+ switch oauth.IDP.Provider {
+ case "google":
+ // No additional validation needed
+ case "azure":
+ if oauth.IDP.TenantID == "" {
+ return fmt.Errorf("idp.tenantId is required for Azure AD")
+ }
+ case "github":
+ // No additional validation needed
+ case "oidc":
+ // Either discoveryUrl or all manual endpoints required
+ if oauth.IDP.DiscoveryURL == "" {
+ if oauth.IDP.AuthorizationURL == "" || oauth.IDP.TokenURL == "" || oauth.IDP.UserInfoURL == "" {
+ return fmt.Errorf("idp.discoveryUrl or all of (authorizationUrl, tokenUrl, userInfoUrl) required for OIDC")
+ }
+ }
+ default:
+ return fmt.Errorf("unsupported idp.provider: %s (must be google, azure, github, or oidc)", oauth.IDP.Provider)
+ }
+
if len(oauth.JWTSecret) < 32 {
return fmt.Errorf("jwtSecret must be at least 32 characters (got %d). Generate with: openssl rand -base64 32", len(oauth.JWTSecret))
}
if len(oauth.EncryptionKey) != 32 {
return fmt.Errorf("encryptionKey must be exactly 32 characters (got %d). Generate with: openssl rand -base64 32 | head -c 32", len(oauth.EncryptionKey))
}
- if len(oauth.AllowedDomains) == 0 {
- return fmt.Errorf("at least one allowed domain is required")
+
+ // Domain or org validation - at least one access control mechanism required
+ if len(oauth.AllowedDomains) == 0 && len(oauth.IDP.AllowedOrgs) == 0 {
+ return fmt.Errorf("at least one of allowedDomains or idp.allowedOrgs is required")
}
+
if oauth.Storage == "firestore" {
if oauth.GCPProject == "" {
return fmt.Errorf("gcpProject is required when using firestore storage")
diff --git a/internal/config/load_test.go b/internal/config/load_test.go
index a79b8e3..4fb5659 100644
--- a/internal/config/load_test.go
+++ b/internal/config/load_test.go
@@ -41,17 +41,20 @@ func TestValidateConfig_UserTokensRequireOAuth(t *testing.T) {
BaseURL: "https://test.example.com",
Addr: ":8080",
Auth: &OAuthAuthConfig{
- Kind: "oauth",
- Issuer: "https://auth.example.com",
- GoogleClientID: "test-client",
- GoogleClientSecret: "test-secret",
- GoogleRedirectURI: "https://test.example.com/callback",
- JWTSecret: "test-jwt-secret-must-be-32-bytes-long",
- EncryptionKey: "test-encryption-key-32-bytes-ok!",
- AllowedDomains: []string{"example.com"},
- AllowedOrigins: []string{"https://test.example.com"},
- TokenTTL: time.Hour,
- RefreshTokenTTL: 30 * 24 * time.Hour,
+ Kind: "oauth",
+ Issuer: "https://auth.example.com",
+ IDP: IDPConfig{
+ Provider: "google",
+ ClientID: "test-client",
+ ClientSecret: "test-secret",
+ RedirectURI: "https://test.example.com/callback",
+ },
+ JWTSecret: "test-jwt-secret-must-be-32-bytes-long",
+ EncryptionKey: "test-encryption-key-32-bytes-ok!",
+ AllowedDomains: []string{"example.com"},
+ AllowedOrigins: []string{"https://test.example.com"},
+ TokenTTL: time.Hour,
+ RefreshTokenTTL: 30 * 24 * time.Hour,
},
},
MCPServers: map[string]*MCPClientConfig{
@@ -98,17 +101,20 @@ func TestValidateConfig_SessionConfig(t *testing.T) {
BaseURL: "https://test.example.com",
Addr: ":8080",
Auth: &OAuthAuthConfig{
- Kind: "oauth",
- Issuer: "https://auth.example.com",
- GoogleClientID: "test-client",
- GoogleClientSecret: "test-secret",
- GoogleRedirectURI: "https://test.example.com/callback",
- JWTSecret: "test-jwt-secret-must-be-32-bytes-long",
- EncryptionKey: "test-encryption-key-32-bytes-ok!",
- AllowedDomains: []string{"example.com"},
- AllowedOrigins: []string{"https://test.example.com"},
- TokenTTL: time.Hour,
- RefreshTokenTTL: 30 * 24 * time.Hour,
+ Kind: "oauth",
+ Issuer: "https://auth.example.com",
+ IDP: IDPConfig{
+ Provider: "google",
+ ClientID: "test-client",
+ ClientSecret: "test-secret",
+ RedirectURI: "https://test.example.com/callback",
+ },
+ JWTSecret: "test-jwt-secret-must-be-32-bytes-long",
+ EncryptionKey: "test-encryption-key-32-bytes-ok!",
+ AllowedDomains: []string{"example.com"},
+ AllowedOrigins: []string{"https://test.example.com"},
+ TokenTTL: time.Hour,
+ RefreshTokenTTL: 30 * 24 * time.Hour,
},
Sessions: &SessionConfig{
Timeout: 10 * time.Minute,
@@ -128,17 +134,20 @@ func TestValidateConfig_SessionConfig(t *testing.T) {
BaseURL: "https://test.example.com",
Addr: ":8080",
Auth: &OAuthAuthConfig{
- Kind: "oauth",
- Issuer: "https://auth.example.com",
- GoogleClientID: "test-client",
- GoogleClientSecret: "test-secret",
- GoogleRedirectURI: "https://test.example.com/callback",
- JWTSecret: "test-jwt-secret-must-be-32-bytes-long",
- EncryptionKey: "test-encryption-key-32-bytes-ok!",
- AllowedDomains: []string{"example.com"},
- AllowedOrigins: []string{"https://test.example.com"},
- TokenTTL: time.Hour,
- RefreshTokenTTL: 30 * 24 * time.Hour,
+ Kind: "oauth",
+ Issuer: "https://auth.example.com",
+ IDP: IDPConfig{
+ Provider: "google",
+ ClientID: "test-client",
+ ClientSecret: "test-secret",
+ RedirectURI: "https://test.example.com/callback",
+ },
+ JWTSecret: "test-jwt-secret-must-be-32-bytes-long",
+ EncryptionKey: "test-encryption-key-32-bytes-ok!",
+ AllowedDomains: []string{"example.com"},
+ AllowedOrigins: []string{"https://test.example.com"},
+ TokenTTL: time.Hour,
+ RefreshTokenTTL: 30 * 24 * time.Hour,
},
Sessions: &SessionConfig{
Timeout: -1 * time.Minute,
@@ -156,17 +165,20 @@ func TestValidateConfig_SessionConfig(t *testing.T) {
BaseURL: "https://test.example.com",
Addr: ":8080",
Auth: &OAuthAuthConfig{
- Kind: "oauth",
- Issuer: "https://auth.example.com",
- GoogleClientID: "test-client",
- GoogleClientSecret: "test-secret",
- GoogleRedirectURI: "https://test.example.com/callback",
- JWTSecret: "test-jwt-secret-must-be-32-bytes-long",
- EncryptionKey: "test-encryption-key-32-bytes-ok!",
- AllowedDomains: []string{"example.com"},
- AllowedOrigins: []string{"https://test.example.com"},
- TokenTTL: time.Hour,
- RefreshTokenTTL: 30 * 24 * time.Hour,
+ Kind: "oauth",
+ Issuer: "https://auth.example.com",
+ IDP: IDPConfig{
+ Provider: "google",
+ ClientID: "test-client",
+ ClientSecret: "test-secret",
+ RedirectURI: "https://test.example.com/callback",
+ },
+ JWTSecret: "test-jwt-secret-must-be-32-bytes-long",
+ EncryptionKey: "test-encryption-key-32-bytes-ok!",
+ AllowedDomains: []string{"example.com"},
+ AllowedOrigins: []string{"https://test.example.com"},
+ TokenTTL: time.Hour,
+ RefreshTokenTTL: 30 * 24 * time.Hour,
},
Sessions: &SessionConfig{
Timeout: 10 * time.Minute,
@@ -184,17 +196,20 @@ func TestValidateConfig_SessionConfig(t *testing.T) {
BaseURL: "https://test.example.com",
Addr: ":8080",
Auth: &OAuthAuthConfig{
- Kind: "oauth",
- Issuer: "https://auth.example.com",
- GoogleClientID: "test-client",
- GoogleClientSecret: "test-secret",
- GoogleRedirectURI: "https://test.example.com/callback",
- JWTSecret: "test-jwt-secret-must-be-32-bytes-long",
- EncryptionKey: "test-encryption-key-32-bytes-ok!",
- AllowedDomains: []string{"example.com"},
- AllowedOrigins: []string{"https://test.example.com"},
- TokenTTL: time.Hour,
- RefreshTokenTTL: 30 * 24 * time.Hour,
+ Kind: "oauth",
+ Issuer: "https://auth.example.com",
+ IDP: IDPConfig{
+ Provider: "google",
+ ClientID: "test-client",
+ ClientSecret: "test-secret",
+ RedirectURI: "https://test.example.com/callback",
+ },
+ JWTSecret: "test-jwt-secret-must-be-32-bytes-long",
+ EncryptionKey: "test-encryption-key-32-bytes-ok!",
+ AllowedDomains: []string{"example.com"},
+ AllowedOrigins: []string{"https://test.example.com"},
+ TokenTTL: time.Hour,
+ RefreshTokenTTL: 30 * 24 * time.Hour,
},
Sessions: &SessionConfig{
Timeout: 0,
diff --git a/internal/config/types.go b/internal/config/types.go
index 515f202..87d714e 100644
--- a/internal/config/types.go
+++ b/internal/config/types.go
@@ -187,6 +187,11 @@ type MCPClientConfig struct {
InlineConfig json.RawMessage `json:"inline,omitempty"`
}
+// IsStdio returns true if this is a stdio-based MCP server
+func (c *MCPClientConfig) IsStdio() bool {
+ return c.TransportType == MCPClientTypeStdio
+}
+
// SessionConfig represents session management configuration
type SessionConfig struct {
Timeout time.Duration
@@ -200,12 +205,39 @@ type AdminConfig struct {
AdminEmails []string `json:"adminEmails"`
}
+// IDPConfig represents identity provider configuration.
+type IDPConfig struct {
+ // Provider type: "google", "azure", "github", or "oidc"
+ Provider string `json:"provider"`
+
+ // OAuth client configuration
+ ClientID string `json:"clientId"`
+ ClientSecret Secret `json:"clientSecret"`
+ RedirectURI string `json:"redirectUri"`
+
+ // For OIDC: discovery URL or manual endpoint configuration
+ DiscoveryURL string `json:"discoveryUrl,omitempty"`
+ AuthorizationURL string `json:"authorizationUrl,omitempty"`
+ TokenURL string `json:"tokenUrl,omitempty"`
+ UserInfoURL string `json:"userInfoUrl,omitempty"`
+
+ // Custom scopes (optional, defaults per provider)
+ Scopes []string `json:"scopes,omitempty"`
+
+ // Azure-specific: tenant ID
+ TenantID string `json:"tenantId,omitempty"`
+
+ // GitHub-specific: allowed organizations
+ AllowedOrgs []string `json:"allowedOrgs,omitempty"`
+}
+
// OAuthAuthConfig represents OAuth 2.0 configuration with resolved values
type OAuthAuthConfig struct {
Kind AuthKind `json:"kind"`
Issuer string `json:"issuer"`
GCPProject string `json:"gcpProject"`
- AllowedDomains []string `json:"allowedDomains"` // For Google OAuth email validation
+ IDP IDPConfig `json:"idp"`
+ AllowedDomains []string `json:"allowedDomains"` // For domain-based access control
AllowedOrigins []string `json:"allowedOrigins"` // For CORS validation
TokenTTL time.Duration `json:"tokenTtl"`
RefreshTokenTTL time.Duration `json:"refreshTokenTtl"`
@@ -213,9 +245,6 @@ type OAuthAuthConfig struct {
Storage string `json:"storage"` // "memory" or "firestore"
FirestoreDatabase string `json:"firestoreDatabase,omitempty"` // Optional: Firestore database name
FirestoreCollection string `json:"firestoreCollection,omitempty"` // Optional: Firestore collection name
- GoogleClientID string `json:"googleClientId"`
- GoogleClientSecret Secret `json:"googleClientSecret"`
- GoogleRedirectURI string `json:"googleRedirectUri"`
JWTSecret Secret `json:"jwtSecret"`
EncryptionKey Secret `json:"encryptionKey"`
// DangerouslyAcceptIssuerAudience allows tokens with just the base issuer as audience
diff --git a/internal/config/unmarshal.go b/internal/config/unmarshal.go
index eb20321..0a03878 100644
--- a/internal/config/unmarshal.go
+++ b/internal/config/unmarshal.go
@@ -118,6 +118,7 @@ func (o *OAuthAuthConfig) UnmarshalJSON(data []byte) error {
Kind AuthKind `json:"kind"`
Issuer json.RawMessage `json:"issuer"`
GCPProject json.RawMessage `json:"gcpProject"`
+ IDP json.RawMessage `json:"idp"`
AllowedDomains []string `json:"allowedDomains"`
AllowedOrigins []string `json:"allowedOrigins"`
TokenTTL string `json:"tokenTtl"`
@@ -126,9 +127,6 @@ func (o *OAuthAuthConfig) UnmarshalJSON(data []byte) error {
Storage string `json:"storage"`
FirestoreDatabase string `json:"firestoreDatabase,omitempty"`
FirestoreCollection string `json:"firestoreCollection,omitempty"`
- GoogleClientID json.RawMessage `json:"googleClientId"`
- GoogleClientSecret json.RawMessage `json:"googleClientSecret"`
- GoogleRedirectURI json.RawMessage `json:"googleRedirectUri"`
JWTSecret json.RawMessage `json:"jwtSecret"`
EncryptionKey json.RawMessage `json:"encryptionKey,omitempty"`
DangerouslyAcceptIssuerAudience bool `json:"dangerouslyAcceptIssuerAudience,omitempty"`
@@ -202,38 +200,13 @@ func (o *OAuthAuthConfig) UnmarshalJSON(data []byte) error {
o.GCPProject = parsed.value
}
- if raw.GoogleClientID != nil {
- parsed, err := ParseConfigValue(raw.GoogleClientID)
+ // Parse IDP config
+ if raw.IDP != nil {
+ idp, err := parseIDPConfig(raw.IDP)
if err != nil {
- return fmt.Errorf("parsing googleClientId: %w", err)
+ return fmt.Errorf("parsing idp: %w", err)
}
- if parsed.needsUserToken {
- return fmt.Errorf("googleClientId cannot be a user token reference")
- }
- o.GoogleClientID = parsed.value
- }
-
- if raw.GoogleRedirectURI != nil {
- parsed, err := ParseConfigValue(raw.GoogleRedirectURI)
- if err != nil {
- return fmt.Errorf("parsing googleRedirectUri: %w", err)
- }
- if parsed.needsUserToken {
- return fmt.Errorf("googleRedirectUri cannot be a user token reference")
- }
- o.GoogleRedirectURI = parsed.value
- }
-
- // Parse secret fields
- if raw.GoogleClientSecret != nil {
- parsed, err := ParseConfigValue(raw.GoogleClientSecret)
- if err != nil {
- return fmt.Errorf("parsing googleClientSecret: %w", err)
- }
- if parsed.needsUserToken {
- return fmt.Errorf("googleClientSecret cannot be a user token reference")
- }
- o.GoogleClientSecret = Secret(parsed.value)
+ o.IDP = *idp
}
if raw.JWTSecret != nil {
@@ -274,6 +247,87 @@ func (o *OAuthAuthConfig) UnmarshalJSON(data []byte) error {
return nil
}
+// parseIDPConfig parses the IDP configuration with env var references
+func parseIDPConfig(data json.RawMessage) (*IDPConfig, error) {
+ type rawIDP struct {
+ Provider string `json:"provider"`
+ ClientID json.RawMessage `json:"clientId"`
+ ClientSecret json.RawMessage `json:"clientSecret"`
+ RedirectURI json.RawMessage `json:"redirectUri"`
+ DiscoveryURL json.RawMessage `json:"discoveryUrl,omitempty"`
+ AuthorizationURL json.RawMessage `json:"authorizationUrl,omitempty"`
+ TokenURL json.RawMessage `json:"tokenUrl,omitempty"`
+ UserInfoURL json.RawMessage `json:"userInfoUrl,omitempty"`
+ Scopes []string `json:"scopes,omitempty"`
+ TenantID json.RawMessage `json:"tenantId,omitempty"`
+ AllowedOrgs []string `json:"allowedOrgs,omitempty"`
+ }
+
+ var raw rawIDP
+ if err := json.Unmarshal(data, &raw); err != nil {
+ return nil, err
+ }
+
+ idp := &IDPConfig{
+ Provider: raw.Provider,
+ Scopes: raw.Scopes,
+ AllowedOrgs: raw.AllowedOrgs,
+ }
+
+ // Helper to parse a field that cannot be a user token reference
+ parseField := func(data json.RawMessage, fieldName string) (string, error) {
+ if data == nil {
+ return "", nil
+ }
+ parsed, err := ParseConfigValue(data)
+ if err != nil {
+ return "", fmt.Errorf("parsing %s: %w", fieldName, err)
+ }
+ if parsed.needsUserToken {
+ return "", fmt.Errorf("%s cannot be a user token reference", fieldName)
+ }
+ return parsed.value, nil
+ }
+
+ var err error
+
+ if idp.ClientID, err = parseField(raw.ClientID, "clientId"); err != nil {
+ return nil, err
+ }
+
+ var clientSecret string
+ if clientSecret, err = parseField(raw.ClientSecret, "clientSecret"); err != nil {
+ return nil, err
+ }
+ idp.ClientSecret = Secret(clientSecret)
+
+ if idp.RedirectURI, err = parseField(raw.RedirectURI, "redirectUri"); err != nil {
+ return nil, err
+ }
+
+ if idp.DiscoveryURL, err = parseField(raw.DiscoveryURL, "discoveryUrl"); err != nil {
+ return nil, err
+ }
+
+ if idp.AuthorizationURL, err = parseField(raw.AuthorizationURL, "authorizationUrl"); err != nil {
+ return nil, err
+ }
+
+ if idp.TokenURL, err = parseField(raw.TokenURL, "tokenUrl"); err != nil {
+ return nil, err
+ }
+
+ if idp.UserInfoURL, err = parseField(raw.UserInfoURL, "userInfoUrl"); err != nil {
+ return nil, err
+ }
+
+ if idp.TenantID, err = parseField(raw.TenantID, "tenantId"); err != nil {
+ return nil, err
+ }
+
+ return idp, nil
+}
+
// UnmarshalJSON implements custom unmarshaling for ProxyConfig
func (p *ProxyConfig) UnmarshalJSON(data []byte) error {
// Use a raw type to parse references
diff --git a/internal/config/unmarshal_test.go b/internal/config/unmarshal_test.go
index dc0682a..bc1db55 100644
--- a/internal/config/unmarshal_test.go
+++ b/internal/config/unmarshal_test.go
@@ -255,9 +255,12 @@ func TestOAuthAuthConfig_UnmarshalJSON(t *testing.T) {
"allowedOrigins": ["https://claude.ai", "https://example.com"],
"tokenTtl": "1h",
"storage": "firestore",
- "googleClientId": "test-client-id",
- "googleClientSecret": {"$env": "CLIENT_SECRET"},
- "googleRedirectUri": "https://example.com/callback",
+ "idp": {
+ "provider": "google",
+ "clientId": "test-client-id",
+ "clientSecret": {"$env": "CLIENT_SECRET"},
+ "redirectUri": "https://example.com/callback"
+ },
"jwtSecret": {"$env": "JWT_SECRET"},
"encryptionKey": {"$env": "ENCRYPTION_KEY"}
}`
@@ -271,8 +274,10 @@ func TestOAuthAuthConfig_UnmarshalJSON(t *testing.T) {
assert.Equal(t, "test-project", config.GCPProject)
assert.Equal(t, []string{"example.com"}, config.AllowedDomains)
assert.Equal(t, []string{"https://claude.ai", "https://example.com"}, config.AllowedOrigins)
- assert.Equal(t, "test-client-id", config.GoogleClientID)
- assert.Equal(t, Secret("test-secret-value"), config.GoogleClientSecret)
+ assert.Equal(t, "google", config.IDP.Provider)
+ assert.Equal(t, "test-client-id", config.IDP.ClientID)
+ assert.Equal(t, Secret("test-secret-value"), config.IDP.ClientSecret)
+ assert.Equal(t, "https://example.com/callback", config.IDP.RedirectURI)
assert.Equal(t, Secret("this-is-a-very-long-jwt-secret-key"), config.JWTSecret)
assert.Equal(t, Secret("exactly-32-bytes-long-encryptkey"), config.EncryptionKey)
}
@@ -406,9 +411,12 @@ func TestProxyConfig_SessionConfigIntegration(t *testing.T) {
"auth": {
"kind": "oauth",
"issuer": "https://auth.example.com",
- "googleClientId": "test-client",
- "googleClientSecret": "test-secret",
- "googleRedirectUri": "https://test.example.com/callback",
+ "idp": {
+ "provider": "google",
+ "clientId": "test-client",
+ "clientSecret": "test-secret",
+ "redirectUri": "https://test.example.com/callback"
+ },
"jwtSecret": "test-jwt-secret-must-be-32-bytes-long",
"encryptionKey": "test-encryption-key-32-bytes-ok!",
"allowedDomains": ["example.com"],
diff --git a/internal/config/validation.go b/internal/config/validation.go
index 40612d1..490a123 100644
--- a/internal/config/validation.go
+++ b/internal/config/validation.go
@@ -146,9 +146,6 @@ func validateAuthStructure(auth map[string]any, result *ValidationResult) {
hint string
}{
{"issuer", ""},
- {"googleClientId", ""},
- {"googleClientSecret", ""},
- {"googleRedirectUri", ""},
{"jwtSecret", "Hint: Must be at least 32 bytes long for HMAC-SHA256"},
{"encryptionKey", "Hint: Must be exactly 32 bytes for AES-256-GCM encryption"},
}
@@ -164,12 +161,18 @@ func validateAuthStructure(auth map[string]any, result *ValidationResult) {
})
}
}
- if domains, ok := auth["allowedDomains"].([]any); !ok || len(domains) == 0 {
+
+ // Validate IDP configuration
+ idp, hasIDP := auth["idp"].(map[string]any)
+ if !hasIDP {
result.Errors = append(result.Errors, ValidationError{
- Path: "proxy.auth.allowedDomains",
- Message: "at least one allowed domain is required for OAuth",
+ Path: "proxy.auth.idp",
+ Message: "idp configuration is required for OAuth",
})
+ } else {
+ validateIDPStructure(idp, result)
}
+
if origins, ok := auth["allowedOrigins"].([]any); !ok || len(origins) == 0 {
result.Errors = append(result.Errors, ValidationError{
Path: "proxy.auth.allowedOrigins",
@@ -184,6 +187,74 @@ func validateAuthStructure(auth map[string]any, result *ValidationResult) {
}
}
+// validateIDPStructure checks identity provider configuration
+func validateIDPStructure(idp map[string]any, result *ValidationResult) {
+ provider, ok := idp["provider"].(string)
+ if !ok {
+ result.Errors = append(result.Errors, ValidationError{
+ Path: "proxy.auth.idp.provider",
+ Message: "provider is required. Options: google, azure, github, oidc",
+ })
+ return
+ }
+
+ // Check required fields for all providers
+ if _, ok := idp["clientId"]; !ok {
+ result.Errors = append(result.Errors, ValidationError{
+ Path: "proxy.auth.idp.clientId",
+ Message: "clientId is required for IDP configuration",
+ })
+ }
+ if _, ok := idp["clientSecret"]; !ok {
+ result.Errors = append(result.Errors, ValidationError{
+ Path: "proxy.auth.idp.clientSecret",
+ Message: "clientSecret is required for IDP configuration",
+ })
+ }
+ if _, ok := idp["redirectUri"]; !ok {
+ result.Errors = append(result.Errors, ValidationError{
+ Path: "proxy.auth.idp.redirectUri",
+ Message: "redirectUri is required for IDP configuration",
+ })
+ }
+
+ // Provider-specific validation
+ switch provider {
+ case "google", "github":
+ // No additional required fields
+ case "azure":
+ if _, ok := idp["tenantId"]; !ok {
+ result.Errors = append(result.Errors, ValidationError{
+ Path: "proxy.auth.idp.tenantId",
+ Message: "tenantId is required for Azure AD provider",
+ })
+ }
+ case "oidc":
+ // Either discoveryUrl or manual endpoints required
+ hasDiscovery := false
+ if _, ok := idp["discoveryUrl"]; ok {
+ hasDiscovery = true
+ }
+ if !hasDiscovery {
+ // Check for manual endpoints
+ requiredEndpoints := []string{"authorizationUrl", "tokenUrl", "userInfoUrl"}
+ for _, endpoint := range requiredEndpoints {
+ if _, ok := idp[endpoint]; !ok {
+ result.Errors = append(result.Errors, ValidationError{
+ Path: "proxy.auth.idp." + endpoint,
+ Message: fmt.Sprintf("%s is required for OIDC provider when discoveryUrl is not provided", endpoint),
+ })
+ }
+ }
+ }
+ default:
+ result.Errors = append(result.Errors, ValidationError{
+ Path: "proxy.auth.idp.provider",
+ Message: fmt.Sprintf("unknown provider '%s' - supported providers: google, azure, github, oidc", provider),
+ })
+ }
+}
+
// validateAdminStructure checks admin configuration structure
func validateAdminStructure(admin map[string]any, result *ValidationResult) {
enabled, ok := admin["enabled"].(bool)
diff --git a/internal/config/validation_test.go b/internal/config/validation_test.go
index 910bb00..a8b5694 100644
--- a/internal/config/validation_test.go
+++ b/internal/config/validation_test.go
@@ -51,9 +51,12 @@ func TestValidateFile(t *testing.T) {
"auth": {
"kind": "oauth",
"issuer": "https://example.com",
- "googleClientId": {"$env": "CLIENT_ID"},
- "googleClientSecret": {"$env": "CLIENT_SECRET"},
- "googleRedirectUri": "https://example.com/callback",
+ "idp": {
+ "provider": "google",
+ "clientId": {"$env": "CLIENT_ID"},
+ "clientSecret": {"$env": "CLIENT_SECRET"},
+ "redirectUri": "https://example.com/callback"
+ },
"jwtSecret": {"$env": "JWT_SECRET"},
"encryptionKey": {"$env": "ENCRYPTION_KEY"},
"allowedDomains": ["example.com"],
@@ -225,9 +228,12 @@ func TestValidateFile(t *testing.T) {
"auth": {
"kind": "oauth",
"issuer": "https://example.com",
- "googleClientId": "id",
- "googleClientSecret": "secret",
- "googleRedirectUri": "https://example.com/callback",
+ "idp": {
+ "provider": "google",
+ "clientId": "id",
+ "clientSecret": "secret",
+ "redirectUri": "https://example.com/callback"
+ },
"jwtSecret": "secret",
"encryptionKey": "key",
"allowedDomains": ["example.com"],
@@ -260,15 +266,12 @@ func TestValidateFile(t *testing.T) {
}`,
wantErrors: []string{
"issuer is required for OAuth",
- "googleClientId is required for OAuth",
- "googleClientSecret is required for OAuth",
- "googleRedirectUri is required for OAuth",
"jwtSecret is required for OAuth. Hint: Must be at least 32 bytes long for HMAC-SHA256",
"encryptionKey is required for OAuth. Hint: Must be exactly 32 bytes for AES-256-GCM encryption",
- "at least one allowed domain is required for OAuth",
+ "idp configuration is required for OAuth",
"at least one allowed origin is required for OAuth (CORS configuration)",
},
- wantErrCount: 8,
+ wantErrCount: 5,
},
{
name: "valid_manual_user_authentication",
@@ -280,9 +283,12 @@ func TestValidateFile(t *testing.T) {
"auth": {
"kind": "oauth",
"issuer": "https://example.com",
- "googleClientId": "id",
- "googleClientSecret": "secret",
- "googleRedirectUri": "https://example.com/callback",
+ "idp": {
+ "provider": "google",
+ "clientId": "id",
+ "clientSecret": "secret",
+ "redirectUri": "https://example.com/callback"
+ },
"jwtSecret": "secret123456789012345678901234567890",
"encryptionKey": "key12345678901234567890123456789",
"allowedDomains": ["example.com"],
@@ -315,9 +321,12 @@ func TestValidateFile(t *testing.T) {
"auth": {
"kind": "oauth",
"issuer": "https://example.com",
- "googleClientId": "id",
- "googleClientSecret": "secret",
- "googleRedirectUri": "https://example.com/callback",
+ "idp": {
+ "provider": "google",
+ "clientId": "id",
+ "clientSecret": "secret",
+ "redirectUri": "https://example.com/callback"
+ },
"jwtSecret": "secret123456789012345678901234567890",
"encryptionKey": "key12345678901234567890123456789",
"allowedDomains": ["example.com"],
@@ -354,9 +363,12 @@ func TestValidateFile(t *testing.T) {
"auth": {
"kind": "oauth",
"issuer": "https://example.com",
- "googleClientId": "id",
- "googleClientSecret": "secret",
- "googleRedirectUri": "https://example.com/callback",
+ "idp": {
+ "provider": "google",
+ "clientId": "id",
+ "clientSecret": "secret",
+ "redirectUri": "https://example.com/callback"
+ },
"jwtSecret": "secret123456789012345678901234567890",
"encryptionKey": "key12345678901234567890123456789",
"allowedDomains": ["example.com"],
@@ -515,9 +527,12 @@ func TestValidateFile_ImprovedErrorMessages(t *testing.T) {
"auth": {
"kind": "oauth",
"issuer": "https://example.com",
- "googleClientId": "id",
- "googleClientSecret": "secret",
- "googleRedirectUri": "https://example.com/callback",
+ "idp": {
+ "provider": "google",
+ "clientId": "id",
+ "clientSecret": "secret",
+ "redirectUri": "https://example.com/callback"
+ },
"jwtSecret": "secret123456789012345678901234567890",
"encryptionKey": "key12345678901234567890123456789",
"allowedDomains": ["example.com"],
diff --git a/internal/cookie/cookie.go b/internal/cookie/cookie.go
index 3be28d4..0c38ea8 100644
--- a/internal/cookie/cookie.go
+++ b/internal/cookie/cookie.go
@@ -4,7 +4,7 @@ import (
"net/http"
"time"
- "github.com/dgellow/mcp-front/internal/envutil"
+ "github.com/dgellow/mcp-front/internal/config"
"github.com/dgellow/mcp-front/internal/log"
)
@@ -16,7 +16,7 @@ const (
// SetSession sets a session cookie with appropriate security settings
func SetSession(w http.ResponseWriter, value string, maxAge time.Duration) {
- secure := !envutil.IsDev()
+ secure := !config.IsDev()
http.SetCookie(w, &http.Cookie{
Name: SessionCookie,
Value: value,
@@ -41,7 +41,7 @@ func SetCSRF(w http.ResponseWriter, value string) {
Value: value,
Path: "/",
HttpOnly: false, // CSRF tokens need to be readable by JavaScript
- Secure: !envutil.IsDev(),
+ Secure: !config.IsDev(),
SameSite: http.SameSiteStrictMode,
MaxAge: int((24 * time.Hour).Seconds()), // 24 hours
})
diff --git a/internal/crypto/csrf.go b/internal/crypto/csrf.go
new file mode 100644
index 0000000..53a9dbf
--- /dev/null
+++ b/internal/crypto/csrf.go
@@ -0,0 +1,61 @@
+package crypto
+
+import (
+ "fmt"
+ "strconv"
+ "strings"
+ "time"
+)
+
+// CSRFProtection provides stateless HMAC-based CSRF token generation and validation.
+// Tokens are self-contained: nonce:timestamp:signature, with configurable expiry.
+type CSRFProtection struct {
+ signingKey []byte
+ ttl time.Duration
+}
+
+// NewCSRFProtection creates a new CSRF protection instance
+func NewCSRFProtection(signingKey []byte, ttl time.Duration) CSRFProtection {
+ return CSRFProtection{
+ signingKey: signingKey,
+ ttl: ttl,
+ }
+}
+
+// Generate creates a new CSRF token
+func (c *CSRFProtection) Generate() (string, error) {
+ nonce := GenerateSecureToken()
+ if nonce == "" {
+ return "", fmt.Errorf("failed to generate nonce")
+ }
+
+ timestamp := strconv.FormatInt(time.Now().Unix(), 10)
+ data := nonce + ":" + timestamp
+ signature := SignData(data, c.signingKey)
+
+ return fmt.Sprintf("%s:%s:%s", nonce, timestamp, signature), nil
+}
+
+// Validate checks if a CSRF token is valid and not expired
+func (c *CSRFProtection) Validate(token string) bool {
+ parts := strings.SplitN(token, ":", 3)
+ if len(parts) != 3 {
+ return false
+ }
+
+ nonce := parts[0]
+ timestampStr := parts[1]
+ signature := parts[2]
+
+ timestamp, err := strconv.ParseInt(timestampStr, 10, 64)
+ if err != nil {
+ return false
+ }
+
+ if time.Since(time.Unix(timestamp, 0)) > c.ttl {
+ return false
+ }
+
+ data := nonce + ":" + timestampStr
+ return ValidateSignedData(data, signature, c.signingKey)
+}
diff --git a/internal/googleauth/google.go b/internal/googleauth/google.go
deleted file mode 100644
index 1667a49..0000000
--- a/internal/googleauth/google.go
+++ /dev/null
@@ -1,121 +0,0 @@
-package googleauth
-
-import (
- "context"
- "encoding/json"
- "fmt"
- "net/http"
- "os"
- "slices"
- "strings"
-
- "github.com/dgellow/mcp-front/internal/config"
- emailutil "github.com/dgellow/mcp-front/internal/emailutil"
- "golang.org/x/oauth2"
- "golang.org/x/oauth2/google"
-)
-
-// UserInfo represents Google user information
-type UserInfo struct {
- Email string `json:"email"`
- HostedDomain string `json:"hd"`
- Name string `json:"name"`
- Picture string `json:"picture"`
- VerifiedEmail bool `json:"verified_email"`
-}
-
-// GoogleAuthURL generates a Google OAuth authorization URL
-func GoogleAuthURL(oauthConfig config.OAuthAuthConfig, state string) string {
- googleOAuth := newGoogleOAuth2Config(oauthConfig)
- return googleOAuth.AuthCodeURL(state,
- oauth2.AccessTypeOffline,
- oauth2.ApprovalForce,
- )
-}
-
-// ExchangeCodeForToken exchanges the authorization code for a token
-func ExchangeCodeForToken(ctx context.Context, oauthConfig config.OAuthAuthConfig, code string) (*oauth2.Token, error) {
- googleOAuth := newGoogleOAuth2Config(oauthConfig)
- return googleOAuth.Exchange(ctx, code)
-}
-
-// ValidateUser validates the Google OAuth token and checks domain membership
-func ValidateUser(ctx context.Context, oauthConfig config.OAuthAuthConfig, token *oauth2.Token) (UserInfo, error) {
- googleOAuth := newGoogleOAuth2Config(oauthConfig)
- client := googleOAuth.Client(ctx, token)
- userInfoURL := "https://www.googleapis.com/oauth2/v2/userinfo"
- if customURL := os.Getenv("GOOGLE_USERINFO_URL"); customURL != "" {
- userInfoURL = customURL
- }
- resp, err := client.Get(userInfoURL)
- if err != nil {
- return UserInfo{}, fmt.Errorf("failed to get user info: %w", err)
- }
- defer resp.Body.Close()
-
- if resp.StatusCode != http.StatusOK {
- return UserInfo{}, fmt.Errorf("failed to get user info: status %d", resp.StatusCode)
- }
-
- var userInfo UserInfo
- if err := json.NewDecoder(resp.Body).Decode(&userInfo); err != nil {
- return UserInfo{}, fmt.Errorf("failed to decode user info: %w", err)
- }
-
- // Validate domain if configured
- if len(oauthConfig.AllowedDomains) > 0 {
- userDomain := emailutil.ExtractDomain(userInfo.Email)
- if !slices.Contains(oauthConfig.AllowedDomains, userDomain) {
- return UserInfo{}, fmt.Errorf("domain '%s' is not allowed. Contact your administrator", userDomain)
- }
- }
-
- return userInfo, nil
-}
-
-// ParseClientRequest parses MCP client registration metadata
-func ParseClientRequest(metadata map[string]any) ([]string, []string, error) {
- // Extract redirect URIs
- redirectURIs := []string{}
- if uris, ok := metadata["redirect_uris"].([]any); ok {
- for _, uri := range uris {
- if uriStr, ok := uri.(string); ok {
- redirectURIs = append(redirectURIs, uriStr)
- }
- }
- }
-
- if len(redirectURIs) == 0 {
- return nil, nil, fmt.Errorf("no valid redirect URIs provided")
- }
-
- // Extract scopes, default to read/write if not provided
- scopes := []string{"read", "write"} // Default MCP scopes
- if clientScopes, ok := metadata["scope"].(string); ok {
- if strings.TrimSpace(clientScopes) != "" {
- scopes = strings.Fields(clientScopes)
- }
- }
-
- return redirectURIs, scopes, nil
-}
-
-// newGoogleOAuth2Config creates the OAuth2 config from our Config
-func newGoogleOAuth2Config(oauthConfig config.OAuthAuthConfig) oauth2.Config {
- // Use custom OAuth endpoints if provided (for testing)
- endpoint := google.Endpoint
- if authURL := os.Getenv("GOOGLE_OAUTH_AUTH_URL"); authURL != "" {
- endpoint.AuthURL = authURL
- }
- if tokenURL := os.Getenv("GOOGLE_OAUTH_TOKEN_URL"); tokenURL != "" {
- endpoint.TokenURL = tokenURL
- }
-
- return oauth2.Config{
- ClientID: oauthConfig.GoogleClientID,
- ClientSecret: string(oauthConfig.GoogleClientSecret),
- RedirectURL: oauthConfig.GoogleRedirectURI,
- Scopes: []string{"openid", "profile", "email"},
- Endpoint: endpoint,
- }
-}
diff --git a/internal/googleauth/google_test.go b/internal/googleauth/google_test.go
deleted file mode 100644
index 3f72d4a..0000000
--- a/internal/googleauth/google_test.go
+++ /dev/null
@@ -1,240 +0,0 @@
-package googleauth
-
-import (
- "context"
- "encoding/json"
- "net/http"
- "net/http/httptest"
- "testing"
- "time"
-
- "github.com/dgellow/mcp-front/internal/config"
- "github.com/stretchr/testify/assert"
- "github.com/stretchr/testify/require"
- "golang.org/x/oauth2"
-)
-
-func TestGoogleAuthURL(t *testing.T) {
- oauthConfig := config.OAuthAuthConfig{
- GoogleClientID: "test-client-id",
- GoogleClientSecret: config.Secret("test-client-secret"),
- GoogleRedirectURI: "https://test.example.com/oauth/callback",
- }
-
- state := "test-state-parameter"
- authURL := GoogleAuthURL(oauthConfig, state)
-
- // Verify URL structure
- assert.Contains(t, authURL, "https://accounts.google.com/o/oauth2/auth")
- assert.Contains(t, authURL, "client_id=test-client-id")
- assert.Contains(t, authURL, "redirect_uri=https%3A%2F%2Ftest.example.com%2Foauth%2Fcallback")
- assert.Contains(t, authURL, "state=test-state-parameter")
- assert.Contains(t, authURL, "access_type=offline")
- assert.Contains(t, authURL, "prompt=consent")
- assert.Contains(t, authURL, "scope=openid+profile+email")
-}
-
-func TestExchangeCodeForToken(t *testing.T) {
- // Create a mock OAuth token endpoint
- tokenServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- assert.Equal(t, "POST", r.Method)
- assert.Equal(t, "/token", r.URL.Path)
-
- // Parse form data
- err := r.ParseForm()
- require.NoError(t, err)
-
- assert.Equal(t, "test-code", r.FormValue("code"))
- assert.Equal(t, "test-client-id", r.FormValue("client_id"))
- assert.Equal(t, "test-client-secret", r.FormValue("client_secret"))
- assert.Equal(t, "https://test.example.com/oauth/callback", r.FormValue("redirect_uri"))
- assert.Equal(t, "authorization_code", r.FormValue("grant_type"))
-
- // Return mock token response
- response := map[string]any{
- "access_token": "mock-access-token",
- "refresh_token": "mock-refresh-token",
- "token_type": "Bearer",
- "expires_in": 3600,
- }
- w.Header().Set("Content-Type", "application/json")
- if err := json.NewEncoder(w).Encode(response); err != nil {
- t.Errorf("failed to encode response: %v", err)
- }
- }))
- defer tokenServer.Close()
-
- // Set environment variable for custom token URL
- t.Setenv("GOOGLE_OAUTH_TOKEN_URL", tokenServer.URL+"/token")
-
- oauthConfig := config.OAuthAuthConfig{
- GoogleClientID: "test-client-id",
- GoogleClientSecret: config.Secret("test-client-secret"),
- GoogleRedirectURI: "https://test.example.com/oauth/callback",
- }
-
- token, err := ExchangeCodeForToken(context.Background(), oauthConfig, "test-code")
- require.NoError(t, err)
- require.NotNil(t, token)
-
- assert.Equal(t, "mock-access-token", token.AccessToken)
- assert.Equal(t, "mock-refresh-token", token.RefreshToken)
- assert.Equal(t, "Bearer", token.TokenType)
- assert.WithinDuration(t, time.Now().Add(3600*time.Second), token.Expiry, 5*time.Second)
-}
-
-func TestValidateUser(t *testing.T) {
- t.Run("valid user in allowed domain", func(t *testing.T) {
- // Create mock user info endpoint
- userInfoServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- assert.Equal(t, "GET", r.Method)
- assert.Contains(t, r.Header.Get("Authorization"), "Bearer mock-token")
-
- response := UserInfo{
- Email: "user@example.com",
- HostedDomain: "example.com",
- Name: "Test User",
- Picture: "https://example.com/pic.jpg",
- VerifiedEmail: true,
- }
- w.Header().Set("Content-Type", "application/json")
- if err := json.NewEncoder(w).Encode(response); err != nil {
- t.Errorf("failed to encode response: %v", err)
- }
- }))
- defer userInfoServer.Close()
-
- t.Setenv("GOOGLE_USERINFO_URL", userInfoServer.URL)
-
- oauthConfig := config.OAuthAuthConfig{
- AllowedDomains: []string{"example.com", "test.com"},
- }
-
- token := &oauth2.Token{AccessToken: "mock-token"}
- userInfo, err := ValidateUser(context.Background(), oauthConfig, token)
-
- require.NoError(t, err)
- assert.Equal(t, "user@example.com", userInfo.Email)
- assert.Equal(t, "example.com", userInfo.HostedDomain)
- assert.Equal(t, "Test User", userInfo.Name)
- assert.True(t, userInfo.VerifiedEmail)
- })
-
- t.Run("user from disallowed domain", func(t *testing.T) {
- userInfoServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- response := UserInfo{
- Email: "user@unauthorized.com",
- HostedDomain: "unauthorized.com",
- VerifiedEmail: true,
- }
- w.Header().Set("Content-Type", "application/json")
- if err := json.NewEncoder(w).Encode(response); err != nil {
- t.Errorf("failed to encode response: %v", err)
- }
- }))
- defer userInfoServer.Close()
-
- t.Setenv("GOOGLE_USERINFO_URL", userInfoServer.URL)
-
- oauthConfig := config.OAuthAuthConfig{
- AllowedDomains: []string{"example.com", "test.com"},
- }
-
- token := &oauth2.Token{AccessToken: "mock-token"}
- _, err := ValidateUser(context.Background(), oauthConfig, token)
-
- require.Error(t, err)
- assert.Contains(t, err.Error(), "domain 'unauthorized.com' is not allowed")
- })
-
- t.Run("no domain restrictions", func(t *testing.T) {
- userInfoServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- response := UserInfo{
- Email: "user@anydomain.com",
- VerifiedEmail: true,
- }
- w.Header().Set("Content-Type", "application/json")
- if err := json.NewEncoder(w).Encode(response); err != nil {
- t.Errorf("failed to encode response: %v", err)
- }
- }))
- defer userInfoServer.Close()
-
- t.Setenv("GOOGLE_USERINFO_URL", userInfoServer.URL)
-
- oauthConfig := config.OAuthAuthConfig{
- AllowedDomains: []string{}, // Empty means allow all
- }
-
- token := &oauth2.Token{AccessToken: "mock-token"}
- userInfo, err := ValidateUser(context.Background(), oauthConfig, token)
-
- require.NoError(t, err)
- assert.Equal(t, "user@anydomain.com", userInfo.Email)
- })
-}
-
-func TestParseClientRequest(t *testing.T) {
- t.Run("valid request with redirect URIs and scopes", func(t *testing.T) {
- metadata := map[string]any{
- "redirect_uris": []any{
- "https://example.com/callback1",
- "https://example.com/callback2",
- },
- "scope": "read write admin",
- }
-
- redirectURIs, scopes, err := ParseClientRequest(metadata)
- require.NoError(t, err)
-
- assert.Equal(t, []string{
- "https://example.com/callback1",
- "https://example.com/callback2",
- }, redirectURIs)
- assert.Equal(t, []string{"read", "write", "admin"}, scopes)
- })
-
- t.Run("default scopes when not provided", func(t *testing.T) {
- metadata := map[string]any{
- "redirect_uris": []any{"https://example.com/callback"},
- }
-
- redirectURIs, scopes, err := ParseClientRequest(metadata)
- require.NoError(t, err)
-
- assert.Equal(t, []string{"https://example.com/callback"}, redirectURIs)
- assert.Equal(t, []string{"read", "write"}, scopes, "Should default to read/write")
- })
-
- t.Run("empty scope string uses default", func(t *testing.T) {
- metadata := map[string]any{
- "redirect_uris": []any{"https://example.com/callback"},
- "scope": " ", // Whitespace only
- }
-
- _, scopes, err := ParseClientRequest(metadata)
- require.NoError(t, err)
-
- assert.Equal(t, []string{"read", "write"}, scopes)
- })
-
- t.Run("missing redirect URIs", func(t *testing.T) {
- metadata := map[string]any{
- "scope": "read write",
- }
-
- _, _, err := ParseClientRequest(metadata)
- assert.Error(t, err)
- assert.Contains(t, err.Error(), "no valid redirect URIs")
- })
-
- t.Run("empty redirect URIs array", func(t *testing.T) {
- metadata := map[string]any{
- "redirect_uris": []any{},
- }
-
- _, _, err := ParseClientRequest(metadata)
- assert.Error(t, err)
- assert.Contains(t, err.Error(), "no valid redirect URIs")
- })
-}
diff --git a/internal/idp/azure.go b/internal/idp/azure.go
new file mode 100644
index 0000000..6694740
--- /dev/null
+++ b/internal/idp/azure.go
@@ -0,0 +1,34 @@
+package idp
+
+import "fmt"
+
+// NewAzureProvider creates an Azure AD provider using OIDC discovery.
+// Azure AD is OIDC-compliant, so we use the generic OIDC provider with Azure's tenant-specific discovery URL.
+// Optional direct endpoint overrides (authorizationURL, tokenURL, userInfoURL) skip discovery
+// when all three are provided — useful for testing.
+func NewAzureProvider(tenantID, clientID, clientSecret, redirectURI, authorizationURL, tokenURL, userInfoURL string) (*OIDCProvider, error) {
+ if tenantID == "" {
+ return nil, fmt.Errorf("tenantId is required for Azure AD")
+ }
+
+ cfg := OIDCConfig{
+ ProviderType: "azure",
+ ClientID: clientID,
+ ClientSecret: clientSecret,
+ RedirectURI: redirectURI,
+ Scopes: []string{"openid", "email", "profile"},
+ }
+
+ if authorizationURL != "" && tokenURL != "" && userInfoURL != "" {
+ cfg.AuthorizationURL = authorizationURL
+ cfg.TokenURL = tokenURL
+ cfg.UserInfoURL = userInfoURL
+ } else {
+ cfg.DiscoveryURL = fmt.Sprintf(
+ "https://login.microsoftonline.com/%s/v2.0/.well-known/openid-configuration",
+ tenantID,
+ )
+ }
+
+ return NewOIDCProvider(cfg)
+}
diff --git a/internal/idp/azure_test.go b/internal/idp/azure_test.go
new file mode 100644
index 0000000..453a23a
--- /dev/null
+++ b/internal/idp/azure_test.go
@@ -0,0 +1,15 @@
+package idp
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestNewAzureProvider_MissingTenantID(t *testing.T) {
+ _, err := NewAzureProvider("", "client-id", "client-secret", "https://example.com/callback", "", "", "")
+
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "tenantId is required")
+}
diff --git a/internal/idp/factory.go b/internal/idp/factory.go
new file mode 100644
index 0000000..20e7e4c
--- /dev/null
+++ b/internal/idp/factory.go
@@ -0,0 +1,59 @@
+package idp
+
+import (
+ "fmt"
+
+ "github.com/dgellow/mcp-front/internal/config"
+)
+
+// NewProvider creates a Provider based on the IDPConfig.
+func NewProvider(cfg config.IDPConfig) (Provider, error) {
+ switch cfg.Provider {
+ case "google":
+ return NewGoogleProvider(
+ cfg.ClientID,
+ string(cfg.ClientSecret),
+ cfg.RedirectURI,
+ cfg.AuthorizationURL,
+ cfg.TokenURL,
+ cfg.UserInfoURL,
+ ), nil
+
+ case "azure":
+ return NewAzureProvider(
+ cfg.TenantID,
+ cfg.ClientID,
+ string(cfg.ClientSecret),
+ cfg.RedirectURI,
+ cfg.AuthorizationURL,
+ cfg.TokenURL,
+ cfg.UserInfoURL,
+ )
+
+ case "github":
+ return NewGitHubProvider(
+ cfg.ClientID,
+ string(cfg.ClientSecret),
+ cfg.RedirectURI,
+ cfg.AuthorizationURL,
+ cfg.TokenURL,
+ cfg.UserInfoURL,
+ ), nil
+
+ case "oidc":
+ return NewOIDCProvider(OIDCConfig{
+ ProviderType: "oidc",
+ DiscoveryURL: cfg.DiscoveryURL,
+ AuthorizationURL: cfg.AuthorizationURL,
+ TokenURL: cfg.TokenURL,
+ UserInfoURL: cfg.UserInfoURL,
+ ClientID: cfg.ClientID,
+ ClientSecret: string(cfg.ClientSecret),
+ RedirectURI: cfg.RedirectURI,
+ Scopes: cfg.Scopes,
+ })
+
+ default:
+ return nil, fmt.Errorf("unknown provider type: %s", cfg.Provider)
+ }
+}
diff --git a/internal/idp/factory_test.go b/internal/idp/factory_test.go
new file mode 100644
index 0000000..9b38681
--- /dev/null
+++ b/internal/idp/factory_test.go
@@ -0,0 +1,105 @@
+package idp
+
+import (
+ "testing"
+
+ "github.com/dgellow/mcp-front/internal/config"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestNewProvider(t *testing.T) {
+ tests := []struct {
+ name string
+ cfg config.IDPConfig
+ wantType string
+ wantErr bool
+ errContains string
+ skipCreation bool
+ }{
+ {
+ name: "google_provider",
+ cfg: config.IDPConfig{
+ Provider: "google",
+ ClientID: "test-client-id",
+ ClientSecret: config.Secret("test-client-secret"),
+ RedirectURI: "https://example.com/callback",
+ },
+ wantType: "google",
+ wantErr: false,
+ },
+ {
+ name: "github_provider",
+ cfg: config.IDPConfig{
+ Provider: "github",
+ ClientID: "test-client-id",
+ ClientSecret: config.Secret("test-client-secret"),
+ RedirectURI: "https://example.com/callback",
+ },
+ wantType: "github",
+ wantErr: false,
+ },
+ {
+ name: "azure_provider_missing_tenant",
+ cfg: config.IDPConfig{
+ Provider: "azure",
+ ClientID: "test-client-id",
+ ClientSecret: config.Secret("test-client-secret"),
+ RedirectURI: "https://example.com/callback",
+ },
+ wantErr: true,
+ errContains: "tenantId is required",
+ },
+ {
+ name: "oidc_provider_missing_endpoints",
+ cfg: config.IDPConfig{
+ Provider: "oidc",
+ ClientID: "test-client-id",
+ ClientSecret: config.Secret("test-client-secret"),
+ RedirectURI: "https://example.com/callback",
+ },
+ wantErr: true,
+ errContains: "discoveryUrl or all endpoints",
+ },
+ {
+ name: "oidc_provider_with_direct_endpoints",
+ cfg: config.IDPConfig{
+ Provider: "oidc",
+ ClientID: "test-client-id",
+ ClientSecret: config.Secret("test-client-secret"),
+ RedirectURI: "https://example.com/callback",
+ AuthorizationURL: "https://idp.example.com/authorize",
+ TokenURL: "https://idp.example.com/token",
+ UserInfoURL: "https://idp.example.com/userinfo",
+ },
+ wantType: "oidc",
+ wantErr: false,
+ },
+ {
+ name: "unknown_provider",
+ cfg: config.IDPConfig{
+ Provider: "unknown",
+ },
+ wantErr: true,
+ errContains: "unknown provider type",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ provider, err := NewProvider(tt.cfg)
+
+ if tt.wantErr {
+ require.Error(t, err)
+ if tt.errContains != "" {
+ assert.Contains(t, err.Error(), tt.errContains)
+ }
+ return
+ }
+
+ require.NoError(t, err)
+ require.NotNil(t, provider)
+ assert.Equal(t, tt.wantType, provider.Type())
+ })
+ }
+}
diff --git a/internal/idp/github.go b/internal/idp/github.go
new file mode 100644
index 0000000..b13ec04
--- /dev/null
+++ b/internal/idp/github.go
@@ -0,0 +1,201 @@
+package idp
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "net/http"
+
+ emailutil "github.com/dgellow/mcp-front/internal/emailutil"
+ "golang.org/x/oauth2"
+ "golang.org/x/oauth2/github"
+)
+
+// GitHubProvider implements the Provider interface for GitHub OAuth.
+// GitHub uses OAuth 2.0 (not OIDC) and has its own API for user info and org membership.
+type GitHubProvider struct {
+ config oauth2.Config
+ apiBaseURL string // defaults to https://api.github.com, can be overridden for testing
+}
+
+// githubUserResponse represents GitHub's user API response.
+type githubUserResponse struct {
+ ID int64 `json:"id"`
+ Login string `json:"login"`
+ Email string `json:"email"`
+ Name string `json:"name"`
+ AvatarURL string `json:"avatar_url"`
+}
+
+// githubEmailResponse represents an email from GitHub's emails API.
+type githubEmailResponse struct {
+ Email string `json:"email"`
+ Primary bool `json:"primary"`
+ Verified bool `json:"verified"`
+}
+
+// githubOrgResponse represents an org from GitHub's orgs API.
+type githubOrgResponse struct {
+ Login string `json:"login"`
+}
+
+// NewGitHubProvider creates a new GitHub OAuth provider.
+// Optional endpoint overrides (authorizationURL, tokenURL, apiBaseURL) allow
+// pointing at a non-GitHub server for testing.
+func NewGitHubProvider(clientID, clientSecret, redirectURI, authorizationURL, tokenURL, apiBaseURL string) *GitHubProvider {
+ endpoint := github.Endpoint
+ if authorizationURL != "" {
+ endpoint.AuthURL = authorizationURL
+ }
+ if tokenURL != "" {
+ endpoint.TokenURL = tokenURL
+ }
+
+ apiBase := "https://api.github.com"
+ if apiBaseURL != "" {
+ apiBase = apiBaseURL
+ }
+
+ return &GitHubProvider{
+ config: oauth2.Config{
+ ClientID: clientID,
+ ClientSecret: clientSecret,
+ RedirectURL: redirectURI,
+ Scopes: []string{"user:email", "read:org"},
+ Endpoint: endpoint,
+ },
+ apiBaseURL: apiBase,
+ }
+}
+
+// Type returns the provider type.
+func (p *GitHubProvider) Type() string {
+ return "github"
+}
+
+// AuthURL generates the authorization URL.
+func (p *GitHubProvider) AuthURL(state string) string {
+ return p.config.AuthCodeURL(state)
+}
+
+// ExchangeCode exchanges an authorization code for tokens.
+func (p *GitHubProvider) ExchangeCode(ctx context.Context, code string) (*oauth2.Token, error) {
+ return p.config.Exchange(ctx, code)
+}
+
+// UserInfo fetches user identity from GitHub's API.
+// Always fetches organizations so the authorization layer can check membership.
+func (p *GitHubProvider) UserInfo(ctx context.Context, token *oauth2.Token) (*Identity, error) {
+ client := p.config.Client(ctx, token)
+
+ user, err := p.fetchUser(client)
+ if err != nil {
+ return nil, err
+ }
+
+ // Fetch primary email if not in profile
+ // GitHub only shows verified emails in user profile, so if email is present it's verified
+ email := user.Email
+ emailVerified := email != ""
+ if email == "" {
+ primaryEmail, verified, err := p.fetchPrimaryEmail(client)
+ if err != nil {
+ return nil, fmt.Errorf("failed to get user email: %w", err)
+ }
+ email = primaryEmail
+ emailVerified = verified
+ }
+
+ domain := emailutil.ExtractDomain(email)
+
+ orgs, err := p.fetchOrganizations(client)
+ if err != nil {
+ return nil, fmt.Errorf("failed to get user organizations: %w", err)
+ }
+
+ return &Identity{
+ ProviderType: "github",
+ Subject: fmt.Sprintf("%d", user.ID),
+ Email: email,
+ EmailVerified: emailVerified,
+ Name: user.Name,
+ Picture: user.AvatarURL,
+ Domain: domain,
+ Organizations: orgs,
+ }, nil
+}
+
+func (p *GitHubProvider) fetchUser(client *http.Client) (*githubUserResponse, error) {
+ resp, err := client.Get(p.apiBaseURL + "/user")
+ if err != nil {
+ return nil, fmt.Errorf("failed to get user: %w", err)
+ }
+ defer resp.Body.Close()
+
+ if resp.StatusCode != http.StatusOK {
+ return nil, fmt.Errorf("failed to get user: status %d", resp.StatusCode)
+ }
+
+ var user githubUserResponse
+ if err := json.NewDecoder(resp.Body).Decode(&user); err != nil {
+ return nil, fmt.Errorf("failed to decode user: %w", err)
+ }
+
+ return &user, nil
+}
+
+func (p *GitHubProvider) fetchPrimaryEmail(client *http.Client) (string, bool, error) {
+ resp, err := client.Get(p.apiBaseURL + "/user/emails")
+ if err != nil {
+ return "", false, fmt.Errorf("failed to get emails: %w", err)
+ }
+ defer resp.Body.Close()
+
+ if resp.StatusCode != http.StatusOK {
+ return "", false, fmt.Errorf("failed to get emails: status %d", resp.StatusCode)
+ }
+
+ var emails []githubEmailResponse
+ if err := json.NewDecoder(resp.Body).Decode(&emails); err != nil {
+ return "", false, fmt.Errorf("failed to decode emails: %w", err)
+ }
+
+ for _, email := range emails {
+ if email.Primary && email.Verified {
+ return email.Email, true, nil
+ }
+ }
+
+ // Fallback to first verified email
+ for _, email := range emails {
+ if email.Verified {
+ return email.Email, true, nil
+ }
+ }
+
+ return "", false, fmt.Errorf("no verified email found")
+}
+
+func (p *GitHubProvider) fetchOrganizations(client *http.Client) ([]string, error) {
+ resp, err := client.Get(p.apiBaseURL + "/user/orgs")
+ if err != nil {
+ return nil, fmt.Errorf("failed to get organizations: %w", err)
+ }
+ defer resp.Body.Close()
+
+ if resp.StatusCode != http.StatusOK {
+ return nil, fmt.Errorf("failed to get organizations: status %d", resp.StatusCode)
+ }
+
+ var orgs []githubOrgResponse
+ if err := json.NewDecoder(resp.Body).Decode(&orgs); err != nil {
+ return nil, fmt.Errorf("failed to decode organizations: %w", err)
+ }
+
+ orgNames := make([]string, len(orgs))
+ for i, org := range orgs {
+ orgNames[i] = org.Login
+ }
+
+ return orgNames, nil
+}
diff --git a/internal/idp/github_test.go b/internal/idp/github_test.go
new file mode 100644
index 0000000..ccf0c87
--- /dev/null
+++ b/internal/idp/github_test.go
@@ -0,0 +1,224 @@
+package idp
+
+import (
+ "context"
+ "encoding/json"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "golang.org/x/oauth2"
+)
+
+func TestGitHubProvider_Type(t *testing.T) {
+ provider := NewGitHubProvider("client-id", "client-secret", "https://example.com/callback", "", "", "")
+ assert.Equal(t, "github", provider.Type())
+}
+
+func TestGitHubProvider_AuthURL(t *testing.T) {
+ provider := NewGitHubProvider("client-id", "client-secret", "https://example.com/callback", "", "", "")
+
+ authURL := provider.AuthURL("test-state")
+
+ assert.Contains(t, authURL, "github.com")
+ assert.Contains(t, authURL, "state=test-state")
+ assert.Contains(t, authURL, "client_id=client-id")
+}
+
+func TestGitHubProvider_UserInfo(t *testing.T) {
+ tests := []struct {
+ name string
+ userResp githubUserResponse
+ emailsResp []githubEmailResponse
+ orgsResp []githubOrgResponse
+ expectedEmail string
+ expectedEmailVerified bool
+ expectedDomain string
+ expectedOrgs []string
+ }{
+ {
+ name: "user_with_public_email",
+ userResp: githubUserResponse{
+ ID: 12345,
+ Login: "testuser",
+ Email: "user@company.com",
+ Name: "Test User",
+ AvatarURL: "https://github.com/avatar.jpg",
+ },
+ orgsResp: []githubOrgResponse{{Login: "my-org"}},
+ expectedEmail: "user@company.com",
+ expectedEmailVerified: true,
+ expectedDomain: "company.com",
+ expectedOrgs: []string{"my-org"},
+ },
+ {
+ name: "user_without_public_email_fetches_from_api",
+ userResp: githubUserResponse{
+ ID: 12345,
+ Login: "testuser",
+ Name: "Test User",
+ },
+ emailsResp: []githubEmailResponse{
+ {Email: "secondary@other.com", Primary: false, Verified: true},
+ {Email: "primary@company.com", Primary: true, Verified: true},
+ },
+ orgsResp: []githubOrgResponse{},
+ expectedEmail: "primary@company.com",
+ expectedEmailVerified: true,
+ expectedDomain: "company.com",
+ expectedOrgs: []string{},
+ },
+ {
+ name: "user_with_unverified_primary_falls_back_to_verified",
+ userResp: githubUserResponse{
+ ID: 12345,
+ Login: "testuser",
+ },
+ emailsResp: []githubEmailResponse{
+ {Email: "primary@company.com", Primary: true, Verified: false},
+ {Email: "verified@company.com", Primary: false, Verified: true},
+ },
+ orgsResp: []githubOrgResponse{},
+ expectedEmail: "verified@company.com",
+ expectedEmailVerified: true,
+ expectedDomain: "company.com",
+ expectedOrgs: []string{},
+ },
+ {
+ name: "orgs_always_populated",
+ userResp: githubUserResponse{
+ ID: 12345,
+ Login: "testuser",
+ Email: "user@gmail.com",
+ },
+ orgsResp: []githubOrgResponse{{Login: "org-a"}, {Login: "org-b"}},
+ expectedEmail: "user@gmail.com",
+ expectedEmailVerified: true,
+ expectedDomain: "gmail.com",
+ expectedOrgs: []string{"org-a", "org-b"},
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", "application/json")
+
+ switch r.URL.Path {
+ case "/user":
+ err := json.NewEncoder(w).Encode(tt.userResp)
+ require.NoError(t, err)
+ case "/user/emails":
+ err := json.NewEncoder(w).Encode(tt.emailsResp)
+ require.NoError(t, err)
+ case "/user/orgs":
+ err := json.NewEncoder(w).Encode(tt.orgsResp)
+ require.NoError(t, err)
+ default:
+ w.WriteHeader(http.StatusNotFound)
+ }
+ }))
+ defer server.Close()
+
+ provider := &GitHubProvider{
+ config: oauth2.Config{
+ ClientID: "test-client",
+ ClientSecret: "test-secret",
+ RedirectURL: "https://example.com/callback",
+ Scopes: []string{"user:email", "read:org"},
+ Endpoint: oauth2.Endpoint{
+ AuthURL: server.URL + "/authorize",
+ TokenURL: server.URL + "/token",
+ },
+ },
+ apiBaseURL: server.URL,
+ }
+
+ token := &oauth2.Token{AccessToken: "test-token"}
+ identity, err := provider.UserInfo(context.Background(), token)
+
+ require.NoError(t, err)
+ require.NotNil(t, identity)
+ assert.Equal(t, "github", identity.ProviderType)
+ assert.Equal(t, tt.expectedEmail, identity.Email)
+ assert.Equal(t, tt.expectedEmailVerified, identity.EmailVerified)
+ assert.Equal(t, tt.expectedDomain, identity.Domain)
+ assert.Equal(t, tt.expectedOrgs, identity.Organizations)
+ })
+ }
+}
+
+func TestGitHubProvider_UserInfo_APIErrors(t *testing.T) {
+ tests := []struct {
+ name string
+ userStatus int
+ errContains string
+ }{
+ {
+ name: "user_api_error",
+ userStatus: http.StatusInternalServerError,
+ errContains: "status 500",
+ },
+ {
+ name: "user_unauthorized",
+ userStatus: http.StatusUnauthorized,
+ errContains: "status 401",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(tt.userStatus)
+ }))
+ defer server.Close()
+
+ provider := &GitHubProvider{
+ config: oauth2.Config{
+ ClientID: "test-client",
+ ClientSecret: "test-secret",
+ },
+ apiBaseURL: server.URL,
+ }
+
+ token := &oauth2.Token{AccessToken: "test-token"}
+ _, err := provider.UserInfo(context.Background(), token)
+
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), tt.errContains)
+ })
+ }
+}
+
+func TestGitHubProvider_UserInfo_NoVerifiedEmail(t *testing.T) {
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", "application/json")
+
+ switch r.URL.Path {
+ case "/user":
+ err := json.NewEncoder(w).Encode(githubUserResponse{ID: 123, Login: "test"})
+ require.NoError(t, err)
+ case "/user/emails":
+ err := json.NewEncoder(w).Encode([]githubEmailResponse{
+ {Email: "unverified@example.com", Primary: true, Verified: false},
+ })
+ require.NoError(t, err)
+ default:
+ w.WriteHeader(http.StatusNotFound)
+ }
+ }))
+ defer server.Close()
+
+ provider := &GitHubProvider{
+ config: oauth2.Config{ClientID: "test"},
+ apiBaseURL: server.URL,
+ }
+
+ token := &oauth2.Token{AccessToken: "test-token"}
+ _, err := provider.UserInfo(context.Background(), token)
+
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "no verified email")
+}
diff --git a/internal/idp/google.go b/internal/idp/google.go
new file mode 100644
index 0000000..82e86e2
--- /dev/null
+++ b/internal/idp/google.go
@@ -0,0 +1,113 @@
+package idp
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "net/http"
+
+ emailutil "github.com/dgellow/mcp-front/internal/emailutil"
+ "golang.org/x/oauth2"
+ "golang.org/x/oauth2/google"
+)
+
+// GoogleProvider implements the Provider interface for Google OAuth.
+// Google has specific quirks like `hd` for hosted domain and `verified_email` field.
+type GoogleProvider struct {
+ config oauth2.Config
+ userInfoURL string
+}
+
+// googleUserInfoResponse represents Google's userinfo response.
+// Note: Google uses `hd` for hosted domain and `verified_email` instead of OIDC standard `email_verified`.
+type googleUserInfoResponse struct {
+ Sub string `json:"sub"`
+ Email string `json:"email"`
+ VerifiedEmail bool `json:"verified_email"`
+ Name string `json:"name"`
+ Picture string `json:"picture"`
+ HostedDomain string `json:"hd"`
+}
+
+// NewGoogleProvider creates a new Google OAuth provider.
+// Optional endpoint overrides (authorizationURL, tokenURL, userInfoURL) allow
+// pointing at a non-Google server — useful for testing or corporate proxies.
+func NewGoogleProvider(clientID, clientSecret, redirectURI, authorizationURL, tokenURL, userInfoURL string) *GoogleProvider {
+ endpoint := google.Endpoint
+ if authorizationURL != "" {
+ endpoint.AuthURL = authorizationURL
+ }
+ if tokenURL != "" {
+ endpoint.TokenURL = tokenURL
+ }
+
+ uiURL := "https://www.googleapis.com/oauth2/v2/userinfo"
+ if userInfoURL != "" {
+ uiURL = userInfoURL
+ }
+
+ return &GoogleProvider{
+ config: oauth2.Config{
+ ClientID: clientID,
+ ClientSecret: clientSecret,
+ RedirectURL: redirectURI,
+ Scopes: []string{"openid", "profile", "email"},
+ Endpoint: endpoint,
+ },
+ userInfoURL: uiURL,
+ }
+}
+
+// Type returns the provider type.
+func (p *GoogleProvider) Type() string {
+ return "google"
+}
+
+// AuthURL generates the authorization URL.
+func (p *GoogleProvider) AuthURL(state string) string {
+ return p.config.AuthCodeURL(state,
+ oauth2.AccessTypeOffline,
+ oauth2.ApprovalForce,
+ )
+}
+
+// ExchangeCode exchanges an authorization code for tokens.
+func (p *GoogleProvider) ExchangeCode(ctx context.Context, code string) (*oauth2.Token, error) {
+ return p.config.Exchange(ctx, code)
+}
+
+// UserInfo fetches user identity from Google's userinfo endpoint.
+func (p *GoogleProvider) UserInfo(ctx context.Context, token *oauth2.Token) (*Identity, error) {
+ client := p.config.Client(ctx, token)
+
+ resp, err := client.Get(p.userInfoURL)
+ if err != nil {
+ return nil, fmt.Errorf("failed to get user info: %w", err)
+ }
+ defer resp.Body.Close()
+
+ if resp.StatusCode != http.StatusOK {
+ return nil, fmt.Errorf("failed to get user info: status %d", resp.StatusCode)
+ }
+
+ var googleUser googleUserInfoResponse
+ if err := json.NewDecoder(resp.Body).Decode(&googleUser); err != nil {
+ return nil, fmt.Errorf("failed to decode user info: %w", err)
+ }
+
+ // Use Google's hosted domain if available, otherwise derive from email
+ domain := googleUser.HostedDomain
+ if domain == "" {
+ domain = emailutil.ExtractDomain(googleUser.Email)
+ }
+
+ return &Identity{
+ ProviderType: "google",
+ Subject: googleUser.Sub,
+ Email: googleUser.Email,
+ EmailVerified: googleUser.VerifiedEmail,
+ Name: googleUser.Name,
+ Picture: googleUser.Picture,
+ Domain: domain,
+ }, nil
+}
diff --git a/internal/idp/google_test.go b/internal/idp/google_test.go
new file mode 100644
index 0000000..662c194
--- /dev/null
+++ b/internal/idp/google_test.go
@@ -0,0 +1,121 @@
+package idp
+
+import (
+ "context"
+ "encoding/json"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "golang.org/x/oauth2"
+)
+
+func TestGoogleProvider_Type(t *testing.T) {
+ provider := NewGoogleProvider("client-id", "client-secret", "https://example.com/callback", "", "", "")
+ assert.Equal(t, "google", provider.Type())
+}
+
+func TestGoogleProvider_AuthURL(t *testing.T) {
+ provider := NewGoogleProvider("client-id", "client-secret", "https://example.com/callback", "", "", "")
+
+ authURL := provider.AuthURL("test-state")
+
+ assert.Contains(t, authURL, "accounts.google.com")
+ assert.Contains(t, authURL, "state=test-state")
+ assert.Contains(t, authURL, "client_id=client-id")
+ assert.Contains(t, authURL, "redirect_uri=")
+ assert.Contains(t, authURL, "access_type=offline")
+}
+
+func TestGoogleProvider_UserInfo(t *testing.T) {
+ tests := []struct {
+ name string
+ userInfoResp googleUserInfoResponse
+ expectedDomain string
+ expectedSubject string
+ }{
+ {
+ name: "user_with_hosted_domain",
+ userInfoResp: googleUserInfoResponse{
+ Sub: "12345",
+ Email: "user@company.com",
+ VerifiedEmail: true,
+ Name: "Test User",
+ Picture: "https://example.com/photo.jpg",
+ HostedDomain: "company.com",
+ },
+ expectedDomain: "company.com",
+ expectedSubject: "12345",
+ },
+ {
+ name: "user_without_hosted_domain_derives_from_email",
+ userInfoResp: googleUserInfoResponse{
+ Sub: "12345",
+ Email: "user@gmail.com",
+ VerifiedEmail: true,
+ Name: "Test User",
+ },
+ expectedDomain: "gmail.com",
+ expectedSubject: "12345",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", "application/json")
+ err := json.NewEncoder(w).Encode(tt.userInfoResp)
+ require.NoError(t, err)
+ }))
+ defer server.Close()
+
+ provider := &GoogleProvider{
+ config: oauth2.Config{
+ ClientID: "test-client",
+ ClientSecret: "test-secret",
+ RedirectURL: "https://example.com/callback",
+ Scopes: []string{"openid", "profile", "email"},
+ Endpoint: oauth2.Endpoint{
+ AuthURL: server.URL + "/authorize",
+ TokenURL: server.URL + "/token",
+ },
+ },
+ userInfoURL: server.URL,
+ }
+ token := &oauth2.Token{AccessToken: "test-token"}
+
+ identity, err := provider.UserInfo(context.Background(), token)
+
+ require.NoError(t, err)
+ require.NotNil(t, identity)
+ assert.Equal(t, "google", identity.ProviderType)
+ assert.Equal(t, tt.expectedSubject, identity.Subject)
+ assert.Equal(t, tt.expectedDomain, identity.Domain)
+ assert.Equal(t, tt.userInfoResp.Email, identity.Email)
+ assert.Equal(t, tt.userInfoResp.VerifiedEmail, identity.EmailVerified)
+ })
+ }
+}
+
+func TestGoogleProvider_UserInfo_ServerError(t *testing.T) {
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusInternalServerError)
+ }))
+ defer server.Close()
+
+ provider := &GoogleProvider{
+ config: oauth2.Config{
+ ClientID: "test-client",
+ ClientSecret: "test-secret",
+ },
+ userInfoURL: server.URL,
+ }
+ token := &oauth2.Token{AccessToken: "test-token"}
+
+ _, err := provider.UserInfo(context.Background(), token)
+
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "status 500")
+}
diff --git a/internal/idp/oidc.go b/internal/idp/oidc.go
new file mode 100644
index 0000000..807ac05
--- /dev/null
+++ b/internal/idp/oidc.go
@@ -0,0 +1,178 @@
+package idp
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "time"
+
+ emailutil "github.com/dgellow/mcp-front/internal/emailutil"
+ "golang.org/x/oauth2"
+)
+
+// OIDCConfig configures a generic OIDC provider.
+type OIDCConfig struct {
+ // ProviderType identifies this provider (e.g., "oidc", "google", "azure").
+ ProviderType string
+
+ // Discovery URL for OIDC discovery (optional if endpoints are provided directly).
+ DiscoveryURL string
+
+ // Direct endpoint configuration (used if DiscoveryURL is not set).
+ AuthorizationURL string
+ TokenURL string
+ UserInfoURL string
+
+ // OAuth client configuration.
+ ClientID string
+ ClientSecret string
+ RedirectURI string
+ Scopes []string
+}
+
+// OIDCProvider implements the Provider interface for OIDC-compliant identity providers.
+type OIDCProvider struct {
+ providerType string
+ config oauth2.Config
+ userInfoURL string
+}
+
+// oidcDiscoveryDocument represents the OIDC discovery document.
+type oidcDiscoveryDocument struct {
+ AuthorizationEndpoint string `json:"authorization_endpoint"`
+ TokenEndpoint string `json:"token_endpoint"`
+ UserInfoEndpoint string `json:"userinfo_endpoint"`
+ Issuer string `json:"issuer"`
+}
+
+// oidcUserInfoResponse represents the standard OIDC userinfo response.
+type oidcUserInfoResponse struct {
+ Sub string `json:"sub"`
+ Email string `json:"email"`
+ EmailVerified bool `json:"email_verified"`
+ Name string `json:"name"`
+ Picture string `json:"picture"`
+}
+
+// NewOIDCProvider creates a new OIDC provider.
+// TODO: Add OIDC discovery caching to avoid repeated network calls.
+func NewOIDCProvider(cfg OIDCConfig) (*OIDCProvider, error) {
+ var authURL, tokenURL, userInfoURL string
+
+ if cfg.DiscoveryURL != "" {
+ discovery, err := fetchOIDCDiscovery(cfg.DiscoveryURL)
+ if err != nil {
+ return nil, fmt.Errorf("failed to fetch OIDC discovery: %w", err)
+ }
+ authURL = discovery.AuthorizationEndpoint
+ tokenURL = discovery.TokenEndpoint
+ userInfoURL = discovery.UserInfoEndpoint
+ } else {
+ if cfg.AuthorizationURL == "" || cfg.TokenURL == "" || cfg.UserInfoURL == "" {
+ return nil, fmt.Errorf("either discoveryUrl or all endpoints (authorizationUrl, tokenUrl, userInfoUrl) must be provided")
+ }
+ authURL = cfg.AuthorizationURL
+ tokenURL = cfg.TokenURL
+ userInfoURL = cfg.UserInfoURL
+ }
+
+ scopes := cfg.Scopes
+ if len(scopes) == 0 {
+ scopes = []string{"openid", "email", "profile"}
+ }
+
+ providerType := cfg.ProviderType
+ if providerType == "" {
+ providerType = "oidc"
+ }
+
+ return &OIDCProvider{
+ providerType: providerType,
+ config: oauth2.Config{
+ ClientID: cfg.ClientID,
+ ClientSecret: cfg.ClientSecret,
+ RedirectURL: cfg.RedirectURI,
+ Scopes: scopes,
+ Endpoint: oauth2.Endpoint{
+ AuthURL: authURL,
+ TokenURL: tokenURL,
+ },
+ },
+ userInfoURL: userInfoURL,
+ }, nil
+}
+
+func fetchOIDCDiscovery(discoveryURL string) (*oidcDiscoveryDocument, error) {
+ client := &http.Client{Timeout: 10 * time.Second}
+ resp, err := client.Get(discoveryURL)
+ if err != nil {
+ return nil, fmt.Errorf("failed to fetch discovery document: %w", err)
+ }
+ defer resp.Body.Close()
+
+ if resp.StatusCode != http.StatusOK {
+ return nil, fmt.Errorf("discovery endpoint returned status %d", resp.StatusCode)
+ }
+
+ var discovery oidcDiscoveryDocument
+ if err := json.NewDecoder(resp.Body).Decode(&discovery); err != nil {
+ return nil, fmt.Errorf("failed to decode discovery document: %w", err)
+ }
+
+ if discovery.AuthorizationEndpoint == "" || discovery.TokenEndpoint == "" || discovery.UserInfoEndpoint == "" {
+ return nil, fmt.Errorf("discovery document missing required endpoints")
+ }
+
+ return &discovery, nil
+}
+
+// Type returns the provider type.
+func (p *OIDCProvider) Type() string {
+ return p.providerType
+}
+
+// AuthURL generates the authorization URL.
+func (p *OIDCProvider) AuthURL(state string) string {
+ return p.config.AuthCodeURL(state,
+ oauth2.AccessTypeOffline,
+ oauth2.ApprovalForce,
+ )
+}
+
+// ExchangeCode exchanges an authorization code for tokens.
+func (p *OIDCProvider) ExchangeCode(ctx context.Context, code string) (*oauth2.Token, error) {
+ return p.config.Exchange(ctx, code)
+}
+
+// UserInfo fetches user identity from the OIDC userinfo endpoint.
+// TODO: Add ID token validation as optimization (avoids network call).
+func (p *OIDCProvider) UserInfo(ctx context.Context, token *oauth2.Token) (*Identity, error) {
+ client := p.config.Client(ctx, token)
+ resp, err := client.Get(p.userInfoURL)
+ if err != nil {
+ return nil, fmt.Errorf("failed to get user info: %w", err)
+ }
+ defer resp.Body.Close()
+
+ if resp.StatusCode != http.StatusOK {
+ return nil, fmt.Errorf("failed to get user info: status %d", resp.StatusCode)
+ }
+
+ var userInfoResp oidcUserInfoResponse
+ if err := json.NewDecoder(resp.Body).Decode(&userInfoResp); err != nil {
+ return nil, fmt.Errorf("failed to decode user info: %w", err)
+ }
+
+ domain := emailutil.ExtractDomain(userInfoResp.Email)
+
+ return &Identity{
+ ProviderType: p.providerType,
+ Subject: userInfoResp.Sub,
+ Email: userInfoResp.Email,
+ EmailVerified: userInfoResp.EmailVerified,
+ Name: userInfoResp.Name,
+ Picture: userInfoResp.Picture,
+ Domain: domain,
+ }, nil
+}
diff --git a/internal/idp/oidc_test.go b/internal/idp/oidc_test.go
new file mode 100644
index 0000000..ced5b9f
--- /dev/null
+++ b/internal/idp/oidc_test.go
@@ -0,0 +1,189 @@
+package idp
+
+import (
+ "context"
+ "encoding/json"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "golang.org/x/oauth2"
+)
+
+func TestNewOIDCProvider_WithDirectEndpoints(t *testing.T) {
+ provider, err := NewOIDCProvider(OIDCConfig{
+ ProviderType: "custom",
+ AuthorizationURL: "https://idp.example.com/authorize",
+ TokenURL: "https://idp.example.com/token",
+ UserInfoURL: "https://idp.example.com/userinfo",
+ ClientID: "client-id",
+ ClientSecret: "client-secret",
+ RedirectURI: "https://example.com/callback",
+ })
+
+ require.NoError(t, err)
+ require.NotNil(t, provider)
+ assert.Equal(t, "custom", provider.Type())
+}
+
+func TestNewOIDCProvider_WithDiscovery(t *testing.T) {
+ discovery := oidcDiscoveryDocument{
+ Issuer: "https://idp.example.com",
+ AuthorizationEndpoint: "https://idp.example.com/authorize",
+ TokenEndpoint: "https://idp.example.com/token",
+ UserInfoEndpoint: "https://idp.example.com/userinfo",
+ }
+
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", "application/json")
+ err := json.NewEncoder(w).Encode(discovery)
+ require.NoError(t, err)
+ }))
+ defer server.Close()
+
+ provider, err := NewOIDCProvider(OIDCConfig{
+ DiscoveryURL: server.URL,
+ ClientID: "client-id",
+ ClientSecret: "client-secret",
+ RedirectURI: "https://example.com/callback",
+ })
+
+ require.NoError(t, err)
+ require.NotNil(t, provider)
+ assert.Equal(t, "oidc", provider.Type())
+}
+
+func TestNewOIDCProvider_MissingEndpoints(t *testing.T) {
+ _, err := NewOIDCProvider(OIDCConfig{
+ ClientID: "client-id",
+ ClientSecret: "client-secret",
+ RedirectURI: "https://example.com/callback",
+ })
+
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "discoveryUrl or all endpoints")
+}
+
+func TestNewOIDCProvider_PartialEndpoints(t *testing.T) {
+ _, err := NewOIDCProvider(OIDCConfig{
+ AuthorizationURL: "https://idp.example.com/authorize",
+ ClientID: "client-id",
+ ClientSecret: "client-secret",
+ RedirectURI: "https://example.com/callback",
+ })
+
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "discoveryUrl or all endpoints")
+}
+
+func TestNewOIDCProvider_DiscoveryMissingEndpoints(t *testing.T) {
+ discovery := oidcDiscoveryDocument{
+ Issuer: "https://idp.example.com",
+ }
+
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", "application/json")
+ err := json.NewEncoder(w).Encode(discovery)
+ require.NoError(t, err)
+ }))
+ defer server.Close()
+
+ _, err := NewOIDCProvider(OIDCConfig{
+ DiscoveryURL: server.URL,
+ ClientID: "client-id",
+ ClientSecret: "client-secret",
+ RedirectURI: "https://example.com/callback",
+ })
+
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "missing required endpoints")
+}
+
+func TestOIDCProvider_AuthURL(t *testing.T) {
+ provider, err := NewOIDCProvider(OIDCConfig{
+ AuthorizationURL: "https://idp.example.com/authorize",
+ TokenURL: "https://idp.example.com/token",
+ UserInfoURL: "https://idp.example.com/userinfo",
+ ClientID: "client-id",
+ ClientSecret: "client-secret",
+ RedirectURI: "https://example.com/callback",
+ })
+ require.NoError(t, err)
+
+ authURL := provider.AuthURL("test-state")
+
+ assert.Contains(t, authURL, "https://idp.example.com/authorize")
+ assert.Contains(t, authURL, "state=test-state")
+ assert.Contains(t, authURL, "client_id=client-id")
+}
+
+func TestOIDCProvider_UserInfo(t *testing.T) {
+ userInfoServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", "application/json")
+ resp := oidcUserInfoResponse{
+ Sub: "12345",
+ Email: "user@example.com",
+ EmailVerified: true,
+ Name: "Test User",
+ Picture: "https://example.com/photo.jpg",
+ }
+ err := json.NewEncoder(w).Encode(resp)
+ require.NoError(t, err)
+ }))
+ defer userInfoServer.Close()
+
+ provider, err := NewOIDCProvider(OIDCConfig{
+ AuthorizationURL: "https://idp.example.com/authorize",
+ TokenURL: "https://idp.example.com/token",
+ UserInfoURL: userInfoServer.URL,
+ ClientID: "client-id",
+ ClientSecret: "client-secret",
+ RedirectURI: "https://example.com/callback",
+ })
+ require.NoError(t, err)
+
+ token := &oauth2.Token{AccessToken: "test-token"}
+ identity, err := provider.UserInfo(context.Background(), token)
+
+ require.NoError(t, err)
+ require.NotNil(t, identity)
+ assert.Equal(t, "oidc", identity.ProviderType)
+ assert.Equal(t, "12345", identity.Subject)
+ assert.Equal(t, "user@example.com", identity.Email)
+ assert.Equal(t, "example.com", identity.Domain)
+ assert.True(t, identity.EmailVerified)
+}
+
+func TestOIDCProvider_DefaultScopes(t *testing.T) {
+ provider, err := NewOIDCProvider(OIDCConfig{
+ AuthorizationURL: "https://idp.example.com/authorize",
+ TokenURL: "https://idp.example.com/token",
+ UserInfoURL: "https://idp.example.com/userinfo",
+ ClientID: "client-id",
+ ClientSecret: "client-secret",
+ RedirectURI: "https://example.com/callback",
+ })
+ require.NoError(t, err)
+
+ authURL := provider.AuthURL("test-state")
+ assert.Contains(t, authURL, "scope=openid")
+}
+
+func TestOIDCProvider_CustomScopes(t *testing.T) {
+ provider, err := NewOIDCProvider(OIDCConfig{
+ AuthorizationURL: "https://idp.example.com/authorize",
+ TokenURL: "https://idp.example.com/token",
+ UserInfoURL: "https://idp.example.com/userinfo",
+ ClientID: "client-id",
+ ClientSecret: "client-secret",
+ RedirectURI: "https://example.com/callback",
+ Scopes: []string{"openid", "custom_scope"},
+ })
+ require.NoError(t, err)
+
+ authURL := provider.AuthURL("test-state")
+ assert.Contains(t, authURL, "scope=openid")
+ assert.Contains(t, authURL, "custom_scope")
+}
diff --git a/internal/idp/provider.go b/internal/idp/provider.go
new file mode 100644
index 0000000..e700112
--- /dev/null
+++ b/internal/idp/provider.go
@@ -0,0 +1,37 @@
+package idp
+
+import (
+ "context"
+
+ "golang.org/x/oauth2"
+)
+
+// Identity represents user identity as reported by an identity provider.
+// Providers populate this with identity information only — access control
+// (domain, org checks) is handled by the authorization layer.
+type Identity struct {
+ ProviderType string `json:"provider_type"`
+ Subject string `json:"sub"`
+ Email string `json:"email"`
+ EmailVerified bool `json:"email_verified"`
+ Name string `json:"name"`
+ Picture string `json:"picture"`
+ Domain string `json:"domain"`
+ Organizations []string `json:"organizations,omitempty"`
+}
+
+// Provider abstracts identity provider operations.
+type Provider interface {
+ // Type returns the provider type identifier (e.g., "google", "azure", "github", "oidc").
+ Type() string
+
+ // AuthURL generates the authorization URL for the OAuth flow.
+ AuthURL(state string) string
+
+ // ExchangeCode exchanges an authorization code for tokens.
+ ExchangeCode(ctx context.Context, code string) (*oauth2.Token, error)
+
+ // UserInfo fetches user identity from the provider.
+ // Returns identity information only — no access control validation.
+ UserInfo(ctx context.Context, token *oauth2.Token) (*Identity, error)
+}
diff --git a/internal/jsonrpc/errors.go b/internal/jsonrpc/errors.go
index ee3196c..8fcfcf1 100644
--- a/internal/jsonrpc/errors.go
+++ b/internal/jsonrpc/errors.go
@@ -46,25 +46,3 @@ func NewStandardError(code int) *Error {
Message: message,
}
}
-
-// GetErrorName returns a human-readable name for standard error codes
-func GetErrorName(code int) string {
- switch code {
- case ParseError:
- return "PARSE_ERROR"
- case InvalidRequest:
- return "INVALID_REQUEST"
- case MethodNotFound:
- return "METHOD_NOT_FOUND"
- case InvalidParams:
- return "INVALID_PARAMS"
- case InternalError:
- return "INTERNAL_ERROR"
- default:
- // Check if it's an MCP-specific error code
- if code >= -32099 && code <= -32000 {
- return "SERVER_ERROR"
- }
- return "UNKNOWN_ERROR"
- }
-}
diff --git a/internal/mcpfront.go b/internal/mcpfront.go
index e6f5b80..ec70773 100644
--- a/internal/mcpfront.go
+++ b/internal/mcpfront.go
@@ -14,6 +14,7 @@ import (
"github.com/dgellow/mcp-front/internal/client"
"github.com/dgellow/mcp-front/internal/config"
"github.com/dgellow/mcp-front/internal/crypto"
+ "github.com/dgellow/mcp-front/internal/idp"
"github.com/dgellow/mcp-front/internal/inline"
"github.com/dgellow/mcp-front/internal/log"
"github.com/dgellow/mcp-front/internal/oauth"
@@ -52,7 +53,7 @@ func NewMCPFront(ctx context.Context, cfg config.Config) (*MCPFront, error) {
}
// Setup authentication (OAuth components and service client)
- oauthProvider, sessionEncryptor, authConfig, serviceOAuthClient, err := setupAuthentication(ctx, cfg, store)
+ oauthProvider, idpProvider, sessionEncryptor, authConfig, serviceOAuthClient, err := setupAuthentication(ctx, cfg, store)
if err != nil {
return nil, fmt.Errorf("failed to setup authentication: %w", err)
}
@@ -98,6 +99,7 @@ func NewMCPFront(ctx context.Context, cfg config.Config) (*MCPFront, error) {
cfg,
store,
oauthProvider,
+ idpProvider,
sessionEncryptor,
authConfig,
serviceOAuthClient,
@@ -227,32 +229,42 @@ func setupStorage(ctx context.Context, cfg config.Config) (storage.Storage, erro
}
// setupAuthentication creates individual OAuth components using clean constructors
-func setupAuthentication(ctx context.Context, cfg config.Config, store storage.Storage) (fosite.OAuth2Provider, crypto.Encryptor, config.OAuthAuthConfig, *auth.ServiceOAuthClient, error) {
+func setupAuthentication(ctx context.Context, cfg config.Config, store storage.Storage) (fosite.OAuth2Provider, idp.Provider, crypto.Encryptor, config.OAuthAuthConfig, *auth.ServiceOAuthClient, error) {
oauthAuth := cfg.Proxy.Auth
if oauthAuth == nil {
// OAuth not configured
- return nil, nil, config.OAuthAuthConfig{}, nil, nil
+ return nil, nil, nil, config.OAuthAuthConfig{}, nil, nil
}
log.LogDebug("initializing OAuth components")
+ // Create identity provider
+ idpProvider, err := idp.NewProvider(oauthAuth.IDP)
+ if err != nil {
+ return nil, nil, nil, config.OAuthAuthConfig{}, nil, fmt.Errorf("failed to create identity provider: %w", err)
+ }
+
+ log.LogInfoWithFields("mcpfront", "Identity provider configured", map[string]any{
+ "type": idpProvider.Type(),
+ })
+
// Generate or validate JWT secret using clean constructor
jwtSecret, err := oauth.GenerateJWTSecret(string(oauthAuth.JWTSecret))
if err != nil {
- return nil, nil, config.OAuthAuthConfig{}, nil, fmt.Errorf("failed to setup JWT secret: %w", err)
+ return nil, nil, nil, config.OAuthAuthConfig{}, nil, fmt.Errorf("failed to setup JWT secret: %w", err)
}
// Create session encryptor using clean constructor
encryptionKey := []byte(oauthAuth.EncryptionKey)
sessionEncryptor, err := oauth.NewSessionEncryptor(encryptionKey)
if err != nil {
- return nil, nil, config.OAuthAuthConfig{}, nil, fmt.Errorf("failed to create session encryptor: %w", err)
+ return nil, nil, nil, config.OAuthAuthConfig{}, nil, fmt.Errorf("failed to create session encryptor: %w", err)
}
// Create OAuth provider using clean constructor
oauthProvider, err := oauth.NewOAuthProvider(*oauthAuth, store, jwtSecret)
if err != nil {
- return nil, nil, config.OAuthAuthConfig{}, nil, fmt.Errorf("failed to create OAuth provider: %w", err)
+ return nil, nil, nil, config.OAuthAuthConfig{}, nil, fmt.Errorf("failed to create OAuth provider: %w", err)
}
// Create OAuth client for service authentication and token refresh
@@ -279,7 +291,7 @@ func setupAuthentication(ctx context.Context, cfg config.Config, store storage.S
}
}
- return oauthProvider, sessionEncryptor, *oauthAuth, serviceOAuthClient, nil
+ return oauthProvider, idpProvider, sessionEncryptor, *oauthAuth, serviceOAuthClient, nil
}
// buildHTTPHandler creates the complete HTTP handler with all routing and middleware
@@ -287,6 +299,7 @@ func buildHTTPHandler(
cfg config.Config,
storage storage.Storage,
oauthProvider fosite.OAuth2Provider,
+ idpProvider idp.Provider,
sessionEncryptor crypto.Encryptor,
authConfig config.OAuthAuthConfig,
serviceOAuthClient *auth.ServiceOAuthClient,
@@ -337,6 +350,7 @@ func buildHTTPHandler(
authHandlers := server.NewAuthHandlers(
oauthProvider,
authConfig,
+ idpProvider,
storage,
sessionEncryptor,
cfg.MCPServers,
@@ -352,7 +366,7 @@ func buildHTTPHandler(
// Returns 404 by default, or base issuer metadata if dangerouslyAcceptIssuerAudience is enabled
mux.Handle(route("/.well-known/oauth-protected-resource"), server.ChainMiddleware(http.HandlerFunc(authHandlers.ProtectedResourceMetadataHandler), oauthMiddleware...))
mux.Handle(route("/authorize"), server.ChainMiddleware(http.HandlerFunc(authHandlers.AuthorizeHandler), oauthMiddleware...))
- mux.Handle(route("/oauth/callback"), server.ChainMiddleware(http.HandlerFunc(authHandlers.GoogleCallbackHandler), oauthMiddleware...))
+ mux.Handle(route("/oauth/callback"), server.ChainMiddleware(http.HandlerFunc(authHandlers.IDPCallbackHandler), oauthMiddleware...))
mux.Handle(route("/token"), server.ChainMiddleware(http.HandlerFunc(authHandlers.TokenHandler), oauthMiddleware...))
mux.Handle(route("/register"), server.ChainMiddleware(http.HandlerFunc(authHandlers.RegisterHandler), oauthMiddleware...))
mux.Handle(route("/clients/{client_id}"), server.ChainMiddleware(http.HandlerFunc(authHandlers.ClientMetadataHandler), oauthMiddleware...))
@@ -361,12 +375,12 @@ func buildHTTPHandler(
tokenMiddleware := []server.MiddlewareFunc{
corsMiddleware,
tokenLogger,
- server.NewBrowserSSOMiddleware(authConfig, sessionEncryptor, browserStateToken),
+ server.NewBrowserSSOMiddleware(authConfig, idpProvider, sessionEncryptor, browserStateToken),
mcpRecover,
}
// Create token handlers
- tokenHandlers := server.NewTokenHandlers(storage, cfg.MCPServers, true, serviceOAuthClient)
+ tokenHandlers := server.NewTokenHandlers(storage, cfg.MCPServers, true, serviceOAuthClient, []byte(authConfig.EncryptionKey))
// Token management UI endpoints
mux.Handle(route("/my/tokens"), server.ChainMiddleware(http.HandlerFunc(tokenHandlers.ListTokensHandler), tokenMiddleware...))
@@ -407,7 +421,7 @@ func buildHTTPHandler(
}
} else {
// For stdio/SSE servers
- if isStdioServer(serverConfig) {
+ if serverConfig.IsStdio() {
sseServer, mcpServer, err = buildStdioSSEServer(serverName, baseURL, sessionManager)
if err != nil {
return nil, fmt.Errorf("failed to create SSE server for %s: %w", serverName, err)
@@ -475,7 +489,7 @@ func buildHTTPHandler(
// Add browser SSO if OAuth is enabled
if oauthProvider != nil {
// Reuse the same browserStateToken created earlier for consistency
- adminMiddleware = append(adminMiddleware, server.NewBrowserSSOMiddleware(authConfig, sessionEncryptor, browserStateToken))
+ adminMiddleware = append(adminMiddleware, server.NewBrowserSSOMiddleware(authConfig, idpProvider, sessionEncryptor, browserStateToken))
}
// Add admin check middleware
@@ -591,8 +605,3 @@ func buildStdioSSEServer(serverName, baseURL string, sessionManager *client.Stdi
return sseServer, mcpServer, nil
}
-
-// isStdioServer checks if this is a stdio-based server
-func isStdioServer(cfg *config.MCPClientConfig) bool {
- return cfg.TransportType == config.MCPClientTypeStdio
-}
diff --git a/internal/oauth/client_registration.go b/internal/oauth/client_registration.go
new file mode 100644
index 0000000..75510df
--- /dev/null
+++ b/internal/oauth/client_registration.go
@@ -0,0 +1,32 @@
+package oauth
+
+import (
+ "fmt"
+ "strings"
+)
+
+// ParseClientRegistration parses MCP client registration metadata.
+// This is provider-agnostic as it deals with MCP client registration, not IDP.
+func ParseClientRegistration(metadata map[string]any) (redirectURIs []string, scopes []string, err error) {
+ redirectURIs = []string{}
+ if uris, ok := metadata["redirect_uris"].([]any); ok {
+ for _, uri := range uris {
+ if uriStr, ok := uri.(string); ok {
+ redirectURIs = append(redirectURIs, uriStr)
+ }
+ }
+ }
+
+ if len(redirectURIs) == 0 {
+ return nil, nil, fmt.Errorf("no valid redirect URIs provided")
+ }
+
+ scopes = []string{"read", "write"}
+ if clientScopes, ok := metadata["scope"].(string); ok {
+ if strings.TrimSpace(clientScopes) != "" {
+ scopes = strings.Fields(clientScopes)
+ }
+ }
+
+ return redirectURIs, scopes, nil
+}
diff --git a/internal/oauth/client_registration_test.go b/internal/oauth/client_registration_test.go
new file mode 100644
index 0000000..4aacd68
--- /dev/null
+++ b/internal/oauth/client_registration_test.go
@@ -0,0 +1,113 @@
+package oauth
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestParseClientRegistration(t *testing.T) {
+ tests := []struct {
+ name string
+ metadata map[string]any
+ wantRedirectURIs []string
+ wantScopes []string
+ wantErr bool
+ errContains string
+ }{
+ {
+ name: "valid_with_single_redirect_uri",
+ metadata: map[string]any{
+ "redirect_uris": []any{"https://example.com/callback"},
+ },
+ wantRedirectURIs: []string{"https://example.com/callback"},
+ wantScopes: []string{"read", "write"},
+ wantErr: false,
+ },
+ {
+ name: "valid_with_multiple_redirect_uris",
+ metadata: map[string]any{
+ "redirect_uris": []any{
+ "https://example.com/callback",
+ "https://example.com/callback2",
+ },
+ },
+ wantRedirectURIs: []string{
+ "https://example.com/callback",
+ "https://example.com/callback2",
+ },
+ wantScopes: []string{"read", "write"},
+ wantErr: false,
+ },
+ {
+ name: "valid_with_custom_scopes",
+ metadata: map[string]any{
+ "redirect_uris": []any{"https://example.com/callback"},
+ "scope": "openid profile email",
+ },
+ wantRedirectURIs: []string{"https://example.com/callback"},
+ wantScopes: []string{"openid", "profile", "email"},
+ wantErr: false,
+ },
+ {
+ name: "valid_with_empty_scope_uses_default",
+ metadata: map[string]any{
+ "redirect_uris": []any{"https://example.com/callback"},
+ "scope": " ",
+ },
+ wantRedirectURIs: []string{"https://example.com/callback"},
+ wantScopes: []string{"read", "write"},
+ wantErr: false,
+ },
+ {
+ name: "missing_redirect_uris",
+ metadata: map[string]any{},
+ wantErr: true,
+ errContains: "no valid redirect URIs",
+ },
+ {
+ name: "empty_redirect_uris",
+ metadata: map[string]any{
+ "redirect_uris": []any{},
+ },
+ wantErr: true,
+ errContains: "no valid redirect URIs",
+ },
+ {
+ name: "redirect_uris_wrong_type",
+ metadata: map[string]any{
+ "redirect_uris": "https://example.com/callback",
+ },
+ wantErr: true,
+ errContains: "no valid redirect URIs",
+ },
+ {
+ name: "redirect_uri_non_string_elements_ignored",
+ metadata: map[string]any{
+ "redirect_uris": []any{123, "https://example.com/callback", nil},
+ },
+ wantRedirectURIs: []string{"https://example.com/callback"},
+ wantScopes: []string{"read", "write"},
+ wantErr: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ redirectURIs, scopes, err := ParseClientRegistration(tt.metadata)
+
+ if tt.wantErr {
+ require.Error(t, err)
+ if tt.errContains != "" {
+ assert.Contains(t, err.Error(), tt.errContains)
+ }
+ return
+ }
+
+ require.NoError(t, err)
+ assert.Equal(t, tt.wantRedirectURIs, redirectURIs)
+ assert.Equal(t, tt.wantScopes, scopes)
+ })
+ }
+}
diff --git a/internal/oauth/metadata.go b/internal/oauth/metadata.go
index e750415..078906e 100644
--- a/internal/oauth/metadata.go
+++ b/internal/oauth/metadata.go
@@ -53,29 +53,9 @@ func AuthorizationServerMetadata(issuer string) (map[string]any, error) {
}, nil
}
-// ProtectedResourceMetadata builds OAuth 2.0 Protected Resource Metadata per RFC 9728
-// https://datatracker.ietf.org/doc/html/rfc9728
-//
-// Deprecated: Use ServiceProtectedResourceMetadata for per-service metadata endpoints.
-// This function returns the base issuer as the resource, which doesn't support
-// per-service audience validation required by RFC 8707.
-func ProtectedResourceMetadata(issuer string) (map[string]any, error) {
- authzServerURL, err := urlutil.JoinPath(issuer, ".well-known", "oauth-authorization-server")
- if err != nil {
- return nil, err
- }
-
- return map[string]any{
- "resource": issuer,
- "authorization_servers": []string{
- issuer,
- },
- "_links": map[string]any{
- "oauth-authorization-server": map[string]string{
- "href": authzServerURL,
- },
- },
- }, nil
+// AuthorizationServerMetadataURI returns the well-known URI for the authorization server metadata.
+func AuthorizationServerMetadataURI(issuer string) (string, error) {
+ return urlutil.JoinPath(issuer, ".well-known", "oauth-authorization-server")
}
// ServiceProtectedResourceMetadata builds OAuth 2.0 Protected Resource Metadata per RFC 9728
@@ -92,7 +72,7 @@ func ServiceProtectedResourceMetadata(issuer string, serviceName string) (map[st
return nil, err
}
- authzServerURL, err := urlutil.JoinPath(issuer, ".well-known", "oauth-authorization-server")
+ authzServerURL, err := AuthorizationServerMetadataURI(issuer)
if err != nil {
return nil, err
}
diff --git a/internal/oauth/metadata_test.go b/internal/oauth/metadata_test.go
index db9e8df..4367c09 100644
--- a/internal/oauth/metadata_test.go
+++ b/internal/oauth/metadata_test.go
@@ -83,74 +83,6 @@ func TestAuthorizationServerMetadata(t *testing.T) {
}
}
-func TestProtectedResourceMetadata(t *testing.T) {
- tests := []struct {
- name string
- issuer string
- wantErr bool
- }{
- {
- name: "valid issuer",
- issuer: "https://example.com",
- wantErr: false,
- },
- {
- name: "issuer with path",
- issuer: "https://example.com/oauth",
- wantErr: false,
- },
- {
- name: "invalid issuer",
- issuer: "://invalid",
- wantErr: true,
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- metadata, err := ProtectedResourceMetadata(tt.issuer)
- if (err != nil) != tt.wantErr {
- t.Errorf("ProtectedResourceMetadata() error = %v, wantErr %v", err, tt.wantErr)
- return
- }
-
- if tt.wantErr {
- return
- }
-
- // Verify resource field
- if metadata["resource"] != tt.issuer {
- t.Errorf("resource = %v, want %v", metadata["resource"], tt.issuer)
- }
-
- // Verify authorization_servers array
- authzServers, ok := metadata["authorization_servers"].([]string)
- if !ok || len(authzServers) == 0 {
- t.Error("authorization_servers is missing or empty")
- }
-
- if authzServers[0] != tt.issuer {
- t.Errorf("authorization_servers[0] = %v, want %v", authzServers[0], tt.issuer)
- }
-
- // Verify _links structure
- links, ok := metadata["_links"].(map[string]any)
- if !ok {
- t.Error("_links is missing or wrong type")
- }
-
- authzServerLink, ok := links["oauth-authorization-server"].(map[string]string)
- if !ok {
- t.Error("oauth-authorization-server link is missing or wrong type")
- }
-
- if authzServerLink["href"] == "" {
- t.Error("oauth-authorization-server href is empty")
- }
- })
- }
-}
-
func TestServiceProtectedResourceMetadata(t *testing.T) {
tests := []struct {
name string
diff --git a/internal/oauth/provider.go b/internal/oauth/provider.go
index 8ceca50..09eb568 100644
--- a/internal/oauth/provider.go
+++ b/internal/oauth/provider.go
@@ -12,10 +12,9 @@ import (
"github.com/dgellow/mcp-front/internal/config"
"github.com/dgellow/mcp-front/internal/crypto"
- "github.com/dgellow/mcp-front/internal/envutil"
jsonwriter "github.com/dgellow/mcp-front/internal/json"
"github.com/dgellow/mcp-front/internal/log"
- "github.com/dgellow/mcp-front/internal/oauthsession"
+ "github.com/dgellow/mcp-front/internal/session"
"github.com/dgellow/mcp-front/internal/storage"
"github.com/dgellow/mcp-front/internal/urlutil"
"github.com/ory/fosite"
@@ -48,8 +47,8 @@ func NewOAuthProvider(oauthConfig config.OAuthAuthConfig, store storage.Storage,
// Determine min parameter entropy based on environment
minEntropy := 8 // Production default - enforce secure state parameters (8+ chars)
- log.Logf("OAuth provider initialization - MCP_FRONT_ENV=%s, isDevelopmentMode=%v", os.Getenv("MCP_FRONT_ENV"), envutil.IsDev())
- if envutil.IsDev() {
+ log.Logf("OAuth provider initialization - MCP_FRONT_ENV=%s, isDevelopmentMode=%v", os.Getenv("MCP_FRONT_ENV"), config.IsDev())
+ if config.IsDev() {
minEntropy = 0 // Development mode - allow empty state parameters
log.LogWarn("Development mode enabled - OAuth security checks relaxed (state parameter entropy: %d)", minEntropy)
}
@@ -158,8 +157,8 @@ func NewValidateTokenMiddleware(provider fosite.OAuth2Provider, issuer string, a
// - This is documented fosite behavior, not a bug
// - The actual session data must be retrieved from the returned AccessRequester
// See: https://github.com/ory/fosite/issues/256
- session := &oauthsession.Session{DefaultSession: &fosite.DefaultSession{}}
- _, accessRequest, err := provider.IntrospectToken(ctx, token, fosite.AccessToken, session)
+ oauthSession := &session.OAuthSession{DefaultSession: &fosite.DefaultSession{}}
+ _, accessRequest, err := provider.IntrospectToken(ctx, token, fosite.AccessToken, oauthSession)
if err != nil {
jsonwriter.WriteUnauthorizedRFC9728(w, "Invalid or expired token", metadataURI)
return
@@ -180,9 +179,9 @@ func NewValidateTokenMiddleware(provider fosite.OAuth2Provider, issuer string, a
// This is the correct way to retrieve session data after token introspection
var userEmail string
if accessRequest != nil {
- if reqSession, ok := accessRequest.GetSession().(*oauthsession.Session); ok {
- if reqSession.UserInfo.Email != "" {
- userEmail = reqSession.UserInfo.Email
+ if reqSession, ok := accessRequest.GetSession().(*session.OAuthSession); ok {
+ if reqSession.Identity.Email != "" {
+ userEmail = reqSession.Identity.Email
}
}
}
diff --git a/internal/oauthsession/session.go b/internal/oauthsession/session.go
deleted file mode 100644
index f8f093f..0000000
--- a/internal/oauthsession/session.go
+++ /dev/null
@@ -1,20 +0,0 @@
-package oauthsession
-
-import (
- "github.com/dgellow/mcp-front/internal/googleauth"
- "github.com/ory/fosite"
-)
-
-// Session extends DefaultSession with user information
-type Session struct {
- *fosite.DefaultSession
- UserInfo googleauth.UserInfo `json:"user_info"`
-}
-
-// Clone implements fosite.Session
-func (s *Session) Clone() fosite.Session {
- return &Session{
- DefaultSession: s.DefaultSession.Clone().(*fosite.DefaultSession),
- UserInfo: s.UserInfo,
- }
-}
diff --git a/internal/server/admin_handlers.go b/internal/server/admin_handlers.go
index b76db11..4c1c20d 100644
--- a/internal/server/admin_handlers.go
+++ b/internal/server/admin_handlers.go
@@ -4,8 +4,6 @@ import (
"fmt"
"net/http"
"net/url"
- "strconv"
- "strings"
"time"
"github.com/dgellow/mcp-front/internal/adminauth"
@@ -23,7 +21,7 @@ type AdminHandlers struct {
storage storage.Storage
config config.Config
sessionManager *client.StdioSessionManager
- encryptionKey []byte // For HMAC-based CSRF tokens
+ csrf crypto.CSRFProtection
}
// NewAdminHandlers creates a new admin handlers instance
@@ -32,67 +30,10 @@ func NewAdminHandlers(storage storage.Storage, config config.Config, sessionMana
storage: storage,
config: config,
sessionManager: sessionManager,
- encryptionKey: []byte(encryptionKey),
+ csrf: crypto.NewCSRFProtection([]byte(encryptionKey), 15*time.Minute),
}
}
-// generateCSRFToken creates a new HMAC-based CSRF token
-func (h *AdminHandlers) generateCSRFToken() (string, error) {
- // Generate random nonce
- nonce := crypto.GenerateSecureToken()
- if nonce == "" {
- return "", fmt.Errorf("failed to generate nonce")
- }
-
- // Add timestamp (Unix seconds)
- timestamp := strconv.FormatInt(time.Now().Unix(), 10)
-
- // Create data to sign: nonce:timestamp
- data := nonce + ":" + timestamp
-
- // Sign with HMAC
- signature := crypto.SignData(data, h.encryptionKey)
-
- // Return format: nonce:timestamp:signature
- return fmt.Sprintf("%s:%s:%s", nonce, timestamp, signature), nil
-}
-
-// validateCSRFToken checks if a CSRF token is valid
-func (h *AdminHandlers) validateCSRFToken(token string) bool {
- // Parse token format: nonce:timestamp:signature
- parts := strings.SplitN(token, ":", 3)
- if len(parts) != 3 {
- log.LogDebug("Invalid CSRF token format")
- return false
- }
-
- nonce := parts[0]
- timestampStr := parts[1]
- signature := parts[2]
-
- // Verify timestamp (15 minute expiry)
- timestamp, err := strconv.ParseInt(timestampStr, 10, 64)
- if err != nil {
- log.LogDebug("Invalid CSRF token timestamp: %v", err)
- return false
- }
-
- now := time.Now().Unix()
- if now-timestamp > 900 { // 15 minutes
- log.LogDebug("CSRF token expired")
- return false
- }
-
- // Verify HMAC signature
- data := nonce + ":" + timestampStr
- if !crypto.ValidateSignedData(data, signature, h.encryptionKey) {
- log.LogDebug("Invalid CSRF token signature")
- return false
- }
-
- return true
-}
-
// DashboardHandler shows the admin dashboard
func (h *AdminHandlers) DashboardHandler(w http.ResponseWriter, r *http.Request) {
// Only accept GET
@@ -152,7 +93,7 @@ func (h *AdminHandlers) DashboardHandler(w http.ResponseWriter, r *http.Request)
currentLogLevel := log.GetLogLevel()
// Generate CSRF token
- csrfToken, err := h.generateCSRFToken()
+ csrfToken, err := h.csrf.Generate()
if err != nil {
log.LogErrorWithFields("admin", "Failed to generate CSRF token", map[string]any{
"error": err.Error(),
@@ -209,7 +150,7 @@ func (h *AdminHandlers) UserActionHandler(w http.ResponseWriter, r *http.Request
}
// Validate CSRF
- if !h.validateCSRFToken(r.FormValue("csrf_token")) {
+ if !h.csrf.Validate(r.FormValue("csrf_token")) {
jsonwriter.WriteForbidden(w, "Invalid CSRF token")
return
}
@@ -379,7 +320,7 @@ func (h *AdminHandlers) SessionActionHandler(w http.ResponseWriter, r *http.Requ
}
// Validate CSRF
- if !h.validateCSRFToken(r.FormValue("csrf_token")) {
+ if !h.csrf.Validate(r.FormValue("csrf_token")) {
jsonwriter.WriteForbidden(w, "Invalid CSRF token")
return
}
@@ -472,7 +413,7 @@ func (h *AdminHandlers) LoggingActionHandler(w http.ResponseWriter, r *http.Requ
}
// Validate CSRF
- if !h.validateCSRFToken(r.FormValue("csrf_token")) {
+ if !h.csrf.Validate(r.FormValue("csrf_token")) {
jsonwriter.WriteForbidden(w, "Invalid CSRF token")
return
}
diff --git a/internal/server/admin_handlers_test.go b/internal/server/admin_handlers_test.go
index b3587fa..a2aa8ee 100644
--- a/internal/server/admin_handlers_test.go
+++ b/internal/server/admin_handlers_test.go
@@ -27,7 +27,7 @@ func TestAdminHandlers_CSRF(t *testing.T) {
t.Run("generate and validate CSRF token", func(t *testing.T) {
// Generate token
- token, err := handlers.generateCSRFToken()
+ token, err := handlers.csrf.Generate()
if err != nil {
t.Fatalf("Failed to generate CSRF token: %v", err)
}
@@ -38,13 +38,13 @@ func TestAdminHandlers_CSRF(t *testing.T) {
}
// Token should be valid immediately
- if !handlers.validateCSRFToken(token) {
+ if !handlers.csrf.Validate(token) {
t.Error("Token should be valid immediately after generation")
}
// Token should not be valid twice (though with HMAC it actually can be)
// With HMAC-based tokens, they can be validated multiple times
- if !handlers.validateCSRFToken(token) {
+ if !handlers.csrf.Validate(token) {
t.Error("HMAC token should be valid on second validation")
}
})
@@ -58,7 +58,7 @@ func TestAdminHandlers_CSRF(t *testing.T) {
}
for _, token := range invalidTokens {
- if handlers.validateCSRFToken(token) {
+ if handlers.csrf.Validate(token) {
t.Errorf("Token '%s' should be invalid", token)
}
}
@@ -69,7 +69,7 @@ func TestAdminHandlers_CSRF(t *testing.T) {
// An expired token would have a timestamp from > 15 minutes ago
expiredToken := "test-nonce:0:invalid-signature"
- if handlers.validateCSRFToken(expiredToken) {
+ if handlers.csrf.Validate(expiredToken) {
t.Error("Expired token should be invalid")
}
})
@@ -79,18 +79,18 @@ func TestAdminHandlers_CSRF(t *testing.T) {
handlers2 := NewAdminHandlers(storage, cfg, sessionManager, "different-encryption-key-32bytes")
// Generate token with first handler
- token1, err := handlers.generateCSRFToken()
+ token1, err := handlers.csrf.Generate()
if err != nil {
t.Fatalf("Failed to generate token: %v", err)
}
// Token from handler1 should not validate with handler2
- if handlers2.validateCSRFToken(token1) {
+ if handlers2.csrf.Validate(token1) {
t.Error("Token should not validate with different encryption key")
}
// But should still validate with original handler
- if !handlers.validateCSRFToken(token1) {
+ if !handlers.csrf.Validate(token1) {
t.Error("Token should validate with original handler")
}
})
diff --git a/internal/server/auth_handlers.go b/internal/server/auth_handlers.go
index 23a98d5..cdb94e8 100644
--- a/internal/server/auth_handlers.go
+++ b/internal/server/auth_handlers.go
@@ -7,19 +7,19 @@ import (
"fmt"
"net/http"
"net/url"
+ "slices"
"strings"
"time"
"github.com/dgellow/mcp-front/internal/auth"
- "github.com/dgellow/mcp-front/internal/browserauth"
"github.com/dgellow/mcp-front/internal/config"
+ "github.com/dgellow/mcp-front/internal/cookie"
"github.com/dgellow/mcp-front/internal/crypto"
- "github.com/dgellow/mcp-front/internal/envutil"
- "github.com/dgellow/mcp-front/internal/googleauth"
+ "github.com/dgellow/mcp-front/internal/idp"
jsonwriter "github.com/dgellow/mcp-front/internal/json"
"github.com/dgellow/mcp-front/internal/log"
"github.com/dgellow/mcp-front/internal/oauth"
- "github.com/dgellow/mcp-front/internal/oauthsession"
+ "github.com/dgellow/mcp-front/internal/session"
"github.com/dgellow/mcp-front/internal/storage"
"github.com/ory/fosite"
)
@@ -28,6 +28,7 @@ import (
type AuthHandlers struct {
oauthProvider fosite.OAuth2Provider
authConfig config.OAuthAuthConfig
+ idpProvider idp.Provider
storage storage.Storage
sessionEncryptor crypto.Encryptor
mcpServers map[string]*config.MCPClientConfig
@@ -37,18 +38,19 @@ type AuthHandlers struct {
// UpstreamOAuthState stores OAuth state during upstream authentication flow (MCP host → mcp-front)
type UpstreamOAuthState struct {
- UserInfo googleauth.UserInfo `json:"user_info"`
- ClientID string `json:"client_id"`
- RedirectURI string `json:"redirect_uri"`
- Scopes []string `json:"scopes"`
- State string `json:"state"`
- ResponseType string `json:"response_type"`
+ Identity idp.Identity `json:"identity"`
+ ClientID string `json:"client_id"`
+ RedirectURI string `json:"redirect_uri"`
+ Scopes []string `json:"scopes"`
+ State string `json:"state"`
+ ResponseType string `json:"response_type"`
}
// NewAuthHandlers creates new auth handlers with dependency injection
func NewAuthHandlers(
oauthProvider fosite.OAuth2Provider,
authConfig config.OAuthAuthConfig,
+ idpProvider idp.Provider,
storage storage.Storage,
sessionEncryptor crypto.Encryptor,
mcpServers map[string]*config.MCPClientConfig,
@@ -57,6 +59,7 @@ func NewAuthHandlers(
return &AuthHandlers{
oauthProvider: oauthProvider,
authConfig: authConfig,
+ idpProvider: idpProvider,
storage: storage,
sessionEncryptor: sessionEncryptor,
mcpServers: mcpServers,
@@ -98,18 +101,31 @@ func (h *AuthHandlers) ProtectedResourceMetadataHandler(w http.ResponseWriter, r
return
}
- // Workaround mode: return base issuer metadata for broken clients
+ // Workaround mode: return base issuer metadata for broken clients.
+ // This intentionally uses the issuer as the resource (no per-service scoping)
+ // because broken clients don't implement RFC 8707 resource indicators.
+ issuer := h.authConfig.Issuer
log.LogWarnWithFields("oauth", "Serving base protected resource metadata (dangerouslyAcceptIssuerAudience enabled)", map[string]any{
- "issuer": h.authConfig.Issuer,
+ "issuer": issuer,
})
- metadata, err := oauth.ProtectedResourceMetadata(h.authConfig.Issuer)
+ authzServerURL, err := oauth.AuthorizationServerMetadataURI(issuer)
if err != nil {
log.LogError("Failed to build protected resource metadata: %v", err)
jsonwriter.WriteInternalServerError(w, "Internal server error")
return
}
+ metadata := map[string]any{
+ "resource": issuer,
+ "authorization_servers": []string{issuer},
+ "_links": map[string]any{
+ "oauth-authorization-server": map[string]string{
+ "href": authzServerURL,
+ },
+ },
+ }
+
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(metadata); err != nil {
log.LogError("Failed to encode protected resource metadata: %v", err)
@@ -207,7 +223,7 @@ func (h *AuthHandlers) AuthorizeHandler(w http.ResponseWriter, r *http.Request)
// In development mode, generate a secure state parameter if missing
// This works around bugs in OAuth clients that don't send state
stateParam := r.URL.Query().Get("state")
- if envutil.IsDev() && len(stateParam) == 0 {
+ if config.IsDev() && len(stateParam) == 0 {
generatedState := crypto.GenerateSecureToken()
log.LogWarn("Development mode: generating state parameter '%s' for buggy client", generatedState)
q := r.URL.Query()
@@ -259,14 +275,19 @@ func (h *AuthHandlers) AuthorizeHandler(w http.ResponseWriter, r *http.Request)
}
state := ar.GetState()
- h.storage.StoreAuthorizeRequest(state, ar)
+ if err := h.storage.StoreAuthorizeRequest(state, ar); err != nil {
+ log.LogError("Failed to store authorize request: %v", err)
+ h.oauthProvider.WriteAuthorizeError(ctx, w, ar, fosite.ErrServerError.WithHint("Failed to store authorization request"))
+ return
+ }
- authURL := googleauth.GoogleAuthURL(h.authConfig, state)
+ authURL := h.idpProvider.AuthURL(state)
http.Redirect(w, r, authURL, http.StatusFound)
}
-// GoogleCallbackHandler handles the callback from Google OAuth
-func (h *AuthHandlers) GoogleCallbackHandler(w http.ResponseWriter, r *http.Request) {
+// IDPCallbackHandler handles the callback from the identity provider.
+// It dispatches to handleBrowserCallback or handleOAuthClientCallback based on the flow type.
+func (h *AuthHandlers) IDPCallbackHandler(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
state := r.URL.Query().Get("state")
@@ -274,7 +295,7 @@ func (h *AuthHandlers) GoogleCallbackHandler(w http.ResponseWriter, r *http.Requ
if errMsg := r.URL.Query().Get("error"); errMsg != "" {
errDesc := r.URL.Query().Get("error_description")
- log.LogError("Google OAuth error: %s - %s", errMsg, errDesc)
+ log.LogError("OAuth error: %s - %s", errMsg, errDesc)
jsonwriter.WriteBadRequest(w, fmt.Sprintf("Authentication failed: %s", errMsg))
return
}
@@ -293,7 +314,7 @@ func (h *AuthHandlers) GoogleCallbackHandler(w http.ResponseWriter, r *http.Requ
isBrowserFlow = true
stateToken := strings.TrimPrefix(state, "browser:")
- var browserState browserauth.AuthorizationState
+ var browserState session.AuthorizationState
if err := h.oauthStateToken.Verify(stateToken, &browserState); err != nil {
log.LogError("Invalid browser state: %v", err)
jsonwriter.WriteBadRequest(w, "Invalid state parameter")
@@ -301,7 +322,6 @@ func (h *AuthHandlers) GoogleCallbackHandler(w http.ResponseWriter, r *http.Requ
}
returnURL = browserState.ReturnURL
} else {
- // OAuth client flow - retrieve stored authorize request
var found bool
ar, found = h.storage.GetAuthorizeRequest(state)
if !found {
@@ -315,7 +335,7 @@ func (h *AuthHandlers) GoogleCallbackHandler(w http.ResponseWriter, r *http.Requ
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
- token, err := googleauth.ExchangeCodeForToken(ctx, h.authConfig, code)
+ token, err := h.idpProvider.ExchangeCode(ctx, code)
if err != nil {
log.LogError("Failed to exchange code: %v", err)
if !isBrowserFlow && ar != nil {
@@ -326,10 +346,19 @@ func (h *AuthHandlers) GoogleCallbackHandler(w http.ResponseWriter, r *http.Requ
return
}
- // Validate user
- userInfo, err := googleauth.ValidateUser(ctx, h.authConfig, token)
+ identity, err := h.idpProvider.UserInfo(ctx, token)
if err != nil {
- log.LogError("User validation failed: %v", err)
+ log.LogError("Failed to fetch user identity: %v", err)
+ if !isBrowserFlow && ar != nil {
+ h.oauthProvider.WriteAuthorizeError(ctx, w, ar, fosite.ErrServerError.WithHint("Failed to fetch user identity"))
+ } else {
+ jsonwriter.WriteInternalServerError(w, "Authentication failed")
+ }
+ return
+ }
+
+ if err := h.validateAccess(identity); err != nil {
+ log.LogError("Access denied: %v", err)
if !isBrowserFlow && ar != nil {
h.oauthProvider.WriteAuthorizeError(ctx, w, ar, fosite.ErrAccessDenied.WithHint(err.Error()))
} else {
@@ -338,65 +367,61 @@ func (h *AuthHandlers) GoogleCallbackHandler(w http.ResponseWriter, r *http.Requ
return
}
- log.Logf("User authenticated: %s", userInfo.Email)
+ log.Logf("User authenticated: %s", identity.Email)
- // Store user in database
- if err := h.storage.UpsertUser(ctx, userInfo.Email); err != nil {
+ if err := h.storage.UpsertUser(ctx, identity.Email); err != nil {
log.LogWarnWithFields("auth", "Failed to track user", map[string]any{
- "email": userInfo.Email,
+ "email": identity.Email,
"error": err.Error(),
})
}
if isBrowserFlow {
- // Browser SSO flow - set encrypted session cookie
- // Browser sessions should last longer than API tokens for better UX
- sessionDuration := 24 * time.Hour
-
- sessionData := browserauth.SessionCookie{
- Email: userInfo.Email,
- Expires: time.Now().Add(sessionDuration),
- }
-
- // Marshal session data to JSON
- jsonData, err := json.Marshal(sessionData)
- if err != nil {
- log.LogError("Failed to marshal session data: %v", err)
- jsonwriter.WriteInternalServerError(w, "Failed to create session")
- return
- }
+ h.handleBrowserCallback(w, r, identity, returnURL)
+ } else {
+ h.handleOAuthClientCallback(ctx, w, r, ar, identity)
+ }
+}
- // Encrypt session data
- encryptedData, err := h.sessionEncryptor.Encrypt(string(jsonData))
- if err != nil {
- log.LogError("Failed to encrypt session: %v", err)
- jsonwriter.WriteInternalServerError(w, "Failed to create session")
- return
- }
+// handleBrowserCallback handles the browser SSO callback flow: creates an encrypted
+// session cookie and redirects to the return URL.
+func (h *AuthHandlers) handleBrowserCallback(w http.ResponseWriter, r *http.Request, identity *idp.Identity, returnURL string) {
+ sessionDuration := 24 * time.Hour
- // Set secure session cookie
- http.SetCookie(w, &http.Cookie{
- Name: "mcp_session",
- Value: encryptedData,
- Path: "/",
- HttpOnly: true,
- Secure: !envutil.IsDev(),
- SameSite: http.SameSiteStrictMode,
- MaxAge: int(sessionDuration.Seconds()),
- })
+ sessionData := session.BrowserCookie{
+ Email: identity.Email,
+ Provider: identity.ProviderType,
+ Expires: time.Now().Add(sessionDuration),
+ }
- log.LogInfoWithFields("auth", "Browser SSO session created", map[string]any{
- "user": userInfo.Email,
- "duration": sessionDuration,
- "returnURL": returnURL,
- })
+ jsonData, err := json.Marshal(sessionData)
+ if err != nil {
+ log.LogError("Failed to marshal session data: %v", err)
+ jsonwriter.WriteInternalServerError(w, "Failed to create session")
+ return
+ }
- // Redirect to return URL
- http.Redirect(w, r, returnURL, http.StatusFound)
+ encryptedData, err := h.sessionEncryptor.Encrypt(string(jsonData))
+ if err != nil {
+ log.LogError("Failed to encrypt session: %v", err)
+ jsonwriter.WriteInternalServerError(w, "Failed to create session")
return
}
- // OAuth client flow - check if any services need OAuth
+ cookie.SetSession(w, encryptedData, sessionDuration)
+
+ log.LogInfoWithFields("auth", "Browser SSO session created", map[string]any{
+ "user": identity.Email,
+ "duration": sessionDuration,
+ "returnURL": returnURL,
+ })
+
+ http.Redirect(w, r, returnURL, http.StatusFound)
+}
+
+// handleOAuthClientCallback handles the OAuth client callback flow: checks for
+// service auth needs, creates a fosite session, and issues the authorize response.
+func (h *AuthHandlers) handleOAuthClientCallback(ctx context.Context, w http.ResponseWriter, r *http.Request, ar fosite.AuthorizeRequester, identity *idp.Identity) {
needsServiceAuth := false
for _, serverConfig := range h.mcpServers {
if serverConfig.RequiresUserToken &&
@@ -408,7 +433,7 @@ func (h *AuthHandlers) GoogleCallbackHandler(w http.ResponseWriter, r *http.Requ
}
if needsServiceAuth {
- stateData, err := h.signUpstreamOAuthState(ar, userInfo)
+ stateData, err := h.signUpstreamOAuthState(ar, *identity)
if err != nil {
log.LogError("Failed to sign OAuth state: %v", err)
h.oauthProvider.WriteAuthorizeError(ctx, w, ar, fosite.ErrServerError.WithHint("Failed to prepare service authentication"))
@@ -419,20 +444,16 @@ func (h *AuthHandlers) GoogleCallbackHandler(w http.ResponseWriter, r *http.Requ
return
}
- // Create session for token issuance
- // Note: Audience claims are stored in the authorize request (ar.GetGrantedAudience())
- // and will be automatically propagated to access tokens by fosite
- session := &oauthsession.Session{
+ session := &session.OAuthSession{
DefaultSession: &fosite.DefaultSession{
ExpiresAt: map[fosite.TokenType]time.Time{
fosite.AccessToken: time.Now().Add(h.authConfig.TokenTTL),
fosite.RefreshToken: time.Now().Add(h.authConfig.RefreshTokenTTL),
},
},
- UserInfo: userInfo,
+ Identity: *identity,
}
- // Accept the authorization request
response, err := h.oauthProvider.NewAuthorizeResponse(ctx, ar, session)
if err != nil {
log.LogError("Authorize response error: %v", err)
@@ -451,7 +472,7 @@ func (h *AuthHandlers) TokenHandler(w http.ResponseWriter, r *http.Request) {
// Create session for the token exchange
// Note: We create our custom Session type here, and fosite will populate it
// with the session data from the authorization code during NewAccessRequest
- session := &oauthsession.Session{DefaultSession: &fosite.DefaultSession{}}
+ session := &session.OAuthSession{DefaultSession: &fosite.DefaultSession{}}
// Handle token request - this retrieves the session from the authorization code
accessRequest, err := h.oauthProvider.NewAccessRequest(ctx, r, session)
@@ -511,7 +532,7 @@ func (h *AuthHandlers) RegisterHandler(w http.ResponseWriter, r *http.Request) {
}
// Parse client request
- redirectURIs, scopes, err := googleauth.ParseClientRequest(metadata)
+ redirectURIs, scopes, err := oauth.ParseClientRegistration(metadata)
if err != nil {
log.LogError("Client request parsing error: %v", err)
jsonwriter.WriteBadRequest(w, err.Error())
@@ -560,9 +581,9 @@ func (h *AuthHandlers) RegisterHandler(w http.ResponseWriter, r *http.Request) {
}
// signUpstreamOAuthState signs upstream OAuth state for secure storage
-func (h *AuthHandlers) signUpstreamOAuthState(ar fosite.AuthorizeRequester, userInfo googleauth.UserInfo) (string, error) {
+func (h *AuthHandlers) signUpstreamOAuthState(ar fosite.AuthorizeRequester, identity idp.Identity) (string, error) {
state := UpstreamOAuthState{
- UserInfo: userInfo,
+ Identity: identity,
ClientID: ar.GetClient().GetID(),
RedirectURI: ar.GetRedirectURI().String(),
Scopes: ar.GetRequestedScopes(),
@@ -582,6 +603,28 @@ func (h *AuthHandlers) verifyUpstreamOAuthState(signedState string) (*UpstreamOA
return &state, nil
}
+// validateAccess checks whether an authenticated identity meets the configured access policy.
+func (h *AuthHandlers) validateAccess(identity *idp.Identity) error {
+ if len(h.authConfig.AllowedDomains) > 0 &&
+ !slices.Contains(h.authConfig.AllowedDomains, identity.Domain) {
+ return fmt.Errorf("domain '%s' is not allowed. Contact your administrator", identity.Domain)
+ }
+ if len(h.authConfig.IDP.AllowedOrgs) > 0 &&
+ !hasOverlap(identity.Organizations, h.authConfig.IDP.AllowedOrgs) {
+ return fmt.Errorf("user is not a member of any allowed organization. Contact your administrator")
+ }
+ return nil
+}
+
+func hasOverlap(a, b []string) bool {
+ for _, x := range a {
+ if slices.Contains(b, x) {
+ return true
+ }
+ }
+ return false
+}
+
// ServiceSelectionHandler shows the interstitial page for selecting services to connect
func (h *AuthHandlers) ServiceSelectionHandler(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
@@ -602,7 +645,7 @@ func (h *AuthHandlers) ServiceSelectionHandler(w http.ResponseWriter, r *http.Re
return
}
- userEmail := upstreamOAuthState.UserInfo.Email
+ userEmail := upstreamOAuthState.Identity.Email
// Prepare template data
returnURL := fmt.Sprintf("/oauth/services?state=%s", url.QueryEscape(signedState))
@@ -724,7 +767,7 @@ func (h *AuthHandlers) CompleteOAuthHandler(w http.ResponseWriter, r *http.Reque
Client: client,
RequestedScope: upstreamOAuthState.Scopes,
GrantedScope: upstreamOAuthState.Scopes,
- Session: &oauthsession.Session{DefaultSession: &fosite.DefaultSession{}},
+ Session: &session.OAuthSession{DefaultSession: &fosite.DefaultSession{}},
},
}
@@ -737,14 +780,14 @@ func (h *AuthHandlers) CompleteOAuthHandler(w http.ResponseWriter, r *http.Reque
ar.RedirectURI = redirectURI
// Create session with user info
- session := &oauthsession.Session{
+ session := &session.OAuthSession{
DefaultSession: &fosite.DefaultSession{
ExpiresAt: map[fosite.TokenType]time.Time{
fosite.AccessToken: time.Now().Add(h.authConfig.TokenTTL),
fosite.RefreshToken: time.Now().Add(h.authConfig.RefreshTokenTTL),
},
},
- UserInfo: upstreamOAuthState.UserInfo,
+ Identity: upstreamOAuthState.Identity,
}
ar.SetSession(session)
diff --git a/internal/server/auth_handlers_test.go b/internal/server/auth_handlers_test.go
index 6e2e98f..a8cd625 100644
--- a/internal/server/auth_handlers_test.go
+++ b/internal/server/auth_handlers_test.go
@@ -1,6 +1,7 @@
package server
import (
+ "context"
"encoding/json"
"net/http"
"net/http/httptest"
@@ -9,15 +10,43 @@ import (
"time"
"github.com/dgellow/mcp-front/internal/auth"
- "github.com/dgellow/mcp-front/internal/browserauth"
"github.com/dgellow/mcp-front/internal/config"
"github.com/dgellow/mcp-front/internal/crypto"
+ "github.com/dgellow/mcp-front/internal/idp"
"github.com/dgellow/mcp-front/internal/oauth"
+ "github.com/dgellow/mcp-front/internal/session"
"github.com/dgellow/mcp-front/internal/storage"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
+ "golang.org/x/oauth2"
)
+// mockIDPProvider is a mock IDP provider for testing
+type mockIDPProvider struct{}
+
+func (m *mockIDPProvider) Type() string {
+ return "mock"
+}
+
+func (m *mockIDPProvider) AuthURL(state string) string {
+ return "https://auth.example.com/authorize?state=" + state
+}
+
+func (m *mockIDPProvider) ExchangeCode(ctx context.Context, code string) (*oauth2.Token, error) {
+ return &oauth2.Token{AccessToken: "test-token"}, nil
+}
+
+func (m *mockIDPProvider) UserInfo(ctx context.Context, token *oauth2.Token) (*idp.Identity, error) {
+ return &idp.Identity{
+ ProviderType: "mock",
+ Subject: "123",
+ Email: "test@example.com",
+ EmailVerified: true,
+ Name: "Test User",
+ Domain: "example.com",
+ }, nil
+}
+
func TestAuthenticationBoundaries(t *testing.T) {
tests := []struct {
name string
@@ -47,18 +76,21 @@ func TestAuthenticationBoundaries(t *testing.T) {
// Setup test OAuth configuration
oauthConfig := config.OAuthAuthConfig{
- Kind: config.AuthKindOAuth,
- Issuer: "https://test.example.com",
- GoogleClientID: "test-client-id",
- GoogleClientSecret: config.Secret("test-client-secret"),
- GoogleRedirectURI: "https://test.example.com/oauth/callback",
- JWTSecret: config.Secret(strings.Repeat("a", 32)),
- EncryptionKey: config.Secret(strings.Repeat("b", 32)),
- TokenTTL: time.Hour,
- RefreshTokenTTL: 30 * 24 * time.Hour,
- Storage: "memory",
- AllowedDomains: []string{"example.com"},
- AllowedOrigins: []string{"https://test.example.com"},
+ Kind: config.AuthKindOAuth,
+ Issuer: "https://test.example.com",
+ IDP: config.IDPConfig{
+ Provider: "google",
+ ClientID: "test-client-id",
+ ClientSecret: config.Secret("test-client-secret"),
+ RedirectURI: "https://test.example.com/oauth/callback",
+ },
+ JWTSecret: config.Secret(strings.Repeat("a", 32)),
+ EncryptionKey: config.Secret(strings.Repeat("b", 32)),
+ TokenTTL: time.Hour,
+ RefreshTokenTTL: 30 * 24 * time.Hour,
+ Storage: "memory",
+ AllowedDomains: []string{"example.com"},
+ AllowedOrigins: []string{"https://test.example.com"},
}
// Create storage
@@ -77,17 +109,21 @@ func TestAuthenticationBoundaries(t *testing.T) {
// Create service OAuth client
serviceOAuthClient := auth.NewServiceOAuthClient(store, "https://test.example.com", []byte(strings.Repeat("k", 32)))
+ // Create mock IDP provider for testing
+ mockIDP := &mockIDPProvider{}
+
// Create handlers
authHandlers := NewAuthHandlers(
oauthProvider,
oauthConfig,
+ mockIDP,
store,
sessionEncryptor,
map[string]*config.MCPClientConfig{},
serviceOAuthClient,
)
- tokenHandlers := NewTokenHandlers(store, map[string]*config.MCPClientConfig{}, true, serviceOAuthClient)
+ tokenHandlers := NewTokenHandlers(store, map[string]*config.MCPClientConfig{}, true, serviceOAuthClient, []byte(oauthConfig.EncryptionKey))
// Build mux with middlewares
mux := http.NewServeMux()
@@ -103,7 +139,7 @@ func TestAuthenticationBoundaries(t *testing.T) {
// Protected endpoints
tokenMiddleware := []MiddlewareFunc{
corsMiddleware,
- NewBrowserSSOMiddleware(oauthConfig, sessionEncryptor, &browserStateToken),
+ NewBrowserSSOMiddleware(oauthConfig, mockIDP, sessionEncryptor, &browserStateToken),
}
mux.Handle("/my/tokens", ChainMiddleware(
@@ -154,9 +190,10 @@ func TestAuthenticationBoundaries(t *testing.T) {
// Test with valid session cookie (if auth is expected)
if tt.expectAuth {
// Create session data
- sessionData := browserauth.SessionCookie{
- Email: "test@example.com",
- Expires: time.Now().Add(24 * time.Hour),
+ sessionData := session.BrowserCookie{
+ Email: "test@example.com",
+ Provider: "mock",
+ Expires: time.Now().Add(24 * time.Hour),
}
jsonData, err := json.Marshal(sessionData)
require.NoError(t, err)
@@ -186,18 +223,21 @@ func TestAuthenticationBoundaries(t *testing.T) {
func TestOAuthEndpointHandlers(t *testing.T) {
oauthConfig := config.OAuthAuthConfig{
- Kind: config.AuthKindOAuth,
- Issuer: "https://test.example.com",
- GoogleClientID: "test-client-id",
- GoogleClientSecret: config.Secret("test-client-secret"),
- GoogleRedirectURI: "https://test.example.com/oauth/callback",
- JWTSecret: config.Secret(strings.Repeat("a", 32)),
- EncryptionKey: config.Secret(strings.Repeat("b", 32)),
- TokenTTL: time.Hour,
- RefreshTokenTTL: 30 * 24 * time.Hour,
- Storage: "memory",
- AllowedDomains: []string{"example.com"},
- AllowedOrigins: []string{"https://test.example.com"},
+ Kind: config.AuthKindOAuth,
+ Issuer: "https://test.example.com",
+ IDP: config.IDPConfig{
+ Provider: "google",
+ ClientID: "test-client-id",
+ ClientSecret: config.Secret("test-client-secret"),
+ RedirectURI: "https://test.example.com/oauth/callback",
+ },
+ JWTSecret: config.Secret(strings.Repeat("a", 32)),
+ EncryptionKey: config.Secret(strings.Repeat("b", 32)),
+ TokenTTL: time.Hour,
+ RefreshTokenTTL: 30 * 24 * time.Hour,
+ Storage: "memory",
+ AllowedDomains: []string{"example.com"},
+ AllowedOrigins: []string{"https://test.example.com"},
}
store := storage.NewMemoryStorage()
@@ -208,10 +248,12 @@ func TestOAuthEndpointHandlers(t *testing.T) {
sessionEncryptor, err := oauth.NewSessionEncryptor([]byte(oauthConfig.EncryptionKey))
require.NoError(t, err)
serviceOAuthClient := auth.NewServiceOAuthClient(store, "https://test.example.com", []byte(strings.Repeat("k", 32)))
+ mockIDP := &mockIDPProvider{}
authHandlers := NewAuthHandlers(
oauthProvider,
oauthConfig,
+ mockIDP,
store,
sessionEncryptor,
map[string]*config.MCPClientConfig{},
@@ -272,7 +314,7 @@ func TestOAuthEndpointHandlers(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/oauth/services", nil)
// Add valid session cookie
- sessionData := browserauth.SessionCookie{
+ sessionData := session.BrowserCookie{
Email: "test@example.com",
Expires: time.Now().Add(24 * time.Hour),
}
@@ -378,3 +420,88 @@ func TestBearerTokenAuth(t *testing.T) {
})
}
}
+
+func TestValidateAccess(t *testing.T) {
+ tests := []struct {
+ name string
+ allowedDomains []string
+ allowedOrgs []string
+ identity *idp.Identity
+ wantErr bool
+ errContains string
+ }{
+ {
+ name: "no_restrictions",
+ allowedDomains: nil,
+ allowedOrgs: nil,
+ identity: &idp.Identity{Domain: "any.com", Organizations: []string{"any-org"}},
+ wantErr: false,
+ },
+ {
+ name: "domain_allowed",
+ allowedDomains: []string{"company.com"},
+ identity: &idp.Identity{Domain: "company.com"},
+ wantErr: false,
+ },
+ {
+ name: "domain_rejected",
+ allowedDomains: []string{"company.com"},
+ identity: &idp.Identity{Domain: "other.com"},
+ wantErr: true,
+ errContains: "domain 'other.com' is not allowed",
+ },
+ {
+ name: "org_allowed",
+ allowedOrgs: []string{"allowed-org"},
+ identity: &idp.Identity{Domain: "any.com", Organizations: []string{"allowed-org", "other-org"}},
+ wantErr: false,
+ },
+ {
+ name: "org_rejected",
+ allowedOrgs: []string{"required-org"},
+ identity: &idp.Identity{Domain: "any.com", Organizations: []string{"other-org"}},
+ wantErr: true,
+ errContains: "not a member of any allowed organization",
+ },
+ {
+ name: "domain_and_org_both_pass",
+ allowedDomains: []string{"company.com"},
+ allowedOrgs: []string{"my-org"},
+ identity: &idp.Identity{Domain: "company.com", Organizations: []string{"my-org"}},
+ wantErr: false,
+ },
+ {
+ name: "domain_fails_before_org_check",
+ allowedDomains: []string{"company.com"},
+ allowedOrgs: []string{"my-org"},
+ identity: &idp.Identity{Domain: "other.com", Organizations: []string{"my-org"}},
+ wantErr: true,
+ errContains: "domain 'other.com' is not allowed",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ h := &AuthHandlers{
+ authConfig: config.OAuthAuthConfig{
+ AllowedDomains: tt.allowedDomains,
+ IDP: config.IDPConfig{
+ AllowedOrgs: tt.allowedOrgs,
+ },
+ },
+ }
+
+ err := h.validateAccess(tt.identity)
+
+ if tt.wantErr {
+ require.Error(t, err)
+ if tt.errContains != "" {
+ assert.Contains(t, err.Error(), tt.errContains)
+ }
+ return
+ }
+
+ require.NoError(t, err)
+ })
+ }
+}
diff --git a/internal/server/http.go b/internal/server/http.go
index 3bbe7dc..6a0dc68 100644
--- a/internal/server/http.go
+++ b/internal/server/http.go
@@ -4,86 +4,10 @@ import (
"context"
"errors"
"net/http"
- "strings"
- "time"
- "github.com/dgellow/mcp-front/internal/auth"
- "github.com/dgellow/mcp-front/internal/client"
- "github.com/dgellow/mcp-front/internal/config"
"github.com/dgellow/mcp-front/internal/log"
- "github.com/dgellow/mcp-front/internal/storage"
- mcpserver "github.com/mark3labs/mcp-go/server"
)
-// UserTokenService handles user token retrieval and OAuth refresh
-type UserTokenService struct {
- storage storage.Storage
- serviceOAuthClient *auth.ServiceOAuthClient
-}
-
-// NewUserTokenService creates a new user token service
-func NewUserTokenService(storage storage.Storage, serviceOAuthClient *auth.ServiceOAuthClient) *UserTokenService {
- return &UserTokenService{
- storage: storage,
- serviceOAuthClient: serviceOAuthClient,
- }
-}
-
-// GetUserToken retrieves and formats a user token for a service, handling OAuth refresh.
-//
-// Token refresh strategy: Optimistic continuation on failure.
-// If refresh fails, we log a warning and continue with the current token. The external
-// service will reject the expired token with 401, giving the user a clear error.
-// This is acceptable because: (1) refresh failures are rare (network issues, revoked
-// tokens), and (2) forcing users to re-auth is better than silently hiding auth issues.
-func (uts *UserTokenService) GetUserToken(ctx context.Context, userEmail, serviceName string, serviceConfig *config.MCPClientConfig) (string, error) {
- storedToken, err := uts.storage.GetUserToken(ctx, userEmail, serviceName)
- if err != nil {
- return "", err
- }
-
- switch storedToken.Type {
- case storage.TokenTypeManual:
- // Token is already in storedToken.Value, formatUserToken will handle it
- break
- case storage.TokenTypeOAuth:
- if storedToken.OAuthData != nil && uts.serviceOAuthClient != nil {
- if err := uts.serviceOAuthClient.RefreshToken(ctx, userEmail, serviceName, serviceConfig); err != nil {
- log.LogWarnWithFields("user_token", "Failed to refresh OAuth token", map[string]any{
- "service": serviceName,
- "user": userEmail,
- "error": err.Error(),
- })
- // Continue with current token - the service will handle auth failure
- } else {
- // Re-fetch the updated token after refresh
- refreshedToken, err := uts.storage.GetUserToken(ctx, userEmail, serviceName)
- if err != nil {
- log.LogErrorWithFields("user_token", "Failed to fetch token after successful refresh", map[string]any{
- "service": serviceName,
- "user": userEmail,
- "error": err.Error(),
- })
- // Continue with original token - the service will handle auth failure
- } else {
- storedToken = refreshedToken
- var expiresAt time.Time
- if refreshedToken.OAuthData != nil {
- expiresAt = refreshedToken.OAuthData.ExpiresAt
- }
- log.LogInfoWithFields("user_token", "OAuth token refreshed and updated", map[string]any{
- "service": serviceName,
- "user": userEmail,
- "expiresAt": expiresAt,
- })
- }
- }
- }
- }
-
- return formatUserToken(storedToken, serviceConfig.UserAuthentication), nil
-}
-
// HTTPServer manages the HTTP server lifecycle
type HTTPServer struct {
server *http.Server
@@ -99,8 +23,6 @@ func NewHTTPServer(handler http.Handler, addr string) *HTTPServer {
}
}
-// Handler builders and mux assembly
-
// HealthHandler handles health check requests
type HealthHandler struct{}
@@ -143,184 +65,3 @@ func (h *HTTPServer) Stop(ctx context.Context) error {
})
return nil
}
-
-// isStdioServer checks if this is a stdio-based server
-func isStdioServer(config *config.MCPClientConfig) bool {
- return config.Command != ""
-}
-
-// formatUserToken formats a stored token according to the user authentication configuration
-func formatUserToken(storedToken *storage.StoredToken, auth *config.UserAuthentication) string {
- if storedToken == nil {
- return ""
- }
-
- if storedToken.Type == storage.TokenTypeOAuth && storedToken.OAuthData != nil {
- token := storedToken.OAuthData.AccessToken
- if auth.TokenFormat != "" && auth.TokenFormat != "{{token}}" {
- return strings.ReplaceAll(auth.TokenFormat, "{{token}}", token)
- }
- return token
- }
-
- token := storedToken.Value
- if auth != nil && auth.TokenFormat != "" && auth.TokenFormat != "{{token}}" {
- return strings.ReplaceAll(auth.TokenFormat, "{{token}}", token)
- }
- return token
-}
-
-// SessionHandlerKey is the context key for session handlers
-type SessionHandlerKey struct{}
-
-// SessionRequestHandler handles session-specific logic for a request
-type SessionRequestHandler struct {
- h *MCPHandler
- userEmail string
- config *config.MCPClientConfig
- mcpServer *mcpserver.MCPServer // The shared MCP server
-}
-
-// NewSessionRequestHandler creates a new session request handler with all dependencies
-func NewSessionRequestHandler(h *MCPHandler, userEmail string, config *config.MCPClientConfig, mcpServer *mcpserver.MCPServer) *SessionRequestHandler {
- return &SessionRequestHandler{
- h: h,
- userEmail: userEmail,
- config: config,
- mcpServer: mcpServer,
- }
-}
-
-// GetUserEmail returns the user email for this session
-func (s *SessionRequestHandler) GetUserEmail() string {
- return s.userEmail
-}
-
-// GetServerName returns the server name for this session
-func (s *SessionRequestHandler) GetServerName() string {
- return s.h.serverName
-}
-
-// GetStorage returns the storage interface
-func (s *SessionRequestHandler) GetStorage() storage.Storage {
- return s.h.storage
-}
-
-// HandleSessionRegistration handles the registration of a new session
-func HandleSessionRegistration(
- sessionCtx context.Context,
- session mcpserver.ClientSession,
- handler *SessionRequestHandler,
- sessionManager *client.StdioSessionManager,
-) {
- // Create stdio process for this session
- key := client.SessionKey{
- UserEmail: handler.userEmail,
- ServerName: handler.h.serverName,
- SessionID: session.SessionID(),
- }
-
- log.LogDebugWithFields("server", "Registering session", map[string]any{
- "sessionID": session.SessionID(),
- "server": handler.h.serverName,
- "user": handler.userEmail,
- })
-
- log.LogTraceWithFields("server", "Session registration started", map[string]any{
- "sessionID": session.SessionID(),
- "server": handler.h.serverName,
- "user": handler.userEmail,
- "requiresUserToken": handler.config.RequiresUserToken,
- "transportType": handler.config.TransportType,
- "command": handler.config.Command,
- })
-
- var userToken string
- if handler.config.RequiresUserToken && handler.userEmail != "" && handler.h.storage != nil {
- storedToken, err := handler.h.storage.GetUserToken(sessionCtx, handler.userEmail, handler.h.serverName)
- if err != nil {
- log.LogDebugWithFields("server", "No user token found", map[string]any{
- "server": handler.h.serverName,
- "user": handler.userEmail,
- })
- } else if storedToken != nil {
- if handler.config.UserAuthentication != nil {
- userToken = formatUserToken(storedToken, handler.config.UserAuthentication)
- } else {
- userToken = storedToken.Value
- }
- }
- }
-
- stdioSession, err := sessionManager.GetOrCreateSession(
- sessionCtx,
- key,
- handler.config,
- handler.h.info,
- handler.h.setupBaseURL,
- userToken,
- )
- if err != nil {
- log.LogErrorWithFields("server", "Failed to create stdio session", map[string]any{
- "error": err.Error(),
- "sessionID": session.SessionID(),
- "server": handler.h.serverName,
- "user": handler.userEmail,
- })
- return
- }
-
- // Discover and register capabilities from the stdio process
- if err := stdioSession.DiscoverAndRegisterCapabilities(
- sessionCtx,
- handler.mcpServer,
- handler.userEmail,
- handler.config.RequiresUserToken,
- handler.h.storage,
- handler.h.serverName,
- handler.h.setupBaseURL,
- handler.config.UserAuthentication,
- session,
- ); err != nil {
- log.LogErrorWithFields("server", "Failed to discover and register capabilities", map[string]any{
- "error": err.Error(),
- "sessionID": session.SessionID(),
- "server": handler.h.serverName,
- "user": handler.userEmail,
- })
- if err := sessionManager.RemoveSession(key); err != nil {
- log.LogErrorWithFields("server", "Failed to remove session on capability failure", map[string]any{
- "sessionID": session.SessionID(),
- "server": handler.h.serverName,
- "user": handler.userEmail,
- "error": err.Error(),
- })
- }
- return
- }
-
- if handler.userEmail != "" {
- if handler.h.storage != nil {
- activeSession := storage.ActiveSession{
- SessionID: session.SessionID(),
- UserEmail: handler.userEmail,
- ServerName: handler.h.serverName,
- Created: time.Now(),
- LastActive: time.Now(),
- }
- if err := handler.h.storage.TrackSession(sessionCtx, activeSession); err != nil {
- log.LogWarnWithFields("server", "Failed to track session", map[string]any{
- "error": err.Error(),
- "sessionID": session.SessionID(),
- "user": handler.userEmail,
- })
- }
- }
- }
-
- log.LogInfoWithFields("server", "Session successfully created and connected", map[string]any{
- "sessionID": session.SessionID(),
- "server": handler.h.serverName,
- "user": handler.userEmail,
- })
-}
diff --git a/internal/server/http_test.go b/internal/server/http_test.go
index 822b13b..955c9da 100644
--- a/internal/server/http_test.go
+++ b/internal/server/http_test.go
@@ -37,17 +37,20 @@ func TestHealthEndpoint(t *testing.T) {
func TestOAuthEndpointsCORS(t *testing.T) {
// Setup OAuth config
oauthConfig := config.OAuthAuthConfig{
- Kind: config.AuthKindOAuth,
- Issuer: "https://test.example.com",
- GoogleClientID: "test-client-id",
- GoogleClientSecret: config.Secret("test-client-secret"),
- GoogleRedirectURI: "https://test.example.com/oauth/callback",
- JWTSecret: config.Secret(strings.Repeat("a", 32)),
- EncryptionKey: config.Secret(strings.Repeat("b", 32)),
- TokenTTL: time.Hour,
- RefreshTokenTTL: 30 * 24 * time.Hour,
- Storage: "memory",
- AllowedOrigins: []string{"http://localhost:6274"},
+ Kind: config.AuthKindOAuth,
+ Issuer: "https://test.example.com",
+ IDP: config.IDPConfig{
+ Provider: "google",
+ ClientID: "test-client-id",
+ ClientSecret: config.Secret("test-client-secret"),
+ RedirectURI: "https://test.example.com/oauth/callback",
+ },
+ JWTSecret: config.Secret(strings.Repeat("a", 32)),
+ EncryptionKey: config.Secret(strings.Repeat("b", 32)),
+ TokenTTL: time.Hour,
+ RefreshTokenTTL: 30 * 24 * time.Hour,
+ Storage: "memory",
+ AllowedOrigins: []string{"http://localhost:6274"},
}
store := storage.NewMemoryStorage()
@@ -58,10 +61,12 @@ func TestOAuthEndpointsCORS(t *testing.T) {
sessionEncryptor, err := oauth.NewSessionEncryptor([]byte(oauthConfig.EncryptionKey))
require.NoError(t, err)
serviceOAuthClient := auth.NewServiceOAuthClient(store, "https://test.example.com", []byte(strings.Repeat("k", 32)))
+ mockIDP := &mockIDPProvider{}
authHandlers := NewAuthHandlers(
oauthProvider,
oauthConfig,
+ mockIDP,
store,
sessionEncryptor,
map[string]*config.MCPClientConfig{},
diff --git a/internal/server/mcp_handler.go b/internal/server/mcp_handler.go
index 1f517ec..c0bc2ee 100644
--- a/internal/server/mcp_handler.go
+++ b/internal/server/mcp_handler.go
@@ -122,7 +122,7 @@ func (h *MCPHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
log.LogInfoWithFields("mcp", "Handling message request", map[string]any{
"path": r.URL.Path,
"server": h.serverName,
- "isStdio": isStdioServer(serverConfig),
+ "isStdio": serverConfig.IsStdio(),
"user": userEmail,
"remoteAddr": r.RemoteAddr,
"contentLength": r.ContentLength,
@@ -133,7 +133,7 @@ func (h *MCPHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
log.LogInfoWithFields("mcp", "Handling SSE request", map[string]any{
"path": r.URL.Path,
"server": h.serverName,
- "isStdio": isStdioServer(serverConfig),
+ "isStdio": serverConfig.IsStdio(),
"user": userEmail,
"remoteAddr": r.RemoteAddr,
"userAgent": r.UserAgent(),
@@ -168,7 +168,7 @@ func (h *MCPHandler) trackUserAccess(ctx context.Context, userEmail string) {
func (h *MCPHandler) handleSSERequest(ctx context.Context, w http.ResponseWriter, r *http.Request, userEmail string, config *config.MCPClientConfig) {
h.trackUserAccess(ctx, userEmail)
- if !isStdioServer(config) {
+ if !config.IsStdio() {
// For non-stdio servers, handle normally
h.handleNonStdioSSERequest(ctx, w, r, userEmail, config)
return
@@ -206,7 +206,7 @@ func (h *MCPHandler) handleSSERequest(ctx context.Context, w http.ResponseWriter
func (h *MCPHandler) handleMessageRequest(ctx context.Context, w http.ResponseWriter, r *http.Request, userEmail string, config *config.MCPClientConfig) {
h.trackUserAccess(ctx, userEmail)
- if isStdioServer(config) {
+ if config.IsStdio() {
sessionID := r.URL.Query().Get("sessionId")
if sessionID == "" {
jsonrpc.WriteError(w, nil, jsonrpc.InvalidParams, "missing sessionId")
diff --git a/internal/server/middleware.go b/internal/server/middleware.go
index bb04163..c207e44 100644
--- a/internal/server/middleware.go
+++ b/internal/server/middleware.go
@@ -10,15 +10,15 @@ import (
"time"
"github.com/dgellow/mcp-front/internal/adminauth"
- "github.com/dgellow/mcp-front/internal/browserauth"
"github.com/dgellow/mcp-front/internal/config"
"github.com/dgellow/mcp-front/internal/cookie"
"github.com/dgellow/mcp-front/internal/crypto"
- "github.com/dgellow/mcp-front/internal/googleauth"
+ "github.com/dgellow/mcp-front/internal/idp"
jsonwriter "github.com/dgellow/mcp-front/internal/json"
"github.com/dgellow/mcp-front/internal/log"
"github.com/dgellow/mcp-front/internal/oauth"
"github.com/dgellow/mcp-front/internal/servicecontext"
+ "github.com/dgellow/mcp-front/internal/session"
"github.com/dgellow/mcp-front/internal/storage"
"golang.org/x/crypto/bcrypt"
)
@@ -289,7 +289,7 @@ func adminMiddleware(adminConfig *config.AdminConfig, store storage.Storage) Mid
}
// NewBrowserSSOMiddleware creates middleware for browser-based SSO authentication
-func NewBrowserSSOMiddleware(authConfig config.OAuthAuthConfig, sessionEncryptor crypto.Encryptor, browserStateToken *crypto.TokenSigner) MiddlewareFunc {
+func NewBrowserSSOMiddleware(authConfig config.OAuthAuthConfig, idpProvider idp.Provider, sessionEncryptor crypto.Encryptor, browserStateToken *crypto.TokenSigner) MiddlewareFunc {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Check for session cookie
@@ -301,8 +301,8 @@ func NewBrowserSSOMiddleware(authConfig config.OAuthAuthConfig, sessionEncryptor
jsonwriter.WriteInternalServerError(w, "Failed to generate authentication state")
return
}
- googleURL := googleauth.GoogleAuthURL(authConfig, state)
- http.Redirect(w, r, googleURL, http.StatusFound)
+ authURL := idpProvider.AuthURL(state)
+ http.Redirect(w, r, authURL, http.StatusFound)
return
}
@@ -313,13 +313,13 @@ func NewBrowserSSOMiddleware(authConfig config.OAuthAuthConfig, sessionEncryptor
log.LogDebug("Invalid session cookie: %v", err)
cookie.ClearSession(w) // Clear bad cookie
state := generateBrowserState(browserStateToken, r.URL.String())
- googleURL := googleauth.GoogleAuthURL(authConfig, state)
- http.Redirect(w, r, googleURL, http.StatusFound)
+ authURL := idpProvider.AuthURL(state)
+ http.Redirect(w, r, authURL, http.StatusFound)
return
}
// Parse session data
- var sessionData browserauth.SessionCookie
+ var sessionData session.BrowserCookie
if err := json.NewDecoder(strings.NewReader(decrypted)).Decode(&sessionData); err != nil {
// Invalid format
cookie.ClearSession(w)
@@ -331,14 +331,14 @@ func NewBrowserSSOMiddleware(authConfig config.OAuthAuthConfig, sessionEncryptor
if time.Now().After(sessionData.Expires) {
log.LogDebug("Session expired for user %s", sessionData.Email)
cookie.ClearSession(w)
- // Redirect directly to Google OAuth
+ // Redirect directly to OAuth
state := generateBrowserState(browserStateToken, r.URL.String())
if state == "" {
jsonwriter.WriteInternalServerError(w, "Failed to generate authentication state")
return
}
- googleURL := googleauth.GoogleAuthURL(authConfig, state)
- http.Redirect(w, r, googleURL, http.StatusFound)
+ authURL := idpProvider.AuthURL(state)
+ http.Redirect(w, r, authURL, http.StatusFound)
return
}
@@ -353,7 +353,7 @@ func NewBrowserSSOMiddleware(authConfig config.OAuthAuthConfig, sessionEncryptor
// generateBrowserState creates a secure state parameter for browser SSO
func generateBrowserState(browserStateToken *crypto.TokenSigner, returnURL string) string {
- state := browserauth.AuthorizationState{
+ state := session.AuthorizationState{
Nonce: crypto.GenerateSecureToken(),
ReturnURL: returnURL,
}
diff --git a/internal/server/session_handler.go b/internal/server/session_handler.go
new file mode 100644
index 0000000..58334f6
--- /dev/null
+++ b/internal/server/session_handler.go
@@ -0,0 +1,167 @@
+package server
+
+import (
+ "context"
+ "time"
+
+ "github.com/dgellow/mcp-front/internal/client"
+ "github.com/dgellow/mcp-front/internal/config"
+ "github.com/dgellow/mcp-front/internal/log"
+ "github.com/dgellow/mcp-front/internal/storage"
+ mcpserver "github.com/mark3labs/mcp-go/server"
+)
+
+// SessionHandlerKey is the context key for session handlers
+type SessionHandlerKey struct{}
+
+// SessionRequestHandler handles session-specific logic for a request
+type SessionRequestHandler struct {
+ h *MCPHandler
+ userEmail string
+ config *config.MCPClientConfig
+ mcpServer *mcpserver.MCPServer // The shared MCP server
+}
+
+// NewSessionRequestHandler creates a new session request handler with all dependencies
+func NewSessionRequestHandler(h *MCPHandler, userEmail string, config *config.MCPClientConfig, mcpServer *mcpserver.MCPServer) *SessionRequestHandler {
+ return &SessionRequestHandler{
+ h: h,
+ userEmail: userEmail,
+ config: config,
+ mcpServer: mcpServer,
+ }
+}
+
+// GetUserEmail returns the user email for this session
+func (s *SessionRequestHandler) GetUserEmail() string {
+ return s.userEmail
+}
+
+// GetServerName returns the server name for this session
+func (s *SessionRequestHandler) GetServerName() string {
+ return s.h.serverName
+}
+
+// GetStorage returns the storage interface
+func (s *SessionRequestHandler) GetStorage() storage.Storage {
+ return s.h.storage
+}
+
+// HandleSessionRegistration handles the registration of a new session
+func HandleSessionRegistration(
+ sessionCtx context.Context,
+ session mcpserver.ClientSession,
+ handler *SessionRequestHandler,
+ sessionManager *client.StdioSessionManager,
+) {
+ // Create stdio process for this session
+ key := client.SessionKey{
+ UserEmail: handler.userEmail,
+ ServerName: handler.h.serverName,
+ SessionID: session.SessionID(),
+ }
+
+ log.LogDebugWithFields("server", "Registering session", map[string]any{
+ "sessionID": session.SessionID(),
+ "server": handler.h.serverName,
+ "user": handler.userEmail,
+ })
+
+ log.LogTraceWithFields("server", "Session registration started", map[string]any{
+ "sessionID": session.SessionID(),
+ "server": handler.h.serverName,
+ "user": handler.userEmail,
+ "requiresUserToken": handler.config.RequiresUserToken,
+ "transportType": handler.config.TransportType,
+ "command": handler.config.Command,
+ })
+
+ var userToken string
+ if handler.config.RequiresUserToken && handler.userEmail != "" && handler.h.storage != nil {
+ storedToken, err := handler.h.storage.GetUserToken(sessionCtx, handler.userEmail, handler.h.serverName)
+ if err != nil {
+ log.LogDebugWithFields("server", "No user token found", map[string]any{
+ "server": handler.h.serverName,
+ "user": handler.userEmail,
+ })
+ } else if storedToken != nil {
+ if handler.config.UserAuthentication != nil {
+ userToken = formatUserToken(storedToken, handler.config.UserAuthentication)
+ } else {
+ userToken = storedToken.Value
+ }
+ }
+ }
+
+ stdioSession, err := sessionManager.GetOrCreateSession(
+ sessionCtx,
+ key,
+ handler.config,
+ handler.h.info,
+ handler.h.setupBaseURL,
+ userToken,
+ )
+ if err != nil {
+ log.LogErrorWithFields("server", "Failed to create stdio session", map[string]any{
+ "error": err.Error(),
+ "sessionID": session.SessionID(),
+ "server": handler.h.serverName,
+ "user": handler.userEmail,
+ })
+ return
+ }
+
+ // Discover and register capabilities from the stdio process
+ if err := stdioSession.DiscoverAndRegisterCapabilities(
+ sessionCtx,
+ handler.mcpServer,
+ handler.userEmail,
+ handler.config.RequiresUserToken,
+ handler.h.storage,
+ handler.h.serverName,
+ handler.h.setupBaseURL,
+ handler.config.UserAuthentication,
+ session,
+ ); err != nil {
+ log.LogErrorWithFields("server", "Failed to discover and register capabilities", map[string]any{
+ "error": err.Error(),
+ "sessionID": session.SessionID(),
+ "server": handler.h.serverName,
+ "user": handler.userEmail,
+ })
+ if err := sessionManager.RemoveSession(key); err != nil {
+ log.LogErrorWithFields("server", "Failed to remove session on capability failure", map[string]any{
+ "sessionID": session.SessionID(),
+ "server": handler.h.serverName,
+ "user": handler.userEmail,
+ "error": err.Error(),
+ })
+ }
+ return
+ }
+
+ if handler.userEmail != "" {
+ if handler.h.storage != nil {
+ activeSession := storage.ActiveSession{
+ SessionID: session.SessionID(),
+ UserEmail: handler.userEmail,
+ ServerName: handler.h.serverName,
+ Created: time.Now(),
+ LastActive: time.Now(),
+ }
+ if err := handler.h.storage.TrackSession(sessionCtx, activeSession); err != nil {
+ log.LogWarnWithFields("server", "Failed to track session", map[string]any{
+ "error": err.Error(),
+ "sessionID": session.SessionID(),
+ "user": handler.userEmail,
+ })
+ }
+ }
+ }
+
+ log.LogInfoWithFields("server", "Session successfully created and connected", map[string]any{
+ "sessionID": session.SessionID(),
+ "server": handler.h.serverName,
+ "user": handler.userEmail,
+ })
+}
diff --git a/internal/server/token_handlers.go b/internal/server/token_handlers.go
index b2cbdaa..cf33a00 100644
--- a/internal/server/token_handlers.go
+++ b/internal/server/token_handlers.go
@@ -4,7 +4,6 @@ import (
"fmt"
"net/http"
"strings"
- "sync"
"time"
"github.com/dgellow/mcp-front/internal/auth"
@@ -20,40 +19,22 @@ import (
type TokenHandlers struct {
tokenStore storage.UserTokenStore
mcpServers map[string]*config.MCPClientConfig
- csrfTokens sync.Map // Thread-safe CSRF token storage
+ csrf crypto.CSRFProtection
oauthEnabled bool
serviceOAuthClient *auth.ServiceOAuthClient
}
// NewTokenHandlers creates a new token handlers instance
-func NewTokenHandlers(tokenStore storage.UserTokenStore, mcpServers map[string]*config.MCPClientConfig, oauthEnabled bool, serviceOAuthClient *auth.ServiceOAuthClient) *TokenHandlers {
+func NewTokenHandlers(tokenStore storage.UserTokenStore, mcpServers map[string]*config.MCPClientConfig, oauthEnabled bool, serviceOAuthClient *auth.ServiceOAuthClient, csrfKey []byte) *TokenHandlers {
return &TokenHandlers{
tokenStore: tokenStore,
mcpServers: mcpServers,
oauthEnabled: oauthEnabled,
serviceOAuthClient: serviceOAuthClient,
+ csrf: crypto.NewCSRFProtection(csrfKey, 15*time.Minute),
}
}
-// generateCSRFToken creates a new CSRF token
-func (h *TokenHandlers) generateCSRFToken() (string, error) {
- token := crypto.GenerateSecureToken()
- if token == "" {
- return "", fmt.Errorf("failed to generate CSRF token")
- }
- h.csrfTokens.Store(token, true)
- return token, nil
-}
-
-// validateCSRFToken checks if a CSRF token is valid
-func (h *TokenHandlers) validateCSRFToken(token string) bool {
- if _, exists := h.csrfTokens.LoadAndDelete(token); exists {
- // One-time use via LoadAndDelete
- return true
- }
- return false
-}
-
// ListTokensHandler shows the token management page
func (h *TokenHandlers) ListTokensHandler(w http.ResponseWriter, r *http.Request) {
// Only accept GET
@@ -137,7 +118,7 @@ func (h *TokenHandlers) ListTokensHandler(w http.ResponseWriter, r *http.Request
}
// Generate CSRF token
- csrfToken, err := h.generateCSRFToken()
+ csrfToken, err := h.csrf.Generate()
if err != nil {
log.LogErrorWithFields("token", "Failed to generate CSRF token", map[string]any{
"error": err.Error(),
@@ -188,7 +169,7 @@ func (h *TokenHandlers) SetTokenHandler(w http.ResponseWriter, r *http.Request)
// Validate CSRF token
csrfToken := r.FormValue("csrf_token")
- if !h.validateCSRFToken(csrfToken) {
+ if !h.csrf.Validate(csrfToken) {
jsonwriter.WriteForbidden(w, "Invalid CSRF token")
return
}
@@ -304,7 +285,7 @@ func (h *TokenHandlers) DeleteTokenHandler(w http.ResponseWriter, r *http.Reques
// Validate CSRF token
csrfToken := r.FormValue("csrf_token")
- if !h.validateCSRFToken(csrfToken) {
+ if !h.csrf.Validate(csrfToken) {
jsonwriter.WriteForbidden(w, "Invalid CSRF token")
return
}
diff --git a/internal/server/user_token_service.go b/internal/server/user_token_service.go
new file mode 100644
index 0000000..7e14d57
--- /dev/null
+++ b/internal/server/user_token_service.go
@@ -0,0 +1,102 @@
+package server
+
+import (
+ "context"
+ "strings"
+ "time"
+
+ "github.com/dgellow/mcp-front/internal/auth"
+ "github.com/dgellow/mcp-front/internal/config"
+ "github.com/dgellow/mcp-front/internal/log"
+ "github.com/dgellow/mcp-front/internal/storage"
+)
+
+// UserTokenService handles user token retrieval and OAuth refresh
+type UserTokenService struct {
+ storage storage.Storage
+ serviceOAuthClient *auth.ServiceOAuthClient
+}
+
+// NewUserTokenService creates a new user token service
+func NewUserTokenService(storage storage.Storage, serviceOAuthClient *auth.ServiceOAuthClient) *UserTokenService {
+ return &UserTokenService{
+ storage: storage,
+ serviceOAuthClient: serviceOAuthClient,
+ }
+}
+
+// GetUserToken retrieves and formats a user token for a service, handling OAuth refresh.
+//
+// Token refresh strategy: Optimistic continuation on failure.
+// If refresh fails, we log a warning and continue with the current token. The external
+// service will reject the expired token with 401, giving the user a clear error.
+// This is acceptable because: (1) refresh failures are rare (network issues, revoked
+// tokens), and (2) forcing users to re-auth is better than silently hiding auth issues.
+func (uts *UserTokenService) GetUserToken(ctx context.Context, userEmail, serviceName string, serviceConfig *config.MCPClientConfig) (string, error) {
+ storedToken, err := uts.storage.GetUserToken(ctx, userEmail, serviceName)
+ if err != nil {
+ return "", err
+ }
+
+ switch storedToken.Type {
+ case storage.TokenTypeManual:
+ // Token is already in storedToken.Value, formatUserToken will handle it
+ break
+ case storage.TokenTypeOAuth:
+ if storedToken.OAuthData != nil && uts.serviceOAuthClient != nil {
+ if err := uts.serviceOAuthClient.RefreshToken(ctx, userEmail, serviceName, serviceConfig); err != nil {
+ log.LogWarnWithFields("user_token", "Failed to refresh OAuth token", map[string]any{
+ "service": serviceName,
+ "user": userEmail,
+ "error": err.Error(),
+ })
+ // Continue with current token - the service will handle auth failure
+ } else {
+ // Re-fetch the updated token after refresh
+ refreshedToken, err := uts.storage.GetUserToken(ctx, userEmail, serviceName)
+ if err != nil {
+ log.LogErrorWithFields("user_token", "Failed to fetch token after successful refresh", map[string]any{
+ "service": serviceName,
+ "user": userEmail,
+ "error": err.Error(),
+ })
+ // Continue with original token - the service will handle auth failure
+ } else {
+ storedToken = refreshedToken
+ var expiresAt time.Time
+ if refreshedToken.OAuthData != nil {
+ expiresAt = refreshedToken.OAuthData.ExpiresAt
+ }
+ log.LogInfoWithFields("user_token", "OAuth token refreshed and updated", map[string]any{
+ "service": serviceName,
+ "user": userEmail,
+ "expiresAt": expiresAt,
+ })
+ }
+ }
+ }
+ }
+
+ return formatUserToken(storedToken, serviceConfig.UserAuthentication), nil
+}
+
+// formatUserToken formats a stored token according to the user authentication configuration
+func formatUserToken(storedToken *storage.StoredToken, auth *config.UserAuthentication) string {
+ if storedToken == nil {
+ return ""
+ }
+
+ if storedToken.Type == storage.TokenTypeOAuth && storedToken.OAuthData != nil {
+ token := storedToken.OAuthData.AccessToken
+ if auth.TokenFormat != "" && auth.TokenFormat != "{{token}}" {
+ return strings.ReplaceAll(auth.TokenFormat, "{{token}}", token)
+ }
+ return token
+ }
+
+ token := storedToken.Value
+ if auth != nil && auth.TokenFormat != "" && auth.TokenFormat != "{{token}}" {
+ return strings.ReplaceAll(auth.TokenFormat, "{{token}}", token)
+ }
+ return token
+}
diff --git a/internal/session/session.go b/internal/session/session.go
new file mode 100644
index 0000000..226b160
--- /dev/null
+++ b/internal/session/session.go
@@ -0,0 +1,35 @@
+package session
+
+import (
+ "time"
+
+ "github.com/dgellow/mcp-front/internal/idp"
+ "github.com/ory/fosite"
+)
+
+// OAuthSession extends DefaultSession with user information for the OAuth flow
+type OAuthSession struct {
+ *fosite.DefaultSession
+ Identity idp.Identity `json:"identity"`
+}
+
+// Clone implements fosite.Session
+func (s *OAuthSession) Clone() fosite.Session {
+ return &OAuthSession{
+ DefaultSession: s.DefaultSession.Clone().(*fosite.DefaultSession),
+ Identity: s.Identity,
+ }
+}
+
+// BrowserCookie represents the data stored in encrypted browser session cookies
+type BrowserCookie struct {
+ Email string `json:"email"`
+ Provider string `json:"provider"` // IDP that authenticated this user (e.g., "google", "azure", "github")
+ Expires time.Time `json:"expires"`
+}
+
+// AuthorizationState represents the OAuth authorization code flow state parameter
+type AuthorizationState struct {
+ Nonce string `json:"nonce"`
+ ReturnURL string `json:"return_url"`
+}
diff --git a/internal/browserauth/session_test.go b/internal/session/session_test.go
similarity index 56%
rename from internal/browserauth/session_test.go
rename to internal/session/session_test.go
index f7e15de..0294040 100644
--- a/internal/browserauth/session_test.go
+++ b/internal/session/session_test.go
@@ -1,4 +1,4 @@
-package browserauth
+package session
import (
"encoding/json"
@@ -9,43 +9,42 @@ import (
"github.com/stretchr/testify/require"
)
-func TestSessionCookie_MarshalUnmarshal(t *testing.T) {
- original := SessionCookie{
- Email: "user@example.com",
- Expires: time.Now().Add(24 * time.Hour).Truncate(time.Second),
+func TestBrowserCookie_MarshalUnmarshal(t *testing.T) {
+ original := BrowserCookie{
+ Email: "user@example.com",
+ Provider: "google",
+ Expires: time.Now().Add(24 * time.Hour).Truncate(time.Second),
}
- // Marshal to JSON
data, err := json.Marshal(original)
require.NoError(t, err)
- // Unmarshal back
- var unmarshaled SessionCookie
+ var unmarshaled BrowserCookie
err = json.Unmarshal(data, &unmarshaled)
require.NoError(t, err)
- // Truncate for comparison (JSON time serialization)
assert.Equal(t, original.Email, unmarshaled.Email)
+ assert.Equal(t, original.Provider, unmarshaled.Provider)
assert.WithinDuration(t, original.Expires, unmarshaled.Expires, time.Second)
}
-func TestSessionCookie_Expiry(t *testing.T) {
+func TestBrowserCookie_Expiry(t *testing.T) {
t.Run("not expired", func(t *testing.T) {
- session := SessionCookie{
- Email: "user@example.com",
- Expires: time.Now().Add(1 * time.Hour),
+ s := BrowserCookie{
+ Email: "user@example.com",
+ Provider: "google",
+ Expires: time.Now().Add(1 * time.Hour),
}
-
- assert.True(t, session.Expires.After(time.Now()))
+ assert.True(t, s.Expires.After(time.Now()))
})
t.Run("expired", func(t *testing.T) {
- session := SessionCookie{
- Email: "user@example.com",
- Expires: time.Now().Add(-1 * time.Hour),
+ s := BrowserCookie{
+ Email: "user@example.com",
+ Provider: "google",
+ Expires: time.Now().Add(-1 * time.Hour),
}
-
- assert.True(t, session.Expires.Before(time.Now()))
+ assert.True(t, s.Expires.Before(time.Now()))
})
}
@@ -55,11 +54,9 @@ func TestAuthorizationState_MarshalUnmarshal(t *testing.T) {
ReturnURL: "/my/tokens",
}
- // Marshal to JSON
data, err := json.Marshal(original)
require.NoError(t, err)
- // Unmarshal back
var unmarshaled AuthorizationState
err = json.Unmarshal(data, &unmarshaled)
require.NoError(t, err)
diff --git a/internal/storage/firestore.go b/internal/storage/firestore.go
index 1311ee6..0b49554 100644
--- a/internal/storage/firestore.go
+++ b/internal/storage/firestore.go
@@ -16,6 +16,11 @@ import (
"google.golang.org/grpc/status"
)
+const (
+ usersCollection = "mcp_front_users"
+ sessionsCollection = "mcp_front_sessions"
+)
+
// FirestoreStorage implements OAuth client storage using Google Cloud Firestore.
//
// Error handling strategy:
@@ -195,8 +200,9 @@ func (s *FirestoreStorage) loadClientsFromFirestore(ctx context.Context) error {
}
// StoreAuthorizeRequest stores an authorize request with state (in memory only - short-lived)
-func (s *FirestoreStorage) StoreAuthorizeRequest(state string, req fosite.AuthorizeRequester) {
+func (s *FirestoreStorage) StoreAuthorizeRequest(state string, req fosite.AuthorizeRequester) error {
s.stateCache.Store(state, req)
+ return nil
}
// GetAuthorizeRequest retrieves an authorize request by state (one-time use)
@@ -533,7 +539,7 @@ func (s *FirestoreStorage) UpsertUser(ctx context.Context, email string) error {
}
// Try to get existing user first
- doc, err := s.client.Collection("mcp_front_users").Doc(email).Get(ctx)
+ doc, err := s.client.Collection(usersCollection).Doc(email).Get(ctx)
if err == nil {
// User exists, update LastSeen
_, err = doc.Ref.Update(ctx, []firestore.Update{
@@ -547,16 +553,35 @@ func (s *FirestoreStorage) UpsertUser(ctx context.Context, email string) error {
userDoc.FirstSeen = time.Now()
userDoc.Enabled = true
userDoc.IsAdmin = false
- _, err = s.client.Collection("mcp_front_users").Doc(email).Set(ctx, userDoc)
+ _, err = s.client.Collection(usersCollection).Doc(email).Set(ctx, userDoc)
return err
}
return err
}
+// GetUser returns a single user by email
+func (s *FirestoreStorage) GetUser(ctx context.Context, email string) (*UserInfo, error) {
+ doc, err := s.client.Collection(usersCollection).Doc(email).Get(ctx)
+ if err != nil {
+ if status.Code(err) == codes.NotFound {
+ return nil, ErrUserNotFound
+ }
+ return nil, fmt.Errorf("failed to get user from Firestore: %w", err)
+ }
+
+ var userDoc UserDoc
+ if err := doc.DataTo(&userDoc); err != nil {
+ return nil, fmt.Errorf("failed to unmarshal user: %w", err)
+ }
+
+ user := UserInfo(userDoc)
+ return &user, nil
+}
+
// GetAllUsers returns all users
func (s *FirestoreStorage) GetAllUsers(ctx context.Context) ([]UserInfo, error) {
- iter := s.client.Collection("mcp_front_users").Documents(ctx)
+ iter := s.client.Collection(usersCollection).Documents(ctx)
defer iter.Stop()
var users []UserInfo
@@ -583,7 +608,7 @@ func (s *FirestoreStorage) GetAllUsers(ctx context.Context) ([]UserInfo, error)
// UpdateUserStatus updates a user's enabled status
func (s *FirestoreStorage) UpdateUserStatus(ctx context.Context, email string, enabled bool) error {
- _, err := s.client.Collection("mcp_front_users").Doc(email).Update(ctx, []firestore.Update{
+ _, err := s.client.Collection(usersCollection).Doc(email).Update(ctx, []firestore.Update{
{Path: "enabled", Value: enabled},
})
if status.Code(err) == codes.NotFound {
@@ -595,7 +620,7 @@ func (s *FirestoreStorage) UpdateUserStatus(ctx context.Context, email string, e
// DeleteUser removes a user from storage
func (s *FirestoreStorage) DeleteUser(ctx context.Context, email string) error {
// Delete user document
- _, err := s.client.Collection("mcp_front_users").Doc(email).Delete(ctx)
+ _, err := s.client.Collection(usersCollection).Doc(email).Delete(ctx)
if err != nil && status.Code(err) != codes.NotFound {
return err
}
@@ -625,7 +650,7 @@ func (s *FirestoreStorage) DeleteUser(ctx context.Context, email string) error {
// SetUserAdmin updates a user's admin status
func (s *FirestoreStorage) SetUserAdmin(ctx context.Context, email string, isAdmin bool) error {
- _, err := s.client.Collection("mcp_front_users").Doc(email).Update(ctx, []firestore.Update{
+ _, err := s.client.Collection(usersCollection).Doc(email).Update(ctx, []firestore.Update{
{Path: "is_admin", Value: isAdmin},
})
if status.Code(err) == codes.NotFound {
@@ -645,7 +670,7 @@ func (s *FirestoreStorage) TrackSession(ctx context.Context, session ActiveSessi
}
// Check if session exists
- doc, err := s.client.Collection("mcp_front_sessions").Doc(session.SessionID).Get(ctx)
+ doc, err := s.client.Collection(sessionsCollection).Doc(session.SessionID).Get(ctx)
if err == nil {
// Session exists, update LastActive
_, err = doc.Ref.Update(ctx, []firestore.Update{
@@ -657,7 +682,7 @@ func (s *FirestoreStorage) TrackSession(ctx context.Context, session ActiveSessi
// Session doesn't exist, create new
if status.Code(err) == codes.NotFound {
sessionDoc.Created = time.Now()
- _, err = s.client.Collection("mcp_front_sessions").Doc(session.SessionID).Set(ctx, sessionDoc)
+ _, err = s.client.Collection(sessionsCollection).Doc(session.SessionID).Set(ctx, sessionDoc)
return err
}
@@ -666,7 +691,7 @@ func (s *FirestoreStorage) TrackSession(ctx context.Context, session ActiveSessi
// GetActiveSessions returns all active sessions
func (s *FirestoreStorage) GetActiveSessions(ctx context.Context) ([]ActiveSession, error) {
- iter := s.client.Collection("mcp_front_sessions").Documents(ctx)
+ iter := s.client.Collection(sessionsCollection).Documents(ctx)
defer iter.Stop()
var sessions []ActiveSession
@@ -693,7 +718,7 @@ func (s *FirestoreStorage) GetActiveSessions(ctx context.Context) ([]ActiveSessi
// RevokeSession removes a session
func (s *FirestoreStorage) RevokeSession(ctx context.Context, sessionID string) error {
- _, err := s.client.Collection("mcp_front_sessions").Doc(sessionID).Delete(ctx)
+ _, err := s.client.Collection(sessionsCollection).Doc(sessionID).Delete(ctx)
if err != nil && status.Code(err) != codes.NotFound {
return err
}
diff --git a/internal/storage/memory.go b/internal/storage/memory.go
index 6a1dc79..63c2c54 100644
--- a/internal/storage/memory.go
+++ b/internal/storage/memory.go
@@ -43,8 +43,9 @@ func NewMemoryStorage() *MemoryStorage {
}
// StoreAuthorizeRequest stores an authorize request with state
-func (s *MemoryStorage) StoreAuthorizeRequest(state string, req fosite.AuthorizeRequester) {
+func (s *MemoryStorage) StoreAuthorizeRequest(state string, req fosite.AuthorizeRequester) error {
s.stateCache.Store(state, req)
+ return nil
}
// GetAuthorizeRequest retrieves an authorize request by state (one-time use)
@@ -204,6 +205,19 @@ func (s *MemoryStorage) UpsertUser(ctx context.Context, email string) error {
return nil
}
+// GetUser returns a single user by email
+func (s *MemoryStorage) GetUser(ctx context.Context, email string) (*UserInfo, error) {
+ s.usersMutex.RLock()
+ defer s.usersMutex.RUnlock()
+
+ user, exists := s.users[email]
+ if !exists {
+ return nil, ErrUserNotFound
+ }
+ userCopy := *user
+ return &userCopy, nil
+}
+
// GetAllUsers returns all users
func (s *MemoryStorage) GetAllUsers(ctx context.Context) ([]UserInfo, error) {
s.usersMutex.RLock()
diff --git a/internal/storage/storage.go b/internal/storage/storage.go
index 1f9a780..e44cb02 100644
--- a/internal/storage/storage.go
+++ b/internal/storage/storage.go
@@ -76,7 +76,7 @@ type Storage interface {
fosite.Storage
// OAuth state management
- StoreAuthorizeRequest(state string, req fosite.AuthorizeRequester)
+ StoreAuthorizeRequest(state string, req fosite.AuthorizeRequester) error
GetAuthorizeRequest(state string) (fosite.AuthorizeRequester, bool)
// OAuth client management
@@ -89,6 +89,7 @@ type Storage interface {
// User tracking (upserted when users access MCP endpoints)
UpsertUser(ctx context.Context, email string) error
+ GetUser(ctx context.Context, email string) (*UserInfo, error)
GetAllUsers(ctx context.Context) ([]UserInfo, error)
UpdateUserStatus(ctx context.Context, email string, enabled bool) error
DeleteUser(ctx context.Context, email string) error