Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 50 additions & 5 deletions src/semantic-router/pkg/extproc/request_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,43 @@ func serializeOpenAIRequest(req *openai.ChatCompletionNewParams) ([]byte, error)
return json.Marshal(req)
}

// extractStreamParam reports whether the top-level "stream" field of the
// original JSON request body is the boolean value true.
//
// It returns false when the body cannot be decoded into a JSON object, when
// the "stream" key is absent, or when its value is anything other than a
// JSON boolean (for example the string "true", a number, or null).
func extractStreamParam(originalBody []byte) bool {
	var fields map[string]interface{}
	if err := json.Unmarshal(originalBody, &fields); err != nil {
		return false
	}

	// Indexing a missing key (or a nil map) yields a nil interface, which
	// fails the type assertion, so absence and wrong type are handled together.
	enabled, isBool := fields["stream"].(bool)
	return isBool && enabled
}

// serializeOpenAIRequestWithStream marshals req to JSON and, when
// hasStreamParam is true, sets the top-level "stream" field of the result to
// true.
//
// Parameters:
//   - req: the request object to serialize with json.Marshal.
//   - hasStreamParam: whether "stream": true must be present in the output.
//
// Returns the serialized body, or an error if marshalling req itself fails.
// Injecting the stream flag is best-effort: if the marshalled request is not
// a JSON object, or re-encoding fails, the plain marshalled request is
// returned without the flag and with a nil error.
func serializeOpenAIRequestWithStream(req *openai.ChatCompletionNewParams, hasStreamParam bool) ([]byte, error) {
	// First serialize the SDK object.
	sdkBytes, err := json.Marshal(req)
	if err != nil {
		return nil, err
	}
	if !hasStreamParam {
		return sdkBytes, nil
	}

	// Decode only the top level and keep every value as raw bytes. Decoding
	// into map[string]interface{} would turn all numbers into float64 and
	// silently corrupt integers above 2^53 when they are re-encoded.
	var fields map[string]json.RawMessage
	if err := json.Unmarshal(sdkBytes, &fields); err != nil || fields == nil {
		// A nil map with a nil error means the payload was the JSON literal
		// null (e.g. a nil req); writing to that map would panic, so fall
		// back to the unmodified bytes instead.
		return sdkBytes, nil
	}

	fields["stream"] = json.RawMessage("true")
	modifiedBytes, err := json.Marshal(fields)
	if err != nil {
		// Best-effort: keep the original serialization rather than failing.
		return sdkBytes, nil
	}
	return modifiedBytes, nil
}

// addSystemPromptToRequestBody adds a system prompt to the beginning of the messages array in the JSON request body
func addSystemPromptToRequestBody(requestBody []byte, systemPrompt string) ([]byte, error) {
if systemPrompt == "" {
Expand Down Expand Up @@ -166,7 +203,7 @@ type RequestContext struct {
ProcessingStartTime time.Time

// Streaming detection
ExpectStreamingResponse bool // set from request Accept header
ExpectStreamingResponse bool // set from request Accept header or stream parameter
IsStreamingResponse bool // set from response Content-Type

// TTFT tracking
Expand Down Expand Up @@ -207,6 +244,7 @@ func (r *OpenAIRouter) handleRequestHeaders(v *ext_proc.ProcessingRequest_Reques
if accept, ok := ctx.Headers["accept"]; ok {
if strings.Contains(strings.ToLower(accept), "text/event-stream") {
ctx.ExpectStreamingResponse = true
observability.Infof("Client expects streaming response based on Accept header")
}
}

Expand Down Expand Up @@ -246,6 +284,13 @@ func (r *OpenAIRouter) handleRequestBody(v *ext_proc.ProcessingRequest_RequestBo
// Save the original request body
ctx.OriginalRequestBody = v.RequestBody.GetBody()

// Extract stream parameter from original request and update ExpectStreamingResponse if needed
hasStreamParam := extractStreamParam(ctx.OriginalRequestBody)
if hasStreamParam {
observability.Infof("Original request contains stream parameter: true")
ctx.ExpectStreamingResponse = true // Set this if stream param is found
}

// Parse the OpenAI request using SDK types
openAIRequest, err := parseOpenAIRequest(ctx.OriginalRequestBody)
if err != nil {
Expand Down Expand Up @@ -499,8 +544,8 @@ func (r *OpenAIRouter) handleModelRouting(openAIRequest *openai.ChatCompletionNe
// Modify the model in the request
openAIRequest.Model = openai.ChatModel(matchedModel)

// Serialize the modified request
modifiedBody, err := serializeOpenAIRequest(openAIRequest)
// Serialize the modified request with stream parameter preserved
modifiedBody, err := serializeOpenAIRequestWithStream(openAIRequest, ctx.ExpectStreamingResponse)
if err != nil {
observability.Errorf("Error serializing modified request: %v", err)
metrics.RecordRequestError(actualModel, "serialization_error")
Expand Down Expand Up @@ -758,8 +803,8 @@ func (r *OpenAIRouter) handleToolSelection(openAIRequest *openai.ChatCompletionN

// updateRequestWithTools updates the request body with the selected tools
func (r *OpenAIRouter) updateRequestWithTools(openAIRequest *openai.ChatCompletionNewParams, response **ext_proc.ProcessingResponse, ctx *RequestContext) error {
// Re-serialize the request with modified tools
modifiedBody, err := serializeOpenAIRequest(openAIRequest)
// Re-serialize the request with modified tools and preserved stream parameter
modifiedBody, err := serializeOpenAIRequestWithStream(openAIRequest, ctx.ExpectStreamingResponse)
if err != nil {
return err
}
Expand Down
Loading