Skip to content

Commit cd6bf2c

Browse files
committed
feat(request): use model_type instead of model_label
As model_label is already used for things related to the LMCache, it makes more sense to use the type of the model here. Replaces #681 Signed-off-by: Max Wittig <max.wittig@siemens.com>
1 parent b8a08a5 commit cd6bf2c

File tree

2 files changed

+70
-8
lines changed

2 files changed

+70
-8
lines changed

src/vllm_router/service_discovery.py

Lines changed: 68 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,9 @@ class EndpointInfo:
9797
# Endpoint's sleep status
9898
sleep: bool
9999

100+
# Model type (e.g., "transcription", "chat", etc.)
101+
model_type: Optional[str] = None
102+
100103
# Pod name
101104
pod_name: Optional[str] = None
102105

@@ -317,6 +320,7 @@ def get_endpoint_info(self) -> List[EndpointInfo]:
317320
sleep=False,
318321
added_timestamp=self.added_timestamp,
319322
model_label=model_label,
323+
model_type=self.model_types[i] if self.model_types else None,
320324
model_info=self._get_model_info(model),
321325
)
322326
endpoint_infos.append(endpoint_info)
@@ -604,6 +608,20 @@ def _get_model_info(self, pod_ip) -> Dict[str, ModelInfo]:
604608
logger.error(f"Failed to get model info from {url}: {e}")
605609
return {}
606610

611+
def _get_model_type(self, pod) -> str:
612+
"""
613+
Get the model type from the pod's metadata labels.
614+
615+
Args:
616+
pod: The Kubernetes pod object
617+
618+
Returns:
619+
The model type if found, chat otherwise
620+
"""
621+
if isinstance(pod, str) or not pod.metadata.labels:
622+
return "chat" # Default to chat model type
623+
return pod.metadata.labels.get("model-type", "chat")
624+
607625
def _get_model_label(self, pod) -> Optional[str]:
608626
"""
609627
Get the model label from the pod's metadata labels.
@@ -649,9 +667,11 @@ def _watch_engines(self):
649667
if is_pod_ready:
650668
model_names = self._get_model_names(pod_ip)
651669
model_label = self._get_model_label(pod)
670+
model_type = self._get_model_type(pod)
652671
else:
653672
model_names = []
654673
model_label = None
674+
model_type = None
655675

656676
# Record pod status for debugging
657677
if is_container_ready and is_pod_terminating:
@@ -666,13 +686,19 @@ def _watch_engines(self):
666686
is_pod_ready,
667687
model_names,
668688
model_label,
689+
model_type,
669690
)
670691
except Exception as e:
671692
logger.error(f"K8s watcher error: {e}")
672693
time.sleep(0.5)
673694

674695
def _add_engine(
675-
self, engine_name: str, engine_ip: str, model_names: List[str], model_label: str
696+
self,
697+
engine_name: str,
698+
engine_ip: str,
699+
model_names: List[str],
700+
model_label: str,
701+
model_type: Optional[str],
676702
):
677703
logger.info(
678704
f"Discovered new serving engine {engine_name} at "
@@ -689,6 +715,10 @@ def _add_engine(
689715
sleep_status = False
690716

691717
with self.available_engines_lock:
718+
# Determine model type for each model
719+
model_types = [self._get_model_type(model) for model in model_names]
720+
model_type = model_types[0] if model_types else None
721+
692722
self.available_engines[engine_name] = EndpointInfo(
693723
url=f"http://{engine_ip}:{self.port}",
694724
model_names=model_names,
@@ -699,10 +729,12 @@ def _add_engine(
699729
pod_name=engine_name,
700730
namespace=self.namespace,
701731
model_info=model_info,
732+
model_type=model_type,
702733
)
703734

704735
# Store model information in the endpoint info
705736
self.available_engines[engine_name].model_info = model_info
737+
self.available_engines[engine_name].model_type = model_type
706738

707739
# Track all models we've ever seen
708740
with self.known_models_lock:
@@ -721,6 +753,7 @@ def _on_engine_update(
721753
is_pod_ready: bool,
722754
model_names: List[str],
723755
model_label: Optional[str],
756+
model_type: Optional[str] = None,
724757
) -> None:
725758
if event == "ADDED":
726759
if engine_ip is None:
@@ -732,7 +765,9 @@ def _on_engine_update(
732765
if not model_names:
733766
return
734767

735-
self._add_engine(engine_name, engine_ip, model_names, model_label)
768+
self._add_engine(
769+
engine_name, engine_ip, model_names, model_label, model_type
770+
)
736771

737772
elif event == "DELETED":
738773
if engine_name not in self.available_engines:
@@ -745,7 +780,9 @@ def _on_engine_update(
745780
return
746781

747782
if is_pod_ready and model_names:
748-
self._add_engine(engine_name, engine_ip, model_names, model_label)
783+
self._add_engine(
784+
engine_name, engine_ip, model_names, model_label, model_type
785+
)
749786
return
750787

751788
if (
@@ -1055,6 +1092,20 @@ def _get_model_info(self, service_name) -> Dict[str, ModelInfo]:
10551092
logger.error(f"Failed to get model info from {url}: {e}")
10561093
return {}
10571094

1095+
def _get_model_type(self, service) -> str:
1096+
"""
1097+
Get the model type from the service's selector.
1098+
1099+
Args:
1100+
service: The Kubernetes service object
1101+
1102+
Returns:
1103+
The model type if found, chat otherwise
1104+
"""
1105+
if not service.spec.selector:
1106+
return "chat"
1107+
return service.spec.selector.get("model-type", "chat")
1108+
10581109
def _get_model_label(self, service) -> Optional[str]:
10591110
"""
10601111
Get the model label from the service's selector.
@@ -1094,18 +1145,27 @@ def _watch_engines(self):
10941145
else:
10951146
model_names = []
10961147
model_label = None
1148+
model_type = self._get_model_type(service)
1149+
10971150
self._on_engine_update(
10981151
service_name,
10991152
event_type,
11001153
is_service_ready,
11011154
model_names,
11021155
model_label,
1156+
model_type,
11031157
)
11041158
except Exception as e:
11051159
logger.error(f"K8s watcher error: {e}")
11061160
time.sleep(0.5)
11071161

