Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -69,5 +69,6 @@ lint = [
]
test = [
"pytest>=8.3.4",
"pytest-asyncio>=0.25.3"
"pytest-aiohttp>=1.1.0",
"pytest-asyncio>=0.25.3",
]
28 changes: 28 additions & 0 deletions src/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from typing import Any, AsyncGenerator, Awaitable, Callable
from unittest.mock import MagicMock

import aiohttp
import pytest
from aiohttp import web
from fastapi import FastAPI


@pytest.fixture
async def mock_app() -> AsyncGenerator[FastAPI, None]:
    """Yield a mock application whose state exposes a live aiohttp session.

    The yielded object is a ``MagicMock``; its
    ``state.aiohttp_client_wrapper`` attribute is a callable mock that returns
    an ``aiohttp.ClientSession``. The session is opened before the mock is
    yielded and closed when the fixture is torn down.
    """
    # AsyncGenerator is given both type arguments explicitly: the send-type
    # default only exists on Python 3.13+, and this annotation is evaluated
    # when the function is defined.
    app = MagicMock()
    async with aiohttp.ClientSession() as session:
        app.state.aiohttp_client_wrapper = MagicMock(return_value=session)
        yield app


@pytest.fixture
async def make_mock_engine(
    aiohttp_client: Any,
) -> Callable[[dict[str, Callable]], Awaitable[str]]:
    """Return a factory that serves the given POST handlers from a test app.

    Args:
        aiohttp_client: Callable awaited with a ``web.Application`` to obtain
            a test client (presumably the pytest-aiohttp fixture of the same
            name -- confirm against the test dependencies).

    Returns:
        A coroutine function. Awaiting it with a ``{path: handler}`` mapping
        registers every handler as a POST route on a fresh
        ``web.Application`` and resolves to the string form of
        ``client.make_url("")``, which the tests use as the base URL.
    """

    async def _make_mock_engine(routes: dict[str, Callable]) -> str:
        app = web.Application()
        for path, handler in routes.items():
            app.router.add_post(path, handler)

        client = await aiohttp_client(app)
        return str(client.make_url(""))

    return _make_mock_engine
82 changes: 52 additions & 30 deletions src/tests/test_static_service_discovery.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
from unittest.mock import MagicMock
import hashlib
from unittest.mock import AsyncMock, MagicMock

import pytest
from aiohttp import web
from fastapi import FastAPI

from vllm_router.service_discovery import StaticServiceDiscovery


