Skip to content

Commit 9ae03db

Browse files
committed
fix: show initialization errors for all the MCP clients
1 parent e8b675f commit 9ae03db

File tree

3 files changed

+214
-95
lines changed

3 files changed

+214
-95
lines changed
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import logging
2+
import mcp.types as mt
3+
from fastmcp.server.middleware import CallNext, Middleware, MiddlewareContext
4+
from mcp_proxy_for_aws.proxy import AWSMCPProxyClientFactory
5+
from typing_extensions import override
6+
7+
8+
logger = logging.getLogger(__name__)
9+
10+
11+
class InitializeMiddleware(Middleware):
12+
"""Intecept MCP initialize request and initialize the proxy client."""
13+
14+
def __init__(self, client_factory: AWSMCPProxyClientFactory) -> None:
15+
"""Create a middleware with client factory."""
16+
super().__init__()
17+
self._client_factory = client_factory
18+
19+
@override
20+
async def on_initialize(
21+
self,
22+
context: MiddlewareContext[mt.InitializeRequest],
23+
call_next: CallNext[mt.InitializeRequest, None],
24+
) -> None:
25+
try:
26+
logger.debug('Received initiqlize request %s.', context.message)
27+
self._client_factory.set_init_params(context.message)
28+
return await call_next(context)
29+
except Exception:
30+
logger.exception('Initialize failed in middleware.')
31+
raise

mcp_proxy_for_aws/proxy.py

Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
import httpx
2+
import logging
3+
from fastmcp import Client
4+
from fastmcp.client.transports import ClientTransport
5+
from fastmcp.exceptions import NotFoundError
6+
from fastmcp.server.proxy import ClientFactoryT
7+
from fastmcp.server.proxy import FastMCPProxy as _FastMCPProxy
8+
from fastmcp.server.proxy import ProxyClient as _ProxyClient
9+
from fastmcp.server.proxy import ProxyToolManager as _ProxyToolManager
10+
from fastmcp.tools import Tool
11+
from mcp import McpError
12+
from mcp.types import InitializeRequest, JSONRPCError, JSONRPCMessage
13+
from typing import Any
14+
from typing_extensions import override
15+
16+
17+
logger = logging.getLogger(__name__)
18+
19+
20+
class AWSProxyToolManager(_ProxyToolManager):
21+
"""Customized proxy tool manager that better suites our needs."""
22+
23+
def __init__(self, client_factory: ClientFactoryT, **kwargs: Any):
24+
"""Initialize a proxy tool manager.
25+
26+
Cached tools are set to None.
27+
"""
28+
super().__init__(client_factory, **kwargs)
29+
self._cached_tools: dict[str, Tool] | None = None
30+
31+
@override
32+
async def get_tool(self, key: str) -> Tool:
33+
"""Return the tool from cached tools.
34+
35+
This method is invoked when the client tries to call a tool.
36+
37+
tool = self.get_tool(key)
38+
tool.invoke(...)
39+
40+
The parent class implementation always make a mcp call to list the tools.
41+
Since the client already knows the name of the tools, list_tool is not necessary.
42+
We are wasting a network call just to get the tools which were already listed.
43+
44+
In case the server supports notifications/tools/listChanged, the `get_tools` method
45+
will be called explicity , hence, we are not missing the change to the tool list.
46+
"""
47+
if self._cached_tools is None:
48+
logger.debug('cached_tools not found, calling get_tools')
49+
self._cached_tools = await self.get_tools()
50+
if key in self._cached_tools:
51+
return self._cached_tools[key]
52+
raise NotFoundError(f'Tool {key!r} not found')
53+
54+
@override
55+
async def get_tools(self) -> dict[str, Tool]:
56+
"""Return list tools."""
57+
self._cached_tools = await super(AWSProxyToolManager, self).get_tools()
58+
return self._cached_tools
59+
60+
61+
class AWSMCPProxy(_FastMCPProxy):
62+
"""Customized MCP Proxy to better suite our needs."""
63+
64+
def __init__(
65+
self,
66+
*,
67+
client_factory: ClientFactoryT | None = None,
68+
**kwargs,
69+
):
70+
"""Initialize a client."""
71+
super().__init__(client_factory=client_factory, **kwargs)
72+
self._tool_manager = AWSProxyToolManager(
73+
client_factory=self.client_factory,
74+
transformations=self._tool_manager.transformations,
75+
)
76+
77+
78+
class AWSMCPProxyClient(_ProxyClient):
79+
"""Proxy client that handles HTTP errors when connection fails."""
80+
81+
def __init__(self, transport: ClientTransport, **kwargs):
82+
"""Constructor of AutoRefreshProxyCilent."""
83+
super().__init__(transport, **kwargs)
84+
85+
@override
86+
async def _connect(self):
87+
"""Enter as normal && initialize only once."""
88+
logger.debug('Connecting %s', self)
89+
try:
90+
result = await super(AWSMCPProxyClient, self)._connect()
91+
logger.debug('Connected %s', self)
92+
return result
93+
except httpx.HTTPStatusError as http_error:
94+
logger.exception('Connection failed')
95+
response = http_error.response
96+
try:
97+
body = await response.aread()
98+
jsonrpc_msg = JSONRPCMessage.model_validate_json(body).root
99+
except Exception:
100+
logger.debug('HTTP error is not a valid MCP message.')
101+
raise http_error
102+
103+
if isinstance(jsonrpc_msg, JSONRPCError):
104+
logger.debug('Converting HTTP error to MCP error %s', http_error)
105+
# raising McpError so that the sdk can handle the exception properly
106+
raise McpError(error=jsonrpc_msg.error) from http_error
107+
else:
108+
raise http_error
109+
110+
async def __aexit__(self, exc_type, exc_val, exc_tb):
111+
"""The MCP Proxy for AWS project is a proxy from stdio to http (sigv4).
112+
113+
We want the client to remain connected in the until the stdio connection is closed.
114+
115+
https://modelcontextprotocol.io/specification/2024-11-05/basic/transports#stdio
116+
117+
1. close stdin
118+
2. terminate subprocess
119+
120+
There is no equivalent of the streamble-http DELETE concept in stdio to terminate a session.
121+
Hence the connection will be terminated only at program exit.
122+
"""
123+
# return await super().__aexit__(exc_type, exc_val, exc_tb)
124+
pass
125+
126+
127+
class AWSMCPProxyClientFactory:
128+
"""Client factory that returns a connected client."""
129+
130+
def __init__(self, transport: ClientTransport) -> None:
131+
"""Initialize a client factory with transport."""
132+
self._transport = transport
133+
self._client = AWSMCPProxyClient(transport)
134+
self._clients: list[AWSMCPProxyClient] = []
135+
self._initialize_request: InitializeRequest | None = None
136+
137+
def set_init_params(self, initialize_request: InitializeRequest):
138+
"""Set client init parameters."""
139+
self._initialize_request = initialize_request
140+
141+
async def get_client(self) -> Client:
142+
"""Get client."""
143+
if not self._client.is_connected():
144+
self._client = AWSMCPProxyClient(self._transport)
145+
146+
return self._client
147+
148+
async def __call__(self) -> Client:
149+
"""Implement the callable factory interface."""
150+
return await self.get_client()
151+
152+
async def disconnect_all(self):
153+
"""Disconnect all the clients (no throw)."""
154+
for client in reversed(self._clients):
155+
try:
156+
await client._disconnect(force=True)
157+
except Exception:
158+
logger.exception('Failed to disconnect client.')

