Skip to content

Commit f8c6f64

Browse files
committed
feat(utils): add vision model type
Signed-off-by: Max Wittig <[email protected]>
1 parent 9ca3fcb commit f8c6f64

File tree

1 file changed

+54
-15
lines changed

1 file changed

+54
-15
lines changed

src/vllm_router/utils.py

Lines changed: 54 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -66,12 +66,31 @@ def __call__(cls, *args, **kwargs):
6666

6767

6868
class ModelType(enum.Enum):
69-
chat = "/v1/chat/completions"
70-
completion = "/v1/completions"
71-
embeddings = "/v1/embeddings"
72-
rerank = "/v1/rerank"
73-
score = "/v1/score"
74-
transcription = "/v1/audio/transcriptions"
69+
chat = "chat"
70+
completion = "completion"
71+
embeddings = "embeddings"
72+
rerank = "rerank"
73+
score = "score"
74+
transcription = "transcription"
75+
vision = "vision"
76+
77+
@staticmethod
78+
def get_url(model_type: str):
79+
match ModelType[model_type]:
80+
case ModelType.chat:
81+
return "/v1/chat/completions"
82+
case ModelType.completion:
83+
return "/v1/completions"
84+
case ModelType.embeddings:
85+
return "/v1/embeddings"
86+
case ModelType.rerank:
87+
return "/v1/rerank"
88+
case ModelType.score:
89+
return "/v1/score"
90+
case ModelType.transcription:
91+
return "/v1/audio/transcriptions"
92+
case ModelType.vision:
93+
return "/v1/chat/completions"
7594

7695
@staticmethod
7796
def get_test_payload(model_type: str):
@@ -101,6 +120,26 @@ def get_test_payload(model_type: str):
101120
return {
102121
"file": ("empty.wav", _SILENT_WAV_BYTES, "audio/wav"),
103122
}
123+
case ModelType.vision:
124+
return {
125+
"messages": [
126+
{
127+
"role": "user",
128+
"content": [
129+
{
130+
"type": "text",
131+
"text": "This is a test. Just reply with yes",
132+
},
133+
{
134+
"type": "image_url",
135+
"image_url": {
136+
"url": ""
137+
},
138+
},
139+
],
140+
}
141+
]
142+
}
104143

105144
@staticmethod
106145
def get_all_fields():
@@ -186,27 +225,27 @@ def update_content_length(request: Request, request_body: str):
186225

187226

188227
def is_model_healthy(url: str, model: str, model_type: str) -> bool:
189-
model_details = ModelType[model_type]
228+
model_url = ModelType.get_url(model_type)
190229

191230
try:
192231
if model_type == "transcription":
193-
194232
# for transcription, the backend expects multipart/form-data with a file
195233
# we will use pre-generated silent wav bytes
196-
files = {"file": ("empty.wav", _SILENT_WAV_BYTES, "audio/wav")}
197-
data = {"model": model}
198234
response = requests.post(
199-
f"{url}{model_details.value}",
200-
files=files, # multipart/form-data
201-
data=data,
235+
f"{url}{model_url}",
236+
files=ModelType[model_type].get_test_payload(
237+
model_type
238+
), # multipart/form-data
239+
data={"model": model},
202240
timeout=10,
203241
)
204242
else:
205243
# for other model types (chat, completion, etc.)
206244
response = requests.post(
207-
f"{url}{model_details.value}",
245+
f"{url}{model_url}",
208246
headers={"Content-Type": "application/json"},
209-
json={"model": model} | model_details.get_test_payload(model_type),
247+
json={"model": model}
248+
| ModelType[model_type].get_test_payload(model_type),
210249
timeout=10,
211250
)
212251

0 commit comments

Comments
 (0)