def test_init_when_static_backend_health_checks_calls_start_health_checks(
mock_app: FastAPI,
monkeypatch: pytest.MonkeyPatch,
) -> None:
start_health_check_mock = MagicMock()
Expand All @@ -14,7 +18,7 @@ def test_init_when_static_backend_health_checks_calls_start_health_checks(
start_health_check_mock,
)
discovery_instance = StaticServiceDiscovery(
None,
mock_app,
[],
[],
None,
Expand All @@ -28,6 +32,7 @@ def test_init_when_static_backend_health_checks_calls_start_health_checks(


def test_init_when_endpoint_health_check_disabled_does_not_call_start_health_checks(
mock_app: FastAPI,
monkeypatch: pytest.MonkeyPatch,
) -> None:
start_health_check_mock = MagicMock()
Expand All @@ -36,7 +41,7 @@ def test_init_when_endpoint_health_check_disabled_does_not_call_start_health_che
start_health_check_mock,
)
discovery_instance = StaticServiceDiscovery(
None,
mock_app,
[],
[],
None,
Expand All @@ -49,31 +54,39 @@ def test_init_when_endpoint_health_check_disabled_does_not_call_start_health_che
discovery_instance.start_health_check_task.assert_not_called()


def test_get_unhealthy_endpoint_hashes_when_only_healthy_models_exist_does_not_return_unhealthy_endpoint_hashes(
monkeypatch: pytest.MonkeyPatch,
@pytest.mark.asyncio
async def test_get_unhealthy_endpoint_hashes_when_only_healthy_models_exist_does_not_return_unhealthy_endpoint_hashes(
make_mock_engine, mock_app
) -> None:
monkeypatch.setattr("vllm_router.utils.is_model_healthy", lambda *_: True)
mock_response = AsyncMock(return_value=web.json_response(status=200))
base_url = await make_mock_engine({"/v1/chat/completions": mock_response})

discovery_instance = StaticServiceDiscovery(
None,
["http://localhost.com"],
mock_app,
[base_url],
["llama3"],
None,
None,
["chat"],
static_backend_health_checks=True,
static_backend_health_checks=False,
prefill_model_labels=None,
decode_model_labels=None,
)
assert discovery_instance.get_unhealthy_endpoint_hashes() == []
assert await discovery_instance.get_unhealthy_endpoint_hashes() == []


def test_get_unhealthy_endpoint_hashes_when_unhealthy_model_exist_returns_unhealthy_endpoint_hash(
monkeypatch: pytest.MonkeyPatch,
@pytest.mark.asyncio
async def test_get_unhealthy_endpoint_hashes_when_unhealthy_model_exist_returns_unhealthy_endpoint_hash(
make_mock_engine,
mock_app,
) -> None:
monkeypatch.setattr("vllm_router.utils.is_model_healthy", lambda *_: False)
mock_response = AsyncMock(return_value=web.json_response(status=500))
base_url = await make_mock_engine({"/v1/chat/completions": mock_response})
expected_hash = hashlib.md5(f"{base_url}llama3".encode()).hexdigest()

discovery_instance = StaticServiceDiscovery(
None,
["http://localhost.com"],
mock_app,
[base_url],
["llama3"],
None,
None,
Expand All @@ -82,23 +95,33 @@ def test_get_unhealthy_endpoint_hashes_when_unhealthy_model_exist_returns_unheal
prefill_model_labels=None,
decode_model_labels=None,
)
assert discovery_instance.get_unhealthy_endpoint_hashes() == [
"ee7d421a744e07595b70f98c11be93e7"
]
assert await discovery_instance.get_unhealthy_endpoint_hashes() == [expected_hash]


def test_get_unhealthy_endpoint_hashes_when_healthy_and_unhealthy_models_exist_returns_only_unhealthy_endpoint_hash(
monkeypatch: pytest.MonkeyPatch,
@pytest.mark.asyncio
async def test_get_unhealthy_endpoint_hashes_when_healthy_and_unhealthy_models_exist_returns_only_unhealthy_endpoint_hash(
make_mock_engine,
mock_app,
) -> None:
unhealthy_model = "bge-m3"
mock_response = AsyncMock(return_value=web.json_response(status=500))

async def mock_mixed_chat_response(request: web.Request) -> web.Response:
data = await request.json()
status = 500 if data.get("model") == unhealthy_model else 200
return web.json_response(status=status)

base_url = await make_mock_engine(
{
"/v1/chat/completions": mock_mixed_chat_response,
"/v1/embeddings": mock_response,
}
)
expected_hash = hashlib.md5(f"{base_url}{unhealthy_model}".encode()).hexdigest()

def mock_is_model_healthy(url: str, model: str, model_type: str) -> bool:
return model != unhealthy_model

monkeypatch.setattr("vllm_router.utils.is_model_healthy", mock_is_model_healthy)
discovery_instance = StaticServiceDiscovery(
None,
["http://localhost.com", "http://10.123.112.412"],
mock_app,
[base_url, base_url],
["llama3", unhealthy_model],
None,
None,
Expand All @@ -107,12 +130,11 @@ def mock_is_model_healthy(url: str, model: str, model_type: str) -> bool:
prefill_model_labels=None,
decode_model_labels=None,
)
assert discovery_instance.get_unhealthy_endpoint_hashes() == [
"01e1b07eca36d39acacd55a33272a225"
]
assert await discovery_instance.get_unhealthy_endpoint_hashes() == [expected_hash]


def test_get_endpoint_info_when_model_endpoint_hash_is_in_unhealthy_endpoint_does_not_return_endpoint(
mock_app: FastAPI,
monkeypatch: pytest.MonkeyPatch,
) -> None:
unhealthy_model = "mistral"
Expand All @@ -121,7 +143,7 @@ def mock_get_model_endpoint_hash(url: str, model: str) -> str:
return "some-hash" if model == unhealthy_model else "other-hash"

discovery_instance = StaticServiceDiscovery(
None,
mock_app,
["http://localhost.com", "http://10.123.112.412"],
["llama3", unhealthy_model],
None,
Expand Down
52 changes: 27 additions & 25 deletions src/tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import json
from unittest.mock import MagicMock
from unittest.mock import AsyncMock, MagicMock

import aiohttp
import pytest
import requests
from aiohttp import web
from starlette.datastructures import MutableHeaders

from vllm_router import utils
Expand Down Expand Up @@ -85,33 +86,34 @@ def test_get_all_fields_returns_list_of_strings() -> None:
assert isinstance(fields[0], str)


def test_is_model_healthy_when_requests_responds_with_status_code_200_returns_true(
monkeypatch: pytest.MonkeyPatch,
@pytest.mark.asyncio
async def test_is_model_healthy_when_aiohttp_responds_with_status_code_200_returns_true(
make_mock_engine,
) -> None:
request_mock = MagicMock(return_value=MagicMock(status_code=200))
monkeypatch.setattr("requests.post", request_mock)
assert utils.is_model_healthy("http://localhost", "test", "chat") is True
mock_response = AsyncMock(return_value=web.json_response(status=200))
base_url = await make_mock_engine({"/v1/chat/completions": mock_response})
async with aiohttp.ClientSession() as session:
result = await utils.is_model_healthy(session, base_url, "test", "chat")
assert result is True


def test_is_model_healthy_when_requests_raises_exception_returns_false(
monkeypatch: pytest.MonkeyPatch,
@pytest.mark.asyncio
async def test_is_model_healthy_when_aiohttp_raises_exception_returns_false(
make_mock_engine,
) -> None:
request_mock = MagicMock(side_effect=requests.exceptions.ReadTimeout)
monkeypatch.setattr("requests.post", request_mock)
assert utils.is_model_healthy("http://localhost", "test", "chat") is False
mock_response = AsyncMock(side_effect=aiohttp.ConnectionTimeoutError())
base_url = await make_mock_engine({"/v1/chat/completions": mock_response})
async with aiohttp.ClientSession() as session:
result = await utils.is_model_healthy(session, base_url, "test", "chat")
assert result is False


def test_is_model_healthy_when_requests_status_with_status_code_not_200_returns_false(
monkeypatch: pytest.MonkeyPatch,
@pytest.mark.asyncio
async def test_is_model_healthy_when_aiohttp_status_with_status_code_not_200_returns_false(
make_mock_engine,
) -> None:

# Mock an internal server error response
mock_response = MagicMock(status_code=500)

# Tell the mock to raise an HTTP Error when raise_for_status() is called
mock_response.raise_for_status.side_effect = requests.exceptions.HTTPError

request_mock = MagicMock(return_value=mock_response)
monkeypatch.setattr("requests.post", request_mock)

assert utils.is_model_healthy("http://localhost", "test", "chat") is False
mock_response = AsyncMock(return_value=web.json_response(status=500))
base_url = await make_mock_engine({"/v1/chat/completions": mock_response})
async with aiohttp.ClientSession() as session:
result = await utils.is_model_healthy(session, base_url, "test", "chat")
assert result is False
9 changes: 6 additions & 3 deletions src/vllm_router/service_discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,7 @@ def __init__(
decode_model_labels: List[str] | None = None,
):
self.app = app
self.client = self.app.state.aiohttp_client_wrapper()
assert len(urls) == len(models), "URLs and models should have the same length"
self.urls = urls
self.models = models
Expand All @@ -232,13 +233,13 @@ def __init__(
self.prefill_model_labels = prefill_model_labels
self.decode_model_labels = decode_model_labels

def get_unhealthy_endpoint_hashes(self) -> list[str]:
async def get_unhealthy_endpoint_hashes(self) -> list[str]:
unhealthy_endpoints = []
try:
for url, model, model_type in zip(
self.urls, self.models, self.model_types, strict=True
):
if utils.is_model_healthy(url, model, model_type):
if await utils.is_model_healthy(self.client, url, model, model_type):
logger.debug(f"{model} at {url} is healthy")
else:
logger.warning(f"{model} at {url} not healthy!")
Expand All @@ -253,7 +254,9 @@ def get_unhealthy_endpoint_hashes(self) -> list[str]:
async def check_model_health(self):
while self._running:
try:
self.unhealthy_endpoint_hashes = self.get_unhealthy_endpoint_hashes()
self.unhealthy_endpoint_hashes = (
await self.get_unhealthy_endpoint_hashes()
)
await asyncio.sleep(60)
except asyncio.CancelledError:
logger.debug("Health check task cancelled")
Expand Down
46 changes: 27 additions & 19 deletions src/vllm_router/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import wave
from typing import Optional

import requests
import aiohttp
from fastapi.requests import Request
from starlette.datastructures import MutableHeaders

Expand Down Expand Up @@ -222,36 +222,44 @@ def update_content_length(request: Request, request_body: str):
request._headers = headers


def is_model_healthy(url: str, model: str, model_type: str) -> bool:
async def is_model_healthy(
session: aiohttp.ClientSession, url: str, model: str, model_type: str
):
model_url = ModelType.get_url(model_type)

try:
if model_type == "transcription":
# for transcription, the backend expects multipart/form-data with a file
# we will use pre-generated silent wav bytes
response = requests.post(
f"{url}{model_url}",
files=ModelType.get_test_payload(model_type), # multipart/form-data
data={"model": model},
timeout=10,
test_payload = ModelType.get_test_payload(model_type)
form_data = aiohttp.FormData()
form_data.add_field(
"file",
test_payload["file"][1],
filename=test_payload["file"][0],
content_type=test_payload["file"][2],
)
form_data.add_field("model", model)

async with session.post(
f"{url}{model_url}",
data=form_data,
timeout=aiohttp.ClientTimeout(total=10),
) as response:
response.raise_for_status()
return True
else:
# for other model types (chat, completion, etc.)
response = requests.post(
async with session.post(
f"{url}{model_url}",
headers={"Content-Type": "application/json"},
json={"model": model} | ModelType.get_test_payload(model_type),
timeout=10,
)

response.raise_for_status()

if model_type == "transcription":
return True
else:
response.json() # verify it's valid json for other model types
return True # validation passed
timeout=aiohttp.ClientTimeout(total=10),
) as response:
response.raise_for_status()
await response.json() # verify it's valid json for other model types
return True # validation passed

except requests.exceptions.RequestException as e:
except aiohttp.ClientError as e:
logger.debug(f"{model_type} Model {model} at {url} is not healthy: {e}")
return False
Comment on lines +225 to 265
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This function is missing a return type hint. Additionally, the if/else block for different model types contains duplicated logic for making the HTTP request. This can be refactored to prepare the request arguments first, then make a single session.post call. This will improve readability and maintainability by reducing code duplication.

async def is_model_healthy(
    session: aiohttp.ClientSession, url: str, model: str, model_type: str
) -> bool:
    """Send one test request to a model endpoint and report whether it succeeded.

    Args:
        session: aiohttp client session used to issue the POST request.
        url: Base URL of the backend; the path returned by
            ``ModelType.get_url(model_type)`` is appended to it.
        model: Model name sent in the request payload.
        model_type: Selects the endpoint path and test payload. The value
            ``"transcription"`` is sent as multipart form data; every other
            value is sent as JSON.

    Returns:
        ``True`` when the request returns a non-error status and, for
        non-transcription models, the body parses as JSON. ``False`` when the
        request fails with a client error, times out, or returns a body that
        is not valid JSON.
    """
    # Function-scope import keeps this block self-contained; asyncio is only
    # needed to name the timeout exception below.
    import asyncio

    model_url = ModelType.get_url(model_type)

    try:
        post_kwargs: dict = {
            "timeout": aiohttp.ClientTimeout(total=10),
        }

        if model_type == "transcription":
            # for transcription, the backend expects multipart/form-data with a file
            # we will use pre-generated silent wav bytes
            test_payload = ModelType.get_test_payload(model_type)
            # The code indexes the "file" entry as (filename, content, content type).
            filename, content, content_type = test_payload["file"][:3]
            form_data = aiohttp.FormData()
            form_data.add_field(
                "file",
                content,
                filename=filename,
                content_type=content_type,
            )
            form_data.add_field("model", model)
            post_kwargs["data"] = form_data
        else:
            # for other model types (chat, completion, etc.)
            post_kwargs["headers"] = {"Content-Type": "application/json"}
            post_kwargs["json"] = {"model": model} | ModelType.get_test_payload(
                model_type
            )

        async with session.post(f"{url}{model_url}", **post_kwargs) as response:
            response.raise_for_status()
            if model_type != "transcription":
                await response.json()  # verify it's valid json for other model types
            return True  # validation passed

    except (aiohttp.ClientError, asyncio.TimeoutError, ValueError) as e:
        # aiohttp.ClientError alone is not enough: an expired total timeout
        # raises asyncio.TimeoutError, and a body that is not valid JSON raises
        # json.JSONDecodeError (a ValueError). Both mean "not healthy" rather
        # than an error the caller should have to handle.
        logger.debug(f"{model_type} Model {model} at {url} is not healthy: {e}")
        return False

Loading