Commit 1c2bd2d

feat(router): use aiohttp session for health checks

Signed-off-by: Nejc Habjan <nejc.habjan@siemens.com>

1 parent: ab2c023

7 files changed (+2516, -280 lines)

pyproject.toml

Lines changed: 2 additions & 1 deletion
@@ -69,5 +69,6 @@ lint = [
 ]
 test = [
     "pytest>=8.3.4",
-    "pytest-asyncio>=0.25.3"
+    "pytest-aiohttp>=1.1.0",
+    "pytest-asyncio>=0.25.3",
 ]

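The pytest-aiohttp dependency added above is what supplies the aiohttp_client fixture that the new conftest.py below relies on; pytest-asyncio continues to drive the async test functions themselves. A minimal, self-contained sketch of what the plugin provides (illustrative only, not code from this commit):

import pytest
from aiohttp import web


@pytest.mark.asyncio
async def test_ping(aiohttp_client) -> None:
    # aiohttp_client (from pytest-aiohttp) starts the application on an
    # ephemeral port and returns a test client bound to it.
    async def ping(request: web.Request) -> web.Response:
        return web.json_response({"ok": True})

    app = web.Application()
    app.router.add_get("/ping", ping)
    client = await aiohttp_client(app)

    resp = await client.get("/ping")
    assert resp.status == 200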
src/tests/conftest.py

Lines changed: 28 additions & 0 deletions
@@ -0,0 +1,28 @@
+from typing import Any, AsyncGenerator, Callable
+from unittest.mock import MagicMock
+
+import aiohttp
+import pytest
+from aiohttp import web
+from fastapi import FastAPI
+
+
+@pytest.fixture
+async def mock_app() -> AsyncGenerator[FastAPI]:
+    mock_app = MagicMock()
+    async with aiohttp.ClientSession() as session:
+        mock_app.state.aiohttp_client_wrapper = MagicMock(return_value=session)
+        yield mock_app
+
+
+@pytest.fixture
+async def make_mock_engine(aiohttp_client: Any) -> Callable[[dict[str, Callable]], str]:
+    async def _make_mock_engine(routes: dict[str, Callable]) -> str:
+        app = web.Application()
+        for path, handler in routes.items():
+            app.router.add_post(path, handler)
+
+        client = await aiohttp_client(app)
+        return str(client.make_url(""))
+
+    return _make_mock_engine

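Taken together, the fixtures above give the tests two things: mock_app mimics a FastAPI app whose state.aiohttp_client_wrapper() hands back a live ClientSession, and make_mock_engine spins up a throwaway aiohttp server and returns its base URL. A minimal usage sketch, assuming the returned base URL carries no trailing slash (the same assumption the tests in this commit make when concatenating endpoint paths):

from unittest.mock import AsyncMock

import pytest
from aiohttp import web


@pytest.mark.asyncio
async def test_mock_engine_roundtrip(make_mock_engine, mock_app) -> None:
    # Serve a canned 200 for the chat completions route, as the real tests do.
    handler = AsyncMock(return_value=web.json_response(status=200))
    base_url = await make_mock_engine({"/v1/chat/completions": handler})

    # The session exposed by the mocked app can reach the fake engine over HTTP.
    session = mock_app.state.aiohttp_client_wrapper()
    async with session.post(
        f"{base_url}/v1/chat/completions", json={"model": "llama3"}
    ) as resp:
        assert resp.status == 200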
src/tests/test_static_service_discovery.py

Lines changed: 52 additions & 30 deletions
@@ -1,11 +1,15 @@
-from unittest.mock import MagicMock
+import hashlib
+from unittest.mock import AsyncMock, MagicMock

 import pytest
+from aiohttp import web
+from fastapi import FastAPI

 from vllm_router.service_discovery import StaticServiceDiscovery


 def test_init_when_static_backend_health_checks_calls_start_health_checks(
+    mock_app: FastAPI,
     monkeypatch: pytest.MonkeyPatch,
 ) -> None:
     start_health_check_mock = MagicMock()
@@ -14,7 +18,7 @@ def test_init_when_static_backend_health_checks_calls_start_health_checks(
         start_health_check_mock,
     )
     discovery_instance = StaticServiceDiscovery(
-        None,
+        mock_app,
         [],
         [],
         None,
@@ -28,6 +32,7 @@ def test_init_when_static_backend_health_checks_calls_start_health_checks(


 def test_init_when_endpoint_health_check_disabled_does_not_call_start_health_checks(
+    mock_app: FastAPI,
     monkeypatch: pytest.MonkeyPatch,
 ) -> None:
     start_health_check_mock = MagicMock()
@@ -36,7 +41,7 @@ def test_init_when_endpoint_health_check_disabled_does_not_call_start_health_che
         start_health_check_mock,
     )
     discovery_instance = StaticServiceDiscovery(
-        None,
+        mock_app,
         [],
         [],
         None,
@@ -49,31 +54,39 @@ def test_init_when_endpoint_health_check_disabled_does_not_call_start_health_che
     discovery_instance.start_health_check_task.assert_not_called()


-def test_get_unhealthy_endpoint_hashes_when_only_healthy_models_exist_does_not_return_unhealthy_endpoint_hashes(
-    monkeypatch: pytest.MonkeyPatch,
+@pytest.mark.asyncio
+async def test_get_unhealthy_endpoint_hashes_when_only_healthy_models_exist_does_not_return_unhealthy_endpoint_hashes(
+    make_mock_engine, mock_app
 ) -> None:
-    monkeypatch.setattr("vllm_router.utils.is_model_healthy", lambda *_: True)
+    mock_response = AsyncMock(return_value=web.json_response(status=200))
+    base_url = await make_mock_engine({"/v1/chat/completions": mock_response})
+
     discovery_instance = StaticServiceDiscovery(
-        None,
-        ["http://localhost.com"],
+        mock_app,
+        [base_url],
         ["llama3"],
         None,
         None,
         ["chat"],
-        static_backend_health_checks=True,
+        static_backend_health_checks=False,
         prefill_model_labels=None,
         decode_model_labels=None,
     )
-    assert discovery_instance.get_unhealthy_endpoint_hashes() == []
+    assert await discovery_instance.get_unhealthy_endpoint_hashes() == []


-def test_get_unhealthy_endpoint_hashes_when_unhealthy_model_exist_returns_unhealthy_endpoint_hash(
-    monkeypatch: pytest.MonkeyPatch,
+@pytest.mark.asyncio
+async def test_get_unhealthy_endpoint_hashes_when_unhealthy_model_exist_returns_unhealthy_endpoint_hash(
+    make_mock_engine,
+    mock_app,
 ) -> None:
-    monkeypatch.setattr("vllm_router.utils.is_model_healthy", lambda *_: False)
+    mock_response = AsyncMock(return_value=web.json_response(status=500))
+    base_url = await make_mock_engine({"/v1/chat/completions": mock_response})
+    expected_hash = hashlib.md5(f"{base_url}llama3".encode()).hexdigest()
+
     discovery_instance = StaticServiceDiscovery(
-        None,
-        ["http://localhost.com"],
+        mock_app,
+        [base_url],
         ["llama3"],
         None,
         None,
@@ -82,23 +95,33 @@ def test_get_unhealthy_endpoint_hashes_when_unhealthy_model_exist_returns_unheal
         prefill_model_labels=None,
         decode_model_labels=None,
     )
-    assert discovery_instance.get_unhealthy_endpoint_hashes() == [
-        "ee7d421a744e07595b70f98c11be93e7"
-    ]
+    assert await discovery_instance.get_unhealthy_endpoint_hashes() == [expected_hash]


-def test_get_unhealthy_endpoint_hashes_when_healthy_and_unhealthy_models_exist_returns_only_unhealthy_endpoint_hash(
-    monkeypatch: pytest.MonkeyPatch,
+@pytest.mark.asyncio
+async def test_get_unhealthy_endpoint_hashes_when_healthy_and_unhealthy_models_exist_returns_only_unhealthy_endpoint_hash(
+    make_mock_engine,
+    mock_app,
 ) -> None:
     unhealthy_model = "bge-m3"
+    mock_response = AsyncMock(return_value=web.json_response(status=500))
+
+    async def mock_mixed_chat_response(request: web.Request) -> web.Response:
+        data = await request.json()
+        status = 500 if data.get("model") == unhealthy_model else 200
+        return web.json_response(status=status)
+
+    base_url = await make_mock_engine(
+        {
+            "/v1/chat/completions": mock_mixed_chat_response,
+            "/v1/embeddings": mock_response,
+        }
+    )
+    expected_hash = hashlib.md5(f"{base_url}{unhealthy_model}".encode()).hexdigest()

-    def mock_is_model_healthy(url: str, model: str, model_type: str) -> bool:
-        return model != unhealthy_model
-
-    monkeypatch.setattr("vllm_router.utils.is_model_healthy", mock_is_model_healthy)
     discovery_instance = StaticServiceDiscovery(
-        None,
-        ["http://localhost.com", "http://10.123.112.412"],
+        mock_app,
+        [base_url, base_url],
         ["llama3", unhealthy_model],
         None,
         None,
@@ -107,12 +130,11 @@ def mock_is_model_healthy(url: str, model: str, model_type: str) -> bool:
         prefill_model_labels=None,
         decode_model_labels=None,
     )
-    assert discovery_instance.get_unhealthy_endpoint_hashes() == [
-        "01e1b07eca36d39acacd55a33272a225"
-    ]
+    assert await discovery_instance.get_unhealthy_endpoint_hashes() == [expected_hash]


 def test_get_endpoint_info_when_model_endpoint_hash_is_in_unhealthy_endpoint_does_not_return_endpoint(
+    mock_app: FastAPI,
     monkeypatch: pytest.MonkeyPatch,
 ) -> None:
     unhealthy_model = "mistral"
@@ -121,7 +143,7 @@ def mock_get_model_endpoint_hash(url: str, model: str) -> str:
         return "some-hash" if model == unhealthy_model else "other-hash"

     discovery_instance = StaticServiceDiscovery(
-        None,
+        mock_app,
         ["http://localhost.com", "http://10.123.112.412"],
         ["llama3", unhealthy_model],
         None,

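One detail from the diff above worth spelling out: the expected endpoint hashes are no longer hard-coded strings. Because the mock engine binds to an ephemeral port, each test derives the hash itself as the MD5 of the endpoint URL concatenated with the model name, mirroring the router's hashing convention. A tiny worked sketch with placeholder values:

import hashlib

# Placeholder values; the real tests use the ephemeral base URL returned by
# make_mock_engine, which is only known at runtime.
url = "http://localhost:8000"
model = "llama3"

expected_hash = hashlib.md5(f"{url}{model}".encode()).hexdigest()
print(expected_hash)  # 32-character hex digest compared against the router's report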
src/tests/test_utils.py

Lines changed: 27 additions & 25 deletions
@@ -1,8 +1,9 @@
 import json
-from unittest.mock import MagicMock
+from unittest.mock import AsyncMock, MagicMock

+import aiohttp
 import pytest
-import requests
+from aiohttp import web
 from starlette.datastructures import MutableHeaders

 from vllm_router import utils
@@ -85,33 +86,34 @@ def test_get_all_fields_returns_list_of_strings() -> None:
     assert isinstance(fields[0], str)


-def test_is_model_healthy_when_requests_responds_with_status_code_200_returns_true(
-    monkeypatch: pytest.MonkeyPatch,
+@pytest.mark.asyncio
+async def test_is_model_healthy_when_aiohttp_responds_with_status_code_200_returns_true(
+    make_mock_engine,
 ) -> None:
-    request_mock = MagicMock(return_value=MagicMock(status_code=200))
-    monkeypatch.setattr("requests.post", request_mock)
-    assert utils.is_model_healthy("http://localhost", "test", "chat") is True
+    mock_response = AsyncMock(return_value=web.json_response(status=200))
+    base_url = await make_mock_engine({"/v1/chat/completions": mock_response})
+    async with aiohttp.ClientSession() as session:
+        result = await utils.is_model_healthy(session, base_url, "test", "chat")
+    assert result is True


-def test_is_model_healthy_when_requests_raises_exception_returns_false(
-    monkeypatch: pytest.MonkeyPatch,
+@pytest.mark.asyncio
+async def test_is_model_healthy_when_aiohttp_raises_exception_returns_false(
+    make_mock_engine,
 ) -> None:
-    request_mock = MagicMock(side_effect=requests.exceptions.ReadTimeout)
-    monkeypatch.setattr("requests.post", request_mock)
-    assert utils.is_model_healthy("http://localhost", "test", "chat") is False
+    mock_response = AsyncMock(side_effect=aiohttp.ConnectionTimeoutError())
+    base_url = await make_mock_engine({"/v1/chat/completions": mock_response})
+    async with aiohttp.ClientSession() as session:
+        result = await utils.is_model_healthy(session, base_url, "test", "chat")
+    assert result is False


-def test_is_model_healthy_when_requests_status_with_status_code_not_200_returns_false(
-    monkeypatch: pytest.MonkeyPatch,
+@pytest.mark.asyncio
+async def test_is_model_healthy_when_aiohttp_status_with_status_code_not_200_returns_false(
+    make_mock_engine,
 ) -> None:
-
-    # Mock an internal server error response
-    mock_response = MagicMock(status_code=500)
-
-    # Tell the mock to raise an HTTP Error when raise_for_status() is called
-    mock_response.raise_for_status.side_effect = requests.exceptions.HTTPError
-
-    request_mock = MagicMock(return_value=mock_response)
-    monkeypatch.setattr("requests.post", request_mock)
-
-    assert utils.is_model_healthy("http://localhost", "test", "chat") is False
+    mock_response = AsyncMock(return_value=web.json_response(status=500))
+    base_url = await make_mock_engine({"/v1/chat/completions": mock_response})
+    async with aiohttp.ClientSession() as session:
+        result = await utils.is_model_healthy(session, base_url, "test", "chat")
+    assert result is False

src/vllm_router/service_discovery.py

Lines changed: 6 additions & 3 deletions
@@ -217,6 +217,7 @@ def __init__(
         decode_model_labels: List[str] | None = None,
     ):
         self.app = app
+        self.client = self.app.state.aiohttp_client_wrapper()
         assert len(urls) == len(models), "URLs and models should have the same length"
         self.urls = urls
         self.models = models
@@ -232,13 +233,13 @@ def __init__(
         self.prefill_model_labels = prefill_model_labels
         self.decode_model_labels = decode_model_labels

-    def get_unhealthy_endpoint_hashes(self) -> list[str]:
+    async def get_unhealthy_endpoint_hashes(self) -> list[str]:
         unhealthy_endpoints = []
         try:
             for url, model, model_type in zip(
                 self.urls, self.models, self.model_types, strict=True
             ):
-                if utils.is_model_healthy(url, model, model_type):
+                if await utils.is_model_healthy(self.client, url, model, model_type):
                     logger.debug(f"{model} at {url} is healthy")
                 else:
                     logger.warning(f"{model} at {url} not healthy!")
@@ -253,7 +254,9 @@ def get_unhealthy_endpoint_hashes(self) -> list[str]:
     async def check_model_health(self):
         while self._running:
             try:
-                self.unhealthy_endpoint_hashes = self.get_unhealthy_endpoint_hashes()
+                self.unhealthy_endpoint_hashes = (
+                    await self.get_unhealthy_endpoint_hashes()
+                )
                 await asyncio.sleep(60)
             except asyncio.CancelledError:
                 logger.debug("Health check task cancelled")

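The constructor above now pulls a shared ClientSession from app.state.aiohttp_client_wrapper(); the wrapper itself is installed on the app outside the hunks shown here. A rough sketch of the kind of startup wiring this implies, written as a FastAPI lifespan and labeled as an assumption (only the attribute name aiohttp_client_wrapper comes from the diff):

from contextlib import asynccontextmanager

import aiohttp
from fastapi import FastAPI


@asynccontextmanager
async def lifespan(app: FastAPI):
    # Hypothetical wiring: one ClientSession for the router's lifetime,
    # exposed as a callable so StaticServiceDiscovery can fetch it lazily.
    session = aiohttp.ClientSession()
    app.state.aiohttp_client_wrapper = lambda: session
    try:
        yield
    finally:
        await session.close()


app = FastAPI(lifespan=lifespan)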
src/vllm_router/utils.py

Lines changed: 27 additions & 19 deletions
@@ -7,7 +7,7 @@
 import wave
 from typing import Optional

-import requests
+import aiohttp
 from fastapi.requests import Request
 from starlette.datastructures import MutableHeaders

@@ -222,36 +222,44 @@ def update_content_length(request: Request, request_body: str):
     request._headers = headers


-def is_model_healthy(url: str, model: str, model_type: str) -> bool:
+async def is_model_healthy(
+    session: aiohttp.ClientSession, url: str, model: str, model_type: str
+):
     model_url = ModelType.get_url(model_type)

     try:
         if model_type == "transcription":
             # for transcription, the backend expects multipart/form-data with a file
             # we will use pre-generated silent wav bytes
-            response = requests.post(
-                f"{url}{model_url}",
-                files=ModelType.get_test_payload(model_type),  # multipart/form-data
-                data={"model": model},
-                timeout=10,
+            test_payload = ModelType.get_test_payload(model_type)
+            form_data = aiohttp.FormData()
+            form_data.add_field(
+                "file",
+                test_payload["file"][1],
+                filename=test_payload["file"][0],
+                content_type=test_payload["file"][2],
             )
+            form_data.add_field("model", model)
+
+            async with session.post(
+                f"{url}{model_url}",
+                data=form_data,
+                timeout=aiohttp.ClientTimeout(total=10),
+            ) as response:
+                response.raise_for_status()
+                return True
         else:
             # for other model types (chat, completion, etc.)
-            response = requests.post(
+            async with session.post(
                 f"{url}{model_url}",
                 headers={"Content-Type": "application/json"},
                 json={"model": model} | ModelType.get_test_payload(model_type),
-                timeout=10,
-            )
-
-        response.raise_for_status()
-
-        if model_type == "transcription":
-            return True
-        else:
-            response.json()  # verify it's valid json for other model types
-            return True  # validation passed
+                timeout=aiohttp.ClientTimeout(total=10),
+            ) as response:
+                response.raise_for_status()
+                await response.json()  # verify it's valid json for other model types
+                return True  # validation passed

-    except requests.exceptions.RequestException as e:
+    except aiohttp.ClientError as e:
         logger.debug(f"{model_type} Model {model} at {url} is not healthy: {e}")
         return False

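With the move to aiohttp, is_model_healthy above is now a coroutine that takes the caller's ClientSession, so the router can reuse one connection pool across all health probes instead of building a fresh connection per requests.post call. A minimal calling sketch with placeholder URL and model names:

import asyncio

import aiohttp

from vllm_router import utils


async def probe() -> None:
    # One shared session; is_model_healthy posts a small test payload to the
    # model's endpoint and returns True only for a successful response.
    async with aiohttp.ClientSession() as session:
        healthy = await utils.is_model_healthy(
            session, "http://localhost:8000", "llama3", "chat"
        )
        print("healthy" if healthy else "unhealthy")


asyncio.run(probe())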