@@ -66,12 +66,31 @@ def __call__(cls, *args, **kwargs):
66
66
67
67
68
68
class ModelType (enum .Enum ):
69
- chat = "/v1/chat/completions"
70
- completion = "/v1/completions"
71
- embeddings = "/v1/embeddings"
72
- rerank = "/v1/rerank"
73
- score = "/v1/score"
74
- transcription = "/v1/audio/transcriptions"
69
+ chat = "chat"
70
+ completion = "completion"
71
+ embeddings = "embeddings"
72
+ rerank = "rerank"
73
+ score = "score"
74
+ transcription = "transcription"
75
+ vision = "vision"
76
+
77
+ @staticmethod
78
+ def get_url (model_type : str ):
79
+ match ModelType [model_type ]:
80
+ case ModelType .chat :
81
+ return "/v1/chat/completions"
82
+ case ModelType .completion :
83
+ return "/v1/completions"
84
+ case ModelType .embeddings :
85
+ return "/v1/embeddings"
86
+ case ModelType .rerank :
87
+ return "/v1/rerank"
88
+ case ModelType .score :
89
+ return "/v1/score"
90
+ case ModelType .transcription :
91
+ return "/v1/audio/transcriptions"
92
+ case ModelType .vision :
93
+ return "/v1/chat/completions"
75
94
76
95
@staticmethod
77
96
def get_test_payload (model_type : str ):
@@ -101,6 +120,26 @@ def get_test_payload(model_type: str):
101
120
return {
102
121
"file" : ("empty.wav" , _SILENT_WAV_BYTES , "audio/wav" ),
103
122
}
123
+ case ModelType .vision :
124
+ return {
125
+ "messages" : [
126
+ {
127
+ "role" : "user" ,
128
+ "content" : [
129
+ {
130
+ "type" : "text" ,
131
+ "text" : "This is a test. Just reply with yes" ,
132
+ },
133
+ {
134
+ "type" : "image_url" ,
135
+ "image_url" : {
136
+ "url" : "data:image/jpeg;base64,iVBORw0KGgoAAAANSUhEUgAAAAIAAAACCAIAAAD91JpzAAAAG0lEQVR4nGLinfJq851wJn69udZSvIAAAAD//yf3BLKCfW8HAAAAAElFTkSuQmCC"
137
+ },
138
+ },
139
+ ],
140
+ }
141
+ ]
142
+ }
104
143
105
144
@staticmethod
106
145
def get_all_fields ():
@@ -186,27 +225,27 @@ def update_content_length(request: Request, request_body: str):
186
225
187
226
188
227
def is_model_healthy (url : str , model : str , model_type : str ) -> bool :
189
- model_details = ModelType [ model_type ]
228
+ model_url = ModelType . get_url ( model_type )
190
229
191
230
try :
192
231
if model_type == "transcription" :
193
-
194
232
# for transcription, the backend expects multipart/form-data with a file
195
233
# we will use pre-generated silent wav bytes
196
- files = {"file" : ("empty.wav" , _SILENT_WAV_BYTES , "audio/wav" )}
197
- data = {"model" : model }
198
234
response = requests .post (
199
- f"{ url } { model_details .value } " ,
200
- files = files , # multipart/form-data
201
- data = data ,
235
+ f"{ url } { model_url } " ,
236
+ files = ModelType [model_type ].get_test_payload (
237
+ model_type
238
+ ), # multipart/form-data
239
+ data = {"model" : model },
202
240
timeout = 10 ,
203
241
)
204
242
else :
205
243
# for other model types (chat, completion, etc.)
206
244
response = requests .post (
207
- f"{ url } { model_details . value } " ,
245
+ f"{ url } { model_url } " ,
208
246
headers = {"Content-Type" : "application/json" },
209
- json = {"model" : model } | model_details .get_test_payload (model_type ),
247
+ json = {"model" : model }
248
+ | ModelType [model_type ].get_test_payload (model_type ),
210
249
timeout = 10 ,
211
250
)
212
251
0 commit comments