1
1
import json
2
2
import httpx
3
- from typing import Dict , Optional , Any , List , Union , Callable , Awaitable , Iterable , Literal , Sequence
3
+ from typing import Dict , Optional , Any , List , Union , Literal , Sequence
4
4
from typing_extensions import Annotated , Doc
5
5
6
6
from fastapi import FastAPI , Request , APIRouter , params
19
19
logger = logging .getLogger (__name__ )
20
20
21
21
22
- class LowlevelMCPServer (Server ):
23
- def call_tool (self ):
24
- """
25
- A near-direct copy of `mcp.server.lowlevel.server.Server.call_tool()`, except that it looks for
26
- the original HTTP request info in the MCP message, and passes it to the tool call handler.
27
- """
28
-
29
- def decorator (
30
- func : Callable [
31
- ...,
32
- Awaitable [Iterable [types .TextContent | types .ImageContent | types .EmbeddedResource ]],
33
- ],
34
- ):
35
- logger .debug ("Registering handler for CallToolRequest" )
36
-
37
- async def handler (req : types .CallToolRequest ):
38
- try :
39
- # HACK: Pull the original HTTP request info from the MCP message. It was injected in
40
- # `FastApiSseTransport.handle_fastapi_post_message()`
41
- if hasattr (req .params , "_http_request_info" ) and req .params ._http_request_info is not None :
42
- http_request_info = HTTPRequestInfo .model_validate (req .params ._http_request_info )
43
- results = await func (req .params .name , (req .params .arguments or {}), http_request_info )
44
- else :
45
- results = await func (req .params .name , (req .params .arguments or {}))
46
- return types .ServerResult (types .CallToolResult (content = list (results ), isError = False ))
47
- except Exception as e :
48
- return types .ServerResult (
49
- types .CallToolResult (
50
- content = [types .TextContent (type = "text" , text = str (e ))],
51
- isError = True ,
52
- )
53
- )
54
-
55
- self .request_handlers [types .CallToolRequest ] = handler
56
- return func
57
-
58
- return decorator
59
-
60
-
61
22
class FastApiMCP :
62
23
"""
63
24
Create an MCP server from a FastAPI app.
@@ -115,14 +76,14 @@ def __init__(
115
76
Doc ("Configuration for MCP authentication" ),
116
77
] = None ,
117
78
headers : Annotated [
118
- Optional [ List [str ] ],
79
+ List [str ],
119
80
Doc (
120
81
"""
121
82
List of HTTP header names to forward from the incoming MCP request into each tool invocation.
122
83
Only headers in this allowlist will be forwarded. Defaults to ['authorization'].
123
84
"""
124
85
),
125
- ] = None ,
86
+ ] = [ "authorization" ] ,
126
87
):
127
88
# Validate operation and tag filtering options
128
89
if include_operations is not None and exclude_operations is not None :
@@ -157,7 +118,7 @@ def __init__(
157
118
timeout = 10.0 ,
158
119
)
159
120
160
- self ._forward_headers = {h .lower () for h in ( headers or [ "Authorization" ]) }
121
+ self ._forward_headers = {h .lower () for h in headers }
161
122
162
123
self .setup_server ()
163
124
@@ -179,16 +140,40 @@ def setup_server(self) -> None:
179
140
# Filter tools based on operation IDs and tags
180
141
self .tools = self ._filter_tools (all_tools , openapi_schema )
181
142
182
- mcp_server : LowlevelMCPServer = LowlevelMCPServer (self .name , self .description )
143
+ mcp_server : Server = Server (self .name , self .description )
183
144
184
145
@mcp_server .list_tools ()
185
146
async def handle_list_tools () -> List [types .Tool ]:
186
147
return self .tools
187
148
188
149
@mcp_server .call_tool ()
189
150
async def handle_call_tool (
190
- name : str , arguments : Dict [str , Any ], http_request_info : Optional [ HTTPRequestInfo ] = None
151
+ name : str , arguments : Dict [str , Any ]
191
152
) -> List [Union [types .TextContent , types .ImageContent , types .EmbeddedResource ]]:
153
+ # Extract HTTP request info from MCP context
154
+ http_request_info = None
155
+ try :
156
+ # Access the MCP server's request context to get the original HTTP Request
157
+ request_context = mcp_server .request_context
158
+
159
+ if request_context and hasattr (request_context , "request" ):
160
+ http_request = request_context .request
161
+
162
+ if http_request and hasattr (http_request , "method" ):
163
+ http_request_info = HTTPRequestInfo (
164
+ method = http_request .method ,
165
+ path = http_request .url .path ,
166
+ headers = dict (http_request .headers ),
167
+ cookies = http_request .cookies ,
168
+ query_params = dict (http_request .query_params ),
169
+ body = None ,
170
+ )
171
+ logger .debug (
172
+ f"Extracted HTTP request info from context: { http_request_info .method } { http_request_info .path } "
173
+ )
174
+ except (LookupError , AttributeError ) as e :
175
+ logger .error (f"Could not extract HTTP request info from context: { e } " )
176
+
192
177
return await self ._execute_api_tool (
193
178
client = self ._http_client ,
194
179
tool_name = name ,
0 commit comments