Skip to content

Commit e0d405d

Browse files
committed
enhance test suite and fix errors
1 parent 9bc99f5 commit e0d405d

File tree

8 files changed

+660
-124
lines changed

8 files changed

+660
-124
lines changed

fastapi_mcp/server.py

Lines changed: 82 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import json
22
import httpx
3-
from contextlib import asynccontextmanager
4-
from typing import Dict, Optional, Any, List, Union, AsyncIterator
3+
from typing import Dict, Optional, Any, List, Union
54

65
from fastapi import FastAPI, Request, APIRouter
76
from fastapi.openapi.utils import get_openapi
@@ -10,6 +9,7 @@
109

1110
from fastapi_mcp.openapi.convert import convert_openapi_to_mcp_tools
1211
from fastapi_mcp.transport.sse import FastApiSseTransport
12+
from fastapi_mcp.types import AsyncClientProtocol
1313

1414
from logging import getLogger
1515

@@ -26,7 +26,24 @@ def __init__(
2626
base_url: Optional[str] = None,
2727
describe_all_responses: bool = False,
2828
describe_full_response_schema: bool = False,
29+
http_client: Optional[AsyncClientProtocol] = None,
2930
):
31+
"""
32+
Create an MCP server from a FastAPI app.
33+
34+
Args:
35+
fastapi: The FastAPI application
36+
name: Name for the MCP server (defaults to app.title)
37+
description: Description for the MCP server (defaults to app.description)
38+
base_url: Base URL for API requests. If not provided, the base URL will be determined from the
39+
FastAPI app's root path. Although optional, it is highly recommended to provide a base URL,
40+
as the root path would be different when the app is deployed.
41+
describe_all_responses: Whether to include all possible response schemas in tool descriptions
42+
describe_full_response_schema: Whether to include full json schema for responses in tool descriptions
43+
http_client: Optional HTTP client to use for API calls. If not provided, a new httpx.AsyncClient will be created.
44+
This is primarily for testing purposes.
45+
"""
46+
3047
self.operation_map: Dict[str, Dict[str, Any]]
3148
self.tools: List[types.Tool]
3249

