Skip to content

Commit 99ee583

Browse files
committed
fix: avoid infinite recursion
can happen when the network is down and connect keeps failing
1 parent aa601be commit 99ee583

File tree

3 files changed

+91
-12
lines changed

3 files changed

+91
-12
lines changed

mcp_proxy_for_aws/proxy.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -78,12 +78,13 @@ def __init__(
7878
class AWSMCPProxyClient(_ProxyClient):
7979
"""Proxy client that handles HTTP errors when connection fails."""
8080

81-
def __init__(self, transport: ClientTransport, **kwargs):
81+
def __init__(self, transport: ClientTransport, max_connect_retry=3, **kwargs):
8282
"""Constructor of AutoRefreshProxyCilent."""
8383
super().__init__(transport, **kwargs)
84+
self._max_connect_retry = max_connect_retry
8485

8586
@override
86-
async def _connect(self):
87+
async def _connect(self, retry=0):
8788
"""Enter as normal && initialize only once."""
8889
logger.debug('Connecting %s', self)
8990
try:
@@ -96,27 +97,36 @@ async def _connect(self):
9697
try:
9798
body = await response.aread()
9899
jsonrpc_msg = JSONRPCMessage.model_validate_json(body).root
99-
except Exception:
100-
logger.debug('HTTP error is not a valid MCP message.')
100+
except Exception as e:
101+
logger.debug('HTTP error is not a valid MCP message.', exc_info=e)
101102
raise http_error
102103

103104
if isinstance(jsonrpc_msg, JSONRPCError):
104-
logger.debug('Converting HTTP error to MCP error %s', http_error)
105+
logger.debug('Converting HTTP error to MCP error', exc_info=http_error)
105106
# raising McpError so that the sdk can handle the exception properly
106107
raise McpError(error=jsonrpc_msg.error) from http_error
107108
else:
108109
raise http_error
109-
except RuntimeError:
110+
except RuntimeError as e:
111+
if isinstance(e.__cause__, McpError):
112+
raise e.__cause__
113+
114+
if retry > self._max_connect_retry:
115+
raise e
116+
110117
try:
111-
logger.warning('encountered runtime error, try force disconnect.')
118+
logger.warning('encountered runtime error, try force disconnect.', exc_info=e)
112119
await self._disconnect(force=True)
113-
except Exception:
120+
except httpx.TimeoutException:
114121
# _disconnect awaits on the session_task,
115122
# which raises the timeout error that caused the client session to be terminated.
116123
# the error is ignored as long as the counter is force set to 0.
117124
# TODO: investigate how timeout error is handled by fastmcp and httpx
118-
logger.exception('encountered another error, ignoring.')
119-
return await self._connect()
125+
logger.exception(
126+
'Session was terminated due to timeout error, ignore and reconnect'
127+
)
128+
129+
return await self._connect(retry + 1)
120130

121131
async def __aexit__(self, exc_type, exc_val, exc_tb):
122132
"""The MCP Proxy for AWS project is a proxy from stdio to http (sigv4).

tests/unit/test_proxy.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,3 +227,72 @@ async def test_client_factory_disconnect_all_handles_exceptions():
227227
await factory.disconnect()
228228

229229
mock_client._disconnect.assert_called_once_with(force=True)
230+
231+
232+
@pytest.mark.asyncio
233+
async def test_proxy_client_connect_runtime_error_with_mcp_error():
234+
"""Test connection handles RuntimeError wrapping McpError."""
235+
mock_transport = Mock(spec=ClientTransport)
236+
client = AWSMCPProxyClient(mock_transport)
237+
238+
error_data = ErrorData(code=-32600, message='Invalid Request')
239+
mcp_error = McpError(error=error_data)
240+
runtime_error = RuntimeError('Connection failed')
241+
runtime_error.__cause__ = mcp_error
242+
243+
with patch('mcp_proxy_for_aws.proxy._ProxyClient._connect', side_effect=runtime_error):
244+
with pytest.raises(McpError) as exc_info:
245+
await client._connect()
246+
assert exc_info.value.error.code == -32600
247+
248+
249+
@pytest.mark.asyncio
250+
async def test_proxy_client_connect_runtime_error_max_retries():
251+
"""Test connection stops retrying after max_connect_retry."""
252+
mock_transport = Mock(spec=ClientTransport)
253+
client = AWSMCPProxyClient(mock_transport, max_connect_retry=2)
254+
255+
runtime_error = RuntimeError('Connection failed')
256+
257+
with patch('mcp_proxy_for_aws.proxy._ProxyClient._connect', side_effect=runtime_error):
258+
with patch.object(client, '_disconnect', new_callable=AsyncMock) as mock_disconnect:
259+
with pytest.raises(RuntimeError):
260+
await client._connect()
261+
assert mock_disconnect.call_count == 3
262+
263+
264+
@pytest.mark.asyncio
265+
async def test_proxy_client_connect_runtime_error_with_timeout():
266+
"""Test connection handles TimeoutException during disconnect."""
267+
mock_transport = Mock(spec=ClientTransport)
268+
client = AWSMCPProxyClient(mock_transport, max_connect_retry=1)
269+
270+
runtime_error = RuntimeError('Connection failed')
271+
call_count = 0
272+
273+
async def mock_connect_side_effect(*args, **kwargs):
274+
nonlocal call_count
275+
call_count += 1
276+
if call_count <= 2:
277+
raise runtime_error
278+
return 'connected'
279+
280+
with patch(
281+
'mcp_proxy_for_aws.proxy._ProxyClient._connect', side_effect=mock_connect_side_effect
282+
):
283+
with patch.object(
284+
client,
285+
'_disconnect',
286+
new_callable=AsyncMock,
287+
side_effect=httpx.TimeoutException('timeout'),
288+
):
289+
result = await client._connect()
290+
assert result == 'connected'
291+
292+
293+
@pytest.mark.asyncio
294+
async def test_proxy_client_max_connect_retry_default():
295+
"""Test default max_connect_retry is 3."""
296+
mock_transport = Mock(spec=ClientTransport)
297+
client = AWSMCPProxyClient(mock_transport)
298+
assert client._max_connect_retry == 3

uv.lock

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)