From ff3fb8f26716a6c68d4a9ceaca8850cb7cb4d551 Mon Sep 17 00:00:00 2001
From: Max Wittig
Date: Thu, 24 Jul 2025 14:51:11 +0200
Subject: [PATCH] feat(utils): add vision model type

Signed-off-by: Max Wittig
---
Notes (review; this section is not part of the commit message):

* ModelType member values change from endpoint paths to plain type names.
  The endpoint path is now obtained from ModelType.get_url(model_type).
  Any code outside this diff that still reads `ModelType.<member>.value`
  expecting a path will receive the type name instead -- TODO confirm
  there are no such callers.
* The new `vision` type maps to the same endpoint as `chat`
  (/v1/chat/completions) and gets its own test payload consisting of a
  text part plus an image_url part.
* NOTE(review): the image_url "url" in the vision test payload is an empty
  string in this patch -- confirm the backend accepts that for a health
  probe, or whether an image/data URL was intended here.
* get_url looks the name up with ModelType[model_type], so an unknown name
  raises KeyError. The match statement has no default case; a member added
  later without a matching case would make get_url return None.
* is_model_healthy now takes the multipart `files` payload for
  transcription from ModelType.get_test_payload instead of building it
  inline.

 src/vllm_router/utils.py | 64 ++++++++++++++++++++++++++++++----------
 1 file changed, 49 insertions(+), 15 deletions(-)

diff --git a/src/vllm_router/utils.py b/src/vllm_router/utils.py
index 61a8aa55..17b4464e 100644
--- a/src/vllm_router/utils.py
+++ b/src/vllm_router/utils.py
@@ -66,12 +66,29 @@ def __call__(cls, *args, **kwargs):
 
 
 class ModelType(enum.Enum):
-    chat = "/v1/chat/completions"
-    completion = "/v1/completions"
-    embeddings = "/v1/embeddings"
-    rerank = "/v1/rerank"
-    score = "/v1/score"
-    transcription = "/v1/audio/transcriptions"
+    chat = "chat"
+    completion = "completion"
+    embeddings = "embeddings"
+    rerank = "rerank"
+    score = "score"
+    transcription = "transcription"
+    vision = "vision"
+
+    @staticmethod
+    def get_url(model_type: str):
+        match ModelType[model_type]:
+            case ModelType.chat | ModelType.vision:
+                return "/v1/chat/completions"
+            case ModelType.completion:
+                return "/v1/completions"
+            case ModelType.embeddings:
+                return "/v1/embeddings"
+            case ModelType.rerank:
+                return "/v1/rerank"
+            case ModelType.score:
+                return "/v1/score"
+            case ModelType.transcription:
+                return "/v1/audio/transcriptions"
 
     @staticmethod
     def get_test_payload(model_type: str):
@@ -101,6 +118,26 @@ def get_test_payload(model_type: str):
                 return {
                     "file": ("empty.wav", _SILENT_WAV_BYTES, "audio/wav"),
                 }
+            case ModelType.vision:
+                return {
+                    "messages": [
+                        {
+                            "role": "user",
+                            "content": [
+                                {
+                                    "type": "text",
+                                    "text": "This is a test. Just reply with yes",
+                                },
+                                {
+                                    "type": "image_url",
+                                    "image_url": {
+                                        "url": ""
+                                    },
+                                },
+                            ],
+                        }
+                    ]
+                }
 
     @staticmethod
     def get_all_fields():
@@ -186,27 +223,24 @@ def update_content_length(request: Request, request_body: str):
 
 
 def is_model_healthy(url: str, model: str, model_type: str) -> bool:
-    model_details = ModelType[model_type]
+    model_url = ModelType.get_url(model_type)
 
     try:
         if model_type == "transcription":
-
             # for transcription, the backend expects multipart/form-data with a file
             # we will use pre-generated silent wav bytes
-            files = {"file": ("empty.wav", _SILENT_WAV_BYTES, "audio/wav")}
-            data = {"model": model}
             response = requests.post(
-                f"{url}{model_details.value}",
-                files=files,  # multipart/form-data
-                data=data,
+                f"{url}{model_url}",
+                files=ModelType.get_test_payload(model_type),  # multipart/form-data
+                data={"model": model},
                 timeout=10,
             )
         else:
             # for other model types (chat, completion, etc.)
             response = requests.post(
-                f"{url}{model_details.value}",
+                f"{url}{model_url}",
                 headers={"Content-Type": "application/json"},
-                json={"model": model} | model_details.get_test_payload(model_type),
+                json={"model": model} | ModelType.get_test_payload(model_type),
                 timeout=10,
             )
 