Skip to content

Commit 9b127b2

Browse files
authored
feat: add endpoint health checks to static router (#428)
Closes #420 Signed-off-by: Max Wittig <[email protected]>
1 parent 8742f32 commit 9b127b2

File tree

9 files changed

+1297
-447
lines changed

9 files changed

+1297
-447
lines changed

src/tests/test_parser.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,3 +77,37 @@ def test_load_initial_config_from_config_json_if_required_when_config_file_is_pr
7777
test_parser, args
7878
)
7979
assert args.routing_logic == "roundrobin"
80+
81+
82+
def test_validate_args_when_service_discovery_is_set_to_static_and_static_backend_health_checks_is_set_and_static_model_types_is_not_set_raises_value_error() -> None:
    """Check that ``validate_args`` rejects health checks without model types.

    The arguments describe static service discovery with backend health
    checks switched on but ``static_model_types`` left as ``None``.
    """
    incomplete_args = MagicMock(
        routing_logic="roundrobin",
        service_discovery="static",
        static_backend_health_checks=True,
        static_model_types=None,
    )
    with pytest.raises(ValueError):
        parser.validate_args(incomplete_args)
94+
95+
96+
def test_validate_static_model_types_when_model_types_is_not_defines_raises_value_error() -> None:
    """Check that passing ``None`` as the model types raises ``ValueError``."""
    missing_model_types = None
    with pytest.raises(ValueError):
        parser.validate_static_model_types(missing_model_types)
101+
102+
103+
def test_validate_static_model_types_when_model_types_contains_unsupported_model_type_raises_value_error() -> None:
    """Check that one unknown entry in the list raises ``ValueError``."""
    # "chat" is a valid entry here; "unsupported" is the offending one.
    partly_invalid_model_types = "chat,unsupported"
    with pytest.raises(ValueError):
        parser.validate_static_model_types(partly_invalid_model_types)
