1
1
from uuid import UUID
2
2
from logging import getLogger
3
+ from typing import Union
3
4
4
- from fastapi import Request , Response
5
+ from anyio .streams .memory import MemoryObjectSendStream
6
+ from fastapi import Request , Response , BackgroundTasks , HTTPException
7
+ from fastapi .responses import JSONResponse
5
8
from pydantic import ValidationError
6
9
from mcp .server .sse import SseServerTransport
7
- from mcp .types import JSONRPCMessage
10
+ from mcp .types import JSONRPCMessage , JSONRPCError , ErrorData
8
11
9
12
10
13
logger = getLogger (__name__ )
11
14
12
15
13
16
class FastApiSseTransport (SseServerTransport ):
14
- async def handle_fastapi_post_message (self , request : Request ) -> None :
17
+ async def handle_fastapi_post_message (self , request : Request ) -> Response :
15
18
"""
16
19
A reimplementation of the handle_post_message method of SseServerTransport
17
20
that integrates better with FastAPI.
@@ -21,37 +24,33 @@ async def handle_fastapi_post_message(self, request: Request) -> None:
21
24
approach. Mounting has some known issues and limitations.
22
25
2. Avoid re-constructing the scope, receive, and send from the request, as done
23
26
in the original implementation.
27
+ 3. Use FastAPI's native response handling mechanisms and exception patterns to
28
+ avoid unexpected rabbit holes.
24
29
25
30
The combination of mounting a whole Starlette app and reconstructing the scope
26
31
and send from the request proved to be especially error-prone for us when using
27
32
tracing tools like Sentry, which had destructive effects on the request object
28
33
when using the original implementation.
29
34
"""
30
35
31
- logger .debug ("Handling POST message" )
32
- scope = request .scope
33
- receive = request .receive
34
- send = request ._send
36
+ logger .debug ("Handling POST message with FastAPI patterns" )
35
37
36
38
session_id_param = request .query_params .get ("session_id" )
37
39
if session_id_param is None :
38
40
logger .warning ("Received request without session_id" )
39
- response = Response ("session_id is required" , status_code = 400 )
40
- return await response (scope , receive , send )
41
+ raise HTTPException (status_code = 400 , detail = "session_id is required" )
41
42
42
43
try :
43
44
session_id = UUID (hex = session_id_param )
44
45
logger .debug (f"Parsed session ID: { session_id } " )
45
46
except ValueError :
46
47
logger .warning (f"Received invalid session ID: { session_id_param } " )
47
- response = Response ("Invalid session ID" , status_code = 400 )
48
- return await response (scope , receive , send )
48
+ raise HTTPException (status_code = 400 , detail = "Invalid session ID" )
49
49
50
50
writer = self ._read_stream_writers .get (session_id )
51
51
if not writer :
52
52
logger .warning (f"Could not find session for ID: { session_id } " )
53
- response = Response ("Could not find session" , status_code = 404 )
54
- return await response (scope , receive , send )
53
+ raise HTTPException (status_code = 404 , detail = "Could not find session" )
55
54
56
55
body = await request .body ()
57
56
logger .debug (f"Received JSON: { body .decode ()} " )
@@ -61,12 +60,49 @@ async def handle_fastapi_post_message(self, request: Request) -> None:
61
60
logger .debug (f"Validated client message: { message } " )
62
61
except ValidationError as err :
63
62
logger .error (f"Failed to parse message: { err } " )
64
- response = Response ("Could not parse message" , status_code = 400 )
65
- await response (scope , receive , send )
66
- await writer .send (err )
67
- return
68
-
69
- logger .debug (f"Sending message to writer: { message } " )
70
- response = Response ("Accepted" , status_code = 202 )
71
- await response (scope , receive , send )
72
- await writer .send (message )
63
+ # Create background task to send error
64
+ background_tasks = BackgroundTasks ()
65
+ background_tasks .add_task (self ._send_message_safely , writer , err )
66
+ response = JSONResponse (content = {"error" : "Could not parse message" }, status_code = 400 )
67
+ response .background = background_tasks
68
+ return response
69
+ except Exception as e :
70
+ logger .error (f"Error processing request body: { e } " )
71
+ raise HTTPException (status_code = 400 , detail = "Invalid request body" )
72
+
73
+ # Create background task to send message
74
+ background_tasks = BackgroundTasks ()
75
+ background_tasks .add_task (self ._send_message_safely , writer , message )
76
+ logger .debug ("Accepting message, will send in background" )
77
+
78
+ # Return response with background task
79
+ response = JSONResponse (content = {"message" : "Accepted" }, status_code = 202 )
80
+ response .background = background_tasks
81
+ return response
82
+
83
+ async def _send_message_safely (
84
+ self , writer : MemoryObjectSendStream [JSONRPCMessage ], message : Union [JSONRPCMessage , ValidationError ]
85
+ ):
86
+ """Send a message to the writer, avoiding ASGI race conditions"""
87
+
88
+ try :
89
+ logger .debug (f"Sending message to writer from background task: { message } " )
90
+
91
+ if isinstance (message , ValidationError ):
92
+ # Convert ValidationError to JSONRPCError
93
+ error_data = ErrorData (
94
+ code = - 32700 , # Parse error code in JSON-RPC
95
+ message = "Parse error" ,
96
+ data = {"validation_error" : str (message )},
97
+ )
98
+ json_rpc_error = JSONRPCError (
99
+ jsonrpc = "2.0" ,
100
+ id = "unknown" , # We don't know the ID from the invalid request
101
+ error = error_data ,
102
+ )
103
+ error_message = JSONRPCMessage (root = json_rpc_error )
104
+ await writer .send (error_message )
105
+ else :
106
+ await writer .send (message )
107
+ except Exception as e :
108
+ logger .error (f"Error sending message to writer: { e } " )
0 commit comments