|
1 | 1 | from contextlib import asynccontextmanager
|
2 |
| -from typing import Dict, Optional, Any, Tuple, List, Union, AsyncIterator |
| 2 | +from typing import Dict, Optional, Any, List, Union, AsyncIterator |
3 | 3 |
|
4 | 4 | from fastapi import FastAPI, Request
|
5 | 5 | from fastapi.openapi.utils import get_openapi
|
|
11 | 11 | from fastapi_mcp.execute import execute_api_tool
|
12 | 12 |
|
13 | 13 |
|
14 |
| -def create_mcp_server( |
15 |
| - app: FastAPI, |
16 |
| - name: Optional[str] = None, |
17 |
| - description: Optional[str] = None, |
18 |
| - base_url: Optional[str] = None, |
19 |
| - describe_all_responses: bool = False, |
20 |
| - describe_full_response_schema: bool = False, |
21 |
| -) -> Tuple[Server, Dict[str, Dict[str, Any]]]: |
22 |
| - """ |
23 |
| - Create an MCP server from a FastAPI app. |
24 |
| -
|
25 |
| - Args: |
26 |
| - app: The FastAPI application |
27 |
| - name: Name for the MCP server (defaults to app.title) |
28 |
| - description: Description for the MCP server (defaults to app.description) |
29 |
| - base_url: Base URL for API requests (defaults to http://localhost:$PORT) |
30 |
| - describe_all_responses: Whether to include all possible response schemas in tool descriptions |
31 |
| - describe_full_response_schema: Whether to include full json schema for responses in tool descriptions |
32 |
| -
|
33 |
| - Returns: |
34 |
| - A tuple containing: |
35 |
| - - The created MCP Server instance (NOT mounted to the app) |
36 |
| - - A mapping of operation IDs to operation details for HTTP execution |
37 |
| - """ |
38 |
| - # Get OpenAPI schema from FastAPI app |
39 |
| - openapi_schema = get_openapi( |
40 |
| - title=app.title, |
41 |
| - version=app.version, |
42 |
| - openapi_version=app.openapi_version, |
43 |
| - description=app.description, |
44 |
| - routes=app.routes, |
45 |
| - ) |
46 |
| - |
47 |
| - # Get server name and description from app if not provided |
48 |
| - server_name = name or app.title or "FastAPI MCP" |
49 |
| - server_description = description or app.description |
50 |
| - |
51 |
| - # Convert OpenAPI schema to MCP tools |
52 |
| - tools, operation_map = convert_openapi_to_mcp_tools( |
53 |
| - openapi_schema, |
54 |
| - describe_all_responses=describe_all_responses, |
55 |
| - describe_full_response_schema=describe_full_response_schema, |
56 |
| - ) |
57 |
| - |
58 |
| - # Determine base URL if not provided |
59 |
| - if not base_url: |
60 |
| - # Try to determine the base URL from FastAPI config |
61 |
| - if hasattr(app, "root_path") and app.root_path: |
62 |
| - base_url = app.root_path |
63 |
| - else: |
64 |
| - # Default to localhost with FastAPI default port |
65 |
| - port = 8000 |
66 |
| - for route in app.routes: |
67 |
| - if hasattr(route, "app") and hasattr(route.app, "port"): |
68 |
| - port = route.app.port |
69 |
| - break |
70 |
| - base_url = f"http://localhost:{port}" |
71 |
| - |
72 |
| - # Normalize base URL |
73 |
| - if base_url.endswith("/"): |
74 |
| - base_url = base_url[:-1] |
75 |
| - |
76 |
| - # Create the MCP server |
77 |
| - mcp_server: Server = Server(server_name, server_description) |
78 |
| - |
79 |
| - # Create a lifespan context manager to store the base_url and operation_map |
80 |
| - @asynccontextmanager |
81 |
| - async def server_lifespan(server) -> AsyncIterator[Dict[str, Any]]: |
82 |
| - # Store context data that will be available to all server handlers |
83 |
| - context = {"base_url": base_url, "operation_map": operation_map} |
84 |
| - yield context |
85 |
| - |
86 |
| - # Use our custom lifespan |
87 |
| - mcp_server.lifespan = server_lifespan |
88 |
| - |
89 |
| - # Register handlers for tools |
90 |
| - @mcp_server.list_tools() |
91 |
| - async def handle_list_tools() -> List[types.Tool]: |
92 |
| - """Handler for the tools/list request""" |
93 |
| - return tools |
94 |
| - |
95 |
| - # Register the tool call handler |
96 |
| - @mcp_server.call_tool() |
97 |
| - async def handle_call_tool( |
98 |
| - name: str, arguments: Dict[str, Any] |
99 |
| - ) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]: |
100 |
| - """Handler for the tools/call request""" |
101 |
| - # Get context from server lifespan |
102 |
| - ctx = mcp_server.request_context |
103 |
| - base_url = ctx.lifespan_context["base_url"] |
104 |
| - operation_map = ctx.lifespan_context["operation_map"] |
105 |
| - |
106 |
| - # Execute the tool |
107 |
| - return await execute_api_tool(base_url, name, arguments, operation_map) |
108 |
| - |
109 |
| - return mcp_server, operation_map |
110 |
| - |
111 |
| - |
112 |
| -def mount_mcp_server( |
113 |
| - app: FastAPI, |
114 |
| - mcp_server: Server, |
115 |
| - operation_map: Dict[str, Dict[str, Any]], |
116 |
| - mount_path: str = "/mcp", |
117 |
| - base_url: Optional[str] = None, |
118 |
| -) -> None: |
119 |
| - """ |
120 |
| - Mount an MCP server to a FastAPI app. |
121 |
| -
|
122 |
| - Args: |
123 |
| - app: The FastAPI application |
124 |
| - mcp_server: The MCP server to mount |
125 |
| - operation_map: A mapping of operation IDs to operation details |
126 |
| - mount_path: Path where the MCP server will be mounted |
127 |
| - base_url: Base URL for API requests |
128 |
| - """ |
129 |
| - # Normalize mount path |
130 |
| - if not mount_path.startswith("/"): |
131 |
| - mount_path = f"/{mount_path}" |
132 |
| - if mount_path.endswith("/"): |
133 |
| - mount_path = mount_path[:-1] |
134 |
| - |
135 |
| - # Create SSE transport for MCP messages |
136 |
| - sse_transport = SseServerTransport(f"{mount_path}/messages/") |
137 |
| - |
138 |
| - # Define MCP connection handler |
139 |
| - async def handle_mcp_connection(request: Request): |
140 |
| - async with sse_transport.connect_sse(request.scope, request.receive, request._send) as streams: |
141 |
| - await mcp_server.run( |
142 |
| - streams[0], |
143 |
| - streams[1], |
144 |
| - mcp_server.create_initialization_options(notification_options=None, experimental_capabilities={}), |
145 |
| - ) |
146 |
| - |
147 |
| - # Mount the MCP connection handler |
148 |
| - app.get(mount_path)(handle_mcp_connection) |
149 |
| - app.mount(f"{mount_path}/messages/", app=sse_transport.handle_post_message) |
150 |
| - |
151 |
| - |
152 |
| -def add_mcp_server( |
153 |
| - app: FastAPI, |
154 |
| - mount_path: str = "/mcp", |
155 |
| - name: Optional[str] = None, |
156 |
| - description: Optional[str] = None, |
157 |
| - base_url: Optional[str] = None, |
158 |
| - describe_all_responses: bool = False, |
159 |
| - describe_full_response_schema: bool = False, |
160 |
| -) -> Server: |
161 |
| - """ |
162 |
| - Add an MCP server to a FastAPI app. |
163 |
| -
|
164 |
| - Args: |
165 |
| - app: The FastAPI application |
166 |
| - mount_path: Path where the MCP server will be mounted |
167 |
| - name: Name for the MCP server (defaults to app.title) |
168 |
| - description: Description for the MCP server (defaults to app.description) |
169 |
| - base_url: Base URL for API requests (defaults to http://localhost:$PORT) |
170 |
| - describe_all_responses: Whether to include all possible response schemas in tool descriptions |
171 |
| - describe_full_response_schema: Whether to include full json schema for responses in tool descriptions |
172 |
| -
|
173 |
| - Returns: |
174 |
| - The MCP server instance that was created and mounted |
175 |
| - """ |
176 |
| - # Create MCP server |
177 |
| - mcp_server, operation_map = create_mcp_server( |
178 |
| - app, |
179 |
| - name, |
180 |
| - description, |
181 |
| - base_url, |
182 |
| - describe_all_responses=describe_all_responses, |
183 |
| - describe_full_response_schema=describe_full_response_schema, |
184 |
| - ) |
185 |
| - |
186 |
| - # Mount MCP server |
187 |
| - mount_mcp_server(app, mcp_server, operation_map, mount_path, base_url) |
188 |
| - |
189 |
| - return mcp_server |
| 14 | +class FastApiMCP: |
| 15 | + def __init__( |
| 16 | + self, |
| 17 | + fastapi: FastAPI, |
| 18 | + mount_path: str = "/mcp", |
| 19 | + name: Optional[str] = None, |
| 20 | + description: Optional[str] = None, |
| 21 | + base_url: Optional[str] = None, |
| 22 | + describe_all_responses: bool = False, |
| 23 | + describe_full_response_schema: bool = False, |
| 24 | + ): |
| 25 | + self.operation_map: Dict[str, Dict[str, Any]] |
| 26 | + self.tools: List[types.Tool] |
| 27 | + |
| 28 | + self.fastapi = fastapi |
| 29 | + self.name = name |
| 30 | + self.description = description |
| 31 | + |
| 32 | + self._mount_path = mount_path |
| 33 | + self._base_url = base_url |
| 34 | + self._describe_all_responses = describe_all_responses |
| 35 | + self._describe_full_response_schema = describe_full_response_schema |
| 36 | + |
| 37 | + self.mcp_server = self.create_server() |
| 38 | + |
| 39 | + def create_server(self) -> Server: |
| 40 | + """ |
| 41 | + Create an MCP server from the FastAPI app. |
| 42 | +
|
| 43 | + Args: |
| 44 | + app: The FastAPI application |
| 45 | + name: Name for the MCP server (defaults to app.title) |
| 46 | + description: Description for the MCP server (defaults to app.description) |
| 47 | + base_url: Base URL for API requests (defaults to http://localhost:$PORT) |
| 48 | + describe_all_responses: Whether to include all possible response schemas in tool descriptions |
| 49 | + describe_full_response_schema: Whether to include full json schema for responses in tool descriptions |
| 50 | +
|
| 51 | + Returns: |
| 52 | + A tuple containing: |
| 53 | + - The created MCP Server instance (NOT mounted to the app) |
| 54 | + - A mapping of operation IDs to operation details for HTTP execution |
| 55 | + """ |
| 56 | + # Get OpenAPI schema from FastAPI app |
| 57 | + openapi_schema = get_openapi( |
| 58 | + title=self.fastapi.title, |
| 59 | + version=self.fastapi.version, |
| 60 | + openapi_version=self.fastapi.openapi_version, |
| 61 | + description=self.fastapi.description, |
| 62 | + routes=self.fastapi.routes, |
| 63 | + ) |
| 64 | + |
| 65 | + # Get server name and description from app if not provided |
| 66 | + server_name = self.name or self.fastapi.title or "FastAPI MCP" |
| 67 | + server_description = self.description or self.fastapi.description |
| 68 | + |
| 69 | + # Convert OpenAPI schema to MCP tools |
| 70 | + self.tools, self.operation_map = convert_openapi_to_mcp_tools( |
| 71 | + openapi_schema, |
| 72 | + describe_all_responses=self._describe_all_responses, |
| 73 | + describe_full_response_schema=self._describe_full_response_schema, |
| 74 | + ) |
| 75 | + |
| 76 | + # Determine base URL if not provided |
| 77 | + if not self._base_url: |
| 78 | + # Try to determine the base URL from FastAPI config |
| 79 | + if hasattr(self.fastapi, "root_path") and self.fastapi.root_path: |
| 80 | + self._base_url = self.fastapi.root_path |
| 81 | + else: |
| 82 | + # Default to localhost with FastAPI default port |
| 83 | + port = 8000 |
| 84 | + for route in self.fastapi.routes: |
| 85 | + if hasattr(route, "app") and hasattr(route.app, "port"): |
| 86 | + port = route.app.port |
| 87 | + break |
| 88 | + self._base_url = f"http://localhost:{port}" |
| 89 | + |
| 90 | + # Normalize base URL |
| 91 | + if self._base_url.endswith("/"): |
| 92 | + self._base_url = self._base_url[:-1] |
| 93 | + |
| 94 | + # Create the MCP server |
| 95 | + mcp_server: Server = Server(server_name, server_description) |
| 96 | + |
| 97 | + # Create a lifespan context manager to store the base_url and operation_map |
| 98 | + @asynccontextmanager |
| 99 | + async def server_lifespan(server) -> AsyncIterator[Dict[str, Any]]: |
| 100 | + # Store context data that will be available to all server handlers |
| 101 | + context = {"base_url": self._base_url, "operation_map": self.operation_map} |
| 102 | + yield context |
| 103 | + |
| 104 | + # Use our custom lifespan |
| 105 | + mcp_server.lifespan = server_lifespan |
| 106 | + |
| 107 | + # Register handlers for tools |
| 108 | + @mcp_server.list_tools() |
| 109 | + async def handle_list_tools() -> List[types.Tool]: |
| 110 | + """Handler for the tools/list request""" |
| 111 | + return self.tools |
| 112 | + |
| 113 | + # Register the tool call handler |
| 114 | + @mcp_server.call_tool() |
| 115 | + async def handle_call_tool( |
| 116 | + name: str, arguments: Dict[str, Any] |
| 117 | + ) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]: |
| 118 | + """Handler for the tools/call request""" |
| 119 | + # Get context from server lifespan |
| 120 | + ctx = mcp_server.request_context |
| 121 | + base_url = ctx.lifespan_context["base_url"] |
| 122 | + operation_map = ctx.lifespan_context["operation_map"] |
| 123 | + |
| 124 | + # Execute the tool |
| 125 | + return await execute_api_tool(base_url, name, arguments, operation_map) |
| 126 | + |
| 127 | + return mcp_server |
| 128 | + |
| 129 | + def mount(self) -> None: |
| 130 | + """ |
| 131 | + Mount the MCP server to the FastAPI app. |
| 132 | +
|
| 133 | + Args: |
| 134 | + app: The FastAPI application |
| 135 | + mcp_server: The MCP server to mount |
| 136 | + operation_map: A mapping of operation IDs to operation details |
| 137 | + mount_path: Path where the MCP server will be mounted |
| 138 | + base_url: Base URL for API requests |
| 139 | + """ |
| 140 | + # Normalize mount path |
| 141 | + if not self._mount_path.startswith("/"): |
| 142 | + self._mount_path = f"/{self._mount_path}" |
| 143 | + if self._mount_path.endswith("/"): |
| 144 | + self._mount_path = self._mount_path[:-1] |
| 145 | + |
| 146 | + # Create SSE transport for MCP messages |
| 147 | + sse_transport = SseServerTransport(f"{self._mount_path}/messages/") |
| 148 | + |
| 149 | + # Define MCP connection handler |
| 150 | + async def handle_mcp_connection(request: Request): |
| 151 | + async with sse_transport.connect_sse(request.scope, request.receive, request._send) as streams: |
| 152 | + await self.mcp_server.run( |
| 153 | + streams[0], |
| 154 | + streams[1], |
| 155 | + self.mcp_server.create_initialization_options( |
| 156 | + notification_options=None, experimental_capabilities={} |
| 157 | + ), |
| 158 | + ) |
| 159 | + |
| 160 | + # Mount the MCP connection handler |
| 161 | + self.fastapi.get(self._mount_path)(handle_mcp_connection) |
| 162 | + self.fastapi.mount(f"{self._mount_path}/messages/", app=sse_transport.handle_post_message) |
0 commit comments