Skip to content

Commit 0e4f210

Browse files
committed
fix stream mode
Signed-off-by: akisaya <[email protected]>
1 parent b3658bc commit 0e4f210

File tree

1 file changed

+50
-5
lines changed

1 file changed

+50
-5
lines changed

src/semantic-router/pkg/extproc/request_handler.go

Lines changed: 50 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,43 @@ func serializeOpenAIRequest(req *openai.ChatCompletionNewParams) ([]byte, error)
3232
return json.Marshal(req)
3333
}
3434

35+
// extractStreamParam extracts the stream parameter from the original request body
36+
func extractStreamParam(originalBody []byte) bool {
37+
var requestMap map[string]interface{}
38+
if err := json.Unmarshal(originalBody, &requestMap); err != nil {
39+
return false
40+
}
41+
42+
if streamValue, exists := requestMap["stream"]; exists {
43+
if stream, ok := streamValue.(bool); ok {
44+
return stream
45+
}
46+
}
47+
return false
48+
}
49+
50+
// serializeOpenAIRequestWithStream converts request back to JSON, preserving the stream parameter from original request
51+
func serializeOpenAIRequestWithStream(req *openai.ChatCompletionNewParams, hasStreamParam bool) ([]byte, error) {
52+
// First serialize the SDK object
53+
sdkBytes, err := json.Marshal(req)
54+
if err != nil {
55+
return nil, err
56+
}
57+
58+
// If original request had stream parameter, add it back
59+
if hasStreamParam {
60+
var sdkMap map[string]interface{}
61+
if err := json.Unmarshal(sdkBytes, &sdkMap); err == nil {
62+
sdkMap["stream"] = true
63+
if modifiedBytes, err := json.Marshal(sdkMap); err == nil {
64+
return modifiedBytes, nil
65+
}
66+
}
67+
}
68+
69+
return sdkBytes, nil
70+
}
71+
3572
// addSystemPromptToRequestBody adds a system prompt to the beginning of the messages array in the JSON request body
3673
func addSystemPromptToRequestBody(requestBody []byte, systemPrompt string) ([]byte, error) {
3774
if systemPrompt == "" {
@@ -165,7 +202,7 @@ type RequestContext struct {
165202
ProcessingStartTime time.Time
166203

167204
// Streaming detection
168-
ExpectStreamingResponse bool // set from request Accept header
205+
ExpectStreamingResponse bool // set from request Accept header or stream parameter
169206
IsStreamingResponse bool // set from response Content-Type
170207

171208
// TTFT tracking
@@ -200,6 +237,7 @@ func (r *OpenAIRouter) handleRequestHeaders(v *ext_proc.ProcessingRequest_Reques
200237
if accept, ok := ctx.Headers["accept"]; ok {
201238
if strings.Contains(strings.ToLower(accept), "text/event-stream") {
202239
ctx.ExpectStreamingResponse = true
240+
observability.Infof("Client expects streaming response based on Accept header")
203241
}
204242
}
205243

@@ -230,6 +268,13 @@ func (r *OpenAIRouter) handleRequestBody(v *ext_proc.ProcessingRequest_RequestBo
230268
// Save the original request body
231269
ctx.OriginalRequestBody = v.RequestBody.GetBody()
232270

271+
// Extract stream parameter from original request and update ExpectStreamingResponse if needed
272+
hasStreamParam := extractStreamParam(ctx.OriginalRequestBody)
273+
if hasStreamParam {
274+
observability.Infof("Original request contains stream parameter: true")
275+
ctx.ExpectStreamingResponse = true // Set this if stream param is found
276+
}
277+
233278
// Parse the OpenAI request using SDK types
234279
openAIRequest, err := parseOpenAIRequest(ctx.OriginalRequestBody)
235280
if err != nil {
@@ -472,8 +517,8 @@ func (r *OpenAIRouter) handleModelRouting(openAIRequest *openai.ChatCompletionNe
472517
// Modify the model in the request
473518
openAIRequest.Model = openai.ChatModel(matchedModel)
474519

475-
// Serialize the modified request
476-
modifiedBody, err := serializeOpenAIRequest(openAIRequest)
520+
// Serialize the modified request with stream parameter preserved
521+
modifiedBody, err := serializeOpenAIRequestWithStream(openAIRequest, ctx.ExpectStreamingResponse)
477522
if err != nil {
478523
observability.Errorf("Error serializing modified request: %v", err)
479524
metrics.RecordRequestError(actualModel, "serialization_error")
@@ -728,8 +773,8 @@ func (r *OpenAIRouter) handleToolSelection(openAIRequest *openai.ChatCompletionN
728773

729774
// updateRequestWithTools updates the request body with the selected tools
730775
func (r *OpenAIRouter) updateRequestWithTools(openAIRequest *openai.ChatCompletionNewParams, response **ext_proc.ProcessingResponse, ctx *RequestContext) error {
731-
// Re-serialize the request with modified tools
732-
modifiedBody, err := serializeOpenAIRequest(openAIRequest)
776+
// Re-serialize the request with modified tools and preserved stream parameter
777+
modifiedBody, err := serializeOpenAIRequestWithStream(openAIRequest, ctx.ExpectStreamingResponse)
733778
if err != nil {
734779
return err
735780
}

0 commit comments

Comments
 (0)