Skip to content

Commit 2c86645

Browse files
committed
feat: process response body differentially based on request type
Signed-off-by: Guillaume Calmettes <[email protected]>
1 parent 61f1a18 commit 2c86645

File tree

2 files changed

+163
-4
lines changed

2 files changed

+163
-4
lines changed

pkg/plugins/gateway/gateway.go

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ func (s *Server) Process(srv extProcPb.ExternalProcessor_ProcessServer) error {
8282
var respErrorCode int
8383
var model string
8484
var requestPath string
85+
var requestType OpenAiRequestType
8586
var routingAlgorithm types.RoutingAlgorithm
8687
var routerCtx *types.RoutingContext
8788
var stream, isRespError bool
@@ -113,7 +114,7 @@ func (s *Server) Process(srv extProcPb.ExternalProcessor_ProcessServer) error {
113114
resp, user, rpm, routingAlgorithm, requestPath = s.HandleRequestHeaders(ctx, requestID, req)
114115

115116
case *extProcPb.ProcessingRequest_RequestBody:
116-
resp, model, routerCtx, stream, traceTerm = s.HandleRequestBody(ctx, requestID, requestPath, req, user, routingAlgorithm)
117+
resp, model, routerCtx, stream, requestType, traceTerm = s.HandleRequestBody(ctx, requestID, requestPath, req, user, routingAlgorithm)
117118
if routerCtx != nil {
118119
ctx = routerCtx
119120
}
@@ -135,7 +136,7 @@ func (s *Server) Process(srv extProcPb.ExternalProcessor_ProcessServer) error {
135136
resp = s.responseErrorProcessing(ctx, resp, respErrorCode, model, requestID,
136137
string(req.Request.(*extProcPb.ProcessingRequest_ResponseBody).ResponseBody.GetBody()))
137138
} else {
138-
resp, completed = s.HandleResponseBody(ctx, requestID, req, user, rpm, model, stream, traceTerm, completed)
139+
resp, completed = s.HandleResponseBody(ctx, requestID, req, requestType, user, rpm, model, stream, traceTerm, completed)
139140
}
140141
default:
141142
klog.Infof("Unknown Request type %+v\n", v)
@@ -205,7 +206,8 @@ func (s *Server) validateHTTPRouteStatus(ctx context.Context, model string) erro
205206
}
206207

207208
func (s *Server) responseErrorProcessing(ctx context.Context, resp *extProcPb.ProcessingResponse, respErrorCode int,
208-
model, requestID, errMsg string) *extProcPb.ProcessingResponse {
209+
model, requestID, errMsg string,
210+
) *extProcPb.ProcessingResponse {
209211
httprouteErr := s.validateHTTPRouteStatus(ctx, model)
210212
if errMsg != "" && httprouteErr != nil {
211213
errMsg = fmt.Sprintf("%s. %s", errMsg, httprouteErr.Error())

pkg/plugins/gateway/gateway_rsp_body.go

Lines changed: 158 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,26 @@ import (
3535
"github.com/vllm-project/aibrix/pkg/utils"
3636
)
3737

38-
func (s *Server) HandleResponseBody(ctx context.Context, requestID string, req *extProcPb.ProcessingRequest, user utils.User, rpm int64, model string, stream bool, traceTerm int64, hasCompleted bool) (*extProcPb.ProcessingResponse, bool) {
38+
func (s *Server) HandleResponseBody(ctx context.Context, requestID string, req *extProcPb.ProcessingRequest, requestType OpenAiRequestType, user utils.User, rpm int64, model string, stream bool, traceTerm int64, hasCompleted bool) (*extProcPb.ProcessingResponse, bool) {
3939
b := req.Request.(*extProcPb.ProcessingRequest_ResponseBody)
4040

41+
switch requestType {
42+
case OpenAiRequestChatCompletionsType, OpenAiRequestCompletionsType:
43+
return s.handleChatCompletionsResponseBody(ctx, requestID, b, user, rpm, model, stream, traceTerm, hasCompleted)
44+
case OpenAiRequestEmbeddingsType:
45+
return s.handleEmbeddingsResponseBody(ctx, requestID, b, user, rpm, model, false, traceTerm, hasCompleted)
46+
default:
47+
// all other openAi request types (e.g. audio, image, ..) are not supported yet
48+
return generateErrorResponse(
49+
envoyTypePb.StatusCode_NotImplemented,
50+
[]*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{
51+
Key: HeaderErrorResponseUnknown, RawValue: []byte("true"),
52+
}}},
53+
"request type not supported"), true
54+
}
55+
}
56+
57+
func (s *Server) handleChatCompletionsResponseBody(ctx context.Context, requestID string, b *extProcPb.ProcessingRequest_ResponseBody, user utils.User, rpm int64, model string, stream bool, traceTerm int64, hasCompleted bool) (*extProcPb.ProcessingResponse, bool) {
4158
var res openai.ChatCompletion
4259
var usage openai.CompletionUsage
4360
var promptTokens, completionTokens int64
@@ -203,3 +220,143 @@ func (s *Server) HandleResponseBody(ctx context.Context, requestID string, req *
203220
},
204221
}, complete
205222
}
223+
224+
func (s *Server) handleEmbeddingsResponseBody(ctx context.Context, requestID string, b *extProcPb.ProcessingRequest_ResponseBody, user utils.User, rpm int64, model string, stream bool, traceTerm int64, hasCompleted bool) (*extProcPb.ProcessingResponse, bool) {
225+
var res openai.CreateEmbeddingResponse
226+
var usage openai.CreateEmbeddingResponseUsage
227+
var promptTokens, completionTokens int64
228+
var headers []*configPb.HeaderValueOption
229+
complete := hasCompleted
230+
routerCtx, _ := ctx.(*types.RoutingContext)
231+
232+
defer func() {
233+
// Wrapped in a function to delay the evaluation of parameters. Using complete to make sure DoneRequestTrace only call once for a request.
234+
if !hasCompleted && complete {
235+
s.cache.DoneRequestTrace(routerCtx, requestID, model, promptTokens, completionTokens, traceTerm)
236+
if routerCtx != nil {
237+
routerCtx.Delete()
238+
}
239+
}
240+
}()
241+
242+
// Use request ID as a key to store per-request buffer
243+
// Retrieve or create buffer
244+
buf, _ := requestBuffers.LoadOrStore(requestID, &bytes.Buffer{})
245+
buffer := buf.(*bytes.Buffer)
246+
// Append data to per-request buffer
247+
buffer.Write(b.ResponseBody.Body)
248+
249+
if !b.ResponseBody.EndOfStream {
250+
// Partial data received, wait for more chunks, we just return a common response here.
251+
return &extProcPb.ProcessingResponse{
252+
Response: &extProcPb.ProcessingResponse_ResponseBody{
253+
ResponseBody: &extProcPb.BodyResponse{
254+
Response: &extProcPb.CommonResponse{},
255+
},
256+
},
257+
}, complete
258+
}
259+
260+
// Last part received, process the full response
261+
finalBody := buffer.Bytes()
262+
// Clean up the buffer after final processing
263+
requestBuffers.Delete(requestID)
264+
265+
if err := json.Unmarshal(finalBody, &res); err != nil {
266+
klog.ErrorS(err, "error to unmarshal response", "requestID", requestID, "responseBody", string(b.ResponseBody.GetBody()))
267+
complete = true
268+
return generateErrorResponse(
269+
envoyTypePb.StatusCode_InternalServerError,
270+
[]*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{
271+
Key: HeaderErrorResponseUnmarshal, RawValue: []byte("true"),
272+
}}},
273+
err.Error()), complete
274+
} else if len(res.Model) == 0 {
275+
msg := ErrorUnknownResponse.Error()
276+
responseBodyContent := string(b.ResponseBody.GetBody())
277+
if len(responseBodyContent) != 0 {
278+
msg = responseBodyContent
279+
}
280+
klog.ErrorS(err, "unexpected response", "requestID", requestID, "responseBody", responseBodyContent)
281+
complete = true
282+
return generateErrorResponse(
283+
envoyTypePb.StatusCode_InternalServerError,
284+
[]*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{
285+
Key: HeaderErrorResponseUnknown, RawValue: []byte("true"),
286+
}}},
287+
msg), complete
288+
}
289+
// Do not overwrite model, res can be empty.
290+
usage = res.Usage
291+
292+
var requestEnd string
293+
if usage.TotalTokens != 0 {
294+
complete = true
295+
// Update promptTokens and completeTokens
296+
promptTokens = usage.PromptTokens
297+
completionTokens = 0 // no completion tokens in embeddings request
298+
// Count token per user.
299+
if user.Name != "" {
300+
tpm, err := s.ratelimiter.Incr(ctx, fmt.Sprintf("%v_TPM_CURRENT", user.Name), res.Usage.TotalTokens)
301+
if err != nil {
302+
return generateErrorResponse(
303+
envoyTypePb.StatusCode_InternalServerError,
304+
[]*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{
305+
Key: HeaderErrorIncrTPM, RawValue: []byte("true"),
306+
}}},
307+
err.Error()), complete
308+
}
309+
310+
headers = append(headers,
311+
&configPb.HeaderValueOption{
312+
Header: &configPb.HeaderValue{
313+
Key: HeaderUpdateRPM,
314+
RawValue: []byte(fmt.Sprintf("%d", rpm)),
315+
},
316+
},
317+
&configPb.HeaderValueOption{
318+
Header: &configPb.HeaderValue{
319+
Key: HeaderUpdateTPM,
320+
RawValue: []byte(fmt.Sprintf("%d", tpm)),
321+
},
322+
},
323+
)
324+
requestEnd = fmt.Sprintf(requestEnd+"rpm: %d, tpm: %d, ", rpm, tpm)
325+
}
326+
327+
if routerCtx != nil && routerCtx.HasRouted() {
328+
targetPodIP := routerCtx.TargetAddress()
329+
headers = append(headers,
330+
&configPb.HeaderValueOption{
331+
Header: &configPb.HeaderValue{
332+
Key: HeaderTargetPod,
333+
RawValue: []byte(targetPodIP),
334+
},
335+
},
336+
&configPb.HeaderValueOption{
337+
Header: &configPb.HeaderValue{
338+
Key: HeaderRequestID,
339+
RawValue: []byte(requestID),
340+
},
341+
},
342+
)
343+
requestEnd = fmt.Sprintf(requestEnd+"targetPod: %s", targetPodIP)
344+
}
345+
346+
klog.Infof("request end, requestID: %s - %s", requestID, requestEnd)
347+
} else if b.ResponseBody.EndOfStream {
348+
complete = true
349+
}
350+
351+
return &extProcPb.ProcessingResponse{
352+
Response: &extProcPb.ProcessingResponse_ResponseBody{
353+
ResponseBody: &extProcPb.BodyResponse{
354+
Response: &extProcPb.CommonResponse{
355+
HeaderMutation: &extProcPb.HeaderMutation{
356+
SetHeaders: headers,
357+
},
358+
},
359+
},
360+
},
361+
}, complete
362+
}

0 commit comments

Comments
 (0)