Skip to content

Commit e771672

Browse files
committed
feat: pass client information with proxy version
1 parent 2a7cf94 commit e771672

File tree

2 files changed

+170
-1
lines changed

2 files changed

+170
-1
lines changed

mcp_proxy_for_aws/server.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,13 @@
3636
from mcp.types import (
3737
CONNECTION_CLOSED,
3838
ErrorData,
39+
Implementation,
40+
InitializeRequest,
3941
JSONRPCError,
4042
JSONRPCMessage,
4143
JSONRPCResponse,
4244
)
45+
from mcp_proxy_for_aws import __version__
4346
from mcp_proxy_for_aws.cli import parse_args
4447
from mcp_proxy_for_aws.logging_config import configure_logging
4548
from mcp_proxy_for_aws.middleware.tool_filter import ToolFilteringMiddleware
@@ -48,17 +51,30 @@
4851
determine_aws_region,
4952
determine_service_name,
5053
)
54+
from pydantic import ValidationError
5155

5256

5357
logger = logging.getLogger(__name__)
5458

5559

60+
DEFAULT_CLIENT_INFO = Implementation(name='mcp-proxy-for-aws', version=__version__)
61+
62+
5663
@contextlib.asynccontextmanager
5764
async def _initialize_client(transport: ClientTransport):
5865
"""Handle the exceptions for during client initialize."""
5966
async with contextlib.AsyncExitStack() as stack:
6067
try:
61-
client = await stack.enter_async_context(Client(transport))
68+
client_info: Implementation | None = None
69+
if first_line := sys.stdin.readline():
70+
with contextlib.suppress(ValidationError):
71+
init_request = InitializeRequest.model_validate_json(first_line, by_alias=True)
72+
client_info = init_request.params.clientInfo
73+
client_info.name = f'{client_info.name} via {DEFAULT_CLIENT_INFO.name}@{DEFAULT_CLIENT_INFO.version}'
74+
logger.debug('Using client info %s', client_info)
75+
client = await stack.enter_async_context(
76+
Client(transport, client_info=client_info or DEFAULT_CLIENT_INFO)
77+
)
6278
if client.initialize_result:
6379
print(
6480
client.initialize_result.model_dump_json(

tests/unit/test_client_info.py

Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Tests for client info extraction from stdin."""
16+
17+
import pytest
18+
from io import StringIO
19+
from mcp import InitializeRequest
20+
from mcp.types import Implementation, InitializeRequestParams
21+
from mcp_proxy_for_aws.server import DEFAULT_CLIENT_INFO, _initialize_client
22+
from unittest.mock import AsyncMock, Mock, patch
23+
24+
25+
@pytest.mark.asyncio
26+
async def test_client_info_from_valid_initialize_request():
27+
"""Test extracting client info from valid InitializeRequest in stdin."""
28+
init_request = InitializeRequest(
29+
method='initialize',
30+
params=InitializeRequestParams(
31+
protocolVersion='2024-11-05',
32+
capabilities={},
33+
clientInfo=Implementation(name='test-client', version='1.0.0'),
34+
),
35+
)
36+
37+
mock_stdin = StringIO(init_request.model_dump_json(by_alias=True) + '\n')
38+
mock_transport = Mock()
39+
mock_client = Mock()
40+
mock_client.initialize_result = None
41+
42+
with patch('sys.stdin', mock_stdin):
43+
with patch('mcp_proxy_for_aws.server.Client') as mock_client_class:
44+
mock_client_class.return_value.__aenter__ = AsyncMock(return_value=mock_client)
45+
mock_client_class.return_value.__aexit__ = AsyncMock(return_value=None)
46+
47+
async with _initialize_client(mock_transport):
48+
pass
49+
50+
# Verify Client was called with modified client_info
51+
call_args = mock_client_class.call_args
52+
client_info = call_args.kwargs.get('client_info')
53+
assert client_info is not None
54+
assert 'test-client via mcp-proxy-for-aws@' in client_info.name
55+
56+
57+
@pytest.mark.asyncio
58+
async def test_client_info_with_empty_stdin():
59+
"""Test when stdin is empty."""
60+
mock_stdin = StringIO('')
61+
mock_transport = Mock()
62+
mock_client = Mock()
63+
mock_client.initialize_result = None
64+
65+
with patch('sys.stdin', mock_stdin):
66+
with patch('mcp_proxy_for_aws.server.Client') as mock_client_class:
67+
mock_client_class.return_value.__aenter__ = AsyncMock(return_value=mock_client)
68+
mock_client_class.return_value.__aexit__ = AsyncMock(return_value=None)
69+
70+
async with _initialize_client(mock_transport):
71+
pass
72+
73+
# Verify Client was called with default client_info
74+
call_args = mock_client_class.call_args
75+
client_info = call_args.kwargs.get('client_info')
76+
assert client_info == DEFAULT_CLIENT_INFO
77+
78+
79+
@pytest.mark.asyncio
80+
async def test_client_info_with_invalid_json():
81+
"""Test when stdin contains invalid JSON."""
82+
mock_stdin = StringIO('invalid json\n')
83+
mock_transport = Mock()
84+
mock_client = Mock()
85+
mock_client.initialize_result = None
86+
87+
with patch('sys.stdin', mock_stdin):
88+
with patch('mcp_proxy_for_aws.server.Client') as mock_client_class:
89+
mock_client_class.return_value.__aenter__ = AsyncMock(return_value=mock_client)
90+
mock_client_class.return_value.__aexit__ = AsyncMock(return_value=None)
91+
92+
async with _initialize_client(mock_transport):
93+
pass
94+
95+
# Verify Client was called with default client_info
96+
call_args = mock_client_class.call_args
97+
client_info = call_args.kwargs.get('client_info')
98+
assert client_info == DEFAULT_CLIENT_INFO
99+
100+
101+
@pytest.mark.asyncio
102+
async def test_client_info_with_non_initialize_request():
103+
"""Test when stdin contains valid JSON but not an InitializeRequest."""
104+
mock_stdin = StringIO('{"method": "other", "params": {}}\n')
105+
mock_transport = Mock()
106+
mock_client = Mock()
107+
mock_client.initialize_result = None
108+
109+
with patch('sys.stdin', mock_stdin):
110+
with patch('mcp_proxy_for_aws.server.Client') as mock_client_class:
111+
mock_client_class.return_value.__aenter__ = AsyncMock(return_value=mock_client)
112+
mock_client_class.return_value.__aexit__ = AsyncMock(return_value=None)
113+
114+
async with _initialize_client(mock_transport):
115+
pass
116+
117+
# Verify Client was called with default client_info
118+
call_args = mock_client_class.call_args
119+
client_info = call_args.kwargs.get('client_info')
120+
assert client_info == DEFAULT_CLIENT_INFO
121+
122+
123+
@pytest.mark.asyncio
124+
async def test_client_info_with_malformed_request():
125+
"""Test InitializeRequest with missing required fields."""
126+
# Manually create JSON without clientInfo to test validation error handling
127+
malformed_json = '{"method": "initialize", "params": {"protocolVersion": "2024-11-05", "capabilities": {}}}\n'
128+
129+
mock_stdin = StringIO(malformed_json)
130+
mock_transport = Mock()
131+
mock_client = Mock()
132+
mock_client.initialize_result = None
133+
134+
with patch('sys.stdin', mock_stdin):
135+
with patch('mcp_proxy_for_aws.server.Client') as mock_client_class:
136+
mock_client_class.return_value.__aenter__ = AsyncMock(return_value=mock_client)
137+
mock_client_class.return_value.__aexit__ = AsyncMock(return_value=None)
138+
139+
async with _initialize_client(mock_transport):
140+
pass
141+
142+
# Verify Client was called with default client_info due to validation error
143+
call_args = mock_client_class.call_args
144+
client_info = call_args.kwargs.get('client_info')
145+
assert client_info == DEFAULT_CLIENT_INFO
146+
147+
148+
@pytest.mark.asyncio
149+
async def test_default_client_info_values():
150+
"""Test DEFAULT_CLIENT_INFO has expected values."""
151+
assert DEFAULT_CLIENT_INFO.name == 'mcp-proxy-for-aws'
152+
assert DEFAULT_CLIENT_INFO.version is not None
153+
assert len(DEFAULT_CLIENT_INFO.version) > 0

0 commit comments

Comments
 (0)