mcp_proxy_for_aws/server.py

Lines changed: 25 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -23,26 +23,16 @@
2323
"""
2424

2525
import asyncio
26-
import contextlib
2726
import httpx
2827
import logging
29-
import sys
30-
from fastmcp.client import ClientTransport
3128
from fastmcp.server.middleware.error_handling import RetryMiddleware
3229
from fastmcp.server.middleware.logging import LoggingMiddleware
33-
from fastmcp.server.proxy import FastMCPProxy, ProxyClient
3430
from fastmcp.server.server import FastMCP
35-
from mcp import McpError
36-
from mcp.types import (
37-
CONNECTION_CLOSED,
38-
ErrorData,
39-
JSONRPCError,
40-
JSONRPCMessage,
41-
JSONRPCResponse,
42-
)
4331
from mcp_proxy_for_aws.cli import parse_args
4432
from mcp_proxy_for_aws.logging_config import configure_logging
33+
from mcp_proxy_for_aws.middleware.initialize_middleware import InitializeMiddleware
4534
from mcp_proxy_for_aws.middleware.tool_filter import ToolFilteringMiddleware
35+
from mcp_proxy_for_aws.proxy import AWSMCPProxy, AWSMCPProxyClientFactory
4636
from mcp_proxy_for_aws.utils import (
4737
create_transport_with_sigv4,
4838
determine_aws_region,
@@ -53,62 +43,9 @@
5343
logger = logging.getLogger(__name__)
5444

5545

56-
@contextlib.asynccontextmanager
57-
async def _initialize_client(transport: ClientTransport):
58-
"""Handle the exceptions for during client initialize."""
59-
async with contextlib.AsyncExitStack() as stack:
60-
try:
61-
client = await stack.enter_async_context(ProxyClient(transport))
62-
except httpx.HTTPStatusError as http_error:
63-
logger.error('HTTP Error during initialize %s', http_error)
64-
response = http_error.response
65-
try:
66-
body = await response.aread()
67-
jsonrpc_msg = JSONRPCMessage.model_validate_json(body).root
68-
if isinstance(jsonrpc_msg, (JSONRPCError, JSONRPCResponse)):
69-
line = jsonrpc_msg.model_dump_json(
70-
by_alias=True,
71-
exclude_none=True,
72-
)
73-
logger.debug('Writing the unhandled http error to stdout %s', http_error)
74-
print(line, file=sys.stdout)
75-
else:
76-
logger.debug('Ignoring jsonrpc message type=%s', type(jsonrpc_msg))
77-
except Exception as _:
78-
logger.debug('Cannot read HTTP response body')
79-
raise http_error
80-
except Exception as e:
81-
cause = e.__cause__
82-
if isinstance(cause, McpError):
83-
logger.error('MCP Error during initialize %s', cause.error)
84-
jsonrpc_error = JSONRPCError(jsonrpc='2.0', id=0, error=cause.error)
85-
line = jsonrpc_error.model_dump_json(
86-
by_alias=True,
87-
exclude_none=True,
88-
)
89-
else:
90-
logger.error('Error during initialize %s', e)
91-
jsonrpc_error = JSONRPCError(
92-
jsonrpc='2.0',
93-
id=0,
94-
error=ErrorData(
95-
code=CONNECTION_CLOSED,
96-
message=str(e),
97-
),
98-
)
99-
line = jsonrpc_error.model_dump_json(
100-
by_alias=True,
101-
exclude_none=True,
102-
)
103-
print(line, file=sys.stdout)
104-
raise e
105-
logger.debug('Initialized MCP client')
106-
yield client
107-
108-
10946
async def run_proxy(args) -> None:
11047
"""Set up the server in MCP mode."""
111-
logger.info('Setting up server in MCP mode')
48+
logger.info('Setting up mcp proxy server to %s', args.endpoint)
11249

11350
# Validate and determine service
11451
service = determine_service_name(args.endpoint, args.service)
@@ -134,7 +71,6 @@ async def run_proxy(args) -> None:
13471
metadata,
13572
profile,
13673
)
137-
logger.info('Running in MCP mode')
13874

13975
timeout = httpx.Timeout(
14076
args.timeout,
@@ -147,35 +83,29 @@ async def run_proxy(args) -> None:
14783
transport = create_transport_with_sigv4(
14884
args.endpoint, service, region, metadata, timeout, profile
14985
)
86+
client_factory = AWSMCPProxyClientFactory(transport)
15087

151-
async with _initialize_client(transport) as client:
152-
153-
async def client_factory():
154-
nonlocal client
155-
if not client.is_connected():
156-
logger.debug('Reinitialize client')
157-
client = ProxyClient(transport)
158-
await client._connect()
159-
return client
160-
161-
try:
162-
proxy = FastMCPProxy(
163-
client_factory=client_factory,
164-
name='MCP Proxy for AWS',
165-
instructions=(
166-
'MCP Proxy for AWS provides access to SigV4 protected MCP servers through a single interface. '
167-
'This proxy handles authentication and request routing to the appropriate backend services.'
168-
),
169-
)
170-
add_logging_middleware(proxy, args.log_level)
171-
add_tool_filtering_middleware(proxy, args.read_only)
172-
173-
if args.retries:
174-
add_retry_middleware(proxy, args.retries)
175-
await proxy.run_async(transport='stdio', show_banner=False, log_level=args.log_level)
176-
except Exception as e:
177-
logger.error('Cannot start proxy server: %s', e)
178-
raise e
88+
try:
89+
proxy = AWSMCPProxy(
90+
client_factory=client_factory,
91+
name='MCP Proxy for AWS',
92+
instructions=(
93+
'MCP Proxy for AWS provides access to SigV4 protected MCP servers through a single interface. '
94+
'This proxy handles authentication and request routing to the appropriate backend services.'
95+
),
96+
)
97+
proxy.add_middleware(InitializeMiddleware(client_factory))
98+
add_logging_middleware(proxy, args.log_level)
99+
add_tool_filtering_middleware(proxy, args.read_only)
100+
101+
if args.retries:
102+
add_retry_middleware(proxy, args.retries)
103+
await proxy.run_async(transport='stdio', show_banner=False, log_level=args.log_level)
104+
except Exception as e:
105+
logger.error('Cannot start proxy server: %s', e)
106+
raise e
107+
finally:
108+
await client_factory.disconnect_all()
179109

180110

181111
def add_tool_filtering_middleware(mcp: FastMCP, read_only: bool = False) -> None:

0 commit comments

Comments
 (0)