Skip to content

Commit 4f8daf6

Browse files
committed
add sse transport tests
1 parent 3dd13de commit 4f8daf6

File tree

3 files changed

+158
-6
lines changed

3 files changed

+158
-6
lines changed

tests/fixtures/complex_app.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,20 +8,17 @@
88
Product,
99
Customer,
1010
OrderResponse,
11-
Address,
1211
PaginatedResponse,
1312
ProductCategory,
1413
OrderRequest,
1514
ErrorResponse,
1615
)
1716

1817

19-
@pytest.fixture
20-
def complex_fastapi_app(
18+
def make_complex_fastapi_app(
2119
example_product: Product,
2220
example_customer: Customer,
2321
example_order_response: OrderResponse,
24-
example_address: Address,
2522
) -> FastAPI:
2623
app = FastAPI(
2724
title="Complex E-Commerce API",
@@ -129,3 +126,16 @@ async def get_customer(
129126
return customer_copy
130127

131128
return app
129+
130+
131+
@pytest.fixture
132+
def complex_fastapi_app(
133+
example_product: Product,
134+
example_customer: Customer,
135+
example_order_response: OrderResponse,
136+
) -> FastAPI:
137+
return make_complex_fastapi_app(
138+
example_product=example_product,
139+
example_customer=example_customer,
140+
example_order_response=example_order_response,
141+
)

tests/fixtures/simple_app.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,7 @@
66
from .types import Item
77

88

9-
@pytest.fixture
10-
def simple_fastapi_app() -> FastAPI:
9+
def make_simple_fastapi_app() -> FastAPI:
1110
app = FastAPI(
1211
title="Test API",
1312
description="A test API app for unit testing",
@@ -66,3 +65,8 @@ async def raise_error() -> None:
6665
raise Exception("This is a test error")
6766

6867
return app
68+
69+
70+
@pytest.fixture
71+
def simple_fastapi_app() -> FastAPI:
72+
return make_simple_fastapi_app()

tests/test_sse_transport.py

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
import anyio
2+
import multiprocessing
3+
import socket
4+
import time
5+
from typing import AsyncGenerator, Generator
6+
from mcp.client.session import ClientSession
7+
from mcp.client.sse import sse_client
8+
from mcp import InitializeResult
9+
from mcp.types import EmptyResult, CallToolResult, ListToolsResult
10+
import pytest
11+
import httpx
12+
import uvicorn
13+
from fastapi_mcp import FastApiMCP
14+
15+
from .fixtures.simple_app import make_simple_fastapi_app
16+
17+
18+
HOST = "127.0.0.1"
19+
SERVER_NAME = "Test MCP Server"
20+
21+
22+
@pytest.fixture
23+
def server_port() -> int:
24+
with socket.socket() as s:
25+
s.bind((HOST, 0))
26+
return s.getsockname()[1]
27+
28+
29+
@pytest.fixture
30+
def server_url(server_port: int) -> str:
31+
return f"http://{HOST}:{server_port}"
32+
33+
34+
def run_server(server_port: int) -> None:
35+
fastapi = make_simple_fastapi_app()
36+
mcp = FastApiMCP(
37+
fastapi,
38+
name=SERVER_NAME,
39+
description="Test description",
40+
)
41+
mcp.mount()
42+
43+
server = uvicorn.Server(config=uvicorn.Config(app=fastapi, host=HOST, port=server_port, log_level="error"))
44+
server.run()
45+
46+
# Give server time to start
47+
while not server.started:
48+
time.sleep(0.5)
49+
50+
51+
@pytest.fixture()
52+
def server(server_port: int) -> Generator[None, None, None]:
53+
proc = multiprocessing.Process(target=run_server, kwargs={"server_port": server_port}, daemon=True)
54+
proc.start()
55+
56+
# Wait for server to be running
57+
max_attempts = 20
58+
attempt = 0
59+
while attempt < max_attempts:
60+
try:
61+
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
62+
s.connect((HOST, server_port))
63+
break
64+
except ConnectionRefusedError:
65+
time.sleep(0.1)
66+
attempt += 1
67+
else:
68+
raise RuntimeError(f"Server failed to start after {max_attempts} attempts")
69+
70+
yield
71+
72+
# Signal the server to stop
73+
proc.kill()
74+
proc.join(timeout=2)
75+
if proc.is_alive():
76+
raise RuntimeError("server process failed to terminate")
77+
78+
79+
@pytest.fixture()
80+
async def http_client(server: None, server_url: str) -> AsyncGenerator[httpx.AsyncClient, None]:
81+
async with httpx.AsyncClient(base_url=server_url) as client:
82+
yield client
83+
84+
85+
@pytest.mark.anyio
86+
async def test_raw_sse_connection(http_client: httpx.AsyncClient) -> None:
87+
"""Test the SSE connection establishment simply with an HTTP client."""
88+
async with anyio.create_task_group():
89+
90+
async def connection_test() -> None:
91+
async with http_client.stream("GET", "/mcp") as response:
92+
assert response.status_code == 200
93+
assert response.headers["content-type"] == "text/event-stream; charset=utf-8"
94+
95+
line_number = 0
96+
async for line in response.aiter_lines():
97+
if line_number == 0:
98+
assert line == "event: endpoint"
99+
elif line_number == 1:
100+
assert line.startswith("data: /mcp/messages/?session_id=")
101+
else:
102+
return
103+
line_number += 1
104+
105+
# Add timeout to prevent test from hanging if it fails
106+
with anyio.fail_after(3):
107+
await connection_test()
108+
109+
110+
@pytest.mark.anyio
111+
async def test_sse_basic_connection(server: None, server_url: str) -> None:
112+
async with sse_client(server_url + "/mcp") as streams:
113+
async with ClientSession(*streams) as session:
114+
# Test initialization
115+
result = await session.initialize()
116+
assert isinstance(result, InitializeResult)
117+
assert result.serverInfo.name == SERVER_NAME
118+
119+
# Test ping
120+
ping_result = await session.send_ping()
121+
assert isinstance(ping_result, EmptyResult)
122+
123+
124+
@pytest.mark.anyio
125+
async def test_sse_tool_call(server: None, server_url: str) -> None:
126+
async with sse_client(server_url + "/mcp") as streams:
127+
async with ClientSession(*streams) as session:
128+
await session.initialize()
129+
130+
tools_list_result = await session.list_tools()
131+
assert isinstance(tools_list_result, ListToolsResult)
132+
assert len(tools_list_result.tools) > 0
133+
134+
tool_call_result = await session.call_tool("get_item", {"item_id": 1})
135+
assert isinstance(tool_call_result, CallToolResult)
136+
assert not tool_call_result.isError
137+
assert tool_call_result.content is not None
138+
assert len(tool_call_result.content) > 0

0 commit comments

Comments
 (0)