Skip to content

Commit b35f9f8

Browse files
committed
feat(utils): add vision model type
Signed-off-by: Max Wittig <[email protected]>
1 parent 9ca3fcb commit b35f9f8

File tree

1 file changed

+44
-15
lines changed

1 file changed

+44
-15
lines changed

src/vllm_router/utils.py

Lines changed: 44 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -66,12 +66,24 @@ def __call__(cls, *args, **kwargs):
6666

6767

6868
class ModelType(enum.Enum):
69-
chat = "/v1/chat/completions"
70-
completion = "/v1/completions"
71-
embeddings = "/v1/embeddings"
72-
rerank = "/v1/rerank"
73-
score = "/v1/score"
74-
transcription = "/v1/audio/transcriptions"
69+
chat = "chat"
70+
completion = "completion"
71+
embeddings = "embeddings"
72+
rerank = "rerank"
73+
score = "score"
74+
transcription = "transcription"
75+
vision = "vision"
76+
77+
@staticmethod
78+
def get_url(model_type: str):
79+
match ModelType[model_type]:
80+
case ModelType.chat: return "/v1/chat/completions"
81+
case ModelType.completion: return "/v1/completions"
82+
case ModelType.embeddings: return "/v1/embeddings"
83+
case ModelType.rerank: return "/v1/rerank"
84+
case ModelType.score: return "/v1/score"
85+
case ModelType.transcription: return "/v1/audio/transcriptions"
86+
case ModelType.vision: return "/v1/chat/completions"
7587

7688
@staticmethod
7789
def get_test_payload(model_type: str):
@@ -101,6 +113,26 @@ def get_test_payload(model_type: str):
101113
return {
102114
"file": ("empty.wav", _SILENT_WAV_BYTES, "audio/wav"),
103115
}
116+
case ModelType.vision:
117+
return {
118+
"messages": [
119+
{
120+
"role": "user",
121+
"content": [
122+
{
123+
"type": "text",
124+
"text": "This is a test. Just reply with yes",
125+
},
126+
{
127+
"type": "image_url",
128+
"image_url": {
129+
"url": "data:image/jpeg;base64,iVBORw0KGgoAAAANSUhEUgAAAAIAAAACCAIAAAD91JpzAAAAG0lEQVR4nGLinfJq851wJn69udZSvIAAAAD//yf3BLKCfW8HAAAAAElFTkSuQmCC"
130+
},
131+
},
132+
],
133+
}
134+
]
135+
}
104136

105137
@staticmethod
106138
def get_all_fields():
@@ -186,27 +218,24 @@ def update_content_length(request: Request, request_body: str):
186218

187219

188220
def is_model_healthy(url: str, model: str, model_type: str) -> bool:
189-
model_details = ModelType[model_type]
221+
model_url = ModelType.get_url(model_type)
190222

191223
try:
192224
if model_type == "transcription":
193-
194225
# for transcription, the backend expects multipart/form-data with a file
195226
# we will use pre-generated silent wav bytes
196-
files = {"file": ("empty.wav", _SILENT_WAV_BYTES, "audio/wav")}
197-
data = {"model": model}
198227
response = requests.post(
199-
f"{url}{model_details.value}",
200-
files=files, # multipart/form-data
201-
data=data,
228+
f"{url}{model_url}",
229+
files=ModelType[model_type].get_test_payload(model_type), # multipart/form-data
230+
data={"model": model},
202231
timeout=10,
203232
)
204233
else:
205234
# for other model types (chat, completion, etc.)
206235
response = requests.post(
207-
f"{url}{model_details.value}",
236+
f"{url}{model_url}",
208237
headers={"Content-Type": "application/json"},
209-
json={"model": model} | model_details.get_test_payload(model_type),
238+
json={"model": model} | ModelType[model_type].get_test_payload(model_type),
210239
timeout=10,
211240
)
212241

0 commit comments

Comments
 (0)