@@ -66,12 +66,29 @@ def __call__(cls, *args, **kwargs):
66
66
67
67
68
68
class ModelType (enum .Enum ):
69
- chat = "/v1/chat/completions"
70
- completion = "/v1/completions"
71
- embeddings = "/v1/embeddings"
72
- rerank = "/v1/rerank"
73
- score = "/v1/score"
74
- transcription = "/v1/audio/transcriptions"
69
+ chat = "chat"
70
+ completion = "completion"
71
+ embeddings = "embeddings"
72
+ rerank = "rerank"
73
+ score = "score"
74
+ transcription = "transcription"
75
+ vision = "vision"
76
+
77
@staticmethod
def get_url(model_type: str):
    """Return the API endpoint path for the given model type name.

    ``model_type`` is looked up by name among the ``ModelType`` members,
    so an unknown name raises ``KeyError`` from the enum lookup itself.
    Vision models are served from the same endpoint as chat models.
    """
    # Resolve the member first so an invalid name fails before anything else.
    member = ModelType[model_type]

    chat_endpoint = "/v1/chat/completions"
    endpoint_by_type = {
        ModelType.chat: chat_endpoint,
        ModelType.vision: chat_endpoint,
        ModelType.completion: "/v1/completions",
        ModelType.embeddings: "/v1/embeddings",
        ModelType.rerank: "/v1/rerank",
        ModelType.score: "/v1/score",
        ModelType.transcription: "/v1/audio/transcriptions",
    }
    # .get() yields None for a member without an entry, matching a
    # fall-through of the case list.
    return endpoint_by_type.get(member)
75
92
76
93
@staticmethod
77
94
def get_test_payload (model_type : str ):
@@ -101,6 +118,26 @@ def get_test_payload(model_type: str):
101
118
return {
102
119
"file" : ("empty.wav" , _SILENT_WAV_BYTES , "audio/wav" ),
103
120
}
121
+ case ModelType .vision :
122
+ return {
123
+ "messages" : [
124
+ {
125
+ "role" : "user" ,
126
+ "content" : [
127
+ {
128
+ "type" : "text" ,
129
+ "text" : "This is a test. Just reply with yes" ,
130
+ },
131
+ {
132
+ "type" : "image_url" ,
133
+ "image_url" : {
134
+ "url" : "data:image/jpeg;base64,iVBORw0KGgoAAAANSUhEUgAAAAIAAAACCAIAAAD91JpzAAAAG0lEQVR4nGLinfJq851wJn69udZSvIAAAAD//yf3BLKCfW8HAAAAAElFTkSuQmCC"
135
+ },
136
+ },
137
+ ],
138
+ }
139
+ ]
140
+ }
104
141
105
142
@staticmethod
106
143
def get_all_fields ():
@@ -186,27 +223,27 @@ def update_content_length(request: Request, request_body: str):
186
223
187
224
188
225
def is_model_healthy (url : str , model : str , model_type : str ) -> bool :
189
- model_details = ModelType [ model_type ]
226
+ model_url = ModelType . get_url ( model_type )
190
227
191
228
try :
192
229
if model_type == "transcription" :
193
-
194
230
# for transcription, the backend expects multipart/form-data with a file
195
231
# we will use pre-generated silent wav bytes
196
- files = {"file" : ("empty.wav" , _SILENT_WAV_BYTES , "audio/wav" )}
197
- data = {"model" : model }
198
232
response = requests .post (
199
- f"{ url } { model_details .value } " ,
200
- files = files , # multipart/form-data
201
- data = data ,
233
+ f"{ url } { model_url } " ,
234
+ files = ModelType .get_test_payload (
235
+ model_type
236
+ ), # multipart/form-data
237
+ data = {"model" : model },
202
238
timeout = 10 ,
203
239
)
204
240
else :
205
241
# for other model types (chat, completion, etc.)
206
242
response = requests .post (
207
- f"{ url } { model_details . value } " ,
243
+ f"{ url } { model_url } " ,
208
244
headers = {"Content-Type" : "application/json" },
209
- json = {"model" : model } | model_details .get_test_payload (model_type ),
245
+ json = {"model" : model }
246
+ | ModelType .get_test_payload (model_type ),
210
247
timeout = 10 ,
211
248
)
212
249
0 commit comments