108+
109+
110+
def test_validate_static_model_types_when_model_types_contains_only_supported_model_types_does_not_raise_error() -> None:
    """Check that a list of only supported model types passes validation."""
    supported_model_types = "chat,completion,rerank,score"
    # No explicit assertion: the test fails if the call raises.
    parser.validate_static_model_types(supported_model_types)
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
from unittest.mock import MagicMock
2+
3+
import pytest
4+
5+
from vllm_router.service_discovery import StaticServiceDiscovery
6+
7+
8+
def test_init_when_static_backend_health_checks_calls_start_health_checks(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Check that enabling health checks starts the health check task once.

    ``start_health_check_task`` is replaced on the class with a mock, so no
    background thread is started by the constructor.
    """
    patched_starter = MagicMock()
    monkeypatch.setattr(
        "vllm_router.service_discovery.StaticServiceDiscovery.start_health_check_task",
        patched_starter,
    )

    discovery = StaticServiceDiscovery(
        [], [], None, None, None, static_backend_health_checks=True
    )

    discovery.start_health_check_task.assert_called_once()
20+
21+
22+
def test_init_when_endpoint_health_check_disabled_does_not_call_start_health_checks(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Check that the health check task is not started when the flag is off."""
    patched_starter = MagicMock()
    monkeypatch.setattr(
        "vllm_router.service_discovery.StaticServiceDiscovery.start_health_check_task",
        patched_starter,
    )

    discovery = StaticServiceDiscovery(
        [], [], None, None, None, static_backend_health_checks=False
    )

    discovery.start_health_check_task.assert_not_called()
34+
35+
36+
def test_get_unhealthy_endpoint_hashes_when_only_healthy_models_exist_does_not_return_unhealthy_endpoint_hashes(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Check that no hashes are returned when every health probe succeeds.

    ``is_model_healthy`` is patched to always report a healthy model.
    """
    monkeypatch.setattr("vllm_router.utils.is_model_healthy", lambda *_: True)
    # Health checks are disabled at construction time on purpose: with the
    # flag enabled and ``start_health_check_task`` not patched, the constructor
    # would start a real daemon thread that outlives this test (and the
    # monkeypatch) and races with the direct call asserted below.
    discovery_instance = StaticServiceDiscovery(
        ["http://localhost.com"],
        ["llama3"],
        None,
        None,
        ["chat"],
        static_backend_health_checks=False,
    )
    assert discovery_instance.get_unhealthy_endpoint_hashes() == []
49+
50+
51+
def test_get_unhealthy_endpoint_hashes_when_unhealthy_model_exist_returns_unhealthy_endpoint_hash(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Check that a failing endpoint's hash is reported.

    ``is_model_healthy`` is patched to always report an unhealthy model, so
    the single configured endpoint must show up in the result.
    """
    monkeypatch.setattr("vllm_router.utils.is_model_healthy", lambda *_: False)
    backend_urls = ["http://localhost.com"]
    served_models = ["llama3"]
    served_model_types = ["chat"]

    discovery = StaticServiceDiscovery(
        backend_urls,
        served_models,
        None,
        None,
        served_model_types,
        static_backend_health_checks=False,
    )

    expected_hashes = ["ee7d421a744e07595b70f98c11be93e7"]
    assert discovery.get_unhealthy_endpoint_hashes() == expected_hashes
66+
67+
68+
def test_get_unhealthy_endpoint_hashes_when_healthy_and_unhealthy_models_exist_returns_only_unhealthy_endpoint_hash(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Check that only the failing endpoint's hash is reported.

    Two endpoints are configured; the patched probe fails for exactly one
    model, so exactly one hash is expected.
    """
    failing_model = "bge-m3"

    def fake_is_model_healthy(url: str, model: str, model_type: str) -> bool:
        # Every model except the designated one is reported as healthy.
        return not (model == failing_model)

    monkeypatch.setattr("vllm_router.utils.is_model_healthy", fake_is_model_healthy)
    backend_urls = ["http://localhost.com", "http://10.123.112.412"]
    served_models = ["llama3", failing_model]
    served_model_types = ["chat", "embeddings"]

    discovery = StaticServiceDiscovery(
        backend_urls,
        served_models,
        None,
        None,
        served_model_types,
        static_backend_health_checks=False,
    )

    expected_hashes = ["01e1b07eca36d39acacd55a33272a225"]
    assert discovery.get_unhealthy_endpoint_hashes() == expected_hashes
88+
89+
90+
def test_get_endpoint_info_when_model_endpoint_hash_is_in_unhealthy_endpoint_does_not_return_endpoint(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Check that endpoints marked unhealthy are left out of the endpoint info.

    The hash function is stubbed on the instance so the "mistral" endpoint
    maps to a hash that is listed as unhealthy; only "llama3" should remain.
    """
    failing_model = "mistral"

    def fake_get_model_endpoint_hash(url: str, model: str) -> str:
        if model == failing_model:
            return "some-hash"
        return "other-hash"

    discovery = StaticServiceDiscovery(
        ["http://localhost.com", "http://10.123.112.412"],
        ["llama3", failing_model],
        None,
        None,
        ["chat", "chat"],
        static_backend_health_checks=False,
    )
    discovery.unhealthy_endpoint_hashes = ["some-hash"]
    monkeypatch.setattr(
        discovery, "get_model_endpoint_hash", fake_get_model_endpoint_hash
    )

    assert len(discovery.get_endpoint_info()) == 1
    assert discovery.get_endpoint_info()[0].model_name == "llama3"

src/tests/test_utils.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from unittest.mock import MagicMock
33

44
import pytest
5+
import requests
56
from starlette.datastructures import MutableHeaders
67

78
from vllm_router import utils
@@ -82,3 +83,27 @@ def test_get_all_fields_returns_list_of_strings() -> None:
8283
fields = utils.ModelType.get_all_fields()
8384
assert isinstance(fields, list)
8485
assert isinstance(fields[0], str)
86+
87+
88+
def test_is_model_healthy_when_requests_responds_with_status_code_200_returns_true(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Check that a 200 response from ``requests.post`` counts as healthy."""
    ok_response = MagicMock(status_code=200)
    fake_post = MagicMock(return_value=ok_response)
    monkeypatch.setattr("requests.post", fake_post)

    outcome = utils.is_model_healthy("http://localhost", "test", "chat")

    assert outcome is True
94+
95+
96+
def test_is_model_healthy_when_requests_raises_exception_returns_false(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Check that a ``ReadTimeout`` from ``requests.post`` counts as unhealthy."""
    fake_post = MagicMock(side_effect=requests.exceptions.ReadTimeout)
    monkeypatch.setattr("requests.post", fake_post)

    outcome = utils.is_model_healthy("http://localhost", "test", "chat")

    assert outcome is False
102+
103+
104+
def test_is_model_healthy_when_requests_status_with_status_code_not_200_returns_false(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Check that a 500 response from ``requests.post`` counts as unhealthy."""
    error_response = MagicMock(status_code=500)
    fake_post = MagicMock(return_value=error_response)
    monkeypatch.setattr("requests.post", fake_post)

    outcome = utils.is_model_healthy("http://localhost", "test", "chat")

    assert outcome is False

src/vllm_router/README.md

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ The router can be configured using command-line arguments. Below are the availab
2828
- `--static-backends`: The URLs of static serving engines, separated by commas (e.g., `http://localhost:8000,http://localhost:8001`).
2929
- `--static-models`: The models running in the static serving engines, separated by commas (e.g., `model1,model2`).
3030
- `--static-aliases`: The aliases of the models running in the static serving engines, separated by commas and associated using colons (e.g., `model_alias1:model,model_alias2:model`).
31+
- `--static-backend-health-checks`: Enable this flag to make vllm-router check periodically if the models work by sending dummy requests to their endpoints.
3132
- `--k8s-port`: The port of vLLM processes when using K8s service discovery. Default is `8000`.
3233
- `--k8s-namespace`: The namespace of vLLM pods when using K8s service discovery. Default is `default`.
3334
- `--k8s-label-selector`: The label selector to filter vLLM pods when using K8s service discovery.
@@ -82,11 +83,25 @@ vllm-router --port 8000 \
8283
--static-backends "http://localhost:9001,http://localhost:9002,http://localhost:9003" \
8384
--static-models "facebook/opt-125m,meta-llama/Llama-3.1-8B-Instruct,facebook/opt-125m" \
8485
--static-aliases "gpt4:meta-llama/Llama-3.1-8B-Instruct" \
86+
--static-model-types "chat,chat,chat" \
87+
--static-backend-health-checks \
8588
--engine-stats-interval 10 \
8689
--log-stats \
8790
--routing-logic roundrobin
8891
```
8992

93+
## Backend health checks
94+
95+
By enabling the `--static-backend-health-checks` flag, **vllm-router** will send a simple request to
96+
your LLM nodes every minute to verify that they still work.
97+
If a node is down, it will output a warning and exclude the node from being routed to.
98+
99+
If you enable this flag, it's also required that you specify `--static-model-types` as we have to use
100+
different endpoints for each model type.
101+
102+
> Enabling this flag will put some load on your backend every minute as real requests are sent to the nodes
103+
> to test their functionality.
104+
90105
## Dynamic Router Config
91106

92107
The router can be configured dynamically using a json file when passing the `--dynamic-config-json` option.

src/vllm_router/app.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,7 @@ def initialize_all(app: FastAPI, args):
147147
if args.static_model_labels
148148
else None
149149
),
150+
static_backend_health_checks=args.static_backend_health_checks,
150151
)
151152
elif args.service_discovery == "k8s":
152153
initialize_service_discovery(

src/vllm_router/parsers/parser.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import logging
1717
import sys
1818

19+
from vllm_router import utils
1920
from vllm_router.version import __version__
2021

2122
try:
@@ -65,6 +66,19 @@ def load_initial_config_from_config_json_if_required(
6566
return args
6667

6768

69+
def validate_static_model_types(model_types: str | None) -> None:
70+
if model_types is None:
71+
raise ValueError(
72+
"Static model types must be provided when using the backend healthcheck."
73+
)
74+
all_models = utils.ModelType.get_all_fields()
75+
for model_type in utils.parse_comma_separated_args(model_types):
76+
if model_type not in all_models:
77+
raise ValueError(
78+
f"The model type '{model_type}' is not supported. Supported model types are '{','.join(all_models)}'"
79+
)
80+
81+
6882
# --- Argument Parsing and Initialization ---
6983
def validate_args(args):
7084
verify_required_args_provided(args)
@@ -77,6 +91,8 @@ def validate_args(args):
7791
raise ValueError(
7892
"Static models must be provided when using static service discovery."
7993
)
94+
if args.static_backend_health_checks:
95+
validate_static_model_types(args.static_model_types)
8096
if args.service_discovery == "k8s" and args.k8s_port is None:
8197
raise ValueError("K8s port must be provided when using K8s service discovery.")
8298
if args.routing_logic == "session" and args.session_key is None:
@@ -135,6 +151,11 @@ def parse_args():
135151
default=None,
136152
help="The model labels of static backends, separated by commas. E.g., model1,model2",
137153
)
154+
parser.add_argument(
155+
"--static-backend-health-checks",
156+
action="store_true",
157+
help="Enable this flag to make vllm-router check periodically if the models work by sending dummy requests to their endpoints.",
158+
)
138159
parser.add_argument(
139160
"--k8s-port",
140161
type=int,

src/vllm_router/service_discovery.py

Lines changed: 45 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@
1313
# limitations under the License.
1414

1515
import abc
16+
import asyncio
1617
import enum
18+
import hashlib
1719
import os
1820
import threading
1921
import time
@@ -23,6 +25,7 @@
2325
import requests
2426
from kubernetes import client, config, watch
2527

28+
from vllm_router import utils
2629
from vllm_router.log import init_logger
2730

2831
logger = init_logger(__name__)
@@ -86,6 +89,7 @@ def __init__(
8689
aliases: List[str] | None,
8790
model_labels: List[str] | None,
8891
model_types: List[str] | None,
92+
static_backend_health_checks: bool,
8993
):
9094
assert len(urls) == len(models), "URLs and models should have the same length"
9195
self.urls = urls
@@ -94,6 +98,37 @@ def __init__(
9498
self.model_labels = model_labels
9599
self.model_types = model_types
96100
self.added_timestamp = int(time.time())
101+
self.unhealthy_endpoint_hashes = []
102+
if static_backend_health_checks:
103+
self.start_health_check_task()
104+
105+
def get_unhealthy_endpoint_hashes(self) -> list[str]:
    """Probe every configured endpoint and collect hashes of the failing ones.

    Each (url, model, model type) triple is passed to
    ``utils.is_model_healthy``; endpoints that fail the probe are identified
    by ``self.get_model_endpoint_hash(url, model)``.

    Returns:
        The hashes of all endpoints that failed the probe, in configuration
        order. Empty when every endpoint is healthy.

    Raises:
        ValueError: If ``self.model_types`` is ``None``; a probe cannot be
            made without a model type.
    """
    if self.model_types is None:
        # Without this guard ``zip`` fails with an opaque TypeError.
        raise ValueError(
            "Static model types must be provided to run backend health checks."
        )
    unhealthy_endpoints = []
    # NOTE: ``zip`` stops at the shortest list, so endpoints without a
    # matching model type entry are not probed.
    for url, model, model_type in zip(self.urls, self.models, self.model_types):
        if utils.is_model_healthy(url, model, model_type):
            logger.debug(f"{model} at {url} is healthy")
            continue
        logger.warning(f"{model} at {url} not healthy!")
        unhealthy_endpoints.append(self.get_model_endpoint_hash(url, model))
    return unhealthy_endpoints
114+
115+
async def check_model_health(self, interval: float = 60) -> None:
    """Refresh ``self.unhealthy_endpoint_hashes`` forever, once per interval.

    Args:
        interval: Seconds to wait between two health check rounds.
            Defaults to 60.

    The coroutine never returns on its own; it ends only when cancelled or
    when its event loop stops.
    """
    while True:
        try:
            self.unhealthy_endpoint_hashes = self.get_unhealthy_endpoint_hashes()
        except Exception as e:
            # Keep the loop alive: one failed round must not end health checking.
            logger.error(e)
        # Sleep outside the ``try`` so a failing round is also throttled
        # instead of retrying in a tight loop, and use the asyncio sleep so
        # the event loop is not blocked while waiting.
        await asyncio.sleep(interval)
122+
123+
def start_health_check_task(self) -> None:
    """Run ``check_model_health`` on a new event loop in a daemon thread.

    The loop and the thread are stored on the instance as ``self.loop`` and
    ``self.thread``.
    """
    event_loop = asyncio.new_event_loop()
    self.loop = event_loop
    worker = threading.Thread(target=event_loop.run_forever, daemon=True)
    self.thread = worker
    worker.start()
    # The loop runs in the worker thread, so the coroutine has to be handed
    # over with the thread-safe scheduling call.
    asyncio.run_coroutine_threadsafe(self.check_model_health(), event_loop)
    logger.info("Health check thread started")
129+
130+
def get_model_endpoint_hash(self, url: str, model: str) -> str:
    """Return a fingerprint identifying one (endpoint URL, model) pair.

    Args:
        url: Base URL of the backend.
        model: Name of the model served at that URL.

    Returns:
        The hexadecimal MD5 digest of ``url`` directly followed by ``model``.
    """
    # MD5 is only an identifier here, not a security measure. Saying so
    # explicitly keeps the call working on FIPS-restricted Python builds and
    # leaves the digest unchanged.
    fingerprint_source = f"{url}{model}".encode()
    return hashlib.md5(fingerprint_source, usedforsecurity=False).hexdigest()
97132

98133
def get_endpoint_info(self) -> List[EndpointInfo]:
99134
"""
@@ -103,18 +138,16 @@ def get_endpoint_info(self) -> List[EndpointInfo]:
103138
Returns:
104139
a list of engine URLs
105140
"""
106-
if self.model_labels is None:
107-
endpoint_infos = [
108-
EndpointInfo(url, model, self.added_timestamp, "default")
109-
for url, model in zip(self.urls, self.models)
110-
]
111-
else:
112-
endpoint_infos = [
113-
EndpointInfo(url, model, self.added_timestamp, model_label)
114-
for url, model, model_label in zip(
115-
self.urls, self.models, self.model_labels
116-
)
117-
]
141+
if not self.model_labels:
142+
self.model_labels = ["default"] * len(self.models)
143+
endpoint_infos = [
144+
EndpointInfo(url, model, self.added_timestamp, model_label)
145+
for url, model, model_label in zip(
146+
self.urls, self.models, self.model_labels
147+
)
148+
if self.get_model_endpoint_hash(url, model)
149+
not in self.unhealthy_endpoint_hashes
150+
]
118151
return endpoint_infos
119152

120153

0 commit comments

Comments
 (0)