Skip to content

Commit a2ff5ce

Browse files
committed
Fix ASGI errors
1 parent 1f325f2 commit a2ff5ce

File tree

2 files changed

+59
-23
lines changed

2 files changed

+59
-23
lines changed

fastapi_mcp/server.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,6 @@ async def handle_mcp_connection(request: Request):
167167
# Route for MCP messages
168168
@router.post(f"{mount_path}/messages/")
169169
async def handle_post_message(request: Request):
170-
await sse_transport.handle_fastapi_post_message(request)
170+
return await sse_transport.handle_fastapi_post_message(request)
171171

172172
logger.info(f"MCP server listening at {mount_path}")

fastapi_mcp/transport/sse.py

Lines changed: 58 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,20 @@
11
from uuid import UUID
22
from logging import getLogger
3+
from typing import Union
34

4-
from fastapi import Request, Response
5+
from anyio.streams.memory import MemoryObjectSendStream
6+
from fastapi import Request, Response, BackgroundTasks, HTTPException
7+
from fastapi.responses import JSONResponse
58
from pydantic import ValidationError
69
from mcp.server.sse import SseServerTransport
7-
from mcp.types import JSONRPCMessage
10+
from mcp.types import JSONRPCMessage, JSONRPCError, ErrorData
811

912

1013
logger = getLogger(__name__)
1114

1215

1316
class FastApiSseTransport(SseServerTransport):
14-
async def handle_fastapi_post_message(self, request: Request) -> None:
17+
async def handle_fastapi_post_message(self, request: Request) -> Response:
1518
"""
1619
A reimplementation of the handle_post_message method of SseServerTransport
1720
that integrates better with FastAPI.
@@ -21,37 +24,33 @@ async def handle_fastapi_post_message(self, request: Request) -> None:
2124
approach. Mounting has some known issues and limitations.
2225
2. Avoid re-constructing the scope, receive, and send from the request, as done
2326
in the original implementation.
27+
3. Use FastAPI's native response handling mechanisms and exception patterns to
28+
avoid unexpected rabbit holes.
2429
2530
The combination of mounting a whole Starlette app and reconstructing the scope
2631
and send from the request proved to be especially error-prone for us when using
2732
tracing tools like Sentry, which had destructive effects on the request object
2833
when using the original implementation.
2934
"""
3035

31-
logger.debug("Handling POST message")
32-
scope = request.scope
33-
receive = request.receive
34-
send = request._send
36+
logger.debug("Handling POST message with FastAPI patterns")
3537

3638
session_id_param = request.query_params.get("session_id")
3739
if session_id_param is None:
3840
logger.warning("Received request without session_id")
39-
response = Response("session_id is required", status_code=400)
40-
return await response(scope, receive, send)
41+
raise HTTPException(status_code=400, detail="session_id is required")
4142

4243
try:
4344
session_id = UUID(hex=session_id_param)
4445
logger.debug(f"Parsed session ID: {session_id}")
4546
except ValueError:
4647
logger.warning(f"Received invalid session ID: {session_id_param}")
47-
response = Response("Invalid session ID", status_code=400)
48-
return await response(scope, receive, send)
48+
raise HTTPException(status_code=400, detail="Invalid session ID")
4949

5050
writer = self._read_stream_writers.get(session_id)
5151
if not writer:
5252
logger.warning(f"Could not find session for ID: {session_id}")
53-
response = Response("Could not find session", status_code=404)
54-
return await response(scope, receive, send)
53+
raise HTTPException(status_code=404, detail="Could not find session")
5554

5655
body = await request.body()
5756
logger.debug(f"Received JSON: {body.decode()}")
@@ -61,12 +60,49 @@ async def handle_fastapi_post_message(self, request: Request) -> None:
6160
logger.debug(f"Validated client message: {message}")
6261
except ValidationError as err:
6362
logger.error(f"Failed to parse message: {err}")
64-
response = Response("Could not parse message", status_code=400)
65-
await response(scope, receive, send)
66-
await writer.send(err)
67-
return
68-
69-
logger.debug(f"Sending message to writer: {message}")
70-
response = Response("Accepted", status_code=202)
71-
await response(scope, receive, send)
72-
await writer.send(message)
63+
# Create background task to send error
64+
background_tasks = BackgroundTasks()
65+
background_tasks.add_task(self._send_message_safely, writer, err)
66+
response = JSONResponse(content={"error": "Could not parse message"}, status_code=400)
67+
response.background = background_tasks
68+
return response
69+
except Exception as e:
70+
logger.error(f"Error processing request body: {e}")
71+
raise HTTPException(status_code=400, detail="Invalid request body")
72+
73+
# Create background task to send message
74+
background_tasks = BackgroundTasks()
75+
background_tasks.add_task(self._send_message_safely, writer, message)
76+
logger.debug("Accepting message, will send in background")
77+
78+
# Return response with background task
79+
response = JSONResponse(content={"message": "Accepted"}, status_code=202)
80+
response.background = background_tasks
81+
return response
82+
83+
async def _send_message_safely(
84+
self, writer: MemoryObjectSendStream[JSONRPCMessage], message: Union[JSONRPCMessage, ValidationError]
85+
):
86+
"""Send a message to the writer, avoiding ASGI race conditions"""
87+
88+
try:
89+
logger.debug(f"Sending message to writer from background task: {message}")
90+
91+
if isinstance(message, ValidationError):
92+
# Convert ValidationError to JSONRPCError
93+
error_data = ErrorData(
94+
code=-32700, # Parse error code in JSON-RPC
95+
message="Parse error",
96+
data={"validation_error": str(message)},
97+
)
98+
json_rpc_error = JSONRPCError(
99+
jsonrpc="2.0",
100+
id="unknown", # We don't know the ID from the invalid request
101+
error=error_data,
102+
)
103+
error_message = JSONRPCMessage(root=json_rpc_error)
104+
await writer.send(error_message)
105+
else:
106+
await writer.send(message)
107+
except Exception as e:
108+
logger.error(f"Error sending message to writer: {e}")

0 commit comments

Comments
 (0)