Skip to content

Commit 5f08290

Browse files
committed
remove the hacky approach to injecting http request context, now that MCP sdk supports passing http context directly to tool handlers
1 parent 37109a3 commit 5f08290

File tree

2 files changed

+30
-59
lines changed

2 files changed

+30
-59
lines changed

fastapi_mcp/server.py

Lines changed: 30 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import json
22
import httpx
3-
from typing import Dict, Optional, Any, List, Union, Callable, Awaitable, Iterable, Literal, Sequence
3+
from typing import Dict, Optional, Any, List, Union, Literal, Sequence
44
from typing_extensions import Annotated, Doc
55

66
from fastapi import FastAPI, Request, APIRouter, params
@@ -19,45 +19,6 @@
1919
logger = logging.getLogger(__name__)
2020

2121

22-
class LowlevelMCPServer(Server):
23-
def call_tool(self):
24-
"""
25-
A near-direct copy of `mcp.server.lowlevel.server.Server.call_tool()`, except that it looks for
26-
the original HTTP request info in the MCP message, and passes it to the tool call handler.
27-
"""
28-
29-
def decorator(
30-
func: Callable[
31-
...,
32-
Awaitable[Iterable[types.TextContent | types.ImageContent | types.EmbeddedResource]],
33-
],
34-
):
35-
logger.debug("Registering handler for CallToolRequest")
36-
37-
async def handler(req: types.CallToolRequest):
38-
try:
39-
# HACK: Pull the original HTTP request info from the MCP message. It was injected in
40-
# `FastApiSseTransport.handle_fastapi_post_message()`
41-
if hasattr(req.params, "_http_request_info") and req.params._http_request_info is not None:
42-
http_request_info = HTTPRequestInfo.model_validate(req.params._http_request_info)
43-
results = await func(req.params.name, (req.params.arguments or {}), http_request_info)
44-
else:
45-
results = await func(req.params.name, (req.params.arguments or {}))
46-
return types.ServerResult(types.CallToolResult(content=list(results), isError=False))
47-
except Exception as e:
48-
return types.ServerResult(
49-
types.CallToolResult(
50-
content=[types.TextContent(type="text", text=str(e))],
51-
isError=True,
52-
)
53-
)
54-
55-
self.request_handlers[types.CallToolRequest] = handler
56-
return func
57-
58-
return decorator
59-
60-
6122
class FastApiMCP:
6223
"""
6324
Create an MCP server from a FastAPI app.
@@ -115,14 +76,14 @@ def __init__(
11576
Doc("Configuration for MCP authentication"),
11677
] = None,
11778
headers: Annotated[
118-
Optional[List[str]],
79+
List[str],
11980
Doc(
12081
"""
12182
List of HTTP header names to forward from the incoming MCP request into each tool invocation.
12283
Only headers in this allowlist will be forwarded. Defaults to ['authorization'].
12384
"""
12485
),
125-
] = None,
86+
] = ["authorization"],
12687
):
12788
# Validate operation and tag filtering options
12889
if include_operations is not None and exclude_operations is not None:
@@ -157,7 +118,7 @@ def __init__(
157118
timeout=10.0,
158119
)
159120

160-
self._forward_headers = {h.lower() for h in (headers or ["Authorization"])}
121+
self._forward_headers = {h.lower() for h in headers}
161122

162123
self.setup_server()
163124

@@ -179,16 +140,40 @@ def setup_server(self) -> None:
179140
# Filter tools based on operation IDs and tags
180141
self.tools = self._filter_tools(all_tools, openapi_schema)
181142

182-
mcp_server: LowlevelMCPServer = LowlevelMCPServer(self.name, self.description)
143+
mcp_server: Server = Server(self.name, self.description)
183144

184145
@mcp_server.list_tools()
185146
async def handle_list_tools() -> List[types.Tool]:
186147
return self.tools
187148

188149
@mcp_server.call_tool()
189150
async def handle_call_tool(
190-
name: str, arguments: Dict[str, Any], http_request_info: Optional[HTTPRequestInfo] = None
151+
name: str, arguments: Dict[str, Any]
191152
) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]:
153+
# Extract HTTP request info from MCP context
154+
http_request_info = None
155+
try:
156+
# Access the MCP server's request context to get the original HTTP Request
157+
request_context = mcp_server.request_context
158+
159+
if request_context and hasattr(request_context, "request"):
160+
http_request = request_context.request
161+
162+
if http_request and hasattr(http_request, "method"):
163+
http_request_info = HTTPRequestInfo(
164+
method=http_request.method,
165+
path=http_request.url.path,
166+
headers=dict(http_request.headers),
167+
cookies=http_request.cookies,
168+
query_params=dict(http_request.query_params),
169+
body=None,
170+
)
171+
logger.debug(
172+
f"Extracted HTTP request info from context: {http_request_info.method} {http_request_info.path}"
173+
)
174+
except (LookupError, AttributeError) as e:
175+
logger.error(f"Could not extract HTTP request info from context: {e}")
176+
192177
return await self._execute_api_tool(
193178
client=self._http_client,
194179
tool_name=name,

fastapi_mcp/transport/sse.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from pydantic import ValidationError
1010
from mcp.server.sse import SseServerTransport
1111
from mcp.types import JSONRPCMessage, JSONRPCError, ErrorData
12-
from fastapi_mcp.types import HTTPRequestInfo
1312

1413

1514
logger = logging.getLogger(__name__)
@@ -60,19 +59,6 @@ async def handle_fastapi_post_message(self, request: Request) -> Response:
6059
try:
6160
message = JSONRPCMessage.model_validate_json(body)
6261

63-
# HACK to inject the HTTP request info into the MCP message,
64-
# so we can use it for auth.
65-
# It is then used in our custom `LowlevelMCPServer.call_tool()` decorator.
66-
if hasattr(message.root, "params") and message.root.params is not None:
67-
message.root.params["_http_request_info"] = HTTPRequestInfo(
68-
method=request.method,
69-
path=request.url.path,
70-
headers=dict(request.headers),
71-
cookies=request.cookies,
72-
query_params=dict(request.query_params),
73-
body=body.decode(),
74-
).model_dump(mode="json")
75-
7662
logger.debug(f"Validated client message: {message}")
7763
except ValidationError as err:
7864
logger.error(f"Failed to parse message: {err}")

0 commit comments

Comments
 (0)