Skip to content

Commit d6a04b5

Browse files
committed
fix: do not use multiple clients
1 parent 08e3ef5 commit d6a04b5

File tree

4 files changed

+31
-44
lines changed

4 files changed

+31
-44
lines changed

mcp_proxy_for_aws/proxy.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -140,8 +140,7 @@ class AWSMCPProxyClientFactory:
140140
def __init__(self, transport: ClientTransport) -> None:
141141
"""Initialize a client factory with transport."""
142142
self._transport = transport
143-
self._client = AWSMCPProxyClient(transport)
144-
self._clients: list[AWSMCPProxyClient] = []
143+
self._client: AWSMCPProxyClient | None = None
145144
self._initialize_request: InitializeRequest | None = None
146145

147146
def set_init_params(self, initialize_request: InitializeRequest):
@@ -150,7 +149,7 @@ def set_init_params(self, initialize_request: InitializeRequest):
150149

151150
async def get_client(self) -> Client:
152151
"""Get client."""
153-
if not self._client.is_connected():
152+
if self._client is None:
154153
self._client = AWSMCPProxyClient(self._transport)
155154

156155
return self._client
@@ -159,10 +158,10 @@ async def __call__(self) -> Client:
159158
"""Implement the callable factory interface."""
160159
return await self.get_client()
161160

162-
async def disconnect_all(self):
161+
async def disconnect(self):
163162
"""Disconnect all the clients (no throw)."""
164-
for client in reversed(self._clients):
165-
try:
166-
await client._disconnect(force=True)
167-
except Exception:
168-
logger.exception('Failed to disconnect client.')
163+
try:
164+
if self._client:
165+
await self._client._disconnect(force=True)
166+
except Exception:
167+
logger.exception('Failed to disconnect client.')

mcp_proxy_for_aws/server.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ async def run_proxy(args) -> None:
105105
logger.error('Cannot start proxy server: %s', e)
106106
raise e
107107
finally:
108-
await client_factory.disconnect_all()
108+
await client_factory.disconnect()
109109

110110

111111
def add_tool_filtering_middleware(mcp: FastMCP, read_only: bool = False) -> None:

tests/unit/test_proxy.py

Lines changed: 18 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ def test_client_factory_initialization():
150150
factory = AWSMCPProxyClientFactory(mock_transport)
151151

152152
assert factory._transport == mock_transport
153-
assert isinstance(factory._client, AWSMCPProxyClient)
153+
assert factory._client is None
154154
assert factory._initialize_request is None
155155

156156

@@ -171,10 +171,11 @@ async def test_client_factory_get_client_when_connected():
171171
mock_transport = Mock(spec=ClientTransport)
172172
factory = AWSMCPProxyClientFactory(mock_transport)
173173

174-
factory._client.is_connected = Mock(return_value=True)
174+
mock_client = Mock(spec=AWSMCPProxyClient)
175+
factory._client = mock_client
175176

176177
client = await factory.get_client()
177-
assert client == factory._client
178+
assert client == mock_client
178179

179180

180181
@pytest.mark.asyncio
@@ -183,12 +184,9 @@ async def test_client_factory_get_client_when_disconnected():
183184
mock_transport = Mock(spec=ClientTransport)
184185
factory = AWSMCPProxyClientFactory(mock_transport)
185186

186-
old_client = factory._client
187-
factory._client.is_connected = Mock(return_value=False)
188-
189187
client = await factory.get_client()
190-
assert client != old_client
191188
assert isinstance(client, AWSMCPProxyClient)
189+
assert factory._client == client
192190

193191

194192
@pytest.mark.asyncio
@@ -197,45 +195,35 @@ async def test_client_factory_callable_interface():
197195
mock_transport = Mock(spec=ClientTransport)
198196
factory = AWSMCPProxyClientFactory(mock_transport)
199197

200-
factory._client.is_connected = Mock(return_value=True)
201-
202198
client = await factory()
203-
assert client == factory._client
199+
assert isinstance(client, AWSMCPProxyClient)
204200

