Skip to content

Commit dacee48

Browse files
committed
improve transport test coverage
1 parent 746c993 commit dacee48

File tree

2 files changed

+185
-0
lines changed

2 files changed

+185
-0
lines changed

tests/test_sse_mock_transport.py

Lines changed: 185 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,185 @@
1+
import pytest
2+
import uuid
3+
from uuid import UUID
4+
from unittest.mock import AsyncMock, MagicMock, patch
5+
from fastapi import HTTPException, Request
6+
from pydantic import ValidationError
7+
from anyio.streams.memory import MemoryObjectSendStream
8+
9+
from fastapi_mcp.transport.sse import FastApiSseTransport
10+
from mcp.types import JSONRPCMessage, JSONRPCError
11+
12+
13+
@pytest.fixture
14+
def mock_transport() -> FastApiSseTransport:
15+
# Initialize transport with a mock endpoint
16+
transport = FastApiSseTransport("/messages")
17+
transport._read_stream_writers = {}
18+
return transport
19+
20+
21+
@pytest.fixture
22+
def valid_session_id():
23+
session_id = uuid.uuid4()
24+
return session_id
25+
26+
27+
@pytest.fixture
28+
def mock_writer():
29+
return AsyncMock(spec=MemoryObjectSendStream)
30+
31+
32+
@pytest.mark.anyio
33+
async def test_handle_post_message_missing_session_id(mock_transport: FastApiSseTransport) -> None:
34+
"""Test handling a request with a missing session_id."""
35+
# Create a mock request with no session_id
36+
mock_request = MagicMock(spec=Request)
37+
mock_request.query_params = {}
38+
39+
# Check that the function raises HTTPException with the correct status code
40+
with pytest.raises(HTTPException) as excinfo:
41+
await mock_transport.handle_fastapi_post_message(mock_request)
42+
43+
assert excinfo.value.status_code == 400
44+
assert "session_id is required" in excinfo.value.detail
45+
46+
47+
@pytest.mark.anyio
48+
async def test_handle_post_message_invalid_session_id(mock_transport: FastApiSseTransport) -> None:
49+
"""Test handling a request with an invalid session_id."""
50+
# Create a mock request with an invalid session_id
51+
mock_request = MagicMock(spec=Request)
52+
mock_request.query_params = {"session_id": "not-a-valid-uuid"}
53+
54+
# Check that the function raises HTTPException with the correct status code
55+
with pytest.raises(HTTPException) as excinfo:
56+
await mock_transport.handle_fastapi_post_message(mock_request)
57+
58+
assert excinfo.value.status_code == 400
59+
assert "Invalid session ID" in excinfo.value.detail
60+
61+
62+
@pytest.mark.anyio
63+
async def test_handle_post_message_session_not_found(
64+
mock_transport: FastApiSseTransport, valid_session_id: UUID
65+
) -> None:
66+
"""Test handling a request with a valid session_id that doesn't exist."""
67+
# Create a mock request with a valid session_id
68+
mock_request = MagicMock(spec=Request)
69+
mock_request.query_params = {"session_id": valid_session_id.hex}
70+
71+
# The session_id is valid but not in the transport's writers
72+
with pytest.raises(HTTPException) as excinfo:
73+
await mock_transport.handle_fastapi_post_message(mock_request)
74+
75+
assert excinfo.value.status_code == 404
76+
assert "Could not find session" in excinfo.value.detail
77+
78+
79+
@pytest.mark.anyio
80+
async def test_handle_post_message_validation_error(
81+
mock_transport: FastApiSseTransport, valid_session_id: UUID, mock_writer: AsyncMock
82+
) -> None:
83+
"""Test handling a request with invalid JSON that causes a ValidationError."""
84+
# Set up the mock transport with a valid session
85+
mock_transport._read_stream_writers[valid_session_id] = mock_writer
86+
87+
# Create a mock request with valid session_id but invalid body
88+
mock_request = MagicMock(spec=Request)
89+
mock_request.query_params = {"session_id": valid_session_id.hex}
90+
mock_request.body = AsyncMock(return_value=b'{"invalid": "json"}')
91+
92+
# Mock BackgroundTasks
93+
with patch("fastapi_mcp.transport.sse.BackgroundTasks") as MockBackgroundTasks:
94+
mock_background_tasks = MockBackgroundTasks.return_value
95+
96+
# Call the function
97+
response = await mock_transport.handle_fastapi_post_message(mock_request)
98+
99+
# Verify response and background task setup
100+
assert response.status_code == 400
101+
assert "error" in response.body.decode() if isinstance(response.body, bytes) else False
102+
assert mock_background_tasks.add_task.called
103+
assert response.background == mock_background_tasks
104+
105+
106+
@pytest.mark.anyio
107+
async def test_handle_post_message_general_exception(
108+
mock_transport: FastApiSseTransport, valid_session_id: UUID, mock_writer: AsyncMock
109+
) -> None:
110+
"""Test handling a request that causes a general exception during body processing."""
111+
# Set up the mock transport with a valid session
112+
mock_transport._read_stream_writers[valid_session_id] = mock_writer
113+
114+
# Create a mock request that raises an exception when body is accessed
115+
mock_request = MagicMock(spec=Request)
116+
mock_request.query_params = {"session_id": valid_session_id.hex}
117+
118+
# Instead of mocking the body method to raise an exception,
119+
# we'll patch the body method to return a normal value and then
120+
# patch JSONRPCMessage.model_validate_json to raise the exception
121+
mock_request.body = AsyncMock(return_value=b'{"jsonrpc": "2.0", "method": "test", "id": "1"}')
122+
123+
# Mock the model_validate_json method to raise an Exception
124+
with patch("mcp.types.JSONRPCMessage.model_validate_json", side_effect=Exception("Test exception")):
125+
# Check that the function raises HTTPException with the correct status code
126+
with pytest.raises(HTTPException) as excinfo:
127+
await mock_transport.handle_fastapi_post_message(mock_request)
128+
129+
assert excinfo.value.status_code == 400
130+
assert "Invalid request body" in excinfo.value.detail
131+
132+
133+
@pytest.mark.anyio
134+
async def test_send_message_safely_with_validation_error(
135+
mock_transport: FastApiSseTransport, mock_writer: AsyncMock
136+
) -> None:
137+
"""Test sending a ValidationError message safely."""
138+
# Create a minimal validation error manually instead of using from_exception_data
139+
mock_validation_error = MagicMock(spec=ValidationError)
140+
mock_validation_error.__str__.return_value = "Mock validation error" # type: ignore
141+
142+
# Call the function
143+
await mock_transport._send_message_safely(mock_writer, mock_validation_error)
144+
145+
# Verify that the writer.send was called with a JSONRPCError
146+
assert mock_writer.send.called
147+
sent_message = mock_writer.send.call_args[0][0]
148+
assert isinstance(sent_message, JSONRPCMessage)
149+
assert isinstance(sent_message.root, JSONRPCError)
150+
assert sent_message.root.error.code == -32700 # Parse error code
151+
152+
153+
@pytest.mark.anyio
154+
async def test_send_message_safely_with_jsonrpc_message(
155+
mock_transport: FastApiSseTransport, mock_writer: AsyncMock
156+
) -> None:
157+
"""Test sending a JSONRPCMessage safely."""
158+
# Create a JSONRPCMessage
159+
message = JSONRPCMessage.model_validate({"jsonrpc": "2.0", "id": "123", "method": "test_method", "params": {}})
160+
161+
# Call the function
162+
await mock_transport._send_message_safely(mock_writer, message)
163+
164+
# Verify that the writer.send was called with the message
165+
assert mock_writer.send.called
166+
sent_message = mock_writer.send.call_args[0][0]
167+
assert sent_message == message
168+
169+
170+
@pytest.mark.anyio
171+
async def test_send_message_safely_exception_handling(
172+
mock_transport: FastApiSseTransport, mock_writer: AsyncMock
173+
) -> None:
174+
"""Test exception handling when sending a message."""
175+
# Set up the writer to raise an exception
176+
mock_writer.send.side_effect = Exception("Test exception")
177+
178+
# Create a message
179+
message = JSONRPCMessage.model_validate({"jsonrpc": "2.0", "id": "123", "method": "test_method", "params": {}})
180+
181+
# Call the function - it should not raise an exception
182+
await mock_transport._send_message_safely(mock_writer, message)
183+
184+
# Verify that the writer.send was called
185+
assert mock_writer.send.called
File renamed without changes.

0 commit comments

Comments
 (0)