Skip to content

Commit 68f105e

Browse files
committed
feat(utils): add vision model type
Signed-off-by: Max Wittig <[email protected]>
1 parent 9ca3fcb commit 68f105e

File tree

1 file changed

+52
-15
lines changed

1 file changed

+52
-15
lines changed

src/vllm_router/utils.py

Lines changed: 52 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -66,12 +66,29 @@ def __call__(cls, *args, **kwargs):
6666

6767

6868
class ModelType(enum.Enum):
69-
chat = "/v1/chat/completions"
70-
completion = "/v1/completions"
71-
embeddings = "/v1/embeddings"
72-
rerank = "/v1/rerank"
73-
score = "/v1/score"
74-
transcription = "/v1/audio/transcriptions"
69+
chat = "chat"
70+
completion = "completion"
71+
embeddings = "embeddings"
72+
rerank = "rerank"
73+
score = "score"
74+
transcription = "transcription"
75+
vision = "vision"
76+
77+
@staticmethod
78+
def get_url(model_type: str):
79+
match ModelType[model_type]:
80+
case ModelType.chat | ModelType.vision:
81+
return "/v1/chat/completions"
82+
case ModelType.completion:
83+
return "/v1/completions"
84+
case ModelType.embeddings:
85+
return "/v1/embeddings"
86+
case ModelType.rerank:
87+
return "/v1/rerank"
88+
case ModelType.score:
89+
return "/v1/score"
90+
case ModelType.transcription:
91+
return "/v1/audio/transcriptions"
7592

7693
@staticmethod
7794
def get_test_payload(model_type: str):
@@ -101,6 +118,26 @@ def get_test_payload(model_type: str):
101118
return {
102119
"file": ("empty.wav", _SILENT_WAV_BYTES, "audio/wav"),
103120
}
121+
case ModelType.vision:
122+
return {
123+
"messages": [
124+
{
125+
"role": "user",
126+
"content": [
127+
{
128+
"type": "text",
129+
"text": "This is a test. Just reply with yes",
130+
},
131+
{
132+
"type": "image_url",
133+
"image_url": {
134+
"url": "data:image/jpeg;base64,iVBORw0KGgoAAAANSUhEUgAAAAIAAAACCAIAAAD91JpzAAAAG0lEQVR4nGLinfJq851wJn69udZSvIAAAAD//yf3BLKCfW8HAAAAAElFTkSuQmCC"
135+
},
136+
},
137+
],
138+
}
139+
]
140+
}
104141

105142
@staticmethod
106143
def get_all_fields():
@@ -186,27 +223,27 @@ def update_content_length(request: Request, request_body: str):
186223

187224

188225
def is_model_healthy(url: str, model: str, model_type: str) -> bool:
189-
model_details = ModelType[model_type]
226+
model_url = ModelType.get_url(model_type)
190227

191228
try:
192229
if model_type == "transcription":
193-
194230
# for transcription, the backend expects multipart/form-data with a file
195231
# we will use pre-generated silent wav bytes
196-
files = {"file": ("empty.wav", _SILENT_WAV_BYTES, "audio/wav")}
197-
data = {"model": model}
198232
response = requests.post(
199-
f"{url}{model_details.value}",
200-
files=files, # multipart/form-data
201-
data=data,
233+
f"{url}{model_url}",
234+
files=ModelType.get_test_payload(
235+
model_type
236+
), # multipart/form-data
237+
data={"model": model},
202238
timeout=10,
203239
)
204240
else:
205241
# for other model types (chat, completion, etc.)
206242
response = requests.post(
207-
f"{url}{model_details.value}",
243+
f"{url}{model_url}",
208244
headers={"Content-Type": "application/json"},
209-
json={"model": model} | model_details.get_test_payload(model_type),
245+
json={"model": model}
246+
| ModelType.get_test_payload(model_type),
210247
timeout=10,
211248
)
212249

0 commit comments

Comments
 (0)