1
1
import json
2
2
import httpx
3
- from contextlib import asynccontextmanager
4
- from typing import Dict , Optional , Any , List , Union , AsyncIterator
3
+ from typing import Dict , Optional , Any , List , Union
5
4
6
5
from fastapi import FastAPI , Request , APIRouter
7
6
from fastapi .openapi .utils import get_openapi
10
9
11
10
from fastapi_mcp .openapi .convert import convert_openapi_to_mcp_tools
12
11
from fastapi_mcp .transport .sse import FastApiSseTransport
12
+ from fastapi_mcp .types import AsyncClientProtocol
13
13
14
14
from logging import getLogger
15
15
@@ -26,7 +26,24 @@ def __init__(
26
26
base_url : Optional [str ] = None ,
27
27
describe_all_responses : bool = False ,
28
28
describe_full_response_schema : bool = False ,
29
+ http_client : Optional [AsyncClientProtocol ] = None ,
29
30
):
31
+ """
32
+ Create an MCP server from a FastAPI app.
33
+
34
+ Args:
35
+ fastapi: The FastAPI application
36
+ name: Name for the MCP server (defaults to app.title)
37
+ description: Description for the MCP server (defaults to app.description)
38
+ base_url: Base URL for API requests. If not provided, the base URL will be determined from the
39
+ FastAPI app's root path. Although optional, it is highly recommended to provide a base URL,
40
+ as the root path would be different when the app is deployed.
41
+ describe_all_responses: Whether to include all possible response schemas in tool descriptions
42
+ describe_full_response_schema: Whether to include full json schema for responses in tool descriptions
43
+ http_client: Optional HTTP client to use for API calls. If not provided, a new httpx.AsyncClient will be created.
44
+ This is primarily for testing purposes.
45
+ """
46
+
30
47
self .operation_map : Dict [str , Dict [str , Any ]]
31
48
self .tools : List [types .Tool ]
32
49
@@ -38,27 +55,11 @@ def __init__(
38
55
self ._describe_all_responses = describe_all_responses
39
56
self ._describe_full_response_schema = describe_full_response_schema
40
57
58
+ self ._http_client = http_client or httpx .AsyncClient ()
59
+
41
60
self .server = self .create_server ()
42
61
43
62
def create_server (self ) -> Server :
44
- """
45
- Create an MCP server from the FastAPI app.
46
-
47
- Args:
48
- fastapi: The FastAPI application
49
- name: Name for the MCP server (defaults to app.title)
50
- description: Description for the MCP server (defaults to app.description)
51
- base_url: Base URL for API requests. If not provided, the base URL will be determined from the
52
- FastAPI app's root path. Although optional, it is highly recommended to provide a base URL,
53
- as the root path would be different when the app is deployed.
54
- describe_all_responses: Whether to include all possible response schemas in tool descriptions
55
- describe_full_response_schema: Whether to include full json schema for responses in tool descriptions
56
-
57
- Returns:
58
- A tuple containing:
59
- - The created MCP Server instance (NOT mounted to the app)
60
- - A mapping of operation IDs to operation details for HTTP execution
61
- """
62
63
# Get OpenAPI schema from FastAPI app
63
64
openapi_schema = get_openapi (
64
65
title = self .fastapi .title ,
@@ -93,38 +94,26 @@ def create_server(self) -> Server:
93
94
if self ._base_url .endswith ("/" ):
94
95
self ._base_url = self ._base_url [:- 1 ]
95
96
96
- # Create the MCP server
97
+ # Create the MCP lowlevel server
97
98
mcp_server : Server = Server (self .name , self .description )
98
99
99
- # Create a lifespan context manager to store the base_url and operation_map
100
- @asynccontextmanager
101
- async def server_lifespan (server ) -> AsyncIterator [Dict [str , Any ]]:
102
- # Store context data that will be available to all server handlers
103
- context = {"base_url" : self ._base_url , "operation_map" : self .operation_map }
104
- yield context
105
-
106
- # Use our custom lifespan
107
- mcp_server .lifespan = server_lifespan
108
-
109
100
# Register handlers for tools
110
101
@mcp_server .list_tools ()
111
102
async def handle_list_tools () -> List [types .Tool ]:
112
- """Handler for the tools/list request"""
113
103
return self .tools
114
104
115
105
# Register the tool call handler
116
106
@mcp_server .call_tool ()
117
107
async def handle_call_tool (
118
108
name : str , arguments : Dict [str , Any ]
119
109
) -> List [Union [types .TextContent , types .ImageContent , types .EmbeddedResource ]]:
120
- """Handler for the tools/call request"""
121
- # Get context from server lifespan
122
- ctx = mcp_server .request_context
123
- base_url = ctx .lifespan_context ["base_url" ]
124
- operation_map = ctx .lifespan_context ["operation_map" ]
125
-
126
- # Execute the tool
127
- return await self .execute_api_tool (base_url , name , arguments , operation_map )
110
+ return await self ._execute_api_tool (
111
+ client = self ._http_client ,
112
+ base_url = self ._base_url or "" ,
113
+ tool_name = name ,
114
+ arguments = arguments ,
115
+ operation_map = self .operation_map ,
116
+ )
128
117
129
118
return mcp_server
130
119
@@ -168,8 +157,13 @@ async def handle_post_message(request: Request):
168
157
169
158
logger .info (f"MCP server listening at { mount_path } " )
170
159
171
- async def execute_api_tool (
172
- self , base_url : str , tool_name : str , arguments : Dict [str , Any ], operation_map : Dict [str , Dict [str , Any ]]
160
+ async def _execute_api_tool (
161
+ self ,
162
+ client : AsyncClientProtocol ,
163
+ base_url : str ,
164
+ tool_name : str ,
165
+ arguments : Dict [str , Any ],
166
+ operation_map : Dict [str , Dict [str , Any ]],
173
167
) -> List [Union [types .TextContent , types .ImageContent , types .EmbeddedResource ]]:
174
168
"""
175
169
Execute an MCP tool by making an HTTP request to the corresponding API endpoint.
@@ -179,20 +173,20 @@ async def execute_api_tool(
179
173
tool_name: The name of the tool to execute
180
174
arguments: The arguments for the tool
181
175
operation_map: A mapping from tool names to operation details
176
+ client: Optional HTTP client to use (primarily for testing)
182
177
183
178
Returns:
184
179
The result as MCP content types
185
180
"""
186
181
if tool_name not in operation_map :
187
- return [ types . TextContent ( type = "text" , text = f"Unknown tool: { tool_name } " )]
182
+ raise Exception ( f"Unknown tool: { tool_name } " )
188
183
189
184
operation = operation_map [tool_name ]
190
185
path : str = operation ["path" ]
191
186
method : str = operation ["method" ]
192
187
parameters : List [Dict [str , Any ]] = operation .get ("parameters" , [])
193
188
arguments = arguments .copy () if arguments else {} # Deep copy arguments to avoid mutating the original
194
189
195
- # Prepare URL with path parameters
196
190
url = f"{ base_url } { path } "
197
191
for param in parameters :
198
192
if param .get ("in" ) == "path" and param .get ("name" ) in arguments :
@@ -201,7 +195,6 @@ async def execute_api_tool(
201
195
raise ValueError (f"Parameter name is None for parameter: { param } " )
202
196
url = url .replace (f"{{{ param_name } }}" , str (arguments .pop (param_name )))
203
197
204
- # Prepare query parameters
205
198
query = {}
206
199
for param in parameters :
207
200
if param .get ("in" ) == "query" and param .get ("name" ) in arguments :
@@ -210,7 +203,6 @@ async def execute_api_tool(
210
203
raise ValueError (f"Parameter name is None for parameter: { param } " )
211
204
query [param_name ] = arguments .pop (param_name )
212
205
213
- # Prepare headers
214
206
headers = {}
215
207
for param in parameters :
216
208
if param .get ("in" ) == "header" and param .get ("name" ) in arguments :
@@ -219,32 +211,57 @@ async def execute_api_tool(
219
211
raise ValueError (f"Parameter name is None for parameter: { param } " )
220
212
headers [param_name ] = arguments .pop (param_name )
221
213
222
- # Prepare request body (remaining kwargs)
223
214
body = arguments if arguments else None
224
215
225
216
try :
226
- # Make request
227
217
logger .debug (f"Making { method .upper ()} request to { url } " )
228
- async with httpx .AsyncClient () as client :
229
- if method .lower () == "get" :
230
- response = await client .get (url , params = query , headers = headers )
231
- elif method .lower () == "post" :
232
- response = await client .post (url , params = query , headers = headers , json = body )
233
- elif method .lower () == "put" :
234
- response = await client .put (url , params = query , headers = headers , json = body )
235
- elif method .lower () == "delete" :
236
- response = await client .delete (url , params = query , headers = headers )
237
- elif method .lower () == "patch" :
238
- response = await client .patch (url , params = query , headers = headers , json = body )
239
- else :
240
- return [types .TextContent (type = "text" , text = f"Unsupported HTTP method: { method } " )]
218
+ response = await self ._request (client , method , url , query , headers , body )
241
219
242
- # Process response
220
+ # TODO: Better typing for the AsyncClientProtocol. It should return a ResponseProtocol that has a json() method that returns a dict/list/etc.
243
221
try :
244
222
result = response .json ()
245
- return [types .TextContent (type = "text" , text = json .dumps (result , indent = 2 ))]
223
+ result_text = json .dumps (result , indent = 2 )
224
+ except json .JSONDecodeError :
225
+ if hasattr (response , "text" ):
226
+ result_text = response .text
227
+ else :
228
+ result_text = response .content
229
+
230
+ # If not raising an exception, the MCP server will return the result as a regular text response, without marking it as an error.
231
+ # TODO: Use a raise_for_status() method on the response (it needs to also be implemented in the AsyncClientProtocol)
232
+ if 400 <= response .status_code < 600 :
233
+ raise Exception (
234
+ f"Error calling { tool_name } . Status code: { response .status_code } . Response: { response .text } "
235
+ )
236
+
237
+ try :
238
+ return [types .TextContent (type = "text" , text = result_text )]
246
239
except ValueError :
247
- return [types .TextContent (type = "text" , text = response . text )]
240
+ return [types .TextContent (type = "text" , text = result_text )]
248
241
249
242
except Exception as e :
250
- return [types .TextContent (type = "text" , text = f"Error calling { tool_name } : { str (e )} " )]
243
+ logger .exception (f"Error calling { tool_name } " )
244
+ raise e
245
+
246
+ async def _request (
247
+ self ,
248
+ client : AsyncClientProtocol ,
249
+ method : str ,
250
+ url : str ,
251
+ query : Dict [str , Any ],
252
+ headers : Dict [str , str ],
253
+ body : Optional [Any ],
254
+ ) -> Any :
255
+ """Helper method to make the actual HTTP request"""
256
+ if method .lower () == "get" :
257
+ return await client .get (url , params = query , headers = headers )
258
+ elif method .lower () == "post" :
259
+ return await client .post (url , params = query , headers = headers , json = body )
260
+ elif method .lower () == "put" :
261
+ return await client .put (url , params = query , headers = headers , json = body )
262
+ elif method .lower () == "delete" :
263
+ return await client .delete (url , params = query , headers = headers )
264
+ elif method .lower () == "patch" :
265
+ return await client .patch (url , params = query , headers = headers , json = body )
266
+ else :
267
+ raise ValueError (f"Unsupported HTTP method: { method } " )
0 commit comments