Skip to content

Commit 7cb3809

Browse files
committed
fix unit tests
1 parent 9ae03db commit 7cb3809

File tree

2 files changed

+42
-224
lines changed

2 files changed

+42
-224
lines changed

tests/unit/test_initialize_client.py

Lines changed: 0 additions & 174 deletions
This file was deleted.

tests/unit/test_server.py

Lines changed: 42 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,9 @@
3030
class TestServer:
3131
"""Tests for the server module."""
3232

33-
@patch('mcp_proxy_for_aws.server.ProxyClient')
33+
@patch('mcp_proxy_for_aws.server.AWSMCPProxyClientFactory')
3434
@patch('mcp_proxy_for_aws.server.create_transport_with_sigv4')
35-
@patch('mcp_proxy_for_aws.server.FastMCPProxy')
35+
@patch('mcp_proxy_for_aws.server.AWSMCPProxy')
3636
@patch('mcp_proxy_for_aws.server.determine_aws_region')
3737
@patch('mcp_proxy_for_aws.server.determine_service_name')
3838
@patch('mcp_proxy_for_aws.server.add_tool_filtering_middleware')
@@ -43,9 +43,9 @@ async def test_setup_mcp_mode(
4343
mock_add_filtering,
4444
mock_determine_service,
4545
mock_determine_region,
46-
mock_fastmcp_proxy,
46+
mock_aws_proxy,
4747
mock_create_transport,
48-
mock_client_class,
48+
mock_client_factory_class,
4949
):
5050
"""Test that MCP mode is set up correctly."""
5151
# Arrange
@@ -68,20 +68,18 @@ async def test_setup_mcp_mode(
6868
mock_determine_service.return_value = 'test-service'
6969
mock_determine_region.return_value = 'us-east-1'
7070

71-
# Mock the transport and client
71+
# Mock the transport and client factory
7272
mock_transport = Mock(spec=ClientTransport)
7373
mock_create_transport.return_value = mock_transport
7474

75-
mock_client = Mock()
76-
mock_client.initialize_result = None
77-
mock_client.is_connected = Mock(return_value=True)
78-
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
79-
mock_client.__aexit__ = AsyncMock(return_value=None)
80-
mock_client_class.return_value = mock_client
75+
mock_client_factory = Mock()
76+
mock_client_factory.disconnect_all = AsyncMock()
77+
mock_client_factory_class.return_value = mock_client_factory
8178

8279
mock_proxy = Mock()
8380
mock_proxy.run_async = AsyncMock()
84-
mock_fastmcp_proxy.return_value = mock_proxy
81+
mock_proxy.add_middleware = Mock()
82+
mock_aws_proxy.return_value = mock_proxy
8583

8684
# Act
8785
await run_proxy(mock_args)
@@ -98,17 +96,17 @@ async def test_setup_mcp_mode(
9896
assert call_args[0][3] == {'AWS_REGION': 'us-east-1'} # metadata
9997
# call_args[0][4] is the Timeout object
10098
assert call_args[0][5] is None # profile
101-
mock_client_class.assert_called_once_with(mock_transport)
102-
mock_fastmcp_proxy.assert_called_once()
99+
mock_client_factory_class.assert_called_once_with(mock_transport)
100+
mock_aws_proxy.assert_called_once()
103101
mock_add_filtering.assert_called_once_with(mock_proxy, True)
104102
mock_add_retry.assert_called_once_with(mock_proxy, 1)
105103
mock_proxy.run_async.assert_called_once_with(
106104
transport='stdio', show_banner=False, log_level='INFO'
107105
)
108106

109-
@patch('mcp_proxy_for_aws.server.ProxyClient')
107+
@patch('mcp_proxy_for_aws.server.AWSMCPProxyClientFactory')
110108
@patch('mcp_proxy_for_aws.server.create_transport_with_sigv4')
111-
@patch('mcp_proxy_for_aws.server.FastMCPProxy')
109+
@patch('mcp_proxy_for_aws.server.AWSMCPProxy')
112110
@patch('mcp_proxy_for_aws.server.determine_aws_region')
113111
@patch('mcp_proxy_for_aws.server.determine_service_name')
114112
@patch('mcp_proxy_for_aws.server.add_tool_filtering_middleware')
@@ -117,9 +115,9 @@ async def test_setup_mcp_mode_no_retries(
117115
mock_add_filtering,
118116
mock_determine_service,
119117
mock_determine_region,
120-
mock_fastmcp_proxy,
118+
mock_aws_proxy,
121119
mock_create_transport,
122-
mock_client_class,
120+
mock_client_factory_class,
123121
):
124122
"""Test that MCP mode setup without retries doesn't add retry middleware."""
125123
# Arrange
@@ -142,20 +140,18 @@ async def test_setup_mcp_mode_no_retries(
142140
mock_determine_service.return_value = 'test-service'
143141
mock_determine_region.return_value = 'us-east-1'
144142

145-
# Mock the transport and client
143+
# Mock the transport and client factory
146144
mock_transport = Mock(spec=ClientTransport)
147145
mock_create_transport.return_value = mock_transport
148146

149-
mock_client = Mock()
150-
mock_client.initialize_result = None
151-
mock_client.is_connected = Mock(return_value=True)
152-
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
153-
mock_client.__aexit__ = AsyncMock(return_value=None)
154-
mock_client_class.return_value = mock_client
147+
mock_client_factory = Mock()
148+
mock_client_factory.disconnect_all = AsyncMock()
149+
mock_client_factory_class.return_value = mock_client_factory
155150

156151
mock_proxy = Mock()
157152
mock_proxy.run_async = AsyncMock()
158-
mock_fastmcp_proxy.return_value = mock_proxy
153+
mock_proxy.add_middleware = Mock()
154+
mock_aws_proxy.return_value = mock_proxy
159155

160156
# Act
161157
await run_proxy(mock_args)
@@ -175,16 +171,16 @@ async def test_setup_mcp_mode_no_retries(
175171
} # metadata
176172
# call_args[0][4] is the Timeout object
177173
assert call_args[0][5] == 'test-profile' # profile
178-
mock_client_class.assert_called_once_with(mock_transport)
179-
mock_fastmcp_proxy.assert_called_once()
174+
mock_client_factory_class.assert_called_once_with(mock_transport)
175+
mock_aws_proxy.assert_called_once()
180176
mock_add_filtering.assert_called_once_with(mock_proxy, False)
181177
mock_proxy.run_async.assert_called_once_with(
182178
transport='stdio', show_banner=False, log_level='INFO'
183179
)
184180

185-
@patch('mcp_proxy_for_aws.server.ProxyClient')
181+
@patch('mcp_proxy_for_aws.server.AWSMCPProxyClientFactory')
186182
@patch('mcp_proxy_for_aws.server.create_transport_with_sigv4')
187-
@patch('mcp_proxy_for_aws.server.FastMCPProxy')
183+
@patch('mcp_proxy_for_aws.server.AWSMCPProxy')
188184
@patch('mcp_proxy_for_aws.server.determine_aws_region')
189185
@patch('mcp_proxy_for_aws.server.determine_service_name')
190186
@patch('mcp_proxy_for_aws.server.add_tool_filtering_middleware')
@@ -193,9 +189,9 @@ async def test_setup_mcp_mode_no_metadata_injects_aws_region(
193189
mock_add_filtering,
194190
mock_determine_service,
195191
mock_determine_region,
196-
mock_fastmcp_proxy,
192+
mock_aws_proxy,
197193
mock_create_transport,
198-
mock_client_class,
194+
mock_client_factory_class,
199195
):
200196
"""Test that AWS_REGION is automatically injected when no metadata is provided."""
201197
# Arrange
@@ -219,16 +215,14 @@ async def test_setup_mcp_mode_no_metadata_injects_aws_region(
219215
mock_transport = Mock(spec=ClientTransport)
220216
mock_create_transport.return_value = mock_transport
221217

222-
mock_client = Mock()
223-
mock_client.initialize_result = None
224-
mock_client.is_connected = Mock(return_value=True)
225-
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
226-
mock_client.__aexit__ = AsyncMock(return_value=None)
227-
mock_client_class.return_value = mock_client
218+
mock_client_factory = Mock()
219+
mock_client_factory.disconnect_all = AsyncMock()
220+
mock_client_factory_class.return_value = mock_client_factory
228221

229222
mock_proxy = Mock()
230223
mock_proxy.run_async = AsyncMock()
231-
mock_fastmcp_proxy.return_value = mock_proxy
224+
mock_proxy.add_middleware = Mock()
225+
mock_aws_proxy.return_value = mock_proxy
232226

233227
# Act
234228
await run_proxy(mock_args)
@@ -239,9 +233,9 @@ async def test_setup_mcp_mode_no_metadata_injects_aws_region(
239233
metadata = call_args[0][3]
240234
assert metadata == {'AWS_REGION': 'ap-southeast-1'}
241235

242-
@patch('mcp_proxy_for_aws.server.ProxyClient')
236+
@patch('mcp_proxy_for_aws.server.AWSMCPProxyClientFactory')
243237
@patch('mcp_proxy_for_aws.server.create_transport_with_sigv4')
244-
@patch('mcp_proxy_for_aws.server.FastMCPProxy')
238+
@patch('mcp_proxy_for_aws.server.AWSMCPProxy')
245239
@patch('mcp_proxy_for_aws.server.determine_aws_region')
246240
@patch('mcp_proxy_for_aws.server.determine_service_name')
247241
@patch('mcp_proxy_for_aws.server.add_tool_filtering_middleware')
@@ -250,9 +244,9 @@ async def test_setup_mcp_mode_metadata_without_aws_region_injects_it(
250244
mock_add_filtering,
251245
mock_determine_service,
252246
mock_determine_region,
253-
mock_fastmcp_proxy,
247+
mock_aws_proxy,
254248
mock_create_transport,
255-
mock_client_class,
249+
mock_client_factory_class,
256250
):
257251
"""Test that AWS_REGION is injected even when other metadata is provided."""
258252
# Arrange
@@ -276,16 +270,14 @@ async def test_setup_mcp_mode_metadata_without_aws_region_injects_it(
276270
mock_transport = Mock(spec=ClientTransport)
277271
mock_create_transport.return_value = mock_transport
278272

279-
mock_client = Mock()
280-
mock_client.initialize_result = None
281-
mock_client.is_connected = Mock(return_value=True)
282-
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
283-
mock_client.__aexit__ = AsyncMock(return_value=None)
284-
mock_client_class.return_value = mock_client
273+
mock_client_factory = Mock()
274+
mock_client_factory.disconnect_all = AsyncMock()
275+
mock_client_factory_class.return_value = mock_client_factory
285276

286277
mock_proxy = Mock()
287278
mock_proxy.run_async = AsyncMock()
288-
mock_fastmcp_proxy.return_value = mock_proxy
279+
mock_proxy.add_middleware = Mock()
280+
mock_aws_proxy.return_value = mock_proxy
289281

290282
# Act
291283
await run_proxy(mock_args)

0 commit comments

Comments
 (0)