|
| 1 | +import pytest |
| 2 | +import uuid |
| 3 | +from uuid import UUID |
| 4 | +from unittest.mock import AsyncMock, MagicMock, patch |
| 5 | +from fastapi import HTTPException, Request |
| 6 | +from pydantic import ValidationError |
| 7 | +from anyio.streams.memory import MemoryObjectSendStream |
| 8 | + |
| 9 | +from fastapi_mcp.transport.sse import FastApiSseTransport |
| 10 | +from mcp.types import JSONRPCMessage, JSONRPCError |
| 11 | + |
| 12 | + |
| 13 | +@pytest.fixture |
| 14 | +def mock_transport() -> FastApiSseTransport: |
| 15 | + # Initialize transport with a mock endpoint |
| 16 | + transport = FastApiSseTransport("/messages") |
| 17 | + transport._read_stream_writers = {} |
| 18 | + return transport |
| 19 | + |
| 20 | + |
| 21 | +@pytest.fixture |
| 22 | +def valid_session_id(): |
| 23 | + session_id = uuid.uuid4() |
| 24 | + return session_id |
| 25 | + |
| 26 | + |
| 27 | +@pytest.fixture |
| 28 | +def mock_writer(): |
| 29 | + return AsyncMock(spec=MemoryObjectSendStream) |
| 30 | + |
| 31 | + |
| 32 | +@pytest.mark.anyio |
| 33 | +async def test_handle_post_message_missing_session_id(mock_transport: FastApiSseTransport) -> None: |
| 34 | + """Test handling a request with a missing session_id.""" |
| 35 | + # Create a mock request with no session_id |
| 36 | + mock_request = MagicMock(spec=Request) |
| 37 | + mock_request.query_params = {} |
| 38 | + |
| 39 | + # Check that the function raises HTTPException with the correct status code |
| 40 | + with pytest.raises(HTTPException) as excinfo: |
| 41 | + await mock_transport.handle_fastapi_post_message(mock_request) |
| 42 | + |
| 43 | + assert excinfo.value.status_code == 400 |
| 44 | + assert "session_id is required" in excinfo.value.detail |
| 45 | + |
| 46 | + |
| 47 | +@pytest.mark.anyio |
| 48 | +async def test_handle_post_message_invalid_session_id(mock_transport: FastApiSseTransport) -> None: |
| 49 | + """Test handling a request with an invalid session_id.""" |
| 50 | + # Create a mock request with an invalid session_id |
| 51 | + mock_request = MagicMock(spec=Request) |
| 52 | + mock_request.query_params = {"session_id": "not-a-valid-uuid"} |
| 53 | + |
| 54 | + # Check that the function raises HTTPException with the correct status code |
| 55 | + with pytest.raises(HTTPException) as excinfo: |
| 56 | + await mock_transport.handle_fastapi_post_message(mock_request) |
| 57 | + |
| 58 | + assert excinfo.value.status_code == 400 |
| 59 | + assert "Invalid session ID" in excinfo.value.detail |
| 60 | + |
| 61 | + |
| 62 | +@pytest.mark.anyio |
| 63 | +async def test_handle_post_message_session_not_found( |
| 64 | + mock_transport: FastApiSseTransport, valid_session_id: UUID |
| 65 | +) -> None: |
| 66 | + """Test handling a request with a valid session_id that doesn't exist.""" |
| 67 | + # Create a mock request with a valid session_id |
| 68 | + mock_request = MagicMock(spec=Request) |
| 69 | + mock_request.query_params = {"session_id": valid_session_id.hex} |
| 70 | + |
| 71 | + # The session_id is valid but not in the transport's writers |
| 72 | + with pytest.raises(HTTPException) as excinfo: |
| 73 | + await mock_transport.handle_fastapi_post_message(mock_request) |
| 74 | + |
| 75 | + assert excinfo.value.status_code == 404 |
| 76 | + assert "Could not find session" in excinfo.value.detail |
| 77 | + |
| 78 | + |
| 79 | +@pytest.mark.anyio |
| 80 | +async def test_handle_post_message_validation_error( |
| 81 | + mock_transport: FastApiSseTransport, valid_session_id: UUID, mock_writer: AsyncMock |
| 82 | +) -> None: |
| 83 | + """Test handling a request with invalid JSON that causes a ValidationError.""" |
| 84 | + # Set up the mock transport with a valid session |
| 85 | + mock_transport._read_stream_writers[valid_session_id] = mock_writer |
| 86 | + |
| 87 | + # Create a mock request with valid session_id but invalid body |
| 88 | + mock_request = MagicMock(spec=Request) |
| 89 | + mock_request.query_params = {"session_id": valid_session_id.hex} |
| 90 | + mock_request.body = AsyncMock(return_value=b'{"invalid": "json"}') |
| 91 | + |
| 92 | + # Mock BackgroundTasks |
| 93 | + with patch("fastapi_mcp.transport.sse.BackgroundTasks") as MockBackgroundTasks: |
| 94 | + mock_background_tasks = MockBackgroundTasks.return_value |
| 95 | + |
| 96 | + # Call the function |
| 97 | + response = await mock_transport.handle_fastapi_post_message(mock_request) |
| 98 | + |
| 99 | + # Verify response and background task setup |
| 100 | + assert response.status_code == 400 |
| 101 | + assert "error" in response.body.decode() if isinstance(response.body, bytes) else False |
| 102 | + assert mock_background_tasks.add_task.called |
| 103 | + assert response.background == mock_background_tasks |
| 104 | + |
| 105 | + |
| 106 | +@pytest.mark.anyio |
| 107 | +async def test_handle_post_message_general_exception( |
| 108 | + mock_transport: FastApiSseTransport, valid_session_id: UUID, mock_writer: AsyncMock |
| 109 | +) -> None: |
| 110 | + """Test handling a request that causes a general exception during body processing.""" |
| 111 | + # Set up the mock transport with a valid session |
| 112 | + mock_transport._read_stream_writers[valid_session_id] = mock_writer |
| 113 | + |
| 114 | + # Create a mock request that raises an exception when body is accessed |
| 115 | + mock_request = MagicMock(spec=Request) |
| 116 | + mock_request.query_params = {"session_id": valid_session_id.hex} |
| 117 | + |
| 118 | + # Instead of mocking the body method to raise an exception, |
| 119 | + # we'll patch the body method to return a normal value and then |
| 120 | + # patch JSONRPCMessage.model_validate_json to raise the exception |
| 121 | + mock_request.body = AsyncMock(return_value=b'{"jsonrpc": "2.0", "method": "test", "id": "1"}') |
| 122 | + |
| 123 | + # Mock the model_validate_json method to raise an Exception |
| 124 | + with patch("mcp.types.JSONRPCMessage.model_validate_json", side_effect=Exception("Test exception")): |
| 125 | + # Check that the function raises HTTPException with the correct status code |
| 126 | + with pytest.raises(HTTPException) as excinfo: |
| 127 | + await mock_transport.handle_fastapi_post_message(mock_request) |
| 128 | + |
| 129 | + assert excinfo.value.status_code == 400 |
| 130 | + assert "Invalid request body" in excinfo.value.detail |
| 131 | + |
| 132 | + |
| 133 | +@pytest.mark.anyio |
| 134 | +async def test_send_message_safely_with_validation_error( |
| 135 | + mock_transport: FastApiSseTransport, mock_writer: AsyncMock |
| 136 | +) -> None: |
| 137 | + """Test sending a ValidationError message safely.""" |
| 138 | + # Create a minimal validation error manually instead of using from_exception_data |
| 139 | + mock_validation_error = MagicMock(spec=ValidationError) |
| 140 | + mock_validation_error.__str__.return_value = "Mock validation error" # type: ignore |
| 141 | + |
| 142 | + # Call the function |
| 143 | + await mock_transport._send_message_safely(mock_writer, mock_validation_error) |
| 144 | + |
| 145 | + # Verify that the writer.send was called with a JSONRPCError |
| 146 | + assert mock_writer.send.called |
| 147 | + sent_message = mock_writer.send.call_args[0][0] |
| 148 | + assert isinstance(sent_message, JSONRPCMessage) |
| 149 | + assert isinstance(sent_message.root, JSONRPCError) |
| 150 | + assert sent_message.root.error.code == -32700 # Parse error code |
| 151 | + |
| 152 | + |
| 153 | +@pytest.mark.anyio |
| 154 | +async def test_send_message_safely_with_jsonrpc_message( |
| 155 | + mock_transport: FastApiSseTransport, mock_writer: AsyncMock |
| 156 | +) -> None: |
| 157 | + """Test sending a JSONRPCMessage safely.""" |
| 158 | + # Create a JSONRPCMessage |
| 159 | + message = JSONRPCMessage.model_validate({"jsonrpc": "2.0", "id": "123", "method": "test_method", "params": {}}) |
| 160 | + |
| 161 | + # Call the function |
| 162 | + await mock_transport._send_message_safely(mock_writer, message) |
| 163 | + |
| 164 | + # Verify that the writer.send was called with the message |
| 165 | + assert mock_writer.send.called |
| 166 | + sent_message = mock_writer.send.call_args[0][0] |
| 167 | + assert sent_message == message |
| 168 | + |
| 169 | + |
| 170 | +@pytest.mark.anyio |
| 171 | +async def test_send_message_safely_exception_handling( |
| 172 | + mock_transport: FastApiSseTransport, mock_writer: AsyncMock |
| 173 | +) -> None: |
| 174 | + """Test exception handling when sending a message.""" |
| 175 | + # Set up the writer to raise an exception |
| 176 | + mock_writer.send.side_effect = Exception("Test exception") |
| 177 | + |
| 178 | + # Create a message |
| 179 | + message = JSONRPCMessage.model_validate({"jsonrpc": "2.0", "id": "123", "method": "test_method", "params": {}}) |
| 180 | + |
| 181 | + # Call the function - it should not raise an exception |
| 182 | + await mock_transport._send_message_safely(mock_writer, message) |
| 183 | + |
| 184 | + # Verify that the writer.send was called |
| 185 | + assert mock_writer.send.called |
0 commit comments