diff --git a/src/strands/tools/mcp/mcp_client.py b/src/strands/tools/mcp/mcp_client.py index dec8ec313..4689096a0 100644 --- a/src/strands/tools/mcp/mcp_client.py +++ b/src/strands/tools/mcp/mcp_client.py @@ -26,7 +26,7 @@ from mcp.types import TextContent as MCPTextContent from ...types import PaginatedList -from ...types.exceptions import MCPClientInitializationError +from ...types.exceptions import MCPClientInitializationError, MCPConnectionError from ...types.media import ImageFormat from ...types.tools import ToolResultContent, ToolResultStatus from .mcp_agent_tool import MCPAgentTool @@ -86,6 +86,8 @@ def __init__(self, transport_callable: Callable[[], MCPTransport], *, startup_ti self._background_thread: threading.Thread | None = None self._background_thread_session: ClientSession | None = None self._background_thread_event_loop: AbstractEventLoop | None = None + self._last_runtime_exception: Exception | None = None + self._pending_futures: set[futures.Future] = set() def __enter__(self) -> "MCPClient": """Context manager entry point which initializes the MCP server connection. @@ -177,6 +179,11 @@ async def _set_close_event() -> None: self._log_debug_with_thread("waiting for background thread to join") self._background_thread.join() + + # Fail any remaining pending futures + if self._pending_futures: + self._fail_pending_futures(MCPConnectionError("MCP client was stopped while operations were pending")) + self._log_debug_with_thread("background thread is closed, MCPClient context exited") # Reset fields to allow instance reuse @@ -186,6 +193,7 @@ async def _set_close_event() -> None: self._background_thread_session = None self._background_thread_event_loop = None self._session_id = uuid.uuid4() + self._pending_futures.clear() def list_tools_sync(self, pagination_token: Optional[str] = None) -> PaginatedList[MCPAgentTool]: """Synchronously retrieves the list of available tools from the MCP server. @@ -293,6 +301,9 @@ async def _call_tool_async() -> MCPCallToolResult: try: call_tool_result: MCPCallToolResult = self._invoke_on_background_thread(_call_tool_async()).result() return self._handle_tool_result(tool_use_id, call_tool_result) + except MCPConnectionError as e: + logger.exception("MCP background thread failure during tool call: %s", e) + return self._handle_tool_execution_error(tool_use_id, e) except Exception as e: logger.exception("tool execution failed") return self._handle_tool_execution_error(tool_use_id, e) @@ -331,6 +342,9 @@ async def _call_tool_async() -> MCPCallToolResult: future = self._invoke_on_background_thread(_call_tool_async()) call_tool_result: MCPCallToolResult = await asyncio.wrap_future(future) return self._handle_tool_result(tool_use_id, call_tool_result) + except MCPConnectionError as e: + logger.exception("MCP background thread failure during async tool call: %s", e) + return self._handle_tool_execution_error(tool_use_id, e) except Exception as e: logger.exception("tool execution failed") return self._handle_tool_execution_error(tool_use_id, e) @@ -419,9 +433,13 @@ async def _async_background_thread(self) -> None: if not self._init_future.done(): self._init_future.set_exception(e) else: + # Store the exception for potential recovery handling + self._last_runtime_exception = e self._log_debug_with_thread( "encountered exception on background thread after initialization %s", str(e) ) + # Fail all pending futures so callers don't hang forever + self._fail_pending_futures(e) def _background_task(self) -> None: """Sets up and runs the event loop in the background thread. @@ -476,7 +494,35 @@ def _log_debug_with_thread(self, msg: str, *args: Any, **kwargs: Any) -> None: def _invoke_on_background_thread(self, coro: Coroutine[Any, Any, T]) -> futures.Future[T]: if self._background_thread_session is None or self._background_thread_event_loop is None: raise MCPClientInitializationError("the client session was not initialized") - return asyncio.run_coroutine_threadsafe(coro=coro, loop=self._background_thread_event_loop) + future = asyncio.run_coroutine_threadsafe(coro=coro, loop=self._background_thread_event_loop) + + self._pending_futures.add(future) + + # Remove Future tracking when it completes normally + def cleanup_future(f: futures.Future) -> None: + self._pending_futures.discard(f) + + future.add_done_callback(cleanup_future) + return future def _is_session_active(self) -> bool: return self._background_thread is not None and self._background_thread.is_alive() + + def _fail_pending_futures(self, exception: Exception) -> None: + """Fail all pending futures with the given exception. + + This is called when the background thread encounters a fatal error, + ensuring that any threads waiting on futures don't hang forever. + """ + self._log_debug_with_thread( + "Failing %d pending futures due to background thread exception", len(self._pending_futures) + ) + + for future in list(self._pending_futures): + if not future.done(): + try: + future.set_exception(MCPConnectionError(f"MCP background thread died: {exception}")) + except Exception as e: + self._log_debug_with_thread("Failed to set exception on future: %s", str(e)) + + self._pending_futures.clear() diff --git a/src/strands/types/exceptions.py b/src/strands/types/exceptions.py index 90f2b8d7f..a91ab4409 100644 --- a/src/strands/types/exceptions.py +++ b/src/strands/types/exceptions.py @@ -49,6 +49,15 @@ class ContextWindowOverflowException(Exception): class MCPClientInitializationError(Exception): """Raised when the MCP server fails to initialize properly.""" + +class MCPConnectionError(Exception): + """Raised when the MCP connection fails during runtime. + + This exception indicates that the MCP background thread has died or + the connection has been lost after successful initialization. This is + different from MCPClientInitializationError which occurs during startup. + """ + pass