Skip to content

Commit 275b91a

Browse files
authored
feat(auth): introduce require-oauth flag to comply with OAuth in MCP specification (170)
Introduce require-oauth flag When this flag is enabled, authorization middleware will be turned on. When this flag is enabled, Derived which is generated based on the client token will not be used. --- Wire Authorization middleware to http mux This commit adds authorization middleware. Additionally, this commit rejects the requests if the bearer token is absent in Authorization header of the request. --- Add offline token validation for expiration and audience Per Model Context Protocol specification, MCP Servers must check the audience field of the token to ensure that they are generated specifically for them. This commits parses the JWT token and asserts that audience is correct and token is not expired. --- Add online token verification via TokenReview request to API Server This commit sends online token verification by sending request to TokenReview endpoint of API Server with the token and expected audience. If API Server returns the status as authenticated, that means this token can be used to generate a new ad hoc token for MCP Server. If API Server returns the status as not authenticated, that means this token is invalid and MCP Server returns 401 to force the client to initiate OAuth flow. --- Serve oauth protected resource metadata endpoint --- Introduce server-url to be represented in protected resource metadata --- Add error return type in Derived function --- Return error if error occurs in Derived, when require-oauth --- Add test cases for authorization-url and server-url --- Wire server-url to audience, if it is set --- Remove redundant ssebaseurl parameter from http
1 parent 114726f commit 275b91a

File tree

17 files changed

+827
-34
lines changed

17 files changed

+827
-34
lines changed

pkg/config/config.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@ type StaticConfig struct {
2222
DisableDestructive bool `toml:"disable_destructive,omitempty"`
2323
EnabledTools []string `toml:"enabled_tools,omitempty"`
2424
DisabledTools []string `toml:"disabled_tools,omitempty"`
25+
RequireOAuth bool `toml:"require_oauth,omitempty"`
26+
AuthorizationURL string `toml:"authorization_url,omitempty"`
27+
ServerURL string `toml:"server_url,omitempty"`
2528
}
2629

2730
type GroupVersionKind struct {

pkg/http/authorization.go

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
package http
2+
3+
import (
4+
"encoding/base64"
5+
"encoding/json"
6+
"fmt"
7+
"net/http"
8+
"slices"
9+
"strings"
10+
"time"
11+
12+
"k8s.io/klog/v2"
13+
14+
"github.com/manusa/kubernetes-mcp-server/pkg/mcp"
15+
)
16+
17+
const (
18+
Audience = "kubernetes-mcp-server"
19+
)
20+
21+
// AuthorizationMiddleware validates the OAuth flow using Kubernetes TokenReview API
22+
func AuthorizationMiddleware(requireOAuth bool, serverURL string, mcpServer *mcp.Server) func(http.Handler) http.Handler {
23+
return func(next http.Handler) http.Handler {
24+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
25+
if r.URL.Path == "/healthz" || r.URL.Path == "/.well-known/oauth-protected-resource" {
26+
next.ServeHTTP(w, r)
27+
return
28+
}
29+
if !requireOAuth {
30+
next.ServeHTTP(w, r)
31+
return
32+
}
33+
34+
authHeader := r.Header.Get("Authorization")
35+
if authHeader == "" || !strings.HasPrefix(authHeader, "Bearer ") {
36+
klog.V(1).Infof("Authentication failed - missing or invalid bearer token: %s %s from %s", r.Method, r.URL.Path, r.RemoteAddr)
37+
38+
w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer realm="Kubernetes MCP Server", audience=%s, error="invalid_token"`, Audience))
39+
http.Error(w, "Unauthorized: Bearer token required", http.StatusUnauthorized)
40+
return
41+
}
42+
43+
token := strings.TrimPrefix(authHeader, "Bearer ")
44+
45+
audience := Audience
46+
if serverURL != "" {
47+
audience = serverURL
48+
}
49+
50+
err := validateJWTToken(token, audience)
51+
if err != nil {
52+
klog.V(1).Infof("Authentication failed - JWT validation error: %s %s from %s, error: %v", r.Method, r.URL.Path, r.RemoteAddr, err)
53+
54+
w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer realm="Kubernetes MCP Server", audience=%s, error="invalid_token"`, Audience))
55+
http.Error(w, "Unauthorized: Invalid token", http.StatusUnauthorized)
56+
return
57+
}
58+
59+
// Validate token using Kubernetes TokenReview API
60+
_, _, err = mcpServer.VerifyToken(r.Context(), token, Audience)
61+
if err != nil {
62+
klog.V(1).Infof("Authentication failed - token validation error: %s %s from %s, error: %v", r.Method, r.URL.Path, r.RemoteAddr, err)
63+
64+
w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer realm="Kubernetes MCP Server", audience=%s, error="invalid_token"`, Audience))
65+
http.Error(w, "Unauthorized: Invalid token", http.StatusUnauthorized)
66+
return
67+
}
68+
69+
next.ServeHTTP(w, r)
70+
})
71+
}
72+
}
73+
74+
type JWTClaims struct {
75+
Issuer string `json:"iss"`
76+
Audience []string `json:"aud"`
77+
ExpiresAt int64 `json:"exp"`
78+
}
79+
80+
// validateJWTToken validates basic JWT claims without signature verification
81+
func validateJWTToken(token, audience string) error {
82+
parts := strings.Split(token, ".")
83+
if len(parts) != 3 {
84+
return fmt.Errorf("invalid JWT token format")
85+
}
86+
87+
claims, err := parseJWTClaims(parts[1])
88+
if err != nil {
89+
return fmt.Errorf("failed to parse JWT claims: %v", err)
90+
}
91+
92+
if claims.ExpiresAt > 0 && time.Now().Unix() > claims.ExpiresAt {
93+
return fmt.Errorf("token expired")
94+
}
95+
96+
if !slices.Contains(claims.Audience, audience) {
97+
return fmt.Errorf("token audience mismatch: %v", claims.Audience)
98+
}
99+
100+
return nil
101+
}
102+
103+
func parseJWTClaims(payload string) (*JWTClaims, error) {
104+
// Add padding if needed
105+
if len(payload)%4 != 0 {
106+
payload += strings.Repeat("=", 4-len(payload)%4)
107+
}
108+
109+
decoded, err := base64.URLEncoding.DecodeString(payload)
110+
if err != nil {
111+
return nil, fmt.Errorf("failed to decode JWT payload: %v", err)
112+
}
113+
114+
var claims JWTClaims
115+
if err := json.Unmarshal(decoded, &claims); err != nil {
116+
return nil, fmt.Errorf("failed to unmarshal JWT claims: %v", err)
117+
}
118+
119+
return &claims, nil
120+
}

0 commit comments

Comments
 (0)