1108-
def _add_engine(self, engine_name: str, model_names: List[str], model_label: str):
1162+
def _add_engine(
1163+
self,
1164+
engine_name: str,
1165+
model_names: List[str],
1166+
model_label: str,
1167+
model_type: str,
1168+
):
11091169
logger.info(
11101170
f"Discovered new serving engine {engine_name} at "
11111171
f"running models: {model_names}"
@@ -1131,6 +1191,7 @@ def _add_engine(self, engine_name: str, model_names: List[str], model_label: str
11311191
service_name=engine_name,
11321192
namespace=self.namespace,
11331193
model_info=model_info,
1194+
model_type=model_type,
11341195
)
11351196

11361197
# Store model information in the endpoint info
@@ -1148,6 +1209,7 @@ def _on_engine_update(
11481209
is_service_ready: bool,
11491210
model_names: List[str],
11501211
model_label: Optional[str],
1212+
model_type: str,
11511213
) -> None:
11521214
if event == "ADDED":
11531215
if not engine_name:
@@ -1159,7 +1221,7 @@ def _on_engine_update(
11591221
if not model_names:
11601222
return
11611223

1162-
self._add_engine(engine_name, model_names, model_label)
1224+
self._add_engine(engine_name, model_names, model_label, model_type)
11631225

11641226
elif event == "DELETED":
11651227
if engine_name not in self.available_engines:
@@ -1172,7 +1234,7 @@ def _on_engine_update(
11721234
return
11731235

11741236
if is_service_ready and model_names:
1175-
self._add_engine(engine_name, model_names, model_label)
1237+
self._add_engine(engine_name, model_names, model_label, model_type)
11761238
return
11771239

11781240
if (

src/vllm_router/services/request_service/request.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -585,13 +585,13 @@ async def route_general_transcriptions(
585585

586586
endpoints = service_discovery.get_endpoint_info()
587587

588-
# filter the endpoints url by model name and model label for transcriptions
588+
# filter the endpoints url by model name and model type for transcriptions
589589
transcription_endpoints = []
590590
for ep in endpoints:
591591
for model_name in ep.model_names:
592592
if (
593593
model == model_name
594-
and ep.model_label == "transcription"
594+
and ep.model_type == "transcription"
595595
and not ep.sleep
596596
):
597597
transcription_endpoints.append(ep)

0 commit comments

Comments
 (0)