@@ -35,9 +35,26 @@ import (
35
35
"github.com/vllm-project/aibrix/pkg/utils"
36
36
)
37
37
38
- func (s * Server ) HandleResponseBody (ctx context.Context , requestID string , req * extProcPb.ProcessingRequest , user utils.User , rpm int64 , model string , stream bool , traceTerm int64 , hasCompleted bool ) (* extProcPb.ProcessingResponse , bool ) {
38
+ func (s * Server ) HandleResponseBody (ctx context.Context , requestID string , req * extProcPb.ProcessingRequest , requestType OpenAiRequestType , user utils.User , rpm int64 , model string , stream bool , traceTerm int64 , hasCompleted bool ) (* extProcPb.ProcessingResponse , bool ) {
39
39
b := req .Request .(* extProcPb.ProcessingRequest_ResponseBody )
40
40
41
+ switch requestType {
42
+ case OpenAiRequestChatCompletionsType , OpenAiRequestCompletionsType :
43
+ return s .handleChatCompletionsResponseBody (ctx , requestID , b , user , rpm , model , stream , traceTerm , hasCompleted )
44
+ case OpenAiRequestEmbeddingsType :
45
+ return s .handleEmbeddingsResponseBody (ctx , requestID , b , user , rpm , model , false , traceTerm , hasCompleted )
46
+ default :
47
+ // all other openAi request types (e.g. audio, image, ..) are not supported yet
48
+ return generateErrorResponse (
49
+ envoyTypePb .StatusCode_NotImplemented ,
50
+ []* configPb.HeaderValueOption {{Header : & configPb.HeaderValue {
51
+ Key : HeaderErrorResponseUnknown , RawValue : []byte ("true" ),
52
+ }}},
53
+ "request type not supported" ), true
54
+ }
55
+ }
56
+
57
+ func (s * Server ) handleChatCompletionsResponseBody (ctx context.Context , requestID string , b * extProcPb.ProcessingRequest_ResponseBody , user utils.User , rpm int64 , model string , stream bool , traceTerm int64 , hasCompleted bool ) (* extProcPb.ProcessingResponse , bool ) {
41
58
var res openai.ChatCompletion
42
59
var usage openai.CompletionUsage
43
60
var promptTokens , completionTokens int64
@@ -203,3 +220,143 @@ func (s *Server) HandleResponseBody(ctx context.Context, requestID string, req *
203
220
},
204
221
}, complete
205
222
}
223
+
224
+ func (s * Server ) handleEmbeddingsResponseBody (ctx context.Context , requestID string , b * extProcPb.ProcessingRequest_ResponseBody , user utils.User , rpm int64 , model string , stream bool , traceTerm int64 , hasCompleted bool ) (* extProcPb.ProcessingResponse , bool ) {
225
+ var res openai.CreateEmbeddingResponse
226
+ var usage openai.CreateEmbeddingResponseUsage
227
+ var promptTokens , completionTokens int64
228
+ var headers []* configPb.HeaderValueOption
229
+ complete := hasCompleted
230
+ routerCtx , _ := ctx .(* types.RoutingContext )
231
+
232
+ defer func () {
233
+ // Wrapped in a function to delay the evaluation of parameters. Using complete to make sure DoneRequestTrace only call once for a request.
234
+ if ! hasCompleted && complete {
235
+ s .cache .DoneRequestTrace (routerCtx , requestID , model , promptTokens , completionTokens , traceTerm )
236
+ if routerCtx != nil {
237
+ routerCtx .Delete ()
238
+ }
239
+ }
240
+ }()
241
+
242
+ // Use request ID as a key to store per-request buffer
243
+ // Retrieve or create buffer
244
+ buf , _ := requestBuffers .LoadOrStore (requestID , & bytes.Buffer {})
245
+ buffer := buf .(* bytes.Buffer )
246
+ // Append data to per-request buffer
247
+ buffer .Write (b .ResponseBody .Body )
248
+
249
+ if ! b .ResponseBody .EndOfStream {
250
+ // Partial data received, wait for more chunks, we just return a common response here.
251
+ return & extProcPb.ProcessingResponse {
252
+ Response : & extProcPb.ProcessingResponse_ResponseBody {
253
+ ResponseBody : & extProcPb.BodyResponse {
254
+ Response : & extProcPb.CommonResponse {},
255
+ },
256
+ },
257
+ }, complete
258
+ }
259
+
260
+ // Last part received, process the full response
261
+ finalBody := buffer .Bytes ()
262
+ // Clean up the buffer after final processing
263
+ requestBuffers .Delete (requestID )
264
+
265
+ if err := json .Unmarshal (finalBody , & res ); err != nil {
266
+ klog .ErrorS (err , "error to unmarshal response" , "requestID" , requestID , "responseBody" , string (b .ResponseBody .GetBody ()))
267
+ complete = true
268
+ return generateErrorResponse (
269
+ envoyTypePb .StatusCode_InternalServerError ,
270
+ []* configPb.HeaderValueOption {{Header : & configPb.HeaderValue {
271
+ Key : HeaderErrorResponseUnmarshal , RawValue : []byte ("true" ),
272
+ }}},
273
+ err .Error ()), complete
274
+ } else if len (res .Model ) == 0 {
275
+ msg := ErrorUnknownResponse .Error ()
276
+ responseBodyContent := string (b .ResponseBody .GetBody ())
277
+ if len (responseBodyContent ) != 0 {
278
+ msg = responseBodyContent
279
+ }
280
+ klog .ErrorS (err , "unexpected response" , "requestID" , requestID , "responseBody" , responseBodyContent )
281
+ complete = true
282
+ return generateErrorResponse (
283
+ envoyTypePb .StatusCode_InternalServerError ,
284
+ []* configPb.HeaderValueOption {{Header : & configPb.HeaderValue {
285
+ Key : HeaderErrorResponseUnknown , RawValue : []byte ("true" ),
286
+ }}},
287
+ msg ), complete
288
+ }
289
+ // Do not overwrite model, res can be empty.
290
+ usage = res .Usage
291
+
292
+ var requestEnd string
293
+ if usage .TotalTokens != 0 {
294
+ complete = true
295
+ // Update promptTokens and completeTokens
296
+ promptTokens = usage .PromptTokens
297
+ completionTokens = 0 // no completion tokens in embeddings request
298
+ // Count token per user.
299
+ if user .Name != "" {
300
+ tpm , err := s .ratelimiter .Incr (ctx , fmt .Sprintf ("%v_TPM_CURRENT" , user .Name ), res .Usage .TotalTokens )
301
+ if err != nil {
302
+ return generateErrorResponse (
303
+ envoyTypePb .StatusCode_InternalServerError ,
304
+ []* configPb.HeaderValueOption {{Header : & configPb.HeaderValue {
305
+ Key : HeaderErrorIncrTPM , RawValue : []byte ("true" ),
306
+ }}},
307
+ err .Error ()), complete
308
+ }
309
+
310
+ headers = append (headers ,
311
+ & configPb.HeaderValueOption {
312
+ Header : & configPb.HeaderValue {
313
+ Key : HeaderUpdateRPM ,
314
+ RawValue : []byte (fmt .Sprintf ("%d" , rpm )),
315
+ },
316
+ },
317
+ & configPb.HeaderValueOption {
318
+ Header : & configPb.HeaderValue {
319
+ Key : HeaderUpdateTPM ,
320
+ RawValue : []byte (fmt .Sprintf ("%d" , tpm )),
321
+ },
322
+ },
323
+ )
324
+ requestEnd = fmt .Sprintf (requestEnd + "rpm: %d, tpm: %d, " , rpm , tpm )
325
+ }
326
+
327
+ if routerCtx != nil && routerCtx .HasRouted () {
328
+ targetPodIP := routerCtx .TargetAddress ()
329
+ headers = append (headers ,
330
+ & configPb.HeaderValueOption {
331
+ Header : & configPb.HeaderValue {
332
+ Key : HeaderTargetPod ,
333
+ RawValue : []byte (targetPodIP ),
334
+ },
335
+ },
336
+ & configPb.HeaderValueOption {
337
+ Header : & configPb.HeaderValue {
338
+ Key : HeaderRequestID ,
339
+ RawValue : []byte (requestID ),
340
+ },
341
+ },
342
+ )
343
+ requestEnd = fmt .Sprintf (requestEnd + "targetPod: %s" , targetPodIP )
344
+ }
345
+
346
+ klog .Infof ("request end, requestID: %s - %s" , requestID , requestEnd )
347
+ } else if b .ResponseBody .EndOfStream {
348
+ complete = true
349
+ }
350
+
351
+ return & extProcPb.ProcessingResponse {
352
+ Response : & extProcPb.ProcessingResponse_ResponseBody {
353
+ ResponseBody : & extProcPb.BodyResponse {
354
+ Response : & extProcPb.CommonResponse {
355
+ HeaderMutation : & extProcPb.HeaderMutation {
356
+ SetHeaders : headers ,
357
+ },
358
+ },
359
+ },
360
+ },
361
+ }, complete
362
+ }
0 commit comments