Skip to content

Commit 0f5f2b4

Browse files
committed
test: add unit tests for proxy
1 parent 7cb3809 commit 0f5f2b4

File tree

1 file changed

+241
-0
lines changed

1 file changed

+241
-0
lines changed

tests/unit/test_proxy.py

Lines changed: 241 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,241 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Tests for proxy module."""
16+
17+
import httpx
18+
import pytest
19+
from fastmcp.client.transports import ClientTransport
20+
from fastmcp.exceptions import NotFoundError
21+
from fastmcp.tools import Tool
22+
from mcp import McpError
23+
from mcp.types import ErrorData, InitializeRequest, JSONRPCError
24+
from mcp_proxy_for_aws.proxy import (
25+
AWSMCPProxy,
26+
AWSMCPProxyClient,
27+
AWSMCPProxyClientFactory,
28+
AWSProxyToolManager,
29+
)
30+
from unittest.mock import AsyncMock, Mock, patch
31+
32+
33+
@pytest.mark.asyncio
34+
async def test_tool_manager_get_tool_with_cache():
35+
"""Test get_tool returns from cache when available."""
36+
mock_factory = Mock()
37+
manager = AWSProxyToolManager(mock_factory)
38+
mock_tool = Mock(spec=Tool)
39+
manager._cached_tools = {'test_tool': mock_tool}
40+
41+
result = await manager.get_tool('test_tool')
42+
assert result == mock_tool
43+
44+
45+
@pytest.mark.asyncio
46+
async def test_tool_manager_get_tool_without_cache():
47+
"""Test get_tool fetches tools when cache is empty."""
48+
mock_factory = Mock()
49+
manager = AWSProxyToolManager(mock_factory)
50+
mock_tool = Mock(spec=Tool)
51+
52+
with patch.object(manager, 'get_tools', return_value={'test_tool': mock_tool}):
53+
result = await manager.get_tool('test_tool')
54+
assert result == mock_tool
55+
assert manager._cached_tools == {'test_tool': mock_tool}
56+
57+
58+
@pytest.mark.asyncio
59+
async def test_tool_manager_get_tool_not_found():
60+
"""Test get_tool raises NotFoundError when tool doesn't exist."""
61+
mock_factory = Mock()
62+
manager = AWSProxyToolManager(mock_factory)
63+
manager._cached_tools = {}
64+
65+
with pytest.raises(NotFoundError, match="Tool 'missing_tool' not found"):
66+
await manager.get_tool('missing_tool')
67+
68+
69+
@pytest.mark.asyncio
70+
async def test_tool_manager_get_tools_updates_cache():
71+
"""Test get_tools updates the cache."""
72+
mock_factory = Mock()
73+
manager = AWSProxyToolManager(mock_factory)
74+
mock_tools = {'tool1': Mock(spec=Tool), 'tool2': Mock(spec=Tool)}
75+
76+
with patch('mcp_proxy_for_aws.proxy._ProxyToolManager.get_tools', return_value=mock_tools):
77+
result = await manager.get_tools()
78+
assert result == mock_tools
79+
assert manager._cached_tools == mock_tools
80+
81+
82+
def test_proxy_initialization():
83+
"""Test AWSMCPProxy initializes with custom tool manager."""
84+
mock_factory = Mock()
85+
proxy = AWSMCPProxy(client_factory=mock_factory, name='test')
86+
assert isinstance(proxy._tool_manager, AWSProxyToolManager)
87+
88+
89+
@pytest.mark.asyncio
90+
async def test_proxy_client_connect_success():
91+
"""Test successful connection."""
92+
mock_transport = Mock(spec=ClientTransport)
93+
client = AWSMCPProxyClient(mock_transport)
94+
95+
with patch('mcp_proxy_for_aws.proxy._ProxyClient._connect', return_value='connected'):
96+
result = await client._connect()
97+
assert result == 'connected'
98+
99+
100+
@pytest.mark.asyncio
101+
async def test_proxy_client_connect_http_error_with_mcp_error():
102+
"""Test connection failure with MCP error response."""
103+
mock_transport = Mock(spec=ClientTransport)
104+
client = AWSMCPProxyClient(mock_transport)
105+
106+
error_data = ErrorData(code=-32600, message='Invalid Request')
107+
jsonrpc_error = JSONRPCError(jsonrpc='2.0', id=1, error=error_data)
108+
109+
mock_response = Mock()
110+
mock_response.aread = AsyncMock(return_value=jsonrpc_error.model_dump_json().encode())
111+
112+
http_error = httpx.HTTPStatusError('error', request=Mock(), response=mock_response)
113+
114+
with patch('mcp_proxy_for_aws.proxy._ProxyClient._connect', side_effect=http_error):
115+
with pytest.raises(McpError) as exc_info:
116+
await client._connect()
117+
assert exc_info.value.error.code == -32600
118+
assert exc_info.value.error.message == 'Invalid Request'
119+
120+
121+
@pytest.mark.asyncio
122+
async def test_proxy_client_connect_http_error_non_mcp():
123+
"""Test connection failure with non-MCP HTTP error."""
124+
mock_transport = Mock(spec=ClientTransport)
125+
client = AWSMCPProxyClient(mock_transport)
126+
127+
mock_response = Mock()
128+
mock_response.aread = AsyncMock(return_value=b'Not a JSON-RPC message')
129+
130+
http_error = httpx.HTTPStatusError('error', request=Mock(), response=mock_response)
131+
132+
with patch('mcp_proxy_for_aws.proxy._ProxyClient._connect', side_effect=http_error):
133+
with pytest.raises(httpx.HTTPStatusError):
134+
await client._connect()
135+
136+
137+
@pytest.mark.asyncio
138+
async def test_proxy_client_aexit_does_not_disconnect():
139+
"""Test __aexit__ does not disconnect the client."""
140+
mock_transport = Mock(spec=ClientTransport)
141+
client = AWSMCPProxyClient(mock_transport)
142+
143+
result = await client.__aexit__(None, None, None)
144+
assert result is None
145+
146+
147+
def test_client_factory_initialization():
148+
"""Test factory initialization."""
149+
mock_transport = Mock(spec=ClientTransport)
150+
factory = AWSMCPProxyClientFactory(mock_transport)
151+
152+
assert factory._transport == mock_transport
153+
assert isinstance(factory._client, AWSMCPProxyClient)
154+
assert factory._initialize_request is None
155+
156+
157+
def test_client_factory_set_init_params():
158+
"""Test setting initialization parameters."""
159+
mock_transport = Mock(spec=ClientTransport)
160+
factory = AWSMCPProxyClientFactory(mock_transport)
161+
162+
mock_request = Mock(spec=InitializeRequest)
163+
factory.set_init_params(mock_request)
164+
165+
assert factory._initialize_request == mock_request
166+
167+
168+
@pytest.mark.asyncio
169+
async def test_client_factory_get_client_when_connected():
170+
"""Test get_client returns existing client when connected."""
171+
mock_transport = Mock(spec=ClientTransport)
172+
factory = AWSMCPProxyClientFactory(mock_transport)
173+
174+
factory._client.is_connected = Mock(return_value=True)
175+
176+
client = await factory.get_client()
177+
assert client == factory._client
178+
179+
180+
@pytest.mark.asyncio
181+
async def test_client_factory_get_client_when_disconnected():
182+
"""Test get_client creates new client when disconnected."""
183+
mock_transport = Mock(spec=ClientTransport)
184+
factory = AWSMCPProxyClientFactory(mock_transport)
185+
186+
old_client = factory._client
187+
factory._client.is_connected = Mock(return_value=False)
188+
189+
client = await factory.get_client()
190+
assert client != old_client
191+
assert isinstance(client, AWSMCPProxyClient)
192+
193+
194+
@pytest.mark.asyncio
195+
async def test_client_factory_callable_interface():
196+
"""Test factory callable interface."""
197+
mock_transport = Mock(spec=ClientTransport)
198+
factory = AWSMCPProxyClientFactory(mock_transport)
199+
200+
factory._client.is_connected = Mock(return_value=True)
201+
202+
client = await factory()
203+
assert client == factory._client
204+
205+
206+
@pytest.mark.asyncio
207+
async def test_client_factory_disconnect_all():
208+
"""Test disconnect_all disconnects all clients."""
209+
mock_transport = Mock(spec=ClientTransport)
210+
factory = AWSMCPProxyClientFactory(mock_transport)
211+
212+
mock_client1 = Mock()
213+
mock_client1._disconnect = AsyncMock()
214+
mock_client2 = Mock()
215+
mock_client2._disconnect = AsyncMock()
216+
217+
factory._clients = [mock_client1, mock_client2]
218+
219+
await factory.disconnect_all()
220+
221+
mock_client1._disconnect.assert_called_once_with(force=True)
222+
mock_client2._disconnect.assert_called_once_with(force=True)
223+
224+
225+
@pytest.mark.asyncio
226+
async def test_client_factory_disconnect_all_handles_exceptions():
227+
"""Test disconnect_all handles exceptions gracefully."""
228+
mock_transport = Mock(spec=ClientTransport)
229+
factory = AWSMCPProxyClientFactory(mock_transport)
230+
231+
mock_client1 = Mock()
232+
mock_client1._disconnect = AsyncMock(side_effect=Exception('Disconnect failed'))
233+
mock_client2 = Mock()
234+
mock_client2._disconnect = AsyncMock()
235+
236+
factory._clients = [mock_client1, mock_client2]
237+
238+
await factory.disconnect_all()
239+
240+
mock_client1._disconnect.assert_called_once_with(force=True)
241+
mock_client2._disconnect.assert_called_once_with(force=True)

0 commit comments

Comments
 (0)