Skip to content

Commit f161149

Browse files
Implement RFC 7523 JWT flows (modelcontextprotocol#1247)
Co-authored-by: Yann Jouanin <[email protected]>
1 parent db9e451 commit f161149

File tree

10 files changed

+463
-52
lines changed

10 files changed

+463
-52
lines changed

examples/servers/simple-auth/mcp_simple_auth/simple_auth_provider.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,8 @@ async def get_client(self, client_id: str) -> OAuthClientInformationFull | None:
7373

7474
async def register_client(self, client_info: OAuthClientInformationFull):
7575
"""Register a new OAuth client."""
76+
if not client_info.client_id:
77+
raise ValueError("No client_id provided")
7678
self.clients[client_info.client_id] = client_info
7779

7880
async def authorize(self, client: OAuthClientInformationFull, params: AuthorizationParams) -> str:
@@ -209,6 +211,8 @@ async def exchange_authorization_code(
209211
"""Exchange authorization code for tokens."""
210212
if authorization_code.code not in self.auth_codes:
211213
raise ValueError("Invalid authorization code")
214+
if not client.client_id:
215+
raise ValueError("No client_id provided")
212216

213217
# Generate MCP access token
214218
mcp_token = f"mcp_{secrets.token_hex(32)}"

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ dependencies = [
3333
"uvicorn>=0.31.1; sys_platform != 'emscripten'",
3434
"jsonschema>=4.20.0",
3535
"pywin32>=310; sys_platform == 'win32'",
36+
"pyjwt[crypto]>=2.10.1",
3637
]
3738

3839
[project.optional-dependencies]

src/mcp/client/auth/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
"""
2+
OAuth2 Authentication implementation for HTTPX.
3+
4+
Implements authorization code flow with PKCE and automatic token refresh.
5+
"""
6+
7+
from mcp.client.auth.oauth2 import * # noqa: F403
Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
import time
2+
from collections.abc import Awaitable, Callable
3+
from typing import Any
4+
from uuid import uuid4
5+
6+
import httpx
7+
import jwt
8+
from pydantic import BaseModel, Field
9+
10+
from mcp.client.auth import OAuthClientProvider, OAuthFlowError, OAuthTokenError, TokenStorage
11+
from mcp.shared.auth import OAuthClientMetadata
12+
13+
14+
class JWTParameters(BaseModel):
15+
"""JWT parameters."""
16+
17+
assertion: str | None = Field(
18+
default=None,
19+
description="JWT assertion for JWT authentication. "
20+
"Will be used instead of generating a new assertion if provided.",
21+
)
22+
23+
issuer: str | None = Field(default=None, description="Issuer for JWT assertions.")
24+
subject: str | None = Field(default=None, description="Subject identifier for JWT assertions.")
25+
audience: str | None = Field(default=None, description="Audience for JWT assertions.")
26+
claims: dict[str, Any] | None = Field(default=None, description="Additional claims for JWT assertions.")
27+
jwt_signing_algorithm: str | None = Field(default="RS256", description="Algorithm for signing JWT assertions.")
28+
jwt_signing_key: str | None = Field(default=None, description="Private key for JWT signing.")
29+
jwt_lifetime_seconds: int = Field(default=300, description="Lifetime of generated JWT in seconds.")
30+
31+
def to_assertion(self, with_audience_fallback: str | None = None) -> str:
32+
if self.assertion is not None:
33+
# Prebuilt JWT (e.g. acquired out-of-band)
34+
assertion = self.assertion
35+
else:
36+
if not self.jwt_signing_key:
37+
raise OAuthFlowError("Missing signing key for JWT bearer grant")
38+
if not self.issuer:
39+
raise OAuthFlowError("Missing issuer for JWT bearer grant")
40+
if not self.subject:
41+
raise OAuthFlowError("Missing subject for JWT bearer grant")
42+
43+
audience = self.audience if self.audience else with_audience_fallback
44+
if not audience:
45+
raise OAuthFlowError("Missing audience for JWT bearer grant")
46+
47+
now = int(time.time())
48+
claims: dict[str, Any] = {
49+
"iss": self.issuer,
50+
"sub": self.subject,
51+
"aud": audience,
52+
"exp": now + self.jwt_lifetime_seconds,
53+
"iat": now,
54+
"jti": str(uuid4()),
55+
}
56+
claims.update(self.claims or {})
57+
58+
assertion = jwt.encode(
59+
claims,
60+
self.jwt_signing_key,
61+
algorithm=self.jwt_signing_algorithm or "RS256",
62+
)
63+
return assertion
64+
65+
66+
class RFC7523OAuthClientProvider(OAuthClientProvider):
67+
"""OAuth client provider for RFC7532 clients."""
68+
69+
jwt_parameters: JWTParameters | None = None
70+
71+
def __init__(
72+
self,
73+
server_url: str,
74+
client_metadata: OAuthClientMetadata,
75+
storage: TokenStorage,
76+
redirect_handler: Callable[[str], Awaitable[None]] | None = None,
77+
callback_handler: Callable[[], Awaitable[tuple[str, str | None]]] | None = None,
78+
timeout: float = 300.0,
79+
jwt_parameters: JWTParameters | None = None,
80+
) -> None:
81+
super().__init__(server_url, client_metadata, storage, redirect_handler, callback_handler, timeout)
82+
self.jwt_parameters = jwt_parameters
83+
84+
async def _exchange_token_authorization_code(
85+
self, auth_code: str, code_verifier: str, *, token_data: dict[str, Any] | None = None
86+
) -> httpx.Request:
87+
"""Build token exchange request for authorization_code flow."""
88+
token_data = token_data or {}
89+
if self.context.client_metadata.token_endpoint_auth_method == "private_key_jwt":
90+
self._add_client_authentication_jwt(token_data=token_data)
91+
return await super()._exchange_token_authorization_code(auth_code, code_verifier, token_data=token_data)
92+
93+
async def _perform_authorization(self) -> httpx.Request:
94+
"""Perform the authorization flow."""
95+
if "urn:ietf:params:oauth:grant-type:jwt-bearer" in self.context.client_metadata.grant_types:
96+
token_request = await self._exchange_token_jwt_bearer()
97+
return token_request
98+
else:
99+
return await super()._perform_authorization()
100+
101+
def _add_client_authentication_jwt(self, *, token_data: dict[str, Any]):
102+
"""Add JWT assertion for client authentication to token endpoint parameters."""
103+
if not self.jwt_parameters:
104+
raise OAuthTokenError("Missing JWT parameters for private_key_jwt flow")
105+
if not self.context.oauth_metadata:
106+
raise OAuthTokenError("Missing OAuth metadata for private_key_jwt flow")
107+
108+
# We need to set the audience to the issuer identifier of the authorization server
109+
# https://datatracker.ietf.org/doc/html/draft-ietf-oauth-rfc7523bis-01#name-updates-to-rfc-7523
110+
issuer = str(self.context.oauth_metadata.issuer)
111+
assertion = self.jwt_parameters.to_assertion(with_audience_fallback=issuer)
112+
113+
# When using private_key_jwt, in a client_credentials flow, we use RFC 7523 Section 2.2
114+
token_data["client_assertion"] = assertion
115+
token_data["client_assertion_type"] = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"
116+
# We need to set the audience to the resource server, the audience is difference from the one in claims
117+
# it represents the resource server that will validate the token
118+
token_data["audience"] = self.context.get_resource_url()
119+
120+
async def _exchange_token_jwt_bearer(self) -> httpx.Request:
121+
"""Build token exchange request for JWT bearer grant."""
122+
if not self.context.client_info:
123+
raise OAuthFlowError("Missing client info")
124+
if not self.jwt_parameters:
125+
raise OAuthFlowError("Missing JWT parameters")
126+
if not self.context.oauth_metadata:
127+
raise OAuthTokenError("Missing OAuth metadata")
128+
129+
# We need to set the audience to the issuer identifier of the authorization server
130+
# https://datatracker.ietf.org/doc/html/draft-ietf-oauth-rfc7523bis-01#name-updates-to-rfc-7523
131+
issuer = str(self.context.oauth_metadata.issuer)
132+
assertion = self.jwt_parameters.to_assertion(with_audience_fallback=issuer)
133+
134+
token_data = {
135+
"grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer",
136+
"assertion": assertion,
137+
}
138+
139+
if self.context.should_include_resource_param(self.context.protocol_version):
140+
token_data["resource"] = self.context.get_resource_url()
141+
142+
if self.context.client_metadata.scope:
143+
token_data["scope"] = self.context.client_metadata.scope
144+
145+
token_url = self._get_token_endpoint()
146+
return httpx.Request(
147+
"POST", token_url, data=token_data, headers={"Content-Type": "application/x-www-form-urlencoded"}
148+
)

src/mcp/client/auth.py renamed to src/mcp/client/auth/oauth2.py

Lines changed: 51 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import time
1414
from collections.abc import AsyncGenerator, Awaitable, Callable
1515
from dataclasses import dataclass, field
16-
from typing import Protocol
16+
from typing import Any, Protocol
1717
from urllib.parse import urlencode, urljoin, urlparse
1818

1919
import anyio
@@ -88,8 +88,8 @@ class OAuthContext:
8888
server_url: str
8989
client_metadata: OAuthClientMetadata
9090
storage: TokenStorage
91-
redirect_handler: Callable[[str], Awaitable[None]]
92-
callback_handler: Callable[[], Awaitable[tuple[str, str | None]]]
91+
redirect_handler: Callable[[str], Awaitable[None]] | None
92+
callback_handler: Callable[[], Awaitable[tuple[str, str | None]]] | None
9393
timeout: float = 300.0
9494

9595
# Discovered metadata
@@ -189,8 +189,8 @@ def __init__(
189189
server_url: str,
190190
client_metadata: OAuthClientMetadata,
191191
storage: TokenStorage,
192-
redirect_handler: Callable[[str], Awaitable[None]],
193-
callback_handler: Callable[[], Awaitable[tuple[str, str | None]]],
192+
redirect_handler: Callable[[str], Awaitable[None]] | None = None,
193+
callback_handler: Callable[[], Awaitable[tuple[str, str | None]]] | None = None,
194194
timeout: float = 300.0,
195195
):
196196
"""Initialize OAuth2 authentication."""
@@ -351,8 +351,21 @@ async def _handle_registration_response(self, response: httpx.Response) -> None:
351351
except ValidationError as e:
352352
raise OAuthRegistrationError(f"Invalid registration response: {e}")
353353

354-
async def _perform_authorization(self) -> tuple[str, str]:
354+
async def _perform_authorization(self) -> httpx.Request:
355+
"""Perform the authorization flow."""
356+
auth_code, code_verifier = await self._perform_authorization_code_grant()
357+
token_request = await self._exchange_token_authorization_code(auth_code, code_verifier)
358+
return token_request
359+
360+
async def _perform_authorization_code_grant(self) -> tuple[str, str]:
355361
"""Perform the authorization redirect and get auth code."""
362+
if self.context.client_metadata.redirect_uris is None:
363+
raise OAuthFlowError("No redirect URIs provided for authorization code grant")
364+
if not self.context.redirect_handler:
365+
raise OAuthFlowError("No redirect handler provided for authorization code grant")
366+
if not self.context.callback_handler:
367+
raise OAuthFlowError("No callback handler provided for authorization code grant")
368+
356369
if self.context.oauth_metadata and self.context.oauth_metadata.authorization_endpoint:
357370
auth_endpoint = str(self.context.oauth_metadata.authorization_endpoint)
358371
else:
@@ -397,24 +410,34 @@ async def _perform_authorization(self) -> tuple[str, str]:
397410
# Return auth code and code verifier for token exchange
398411
return auth_code, pkce_params.code_verifier
399412

400-
async def _exchange_token(self, auth_code: str, code_verifier: str) -> httpx.Request:
401-
"""Build token exchange request."""
402-
if not self.context.client_info:
403-
raise OAuthFlowError("Missing client info")
404-
413+
def _get_token_endpoint(self) -> str:
405414
if self.context.oauth_metadata and self.context.oauth_metadata.token_endpoint:
406415
token_url = str(self.context.oauth_metadata.token_endpoint)
407416
else:
408417
auth_base_url = self.context.get_authorization_base_url(self.context.server_url)
409418
token_url = urljoin(auth_base_url, "/token")
419+
return token_url
420+
421+
async def _exchange_token_authorization_code(
422+
self, auth_code: str, code_verifier: str, *, token_data: dict[str, Any] | None = {}
423+
) -> httpx.Request:
424+
"""Build token exchange request for authorization_code flow."""
425+
if self.context.client_metadata.redirect_uris is None:
426+
raise OAuthFlowError("No redirect URIs provided for authorization code grant")
427+
if not self.context.client_info:
428+
raise OAuthFlowError("Missing client info")
410429

411-
token_data = {
412-
"grant_type": "authorization_code",
413-
"code": auth_code,
414-
"redirect_uri": str(self.context.client_metadata.redirect_uris[0]),
415-
"client_id": self.context.client_info.client_id,
416-
"code_verifier": code_verifier,
417-
}
430+
token_url = self._get_token_endpoint()
431+
token_data = token_data or {}
432+
token_data.update(
433+
{
434+
"grant_type": "authorization_code",
435+
"code": auth_code,
436+
"redirect_uri": str(self.context.client_metadata.redirect_uris[0]),
437+
"client_id": self.context.client_info.client_id,
438+
"code_verifier": code_verifier,
439+
}
440+
)
418441

419442
# Only include resource param if conditions are met
420443
if self.context.should_include_resource_param(self.context.protocol_version):
@@ -430,7 +453,9 @@ async def _exchange_token(self, auth_code: str, code_verifier: str) -> httpx.Req
430453
async def _handle_token_response(self, response: httpx.Response) -> None:
431454
"""Handle token exchange response."""
432455
if response.status_code != 200:
433-
raise OAuthTokenError(f"Token exchange failed: {response.status_code}")
456+
body = await response.aread()
457+
body = body.decode("utf-8")
458+
raise OAuthTokenError(f"Token exchange failed ({response.status_code}): {body}")
434459

435460
try:
436461
content = await response.aread()
@@ -577,12 +602,8 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
577602
registration_response = yield registration_request
578603
await self._handle_registration_response(registration_response)
579604

580-
# Step 5: Perform authorization
581-
auth_code, code_verifier = await self._perform_authorization()
582-
583-
# Step 6: Exchange authorization code for tokens
584-
token_request = await self._exchange_token(auth_code, code_verifier)
585-
token_response = yield token_request
605+
# Step 5: Perform authorization and complete token exchange
606+
token_response = yield await self._perform_authorization()
586607
await self._handle_token_response(token_response)
587608
except Exception:
588609
logger.exception("OAuth flow error")
@@ -601,17 +622,13 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
601622
# Step 2a: Update the required scopes
602623
self._select_scopes(response)
603624

604-
# Step 2b: Perform (re-)authorization
605-
auth_code, code_verifier = await self._perform_authorization()
606-
607-
# Step 2c: Exchange authorization code for tokens
608-
token_request = await self._exchange_token(auth_code, code_verifier)
609-
token_response = yield token_request
625+
# Step 2b: Perform (re-)authorization and token exchange
626+
token_response = yield await self._perform_authorization()
610627
await self._handle_token_response(token_response)
611628
except Exception:
612629
logger.exception("OAuth flow error")
613630
raise
614631

615-
# Retry with new tokens
616-
self._add_auth_header(request)
617-
yield request
632+
# Retry with new tokens
633+
self._add_auth_header(request)
634+
yield request

src/mcp/shared/auth.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -41,13 +41,13 @@ class OAuthClientMetadata(BaseModel):
4141
for the full specification.
4242
"""
4343

44-
redirect_uris: list[AnyUrl] = Field(..., min_length=1)
45-
# token_endpoint_auth_method: this implementation only supports none &
46-
# client_secret_post;
47-
# ie: we do not support client_secret_basic
48-
token_endpoint_auth_method: Literal["none", "client_secret_post"] = "client_secret_post"
49-
# grant_types: this implementation only supports authorization_code & refresh_token
50-
grant_types: list[Literal["authorization_code", "refresh_token"] | str] = [
44+
redirect_uris: list[AnyUrl] | None = Field(..., min_length=1)
45+
# supported auth methods for the token endpoint
46+
token_endpoint_auth_method: Literal["none", "client_secret_post", "private_key_jwt"] = "client_secret_post"
47+
# supported grant_types of this implementation
48+
grant_types: list[
49+
Literal["authorization_code", "refresh_token", "urn:ietf:params:oauth:grant-type:jwt-bearer"] | str
50+
] = [
5151
"authorization_code",
5252
"refresh_token",
5353
]
@@ -82,10 +82,10 @@ def validate_scope(self, requested_scope: str | None) -> list[str] | None:
8282
def validate_redirect_uri(self, redirect_uri: AnyUrl | None) -> AnyUrl:
8383
if redirect_uri is not None:
8484
# Validate redirect_uri against client's registered redirect URIs
85-
if redirect_uri not in self.redirect_uris:
85+
if self.redirect_uris is None or redirect_uri not in self.redirect_uris:
8686
raise InvalidRedirectUriError(f"Redirect URI '{redirect_uri}' not registered for client")
8787
return redirect_uri
88-
elif len(self.redirect_uris) == 1:
88+
elif self.redirect_uris is not None and len(self.redirect_uris) == 1:
8989
return self.redirect_uris[0]
9090
else:
9191
raise InvalidRedirectUriError("redirect_uri must be specified when client has multiple registered URIs")
@@ -97,7 +97,7 @@ class OAuthClientInformationFull(OAuthClientMetadata):
9797
(client information plus metadata).
9898
"""
9999

100-
client_id: str
100+
client_id: str | None = None
101101
client_secret: str | None = None
102102
client_id_issued_at: int | None = None
103103
client_secret_expires_at: int | None = None

0 commit comments

Comments
 (0)