Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/semantic-router/pkg/extproc/request_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -516,7 +516,7 @@ func (r *OpenAIRouter) handleCaching(ctx *RequestContext) (*ext_proc.ProcessingR
"query": requestQuery,
})
// Return immediate response from cache
response := http.CreateCacheHitResponse(cachedResponse)
response := http.CreateCacheHitResponse(cachedResponse, ctx.ExpectStreamingResponse)
ctx.TraceContext = spanCtx
return response, true
}
Expand Down
55 changes: 52 additions & 3 deletions src/semantic-router/pkg/utils/http/response.go
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,56 @@ func CreateJailbreakViolationResponse(jailbreakType string, confidence float32,
}

// CreateCacheHitResponse creates an immediate response from cache
func CreateCacheHitResponse(cachedResponse []byte) *ext_proc.ProcessingResponse {
func CreateCacheHitResponse(cachedResponse []byte, isStreaming bool) *ext_proc.ProcessingResponse {
var responseBody []byte
var contentType string

if isStreaming {
// For streaming responses, convert cached JSON to SSE format
contentType = "text/event-stream"

// Parse the cached JSON response
var cachedCompletion openai.ChatCompletion
if err := json.Unmarshal(cachedResponse, &cachedCompletion); err != nil {
observability.Errorf("Error parsing cached response for streaming conversion: %v", err)
responseBody = []byte("data: {\"error\": \"Failed to convert cached response\"}\n\ndata: [DONE]\n\n")
} else {
// Convert chat.completion to chat.completion.chunk format
streamChunk := map[string]interface{}{
"id": cachedCompletion.ID,
"object": "chat.completion.chunk",
"created": cachedCompletion.Created,
"model": cachedCompletion.Model,
"choices": []map[string]interface{}{},
}

// Convert choices from message format to delta format
for _, choice := range cachedCompletion.Choices {
streamChoice := map[string]interface{}{
"index": choice.Index,
"delta": map[string]interface{}{
"role": choice.Message.Role,
"content": choice.Message.Content,
},
"finish_reason": choice.FinishReason,
}
streamChunk["choices"] = append(streamChunk["choices"].([]map[string]interface{}), streamChoice)
}

chunkJSON, err := json.Marshal(streamChunk)
if err != nil {
observability.Errorf("Error marshaling streaming cache response: %v", err)
responseBody = []byte("data: {\"error\": \"Failed to generate response\"}\n\ndata: [DONE]\n\n")
} else {
responseBody = []byte(fmt.Sprintf("data: %s\n\ndata: [DONE]\n\n", chunkJSON))
}
}
} else {
// For non-streaming responses, use cached JSON directly
contentType = "application/json"
responseBody = cachedResponse
}

immediateResponse := &ext_proc.ImmediateResponse{
Status: &typev3.HttpStatus{
Code: typev3.StatusCode_OK,
Expand All @@ -242,7 +291,7 @@ func CreateCacheHitResponse(cachedResponse []byte) *ext_proc.ProcessingResponse
{
Header: &core.HeaderValue{
Key: "content-type",
RawValue: []byte("application/json"),
RawValue: []byte(contentType),
},
},
{
Expand All @@ -253,7 +302,7 @@ func CreateCacheHitResponse(cachedResponse []byte) *ext_proc.ProcessingResponse
},
},
},
Body: cachedResponse,
Body: responseBody,
}

return &ext_proc.ProcessingResponse{
Expand Down
250 changes: 250 additions & 0 deletions src/semantic-router/pkg/utils/http/response_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,250 @@
package http

import (
"encoding/json"
"strings"
"testing"

"github.com/openai/openai-go"
)

// TestCreateCacheHitResponse_NonStreaming verifies that a cache hit for a
// non-streaming request returns the cached JSON bytes verbatim with an
// application/json content type and the x-vsr-cache-hit marker header.
func TestCreateCacheHitResponse_NonStreaming(t *testing.T) {
	// Build a representative cached chat completion to feed into the helper.
	completion := openai.ChatCompletion{
		ID:      "chatcmpl-test-123",
		Object:  "chat.completion",
		Created: 1234567890,
		Model:   "test-model",
		Choices: []openai.ChatCompletionChoice{
			{
				Index: 0,
				Message: openai.ChatCompletionMessage{
					Role:    "assistant",
					Content: "This is a cached response.",
				},
				FinishReason: "stop",
			},
		},
		Usage: openai.CompletionUsage{
			PromptTokens:     10,
			CompletionTokens: 5,
			TotalTokens:      15,
		},
	}

	raw, err := json.Marshal(completion)
	if err != nil {
		t.Fatalf("Failed to marshal cached response: %v", err)
	}

	// Exercise the non-streaming path.
	resp := CreateCacheHitResponse(raw, false)
	if resp == nil {
		t.Fatal("Response is nil")
	}

	ir := resp.GetImmediateResponse()
	if ir == nil {
		t.Fatal("ImmediateResponse is nil")
	}

	if got := ir.Status.Code.String(); got != "OK" {
		t.Errorf("Expected status OK, got %s", got)
	}

	// Collect the two headers under test in a single pass.
	var contentType, cacheHit string
	for _, h := range ir.Headers.SetHeaders {
		switch h.Header.Key {
		case "content-type":
			contentType = string(h.Header.RawValue)
		case "x-vsr-cache-hit":
			cacheHit = string(h.Header.RawValue)
		}
	}

	if contentType != "application/json" {
		t.Errorf("Expected content-type application/json, got %s", contentType)
	}
	if cacheHit != "true" {
		t.Errorf("Expected x-vsr-cache-hit true, got %s", cacheHit)
	}

	// The non-streaming path must pass the cached bytes through untouched.
	if string(ir.Body) != string(raw) {
		t.Error("Body was modified for non-streaming response")
	}

	// And the body must still round-trip as a valid chat.completion object.
	var parsed openai.ChatCompletion
	if err := json.Unmarshal(ir.Body, &parsed); err != nil {
		t.Errorf("Failed to parse response body as JSON: %v", err)
	}
	if parsed.Object != "chat.completion" {
		t.Errorf("Expected object chat.completion, got %s", parsed.Object)
	}
}

// TestCreateCacheHitResponse_Streaming verifies that a cache hit for a
// streaming request is converted into SSE format: text/event-stream content
// type, a single chat.completion.chunk data frame carrying the cached
// message as a delta, and a trailing "data: [DONE]" terminator.
func TestCreateCacheHitResponse_Streaming(t *testing.T) {
	// Build a representative cached chat completion to convert.
	completion := openai.ChatCompletion{
		ID:      "chatcmpl-test-456",
		Object:  "chat.completion",
		Created: 1234567890,
		Model:   "test-model",
		Choices: []openai.ChatCompletionChoice{
			{
				Index: 0,
				Message: openai.ChatCompletionMessage{
					Role:    "assistant",
					Content: "This is a cached streaming response.",
				},
				FinishReason: "stop",
			},
		},
		Usage: openai.CompletionUsage{
			PromptTokens:     10,
			CompletionTokens: 5,
			TotalTokens:      15,
		},
	}

	raw, err := json.Marshal(completion)
	if err != nil {
		t.Fatalf("Failed to marshal cached response: %v", err)
	}

	// Exercise the streaming path.
	resp := CreateCacheHitResponse(raw, true)
	if resp == nil {
		t.Fatal("Response is nil")
	}

	ir := resp.GetImmediateResponse()
	if ir == nil {
		t.Fatal("ImmediateResponse is nil")
	}

	if got := ir.Status.Code.String(); got != "OK" {
		t.Errorf("Expected status OK, got %s", got)
	}

	// Collect the two headers under test in a single pass.
	var contentType, cacheHit string
	for _, h := range ir.Headers.SetHeaders {
		switch h.Header.Key {
		case "content-type":
			contentType = string(h.Header.RawValue)
		case "x-vsr-cache-hit":
			cacheHit = string(h.Header.RawValue)
		}
	}

	if contentType != "text/event-stream" {
		t.Errorf("Expected content-type text/event-stream, got %s", contentType)
	}
	if cacheHit != "true" {
		t.Errorf("Expected x-vsr-cache-hit true, got %s", cacheHit)
	}

	// The body must be framed as server-sent events.
	bodyStr := string(ir.Body)
	if !strings.HasPrefix(bodyStr, "data: ") {
		t.Error("Body does not start with 'data: ' prefix")
	}
	if !strings.Contains(bodyStr, "data: [DONE]") {
		t.Error("Body does not contain 'data: [DONE]' terminator")
	}

	// Pull out the first payload-carrying data line (skipping the terminator).
	var dataLine string
	for _, line := range strings.Split(bodyStr, "\n") {
		if strings.HasPrefix(line, "data: ") && !strings.Contains(line, "[DONE]") {
			dataLine = strings.TrimPrefix(line, "data: ")
			break
		}
	}
	if dataLine == "" {
		t.Fatal("No data line found in SSE response")
	}

	// The frame must decode as a chat.completion.chunk carrying the cached
	// message in delta form.
	var chunk map[string]interface{}
	if err := json.Unmarshal([]byte(dataLine), &chunk); err != nil {
		t.Fatalf("Failed to parse SSE data as JSON: %v", err)
	}

	if chunk["object"] != "chat.completion.chunk" {
		t.Errorf("Expected object chat.completion.chunk, got %v", chunk["object"])
	}
	if chunk["id"] != "chatcmpl-test-456" {
		t.Errorf("Expected id chatcmpl-test-456, got %v", chunk["id"])
	}

	choices, ok := chunk["choices"].([]interface{})
	if !ok || len(choices) == 0 {
		t.Fatal("Choices not found or empty")
	}

	first := choices[0].(map[string]interface{})
	delta, ok := first["delta"].(map[string]interface{})
	if !ok {
		t.Fatal("Delta not found in choice")
	}

	if delta["role"] != "assistant" {
		t.Errorf("Expected role assistant, got %v", delta["role"])
	}
	if delta["content"] != "This is a cached streaming response." {
		t.Errorf("Expected content 'This is a cached streaming response.', got %v", delta["content"])
	}
	if first["finish_reason"] != "stop" {
		t.Errorf("Expected finish_reason stop, got %v", first["finish_reason"])
	}
}

// TestCreateCacheHitResponse_StreamingWithInvalidJSON verifies that a cache
// entry which cannot be parsed as a ChatCompletion still produces a
// well-formed SSE error stream: the content-type must remain
// text/event-stream (the client asked for streaming), the body must keep the
// "data: " framing, carry an error payload, and end with the SSE terminator.
func TestCreateCacheHitResponse_StreamingWithInvalidJSON(t *testing.T) {
	// Test with invalid JSON
	invalidJSON := []byte("invalid json")

	response := CreateCacheHitResponse(invalidJSON, true)

	// Verify response structure
	if response == nil {
		t.Fatal("Response is nil")
	}

	immediateResp := response.GetImmediateResponse()
	if immediateResp == nil {
		t.Fatal("ImmediateResponse is nil")
	}

	// Even on the error path the response must advertise SSE, because the
	// caller requested a streaming response.
	var contentType string
	for _, header := range immediateResp.Headers.SetHeaders {
		if header.Header.Key == "content-type" {
			contentType = string(header.Header.RawValue)
		}
	}
	if contentType != "text/event-stream" {
		t.Errorf("Expected content-type text/event-stream, got %s", contentType)
	}

	// Verify the body is SSE-framed and carries the error payload.
	bodyStr := string(immediateResp.Body)
	if !strings.HasPrefix(bodyStr, "data: ") {
		t.Error("Body does not start with 'data: ' prefix")
	}
	if !strings.Contains(bodyStr, "error") {
		t.Error("Expected error message in response body")
	}
	if !strings.Contains(bodyStr, "data: [DONE]") {
		t.Error("Expected SSE terminator even in error case")
	}
}
Loading