205201

206202
@pytest.mark.asyncio
207203
async def test_client_factory_disconnect_all():
208-
"""Test disconnect_all disconnects all clients."""
204+
"""Test disconnect disconnects the client."""
209205
mock_transport = Mock(spec=ClientTransport)
210206
factory = AWSMCPProxyClientFactory(mock_transport)
211207

212-
mock_client1 = Mock()
213-
mock_client1._disconnect = AsyncMock()
214-
mock_client2 = Mock()
215-
mock_client2._disconnect = AsyncMock()
208+
mock_client = Mock()
209+
mock_client._disconnect = AsyncMock()
210+
factory._client = mock_client
216211

217-
factory._clients = [mock_client1, mock_client2]
212+
await factory.disconnect()
218213

219-
await factory.disconnect_all()
220-
221-
mock_client1._disconnect.assert_called_once_with(force=True)
222-
mock_client2._disconnect.assert_called_once_with(force=True)
214+
mock_client._disconnect.assert_called_once_with(force=True)
223215

224216

225217
@pytest.mark.asyncio
226218
async def test_client_factory_disconnect_all_handles_exceptions():
227-
"""Test disconnect_all handles exceptions gracefully."""
219+
"""Test disconnect handles exceptions gracefully."""
228220
mock_transport = Mock(spec=ClientTransport)
229221
factory = AWSMCPProxyClientFactory(mock_transport)
230222

231-
mock_client1 = Mock()
232-
mock_client1._disconnect = AsyncMock(side_effect=Exception('Disconnect failed'))
233-
mock_client2 = Mock()
234-
mock_client2._disconnect = AsyncMock()
235-
236-
factory._clients = [mock_client1, mock_client2]
223+
mock_client = Mock()
224+
mock_client._disconnect = AsyncMock(side_effect=Exception('Disconnect failed'))
225+
factory._client = mock_client
237226

238-
await factory.disconnect_all()
227+
await factory.disconnect()
239228

240-
mock_client1._disconnect.assert_called_once_with(force=True)
241-
mock_client2._disconnect.assert_called_once_with(force=True)
229+
mock_client._disconnect.assert_called_once_with(force=True)

tests/unit/test_server.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ async def test_setup_mcp_mode(
7373
mock_create_transport.return_value = mock_transport
7474

7575
mock_client_factory = Mock()
76-
mock_client_factory.disconnect_all = AsyncMock()
76+
mock_client_factory.disconnect = AsyncMock()
7777
mock_client_factory_class.return_value = mock_client_factory
7878

7979
mock_proxy = Mock()
@@ -145,7 +145,7 @@ async def test_setup_mcp_mode_no_retries(
145145
mock_create_transport.return_value = mock_transport
146146

147147
mock_client_factory = Mock()
148-
mock_client_factory.disconnect_all = AsyncMock()
148+
mock_client_factory.disconnect = AsyncMock()
149149
mock_client_factory_class.return_value = mock_client_factory
150150

151151
mock_proxy = Mock()
@@ -216,7 +216,7 @@ async def test_setup_mcp_mode_no_metadata_injects_aws_region(
216216
mock_create_transport.return_value = mock_transport
217217

218218
mock_client_factory = Mock()
219-
mock_client_factory.disconnect_all = AsyncMock()
219+
mock_client_factory.disconnect = AsyncMock()
220220
mock_client_factory_class.return_value = mock_client_factory
221221

222222
mock_proxy = Mock()
@@ -271,7 +271,7 @@ async def test_setup_mcp_mode_metadata_without_aws_region_injects_it(
271271
mock_create_transport.return_value = mock_transport
272272

273273
mock_client_factory = Mock()
274-
mock_client_factory.disconnect_all = AsyncMock()
274+
mock_client_factory.disconnect = AsyncMock()
275275
mock_client_factory_class.return_value = mock_client_factory
276276

277277
mock_proxy = Mock()

0 commit comments

Comments
 (0)