Skip to content

Commit 5b53194

Browse files
committed
fix: disconnect client in reverse order
1 parent 0f5f2b4 commit 5b53194

File tree

2 files changed

+51
-14
lines changed

2 files changed

+51
-14
lines changed

mcp_proxy_for_aws/proxy.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ class AWSMCPProxyClientFactory:
130130
def __init__(self, transport: ClientTransport) -> None:
131131
"""Initialize a client factory with transport."""
132132
self._transport = transport
133-
self._client = AWSMCPProxyClient(transport)
133+
self._client: AWSMCPProxyClient | None = None
134134
self._clients: list[AWSMCPProxyClient] = []
135135
self._initialize_request: InitializeRequest | None = None
136136

@@ -140,8 +140,9 @@ def set_init_params(self, initialize_request: InitializeRequest):
140140

141141
async def get_client(self) -> Client:
142142
"""Get client."""
143-
if not self._client.is_connected():
143+
if self._client is None or not self._client.is_connected():
144144
self._client = AWSMCPProxyClient(self._transport)
145+
self._clients.append(self._client)
145146

146147
return self._client
147148

tests/unit/test_proxy.py

Lines changed: 48 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,8 @@ def test_client_factory_initialization():
150150
factory = AWSMCPProxyClientFactory(mock_transport)
151151

152152
assert factory._transport == mock_transport
153-
assert isinstance(factory._client, AWSMCPProxyClient)
153+
assert factory._client is None
154+
assert factory._clients == []
154155
assert factory._initialize_request is None
155156

156157

@@ -171,10 +172,12 @@ async def test_client_factory_get_client_when_connected():
171172
mock_transport = Mock(spec=ClientTransport)
172173
factory = AWSMCPProxyClientFactory(mock_transport)
173174

174-
factory._client.is_connected = Mock(return_value=True)
175+
mock_client = Mock(spec=AWSMCPProxyClient)
176+
mock_client.is_connected = Mock(return_value=True)
177+
factory._client = mock_client
175178

176179
client = await factory.get_client()
177-
assert client == factory._client
180+
assert client == mock_client
178181

179182

180183
@pytest.mark.asyncio
@@ -183,12 +186,14 @@ async def test_client_factory_get_client_when_disconnected():
183186
mock_transport = Mock(spec=ClientTransport)
184187
factory = AWSMCPProxyClientFactory(mock_transport)
185188

186-
old_client = factory._client
187-
factory._client.is_connected = Mock(return_value=False)
189+
mock_old_client = Mock(spec=AWSMCPProxyClient)
190+
mock_old_client.is_connected = Mock(return_value=False)
191+
factory._client = mock_old_client
188192

189193
client = await factory.get_client()
190-
assert client != old_client
194+
assert client != mock_old_client
191195
assert isinstance(client, AWSMCPProxyClient)
196+
assert client in factory._clients
192197

193198

194199
@pytest.mark.asyncio
@@ -197,10 +202,12 @@ async def test_client_factory_callable_interface():
197202
mock_transport = Mock(spec=ClientTransport)
198203
factory = AWSMCPProxyClientFactory(mock_transport)
199204

200-
factory._client.is_connected = Mock(return_value=True)
205+
mock_client = Mock(spec=AWSMCPProxyClient)
206+
mock_client.is_connected = Mock(return_value=True)
207+
factory._client = mock_client
201208

202209
client = await factory()
203-
assert client == factory._client
210+
assert client == mock_client
204211

205212

206213
@pytest.mark.asyncio
@@ -222,20 +229,49 @@ async def test_client_factory_disconnect_all():
222229
mock_client2._disconnect.assert_called_once_with(force=True)
223230

224231

232+
@pytest.mark.asyncio
233+
async def test_client_factory_disconnect_all_reverse_order():
234+
"""Test disconnect_all disconnects clients in reverse order."""
235+
mock_transport = Mock(spec=ClientTransport)
236+
factory = AWSMCPProxyClientFactory(mock_transport)
237+
238+
disconnect_order = []
239+
240+
mock_client1 = Mock()
241+
mock_client1._disconnect = AsyncMock(side_effect=lambda **kwargs: disconnect_order.append(1))
242+
mock_client2 = Mock()
243+
mock_client2._disconnect = AsyncMock(side_effect=lambda **kwargs: disconnect_order.append(2))
244+
mock_client3 = Mock()
245+
mock_client3._disconnect = AsyncMock(side_effect=lambda **kwargs: disconnect_order.append(3))
246+
247+
factory._clients = [mock_client1, mock_client2, mock_client3]
248+
249+
await factory.disconnect_all()
250+
251+
assert disconnect_order == [3, 2, 1]
252+
253+
225254
@pytest.mark.asyncio
226255
async def test_client_factory_disconnect_all_handles_exceptions():
227-
"""Test disconnect_all handles exceptions gracefully."""
256+
"""Test disconnect_all handles exceptions gracefully and continues in reverse order."""
228257
mock_transport = Mock(spec=ClientTransport)
229258
factory = AWSMCPProxyClientFactory(mock_transport)
230259

260+
disconnect_order = []
261+
231262
mock_client1 = Mock()
232-
mock_client1._disconnect = AsyncMock(side_effect=Exception('Disconnect failed'))
263+
mock_client1._disconnect = AsyncMock(side_effect=lambda **kwargs: disconnect_order.append(1))
233264
mock_client2 = Mock()
234-
mock_client2._disconnect = AsyncMock()
265+
mock_client2._disconnect = AsyncMock(side_effect=Exception('Disconnect failed'))
266+
mock_client3 = Mock()
267+
mock_client3._disconnect = AsyncMock(side_effect=lambda **kwargs: disconnect_order.append(3))
235268

236-
factory._clients = [mock_client1, mock_client2]
269+
factory._clients = [mock_client1, mock_client2, mock_client3]
237270

238271
await factory.disconnect_all()
239272

273+
# Verify client3 and client1 were disconnected despite client2 failing
274+
assert disconnect_order == [3, 1]
240275
mock_client1._disconnect.assert_called_once_with(force=True)
241276
mock_client2._disconnect.assert_called_once_with(force=True)
277+
mock_client3._disconnect.assert_called_once_with(force=True)

0 commit comments

Comments
 (0)