Skip to content

Commit 53bb830

Browse files
AkisAya and Xunzhuo authored
feat: add stream mode support (#282)
Signed-off-by: akisaya <[email protected]> Co-authored-by: Xunzhuo <[email protected]>
1 parent f982534 commit 53bb830

File tree

1 file changed

+50
-5
lines changed

1 file changed

+50
-5
lines changed

src/semantic-router/pkg/extproc/request_handler.go

Lines changed: 50 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -33,6 +33,43 @@ func serializeOpenAIRequest(req *openai.ChatCompletionNewParams) ([]byte, error)
3333
return json.Marshal(req)
3434
}
3535

36+
// extractStreamParam reports whether the raw request body is a JSON object
// whose top-level "stream" field is the boolean true. Bodies that fail to
// parse, omit the field, or carry a non-boolean value all yield false.
func extractStreamParam(originalBody []byte) bool {
	var fields map[string]interface{}
	if json.Unmarshal(originalBody, &fields) != nil {
		return false
	}

	// Indexing a missing key yields a nil interface, and a failed comma-ok
	// assertion yields the zero value, so every non-bool case collapses to false.
	enabled, _ := fields["stream"].(bool)
	return enabled
}
50+
51+
// serializeOpenAIRequestWithStream converts request back to JSON, preserving the stream parameter from original request
52+
func serializeOpenAIRequestWithStream(req *openai.ChatCompletionNewParams, hasStreamParam bool) ([]byte, error) {
53+
// First serialize the SDK object
54+
sdkBytes, err := json.Marshal(req)
55+
if err != nil {
56+
return nil, err
57+
}
58+
59+
// If original request had stream parameter, add it back
60+
if hasStreamParam {
61+
var sdkMap map[string]interface{}
62+
if err := json.Unmarshal(sdkBytes, &sdkMap); err == nil {
63+
sdkMap["stream"] = true
64+
if modifiedBytes, err := json.Marshal(sdkMap); err == nil {
65+
return modifiedBytes, nil
66+
}
67+
}
68+
}
69+
70+
return sdkBytes, nil
71+
}
72+
3673
// addSystemPromptToRequestBody adds a system prompt to the beginning of the messages array in the JSON request body
3774
func addSystemPromptToRequestBody(requestBody []byte, systemPrompt string) ([]byte, error) {
3875
if systemPrompt == "" {
@@ -166,7 +203,7 @@ type RequestContext struct {
166203
ProcessingStartTime time.Time
167204

168205
// Streaming detection
169-
ExpectStreamingResponse bool // set from request Accept header
206+
ExpectStreamingResponse bool // set from request Accept header or stream parameter
170207
IsStreamingResponse bool // set from response Content-Type
171208

172209
// TTFT tracking
@@ -207,6 +244,7 @@ func (r *OpenAIRouter) handleRequestHeaders(v *ext_proc.ProcessingRequest_Reques
207244
if accept, ok := ctx.Headers["accept"]; ok {
208245
if strings.Contains(strings.ToLower(accept), "text/event-stream") {
209246
ctx.ExpectStreamingResponse = true
247+
observability.Infof("Client expects streaming response based on Accept header")
210248
}
211249
}
212250

@@ -246,6 +284,13 @@ func (r *OpenAIRouter) handleRequestBody(v *ext_proc.ProcessingRequest_RequestBo
246284
// Save the original request body
247285
ctx.OriginalRequestBody = v.RequestBody.GetBody()
248286

287+
// Extract stream parameter from original request and update ExpectStreamingResponse if needed
288+
hasStreamParam := extractStreamParam(ctx.OriginalRequestBody)
289+
if hasStreamParam {
290+
observability.Infof("Original request contains stream parameter: true")
291+
ctx.ExpectStreamingResponse = true // Set this if stream param is found
292+
}
293+
249294
// Parse the OpenAI request using SDK types
250295
openAIRequest, err := parseOpenAIRequest(ctx.OriginalRequestBody)
251296
if err != nil {
@@ -499,8 +544,8 @@ func (r *OpenAIRouter) handleModelRouting(openAIRequest *openai.ChatCompletionNe
499544
// Modify the model in the request
500545
openAIRequest.Model = openai.ChatModel(matchedModel)
501546

502-
// Serialize the modified request
503-
modifiedBody, err := serializeOpenAIRequest(openAIRequest)
547+
// Serialize the modified request with stream parameter preserved
548+
modifiedBody, err := serializeOpenAIRequestWithStream(openAIRequest, ctx.ExpectStreamingResponse)
504549
if err != nil {
505550
observability.Errorf("Error serializing modified request: %v", err)
506551
metrics.RecordRequestError(actualModel, "serialization_error")
@@ -758,8 +803,8 @@ func (r *OpenAIRouter) handleToolSelection(openAIRequest *openai.ChatCompletionN
758803

759804
// updateRequestWithTools updates the request body with the selected tools
760805
func (r *OpenAIRouter) updateRequestWithTools(openAIRequest *openai.ChatCompletionNewParams, response **ext_proc.ProcessingResponse, ctx *RequestContext) error {
761-
// Re-serialize the request with modified tools
762-
modifiedBody, err := serializeOpenAIRequest(openAIRequest)
806+
// Re-serialize the request with modified tools and preserved stream parameter
807+
modifiedBody, err := serializeOpenAIRequestWithStream(openAIRequest, ctx.ExpectStreamingResponse)
763808
if err != nil {
764809
return err
765810
}

0 commit comments

Comments
 (0)