Skip to content

Commit 9632822

Browse files
authored
fix: write initialize error to stdout (aws#95)
* logging improvement * add uv.lock * fix: proxy must run as mpc stdio server explicitly * fix: on connection failure, write the mcp message to stdout * fix broken unit tests * add unit test for initialize * fix linter
1 parent 308e5a8 commit 9632822

File tree

8 files changed

+3588
-188
lines changed

8 files changed

+3588
-188
lines changed

mcp_proxy_for_aws/logging_config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,3 +48,4 @@ def configure_logging(level: Optional[str] = None) -> None:
4848
# Set httpx logging to WARNING by default to reduce noise
4949
logging.getLogger('httpx').setLevel(logging.WARNING)
5050
logging.getLogger('httpcore').setLevel(logging.WARNING)
51+
logging.getLogger('botocore').setLevel(logging.WARNING)

mcp_proxy_for_aws/server.py

Lines changed: 99 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,23 @@
2323
"""
2424

2525
import asyncio
26+
import contextlib
2627
import httpx
2728
import logging
29+
import sys
2830
from fastmcp import Client
31+
from fastmcp.client import ClientTransport
2932
from fastmcp.server.middleware.error_handling import RetryMiddleware
3033
from fastmcp.server.middleware.logging import LoggingMiddleware
3134
from fastmcp.server.server import FastMCP
35+
from mcp import McpError
36+
from mcp.types import (
37+
CONNECTION_CLOSED,
38+
ErrorData,
39+
JSONRPCError,
40+
JSONRPCMessage,
41+
JSONRPCResponse,
42+
)
3243
from mcp_proxy_for_aws.cli import parse_args
3344
from mcp_proxy_for_aws.logging_config import configure_logging
3445
from mcp_proxy_for_aws.middleware.tool_filter import ToolFilteringMiddleware
@@ -37,13 +48,75 @@
3748
determine_aws_region,
3849
determine_service_name,
3950
)
40-
from typing import Any
4151

4252

4353
logger = logging.getLogger(__name__)
4454

4555

46-
async def setup_mcp_mode(local_mcp: FastMCP, args) -> None:
56+
@contextlib.asynccontextmanager
57+
async def _initialize_client(transport: ClientTransport):
58+
"""Handle the exceptions for during client initialize."""
59+
# line = sys.stdin.readline()
60+
# logger.debug('First line from kiro %s', line)
61+
async with contextlib.AsyncExitStack() as stack:
62+
try:
63+
client = await stack.enter_async_context(Client(transport))
64+
if client.initialize_result:
65+
print(
66+
client.initialize_result.model_dump_json(
67+
by_alias=True,
68+
exclude_none=True,
69+
),
70+
file=sys.stdout,
71+
)
72+
except httpx.HTTPStatusError as http_error:
73+
logger.error('HTTP Error during initialize %s', http_error)
74+
response = http_error.response
75+
try:
76+
body = await response.aread()
77+
jsonrpc_msg = JSONRPCMessage.model_validate_json(body).root
78+
if isinstance(jsonrpc_msg, (JSONRPCError, JSONRPCResponse)):
79+
line = jsonrpc_msg.model_dump_json(
80+
by_alias=True,
81+
exclude_none=True,
82+
)
83+
logger.debug('Writing the unhandled http error to stdout %s', http_error)
84+
print(line, file=sys.stdout)
85+
else:
86+
logger.debug('Ignoring jsonrpc message type=%s', type(jsonrpc_msg))
87+
except Exception as _:
88+
logger.debug('Cannot read HTTP response body')
89+
raise http_error
90+
except Exception as e:
91+
cause = e.__cause__
92+
if isinstance(cause, McpError):
93+
logger.error('MCP Error during initialize %s', cause.error)
94+
jsonrpc_error = JSONRPCError(jsonrpc='2.0', id=0, error=cause.error)
95+
line = jsonrpc_error.model_dump_json(
96+
by_alias=True,
97+
exclude_none=True,
98+
)
99+
else:
100+
logger.error('Error during initialize %s', e)
101+
jsonrpc_error = JSONRPCError(
102+
jsonrpc='2.0',
103+
id=0,
104+
error=ErrorData(
105+
code=CONNECTION_CLOSED,
106+
message=str(e),
107+
),
108+
)
109+
line = jsonrpc_error.model_dump_json(
110+
by_alias=True,
111+
exclude_none=True,
112+
)
113+
print(line, file=sys.stdout)
114+
raise e
115+
logger.debug('Initialized MCP client')
116+
yield client
117+
118+
119+
async def run_proxy(args) -> None:
47120
"""Set up the server in MCP mode."""
48121
logger.info('Setting up server in MCP mode')
49122

@@ -84,16 +157,25 @@ async def setup_mcp_mode(local_mcp: FastMCP, args) -> None:
84157
transport = create_transport_with_sigv4(
85158
args.endpoint, service, region, metadata, timeout, profile
86159
)
87-
async with Client(transport=transport) as client:
88-
# Create proxy with the transport
89-
proxy = FastMCP.as_proxy(client)
90-
add_logging_middleware(proxy, args.log_level)
91-
add_tool_filtering_middleware(proxy, args.read_only)
92-
93-
if args.retries:
94-
add_retry_middleware(proxy, args.retries)
95-
96-
await proxy.run_async()
160+
async with _initialize_client(transport) as client:
161+
try:
162+
proxy = FastMCP.as_proxy(
163+
client,
164+
name='MCP Proxy for AWS',
165+
instructions=(
166+
'MCP Proxy for AWS provides access to SigV4 protected MCP servers through a single interface. '
167+
'This proxy handles authentication and request routing to the appropriate backend services.'
168+
),
169+
)
170+
add_logging_middleware(proxy, args.log_level)
171+
add_tool_filtering_middleware(proxy, args.read_only)
172+
173+
if args.retries:
174+
add_retry_middleware(proxy, args.retries)
175+
await proxy.run_async(transport='stdio')
176+
except Exception as e:
177+
logger.error('Cannot start proxy server: %s', e)
178+
raise e
97179

98180

99181
def add_tool_filtering_middleware(mcp: FastMCP, read_only: bool = False) -> None:
@@ -146,27 +228,12 @@ def main():
146228
configure_logging(args.log_level)
147229
logger.info('Starting MCP Proxy for AWS Server')
148230

149-
# Create FastMCP instance
150-
mcp = FastMCP[Any](
151-
name='MCP Proxy',
152-
instructions=(
153-
'MCP Proxy for AWS Server that provides access to backend servers through a single interface. '
154-
'This proxy handles authentication and request routing to the appropriate backend services.'
155-
),
156-
)
157-
158-
async def setup_and_run():
159-
try:
160-
await setup_mcp_mode(mcp, args)
161-
162-
logger.info('Server setup complete, starting MCP server')
163-
164-
except Exception as e:
165-
logger.error('Failed to start server: %s', e)
166-
raise
167-
168231
# Run the server
169-
asyncio.run(setup_and_run())
232+
try:
233+
asyncio.run(run_proxy(args))
234+
except Exception:
235+
logger.exception('Error launching MCP proxy for aws')
236+
return 1
170237

171238

172239
if __name__ == '__main__':

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ description = "MCP Proxy for AWS"
1616
readme = "README.md"
1717
requires-python = ">=3.10,<3.14"
1818
dependencies = [
19-
"fastmcp>=2.13.0.2",
19+
"fastmcp>=2.13.1",
2020
"boto3>=1.34.0",
2121
"botocore>=1.34.0",
2222
]

tests/integ/mcp/simple_mcp_server/mcp_server.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,10 @@ async def elicit_for_my_name(elicitation_expected: str, ctx: Context):
7979
@mcp.tool
8080
def echo_metadata(ctx: Context):
8181
"""MCP Tool that echoes back the _meta field from the request."""
82-
meta = ctx.request_context.meta
83-
return {'received_meta': meta}
82+
if ctx.request_context:
83+
meta = ctx.request_context.meta
84+
return {'received_meta': meta}
85+
raise RuntimeError('No request context received')
8486

8587

8688
#### Server Setup
Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Tests for _initialize_client error handling."""
16+
17+
import httpx
18+
import pytest
19+
from mcp import McpError
20+
from mcp.types import ErrorData, JSONRPCError, JSONRPCResponse
21+
from mcp_proxy_for_aws.server import _initialize_client
22+
from unittest.mock import AsyncMock, Mock, patch
23+
24+
25+
@pytest.mark.asyncio
26+
async def test_successful_initialization():
27+
"""Test successful client initialization."""
28+
mock_transport = Mock()
29+
mock_client = Mock()
30+
31+
with patch('mcp_proxy_for_aws.server.Client') as mock_client_class:
32+
mock_client_class.return_value.__aenter__ = AsyncMock(return_value=mock_client)
33+
mock_client_class.return_value.__aexit__ = AsyncMock(return_value=None)
34+
35+
async with _initialize_client(mock_transport) as client:
36+
assert client == mock_client
37+
38+
39+
@pytest.mark.asyncio
40+
async def test_http_error_with_jsonrpc_error(capsys):
41+
"""Test HTTPStatusError with JSONRPCError response."""
42+
mock_transport = Mock()
43+
error_data = ErrorData(code=-32600, message='Invalid Request')
44+
jsonrpc_error = JSONRPCError(jsonrpc='2.0', id=1, error=error_data)
45+
46+
mock_response = Mock()
47+
mock_response.aread = AsyncMock(return_value=jsonrpc_error.model_dump_json().encode())
48+
49+
http_error = httpx.HTTPStatusError('error', request=Mock(), response=mock_response)
50+
51+
with patch('mcp_proxy_for_aws.server.Client') as mock_client_class:
52+
mock_client_class.return_value.__aenter__ = AsyncMock(side_effect=http_error)
53+
54+
with pytest.raises(httpx.HTTPStatusError):
55+
async with _initialize_client(mock_transport):
56+
pass
57+
58+
captured = capsys.readouterr()
59+
assert 'Invalid Request' in captured.out
60+
61+
62+
@pytest.mark.asyncio
63+
async def test_http_error_with_jsonrpc_response(capsys):
64+
"""Test HTTPStatusError with JSONRPCResponse."""
65+
mock_transport = Mock()
66+
jsonrpc_response = JSONRPCResponse(jsonrpc='2.0', id=1, result={'status': 'error'})
67+
68+
mock_response = Mock()
69+
mock_response.aread = AsyncMock(return_value=jsonrpc_response.model_dump_json().encode())
70+
71+
http_error = httpx.HTTPStatusError('error', request=Mock(), response=mock_response)
72+
73+
with patch('mcp_proxy_for_aws.server.Client') as mock_client_class:
74+
mock_client_class.return_value.__aenter__ = AsyncMock(side_effect=http_error)
75+
76+
with pytest.raises(httpx.HTTPStatusError):
77+
async with _initialize_client(mock_transport):
78+
pass
79+
80+
captured = capsys.readouterr()
81+
assert '"result":{"status":"error"}' in captured.out
82+
83+
84+
@pytest.mark.asyncio
85+
async def test_http_error_with_invalid_json():
86+
"""Test HTTPStatusError with invalid JSON response."""
87+
mock_transport = Mock()
88+
89+
mock_response = Mock()
90+
mock_response.aread = AsyncMock(return_value=b'invalid json')
91+
92+
http_error = httpx.HTTPStatusError('error', request=Mock(), response=mock_response)
93+
94+
with patch('mcp_proxy_for_aws.server.Client') as mock_client_class:
95+
mock_client_class.return_value.__aenter__ = AsyncMock(side_effect=http_error)
96+
97+
with pytest.raises(httpx.HTTPStatusError):
98+
async with _initialize_client(mock_transport):
99+
pass
100+
101+
102+
@pytest.mark.asyncio
103+
async def test_http_error_with_non_jsonrpc_message():
104+
"""Test HTTPStatusError with non-JSONRPCError/Response message."""
105+
mock_transport = Mock()
106+
107+
mock_response = Mock()
108+
mock_response.aread = AsyncMock(return_value=b'{"jsonrpc":"2.0","method":"test"}')
109+
110+
http_error = httpx.HTTPStatusError('error', request=Mock(), response=mock_response)
111+
112+
with patch('mcp_proxy_for_aws.server.Client') as mock_client_class:
113+
mock_client_class.return_value.__aenter__ = AsyncMock(side_effect=http_error)
114+
115+
with pytest.raises(httpx.HTTPStatusError):
116+
async with _initialize_client(mock_transport):
117+
pass
118+
119+
120+
@pytest.mark.asyncio
121+
async def test_http_error_response_read_failure():
122+
"""Test HTTPStatusError when response.aread() fails."""
123+
mock_transport = Mock()
124+
125+
mock_response = Mock()
126+
mock_response.aread = AsyncMock(side_effect=Exception('Read failed'))
127+
128+
http_error = httpx.HTTPStatusError('error', request=Mock(), response=mock_response)
129+
130+
with patch('mcp_proxy_for_aws.server.Client') as mock_client_class:
131+
mock_client_class.return_value.__aenter__ = AsyncMock(side_effect=http_error)
132+
133+
with pytest.raises(httpx.HTTPStatusError):
134+
async with _initialize_client(mock_transport):
135+
pass
136+
137+
138+
@pytest.mark.asyncio
139+
async def test_generic_error_with_mcp_error_cause(capsys):
140+
"""Test generic exception with McpError as cause."""
141+
mock_transport = Mock()
142+
error_data = ErrorData(code=-32601, message='Method not found')
143+
mcp_error = McpError(error_data)
144+
generic_error = Exception('Wrapper error')
145+
generic_error.__cause__ = mcp_error
146+
147+
with patch('mcp_proxy_for_aws.server.Client') as mock_client_class:
148+
mock_client_class.return_value.__aenter__ = AsyncMock(side_effect=generic_error)
149+
150+
with pytest.raises(Exception):
151+
async with _initialize_client(mock_transport):
152+
pass
153+
154+
captured = capsys.readouterr()
155+
assert 'Method not found' in captured.out
156+
assert '"code":-32601' in captured.out
157+
158+
159+
@pytest.mark.asyncio
160+
async def test_generic_error_without_mcp_error_cause(capsys):
161+
"""Test generic exception without McpError cause."""
162+
mock_transport = Mock()
163+
generic_error = Exception('Generic error')
164+
165+
with patch('mcp_proxy_for_aws.server.Client') as mock_client_class:
166+
mock_client_class.return_value.__aenter__ = AsyncMock(side_effect=generic_error)
167+
168+
with pytest.raises(Exception):
169+
async with _initialize_client(mock_transport):
170+
pass
171+
172+
captured = capsys.readouterr()
173+
assert 'Generic error' in captured.out
174+
assert '"code":-32000' in captured.out

0 commit comments

Comments
 (0)