@@ -66,12 +66,24 @@ def __call__(cls, *args, **kwargs):
66
66
67
67
68
68
class ModelType (enum .Enum ):
69
- chat = "/v1/chat/completions"
70
- completion = "/v1/completions"
71
- embeddings = "/v1/embeddings"
72
- rerank = "/v1/rerank"
73
- score = "/v1/score"
74
- transcription = "/v1/audio/transcriptions"
69
+ chat = "chat"
70
+ completion = "completion"
71
+ embeddings = "embeddings"
72
+ rerank = "rerank"
73
+ score = "score"
74
+ transcription = "transcription"
75
+ vision = "vision"
76
+
77
+ @staticmethod
78
+ def get_url (model_type : str ):
79
+ match ModelType [model_type ]:
80
+ case ModelType .chat : return "/v1/chat/completions"
81
+ case ModelType .completion : return "/v1/completions"
82
+ case ModelType .embeddings : return "/v1/embeddings"
83
+ case ModelType .rerank : return "/v1/rerank"
84
+ case ModelType .score : return "/v1/score"
85
+ case ModelType .transcription : return "/v1/audio/transcriptions"
86
+ case ModelType .vision : return "/v1/chat/completions"
75
87
76
88
@staticmethod
77
89
def get_test_payload (model_type : str ):
@@ -101,6 +113,26 @@ def get_test_payload(model_type: str):
101
113
return {
102
114
"file" : ("empty.wav" , _SILENT_WAV_BYTES , "audio/wav" ),
103
115
}
116
+ case ModelType .vision :
117
+ return {
118
+ "messages" : [
119
+ {
120
+ "role" : "user" ,
121
+ "content" : [
122
+ {
123
+ "type" : "text" ,
124
+ "text" : "This is a test. Just reply with yes" ,
125
+ },
126
+ {
127
+ "type" : "image_url" ,
128
+ "image_url" : {
129
+ "url" : "data:image/jpeg;base64,iVBORw0KGgoAAAANSUhEUgAAAAIAAAACCAIAAAD91JpzAAAAG0lEQVR4nGLinfJq851wJn69udZSvIAAAAD//yf3BLKCfW8HAAAAAElFTkSuQmCC"
130
+ },
131
+ },
132
+ ],
133
+ }
134
+ ]
135
+ }
104
136
105
137
@staticmethod
106
138
def get_all_fields ():
@@ -186,27 +218,24 @@ def update_content_length(request: Request, request_body: str):
186
218
187
219
188
220
def is_model_healthy (url : str , model : str , model_type : str ) -> bool :
189
- model_details = ModelType [ model_type ]
221
+ model_url = ModelType . get_url ( model_type )
190
222
191
223
try :
192
224
if model_type == "transcription" :
193
-
194
225
# for transcription, the backend expects multipart/form-data with a file
195
226
# we will use pre-generated silent wav bytes
196
- files = {"file" : ("empty.wav" , _SILENT_WAV_BYTES , "audio/wav" )}
197
- data = {"model" : model }
198
227
response = requests .post (
199
- f"{ url } { model_details . value } " ,
200
- files = files , # multipart/form-data
201
- data = data ,
228
+ f"{ url } { model_url } " ,
229
+ files = ModelType [ model_type ]. get_test_payload ( model_type ) , # multipart/form-data
230
+ data = { "model" : model } ,
202
231
timeout = 10 ,
203
232
)
204
233
else :
205
234
# for other model types (chat, completion, etc.)
206
235
response = requests .post (
207
- f"{ url } { model_details . value } " ,
236
+ f"{ url } { model_url } " ,
208
237
headers = {"Content-Type" : "application/json" },
209
- json = {"model" : model } | model_details .get_test_payload (model_type ),
238
+ json = {"model" : model } | ModelType [ model_type ] .get_test_payload (model_type ),
210
239
timeout = 10 ,
211
240
)
212
241
0 commit comments