@@ -38,27 +55,11 @@ def __init__(
3855
self._describe_all_responses = describe_all_responses
3956
self._describe_full_response_schema = describe_full_response_schema
4057

58+
self._http_client = http_client or httpx.AsyncClient()
59+
4160
self.server = self.create_server()
4261

4362
def create_server(self) -> Server:
44-
"""
45-
Create an MCP server from the FastAPI app.
46-
47-
Args:
48-
fastapi: The FastAPI application
49-
name: Name for the MCP server (defaults to app.title)
50-
description: Description for the MCP server (defaults to app.description)
51-
base_url: Base URL for API requests. If not provided, the base URL will be determined from the
52-
FastAPI app's root path. Although optional, it is highly recommended to provide a base URL,
53-
as the root path would be different when the app is deployed.
54-
describe_all_responses: Whether to include all possible response schemas in tool descriptions
55-
describe_full_response_schema: Whether to include full json schema for responses in tool descriptions
56-
57-
Returns:
58-
A tuple containing:
59-
- The created MCP Server instance (NOT mounted to the app)
60-
- A mapping of operation IDs to operation details for HTTP execution
61-
"""
6263
# Get OpenAPI schema from FastAPI app
6364
openapi_schema = get_openapi(
6465
title=self.fastapi.title,
@@ -93,38 +94,26 @@ def create_server(self) -> Server:
9394
if self._base_url.endswith("/"):
9495
self._base_url = self._base_url[:-1]
9596

96-
# Create the MCP server
97+
# Create the MCP lowlevel server
9798
mcp_server: Server = Server(self.name, self.description)
9899

99-
# Create a lifespan context manager to store the base_url and operation_map
100-
@asynccontextmanager
101-
async def server_lifespan(server) -> AsyncIterator[Dict[str, Any]]:
102-
# Store context data that will be available to all server handlers
103-
context = {"base_url": self._base_url, "operation_map": self.operation_map}
104-
yield context
105-
106-
# Use our custom lifespan
107-
mcp_server.lifespan = server_lifespan
108-
109100
# Register handlers for tools
110101
@mcp_server.list_tools()
111102
async def handle_list_tools() -> List[types.Tool]:
112-
"""Handler for the tools/list request"""
113103
return self.tools
114104

115105
# Register the tool call handler
116106
@mcp_server.call_tool()
117107
async def handle_call_tool(
118108
name: str, arguments: Dict[str, Any]
119109
) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]:
120-
"""Handler for the tools/call request"""
121-
# Get context from server lifespan
122-
ctx = mcp_server.request_context
123-
base_url = ctx.lifespan_context["base_url"]
124-
operation_map = ctx.lifespan_context["operation_map"]
125-
126-
# Execute the tool
127-
return await self.execute_api_tool(base_url, name, arguments, operation_map)
110+
return await self._execute_api_tool(
111+
client=self._http_client,
112+
base_url=self._base_url or "",
113+
tool_name=name,
114+
arguments=arguments,
115+
operation_map=self.operation_map,
116+
)
128117

129118
return mcp_server
130119

@@ -168,8 +157,13 @@ async def handle_post_message(request: Request):
168157

169158
logger.info(f"MCP server listening at {mount_path}")
170159

171-
async def execute_api_tool(
172-
self, base_url: str, tool_name: str, arguments: Dict[str, Any], operation_map: Dict[str, Dict[str, Any]]
160+
async def _execute_api_tool(
161+
self,
162+
client: AsyncClientProtocol,
163+
base_url: str,
164+
tool_name: str,
165+
arguments: Dict[str, Any],
166+
operation_map: Dict[str, Dict[str, Any]],
173167
) -> List[Union[types.TextContent, types.ImageContent, types.EmbeddedResource]]:
174168
"""
175169
Execute an MCP tool by making an HTTP request to the corresponding API endpoint.
@@ -179,20 +173,20 @@ async def execute_api_tool(
179173
tool_name: The name of the tool to execute
180174
arguments: The arguments for the tool
181175
operation_map: A mapping from tool names to operation details
176+
client: Optional HTTP client to use (primarily for testing)
182177
183178
Returns:
184179
The result as MCP content types
185180
"""
186181
if tool_name not in operation_map:
187-
return [types.TextContent(type="text", text=f"Unknown tool: {tool_name}")]
182+
raise Exception(f"Unknown tool: {tool_name}")
188183

189184
operation = operation_map[tool_name]
190185
path: str = operation["path"]
191186
method: str = operation["method"]
192187
parameters: List[Dict[str, Any]] = operation.get("parameters", [])
193188
arguments = arguments.copy() if arguments else {} # Deep copy arguments to avoid mutating the original
194189

195-
# Prepare URL with path parameters
196190
url = f"{base_url}{path}"
197191
for param in parameters:
198192
if param.get("in") == "path" and param.get("name") in arguments:
@@ -201,7 +195,6 @@ async def execute_api_tool(
201195
raise ValueError(f"Parameter name is None for parameter: {param}")
202196
url = url.replace(f"{{{param_name}}}", str(arguments.pop(param_name)))
203197

204-
# Prepare query parameters
205198
query = {}
206199
for param in parameters:
207200
if param.get("in") == "query" and param.get("name") in arguments:
@@ -210,7 +203,6 @@ async def execute_api_tool(
210203
raise ValueError(f"Parameter name is None for parameter: {param}")
211204
query[param_name] = arguments.pop(param_name)
212205

213-
# Prepare headers
214206
headers = {}
215207
for param in parameters:
216208
if param.get("in") == "header" and param.get("name") in arguments:
@@ -219,32 +211,57 @@ async def execute_api_tool(
219211
raise ValueError(f"Parameter name is None for parameter: {param}")
220212
headers[param_name] = arguments.pop(param_name)
221213

222-
# Prepare request body (remaining kwargs)
223214
body = arguments if arguments else None
224215

225216
try:
226-
# Make request
227217
logger.debug(f"Making {method.upper()} request to {url}")
228-
async with httpx.AsyncClient() as client:
229-
if method.lower() == "get":
230-
response = await client.get(url, params=query, headers=headers)
231-
elif method.lower() == "post":
232-
response = await client.post(url, params=query, headers=headers, json=body)
233-
elif method.lower() == "put":
234-
response = await client.put(url, params=query, headers=headers, json=body)
235-
elif method.lower() == "delete":
236-
response = await client.delete(url, params=query, headers=headers)
237-
elif method.lower() == "patch":
238-
response = await client.patch(url, params=query, headers=headers, json=body)
239-
else:
240-
return [types.TextContent(type="text", text=f"Unsupported HTTP method: {method}")]
218+
response = await self._request(client, method, url, query, headers, body)
241219

242-
# Process response
220+
# TODO: Better typing for the AsyncClientProtocol. It should return a ResponseProtocol that has a json() method that returns a dict/list/etc.
243221
try:
244222
result = response.json()
245-
return [types.TextContent(type="text", text=json.dumps(result, indent=2))]
223+
result_text = json.dumps(result, indent=2)
224+
except json.JSONDecodeError:
225+
if hasattr(response, "text"):
226+
result_text = response.text
227+
else:
228+
result_text = response.content
229+
230+
# If not raising an exception, the MCP server will return the result as a regular text response, without marking it as an error.
231+
# TODO: Use a raise_for_status() method on the response (it needs to also be implemented in the AsyncClientProtocol)
232+
if 400 <= response.status_code < 600:
233+
raise Exception(
234+
f"Error calling {tool_name}. Status code: {response.status_code}. Response: {response.text}"
235+
)
236+
237+
try:
238+
return [types.TextContent(type="text", text=result_text)]
246239
except ValueError:
247-
return [types.TextContent(type="text", text=response.text)]
240+
return [types.TextContent(type="text", text=result_text)]
248241

249242
except Exception as e:
250-
return [types.TextContent(type="text", text=f"Error calling {tool_name}: {str(e)}")]
243+
logger.exception(f"Error calling {tool_name}")
244+
raise e
245+
246+
async def _request(
247+
self,
248+
client: AsyncClientProtocol,
249+
method: str,
250+
url: str,
251+
query: Dict[str, Any],
252+
headers: Dict[str, str],
253+
body: Optional[Any],
254+
) -> Any:
255+
"""Helper method to make the actual HTTP request"""
256+
if method.lower() == "get":
257+
return await client.get(url, params=query, headers=headers)
258+
elif method.lower() == "post":
259+
return await client.post(url, params=query, headers=headers, json=body)
260+
elif method.lower() == "put":
261+
return await client.put(url, params=query, headers=headers, json=body)
262+
elif method.lower() == "delete":
263+
return await client.delete(url, params=query, headers=headers)
264+
elif method.lower() == "patch":
265+
return await client.patch(url, params=query, headers=headers, json=body)
266+
else:
267+
raise ValueError(f"Unsupported HTTP method: {method}")

fastapi_mcp/types.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
from pydantic import BaseModel, ConfigDict
2+
3+
from typing import Any, Protocol, Optional, Dict
4+
5+
6+
class BaseType(BaseModel):
7+
model_config = ConfigDict(extra="forbid")
8+
9+
10+
class AsyncClientProtocol(Protocol):
11+
"""Protocol defining the interface for async HTTP clients."""
12+
13+
async def get(
14+
self, url: str, *, params: Optional[Dict[str, Any]] = None, headers: Optional[Dict[str, str]] = None
15+
) -> Any: ...
16+
17+
async def post(
18+
self,
19+
url: str,
20+
*,
21+
params: Optional[Dict[str, Any]] = None,
22+
headers: Optional[Dict[str, str]] = None,
23+
json: Optional[Any] = None,
24+
) -> Any: ...
25+
26+
async def put(
27+
self,
28+
url: str,
29+
*,
30+
params: Optional[Dict[str, Any]] = None,
31+
headers: Optional[Dict[str, str]] = None,
32+
json: Optional[Any] = None,
33+
) -> Any: ...
34+
35+
async def delete(
36+
self, url: str, *, params: Optional[Dict[str, Any]] = None, headers: Optional[Dict[str, str]] = None
37+
) -> Any: ...
38+
39+
async def patch(
40+
self,
41+
url: str,
42+
*,
43+
params: Optional[Dict[str, Any]] = None,
44+
headers: Optional[Dict[str, str]] = None,
45+
json: Optional[Any] = None,
46+
) -> Any: ...

fastapi_mcp/utils/__init__.py

Whitespace-only changes.

fastapi_mcp/utils/testing.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
import json
2+
from typing import Any, Dict, Optional
3+
4+
from fastapi import FastAPI
5+
from fastapi.testclient import TestClient
6+
7+
from fastapi_mcp.server import AsyncClientProtocol
8+
9+
10+
class FastAPITestClient(AsyncClientProtocol):
11+
def __init__(self, app: FastAPI):
12+
self.client = TestClient(app, raise_server_exceptions=False)
13+
14+
async def get(
15+
self, url: str, *, params: Optional[Dict[str, Any]] = None, headers: Optional[Dict[str, str]] = None
16+
) -> Any:
17+
response = self.client.get(url, params=params, headers=headers)
18+
return self._wrap_response(response)
19+
20+
async def post(
21+
self,
22+
url: str,
23+
*,
24+
params: Optional[Dict[str, Any]] = None,
25+
headers: Optional[Dict[str, str]] = None,
26+
json: Optional[Any] = None,
27+
) -> Any:
28+
response = self.client.post(url, params=params, headers=headers, json=json)
29+
return self._wrap_response(response)
30+
31+
async def put(
32+
self,
33+
url: str,
34+
*,
35+
params: Optional[Dict[str, Any]] = None,
36+
headers: Optional[Dict[str, str]] = None,
37+
json: Optional[Any] = None,
38+
) -> Any:
39+
response = self.client.put(url, params=params, headers=headers, json=json)
40+
return self._wrap_response(response)
41+
42+
async def delete(
43+
self, url: str, *, params: Optional[Dict[str, Any]] = None, headers: Optional[Dict[str, str]] = None
44+
) -> Any:
45+
response = self.client.delete(url, params=params, headers=headers)
46+
return self._wrap_response(response)
47+
48+
async def patch(
49+
self,
50+
url: str,
51+
*,
52+
params: Optional[Dict[str, Any]] = None,
53+
headers: Optional[Dict[str, str]] = None,
54+
json: Optional[Any] = None,
55+
) -> Any:
56+
response = self.client.patch(url, params=params, headers=headers, json=json)
57+
return self._wrap_response(response)
58+
59+
def _wrap_response(self, response: Any) -> Any:
60+
response.json = (
61+
lambda: json.loads(response.content) if hasattr(response, "content") and response.content else None
62+
)
63+
return response

0 commit comments

Comments
 (0)