Skip to content

Commit 7f543d9

Browse files
committed
fastapi-native routes
1 parent 97ffd8e commit 7f543d9

File tree

5 files changed

+105
-25
lines changed

5 files changed

+105
-25
lines changed

examples/full_schema_description_example.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
# Add MCP server to the FastAPI app
1111
mcp = FastApiMCP(
1212
items.app,
13-
mount_path="/mcp",
1413
name="Item API MCP",
1514
description="MCP server for the Item API",
1615
base_url="http://localhost:8000",

examples/simple_example.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
# Add MCP server to the FastAPI app
1111
mcp = FastApiMCP(
1212
items.app,
13-
mount_path="/mcp",
1413
name="Item API MCP",
1514
description="MCP server for the Item API",
1615
base_url="http://localhost:8000",

fastapi_mcp/server.py

Lines changed: 33 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,25 @@
11
from contextlib import asynccontextmanager
22
from typing import Dict, Optional, Any, List, Union, AsyncIterator
33

4-
from fastapi import FastAPI, Request
4+
from fastapi import FastAPI, Request, APIRouter
55
from fastapi.openapi.utils import get_openapi
66
from mcp.server.lowlevel.server import Server
7-
from mcp.server.sse import SseServerTransport
87
import mcp.types as types
98

109
from fastapi_mcp.openapi.convert import convert_openapi_to_mcp_tools
1110
from fastapi_mcp.execute import execute_api_tool
11+
from fastapi_mcp.transport.sse import FastApiSseTransport
12+
13+
from logging import getLogger
14+
15+
16+
logger = getLogger(__name__)
1217

1318

1419
class FastApiMCP:
1520
def __init__(
1621
self,
1722
fastapi: FastAPI,
18-
mount_path: str = "/mcp",
1923
name: Optional[str] = None,
2024
description: Optional[str] = None,
2125
base_url: Optional[str] = None,
@@ -29,7 +33,6 @@ def __init__(
2933
self.name = name
3034
self.description = description
3135

32-
self._mount_path = mount_path
3336
self._base_url = base_url
3437
self._describe_all_responses = describe_all_responses
3538
self._describe_full_response_schema = describe_full_response_schema
@@ -41,10 +44,12 @@ def create_server(self) -> Server:
4144
Create an MCP server from the FastAPI app.
4245
4346
Args:
44-
app: The FastAPI application
47+
fastapi: The FastAPI application
4548
name: Name for the MCP server (defaults to app.title)
4649
description: Description for the MCP server (defaults to app.description)
47-
base_url: Base URL for API requests (defaults to http://localhost:$PORT)
50+
base_url: Base URL for API requests. If not provided, the base URL will be determined from the
51+
FastAPI app's root path. Although optional, it is highly recommended to provide a base URL,
52+
as the root path would be different when the app is deployed.
4853
describe_all_responses: Whether to include all possible response schemas in tool descriptions
4954
describe_full_response_schema: Whether to include full json schema for responses in tool descriptions
5055
@@ -126,37 +131,42 @@ async def handle_call_tool(
126131

127132
return mcp_server
128133

129-
def mount(self) -> None:
134+
def mount(self, router: Optional[FastAPI | APIRouter] = None, mount_path: str = "/mcp") -> None:
130135
"""
131136
Mount the MCP server to the FastAPI app.
132137
133138
Args:
134-
app: The FastAPI application
135-
mcp_server: The MCP server to mount
136-
operation_map: A mapping of operation IDs to operation details
139+
router: The FastAPI app or APIRouter to mount the MCP server to. If not provided, the MCP
140+
server will be mounted to the FastAPI app.
137141
mount_path: Path where the MCP server will be mounted
138-
base_url: Base URL for API requests
139142
"""
140143
# Normalize mount path
141-
if not self._mount_path.startswith("/"):
142-
self._mount_path = f"/{self._mount_path}"
143-
if self._mount_path.endswith("/"):
144-
self._mount_path = self._mount_path[:-1]
144+
if not mount_path.startswith("/"):
145+
mount_path = f"/{mount_path}"
146+
if mount_path.endswith("/"):
147+
mount_path = mount_path[:-1]
148+
149+
if not router:
150+
router = self.fastapi
145151

146152
# Create SSE transport for MCP messages
147-
sse_transport = SseServerTransport(f"{self._mount_path}/messages/")
153+
sse_transport = FastApiSseTransport(f"{mount_path}/messages/")
148154

149-
# Define MCP connection handler
155+
# Route for MCP connection
156+
@router.get(mount_path)
150157
async def handle_mcp_connection(request: Request):
151-
async with sse_transport.connect_sse(request.scope, request.receive, request._send) as streams:
158+
async with sse_transport.connect_sse(request.scope, request.receive, request._send) as (reader, writer):
152159
await self.mcp_server.run(
153-
streams[0],
154-
streams[1],
160+
reader,
161+
writer,
155162
self.mcp_server.create_initialization_options(
156163
notification_options=None, experimental_capabilities={}
157164
),
158165
)
159166

160-
# Mount the MCP connection handler
161-
self.fastapi.get(self._mount_path)(handle_mcp_connection)
162-
self.fastapi.mount(f"{self._mount_path}/messages/", app=sse_transport.handle_post_message)
167+
# Route for MCP messages
168+
@router.post(f"{mount_path}/messages/")
169+
async def handle_post_message(request: Request):
170+
await sse_transport.handle_fastapi_post_message(request)
171+
172+
logger.info(f"MCP server listening at {mount_path}")

fastapi_mcp/transport/__init__.py

Whitespace-only changes.

fastapi_mcp/transport/sse.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
from uuid import UUID
2+
from logging import getLogger
3+
4+
from fastapi import Request, Response
5+
from pydantic import ValidationError
6+
from mcp.server.sse import SseServerTransport
7+
from mcp.types import JSONRPCMessage
8+
9+
10+
logger = getLogger(__name__)
11+
12+
13+
class FastApiSseTransport(SseServerTransport):
14+
async def handle_fastapi_post_message(self, request: Request) -> None:
15+
"""
16+
A reimplementation of the handle_post_message method of SseServerTransport
17+
that integrates better with FastAPI.
18+
19+
A few good reasons for doing this:
20+
1. Avoid mounting a whole Starlette app and instead use a more FastAPI-native
21+
approach. Mounting has some known issues and limitations.
22+
2. Avoid re-constructing the scope, receive, and send from the request, as done
23+
in the original implementation.
24+
25+
The combination of mounting a whole Starlette app and reconstructing the scope
26+
and send from the request proved to be especially error-prone for us when using
27+
tracing tools like Sentry, which had destructive effects on the request object
28+
when using the original implementation.
29+
"""
30+
31+
logger.debug("Handling POST message")
32+
scope = request.scope
33+
receive = request.receive
34+
send = request._send
35+
36+
session_id_param = request.query_params.get("session_id")
37+
if session_id_param is None:
38+
logger.warning("Received request without session_id")
39+
response = Response("session_id is required", status_code=400)
40+
return await response(scope, receive, send)
41+
42+
try:
43+
session_id = UUID(hex=session_id_param)
44+
logger.debug(f"Parsed session ID: {session_id}")
45+
except ValueError:
46+
logger.warning(f"Received invalid session ID: {session_id_param}")
47+
response = Response("Invalid session ID", status_code=400)
48+
return await response(scope, receive, send)
49+
50+
writer = self._read_stream_writers.get(session_id)
51+
if not writer:
52+
logger.warning(f"Could not find session for ID: {session_id}")
53+
response = Response("Could not find session", status_code=404)
54+
return await response(scope, receive, send)
55+
56+
body = await request.body()
57+
logger.debug(f"Received JSON: {body.decode()}")
58+
59+
try:
60+
message = JSONRPCMessage.model_validate_json(body)
61+
logger.debug(f"Validated client message: {message}")
62+
except ValidationError as err:
63+
logger.error(f"Failed to parse message: {err}")
64+
response = Response("Could not parse message", status_code=400)
65+
await response(scope, receive, send)
66+
await writer.send(err)
67+
return
68+
69+
logger.debug(f"Sending message to writer: {message}")
70+
response = Response("Accepted", status_code=202)
71+
await response(scope, receive, send)
72+
await writer.send(message)

0 commit comments

Comments
 (0)