|
1 | 1 | """Cancellation interactions against the low-level Server, driven through the public Client API. |
2 | 2 |
|
3 | | -There is no client-side cancellation API: cancelling means sending a CancelledNotification |
4 | | -carrying the request id, which only the server-side handler can observe (`ctx.request_id`), so |
5 | | -these tests capture the id from inside the blocked handler before cancelling. The handler blocks |
6 | | -on an Event rather than a sleep, and every wait is bounded by `anyio.fail_after`. |
| 3 | +Client-side cancellation is cancelling the caller's scope around an in-flight call; the |
| 4 | +dispatcher then sends the courtesy notifications/cancelled. The receiving-side tests instead |
| 5 | +drive the wire act directly -- sending a CancelledNotification carrying the request id, which |
| 6 | +only the server-side handler can observe (`ctx.request_id`) -- so they capture the id from |
| 7 | +inside the blocked handler before cancelling. Handlers block on an Event rather than a sleep, |
| 8 | +and every wait is bounded by `anyio.fail_after`. |
7 | 9 | """ |
8 | 10 |
|
9 | 11 | import anyio |
|
27 | 29 |
|
28 | 30 | from mcp import MCPError |
29 | 31 | from mcp.client import ClientRequestContext, ClientSession |
| 32 | +from mcp.client._memory import InMemoryTransport |
| 33 | +from mcp.client.client import Client |
30 | 34 | from mcp.server import Server, ServerRequestContext |
31 | 35 | from mcp.shared.memory import MessageStream, create_client_server_memory_streams |
32 | 36 | from mcp.shared.message import SessionMessage |
33 | 37 | from tests.interaction._connect import Connect |
34 | | -from tests.interaction._helpers import IncomingMessage |
| 38 | +from tests.interaction._helpers import IncomingMessage, RecordingTransport |
35 | 39 | from tests.interaction._requirements import requirement |
36 | 40 |
|
37 | 41 | pytestmark = pytest.mark.anyio |
38 | 42 |
|
39 | 43 |
|
| 44 | +@requirement("protocol:cancel:abort-signal") |
| 45 | +async def test_cancelling_the_callers_scope_sends_cancelled_and_abandons_the_call() -> None: |
| 46 | + """Cancelling the scope around an in-flight call sends notifications/cancelled and the call never returns. |
| 47 | +
|
| 48 | + Spec-mandated (cancellation flow): the sender of a cancelled request issues |
| 49 | + notifications/cancelled referencing its id. Legacy-era act: at 2026-07-28 the wire act splits |
| 50 | + by transport (see the manifest entry's note). The wire is observed at the recording-transport |
| 51 | + seam; the reason string is the SDK's own deliberate output. |
| 52 | + """ |
| 53 | + handler_started = anyio.Event() |
| 54 | + |
| 55 | + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: |
| 56 | + assert params.name == "block" |
| 57 | + handler_started.set() |
| 58 | + await anyio.Event().wait() # blocks until the courtesy cancellation interrupts it |
| 59 | + raise NotImplementedError # unreachable: the wait above never completes normally |
| 60 | + |
| 61 | + server = Server("blocker", on_call_tool=call_tool) |
| 62 | + recording = RecordingTransport(InMemoryTransport(server)) |
| 63 | + |
| 64 | + async with Client(recording, mode="legacy") as client: |
| 65 | + with anyio.fail_after(5): |
| 66 | + async with anyio.create_task_group() as task_group: # pragma: no branch |
| 67 | + |
| 68 | + async def call() -> None: |
| 69 | + await client.call_tool("block", {}) |
| 70 | + raise NotImplementedError # unreachable: the surrounding scope is cancelled mid-flight |
| 71 | + |
| 72 | + task_group.start_soon(call) |
| 73 | + await handler_started.wait() |
| 74 | + task_group.cancel_scope.cancel() |
| 75 | + |
| 76 | + (call_request,) = [ |
| 77 | + item.message |
| 78 | + for item in recording.sent |
| 79 | + if isinstance(item.message, JSONRPCRequest) and item.message.method == "tools/call" |
| 80 | + ] |
| 81 | + (cancellation,) = [ |
| 82 | + item.message |
| 83 | + for item in recording.sent |
| 84 | + if isinstance(item.message, JSONRPCNotification) and item.message.method == "notifications/cancelled" |
| 85 | + ] |
| 86 | + assert cancellation.params == snapshot({"requestId": 2, "reason": "caller cancelled"}) |
| 87 | + assert cancellation.params is not None and cancellation.params["requestId"] == call_request.id |
| 88 | + |
| 89 | + |
40 | 90 | @requirement("protocol:cancel:in-flight") |
41 | 91 | @requirement("protocol:cancel:handler-abort-propagates") |
42 | 92 | async def test_cancellation_stops_in_flight_handler(connect: Connect) -> None: |
@@ -87,6 +137,77 @@ async def call_and_capture_error() -> None: |
87 | 137 | assert errors == snapshot([ErrorData(code=0, message="Request cancelled")]) |
88 | 138 |
|
89 | 139 |
|
| 140 | +@requirement("protocol:cancel:in-flight") |
| 141 | +async def test_client_answers_a_cancelled_server_initiated_request_with_the_code_zero_error(connect: Connect) -> None: |
| 142 | + """Cancelling a server-initiated request interrupts the client's callback, and the client |
| 143 | + answers with the code-0 error -- the client half of the divergence on this requirement (the |
| 144 | + spec says the receiver should not respond at all). The server cancels its own sampling |
| 145 | + request while still awaiting it, so the client's answer is observed as the awaited call's |
| 146 | + failure; the whole exchange sits under one fail_after, so a silent client fails the test |
| 147 | + instead of hanging it. |
| 148 | + """ |
| 149 | + callback_started = anyio.Event() |
| 150 | + callback_cancelled = anyio.Event() |
| 151 | + client_request_ids: list[types.RequestId] = [] |
| 152 | + errors: list[ErrorData] = [] |
| 153 | + |
| 154 | + async def sampling_callback( |
| 155 | + context: ClientRequestContext, params: types.CreateMessageRequestParams |
| 156 | + ) -> types.CreateMessageResult: |
| 157 | + client_request_ids.append(context.request_id) |
| 158 | + callback_started.set() |
| 159 | + try: |
| 160 | + await anyio.Event().wait() # blocks until the cancellation interrupts it |
| 161 | + except anyio.get_cancelled_exc_class(): |
| 162 | + callback_cancelled.set() |
| 163 | + raise |
| 164 | + raise NotImplementedError # unreachable |
| 165 | + |
| 166 | + async def list_tools( |
| 167 | + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None |
| 168 | + ) -> types.ListToolsResult: |
| 169 | + return types.ListToolsResult(tools=[types.Tool(name="canceller", input_schema={"type": "object"})]) |
| 170 | + |
| 171 | + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: |
| 172 | + assert params.name == "canceller" |
| 173 | + request = types.CreateMessageRequest( |
| 174 | + params=types.CreateMessageRequestParams( |
| 175 | + messages=[types.SamplingMessage(role="user", content=TextContent(text="Say hello."))], |
| 176 | + max_tokens=8, |
| 177 | + ) |
| 178 | + ) |
| 179 | + with anyio.fail_after(5): |
| 180 | + async with anyio.create_task_group() as task_group: |
| 181 | + |
| 182 | + async def sample_and_capture_error() -> None: |
| 183 | + with pytest.raises(MCPError) as exc_info: |
| 184 | + await ctx.session.send_request(request, types.CreateMessageResult) |
| 185 | + errors.append(exc_info.value.error) |
| 186 | + |
| 187 | + task_group.start_soon(sample_and_capture_error) |
| 188 | + await callback_started.wait() |
| 189 | + await ctx.session.send_notification( |
| 190 | + types.CancelledNotification( |
| 191 | + params=types.CancelledNotificationParams( |
| 192 | + request_id=client_request_ids[0], reason="user aborted" |
| 193 | + ) |
| 194 | + ), |
| 195 | + related_request_id=ctx.request_id, |
| 196 | + ) |
| 197 | + # The join above completes only when the client's answer arrives; the enclosing |
| 198 | + # fail_after turns a silent client into a TimeoutError -- a failed test, not a hang. |
| 199 | + await callback_cancelled.wait() |
| 200 | + return CallToolResult(content=[TextContent(text="cancelled")]) |
| 201 | + |
| 202 | + server = Server("canceller", on_list_tools=list_tools, on_call_tool=call_tool) |
| 203 | + |
| 204 | + async with connect(server, sampling_callback=sampling_callback) as client: |
| 205 | + result = await client.call_tool("canceller", {}) |
| 206 | + |
| 207 | + assert result == snapshot(CallToolResult(content=[TextContent(text="cancelled")])) |
| 208 | + assert errors == snapshot([ErrorData(code=0, message="Request cancelled")]) |
| 209 | + |
| 210 | + |
90 | 211 | @requirement("protocol:cancel:no-further-notifications") |
91 | 212 | async def test_no_notifications_for_a_request_arrive_after_its_cancellation(connect: Connect) -> None: |
92 | 213 | """After a request is cancelled, no further notifications for it reach the wire (spec-mandated). |
|
0 commit comments