diff --git a/src/semantic-router/pkg/extproc/request_handler.go b/src/semantic-router/pkg/extproc/request_handler.go
index b4ea97cf..a1d3319a 100644
--- a/src/semantic-router/pkg/extproc/request_handler.go
+++ b/src/semantic-router/pkg/extproc/request_handler.go
@@ -516,7 +516,7 @@ func (r *OpenAIRouter) handleCaching(ctx *RequestContext) (*ext_proc.ProcessingR
 			"query": requestQuery,
 		})
 		// Return immediate response from cache
-		response := http.CreateCacheHitResponse(cachedResponse)
+		response := http.CreateCacheHitResponse(cachedResponse, ctx.ExpectStreamingResponse)
 		ctx.TraceContext = spanCtx
 		return response, true
 	}
diff --git a/src/semantic-router/pkg/utils/http/response.go b/src/semantic-router/pkg/utils/http/response.go
index d2f03b4c..d9460d3e 100644
--- a/src/semantic-router/pkg/utils/http/response.go
+++ b/src/semantic-router/pkg/utils/http/response.go
@@ -232,7 +232,56 @@ func CreateJailbreakViolationResponse(jailbreakType string, confidence float32,
 }
 
 // CreateCacheHitResponse creates an immediate response from cache
-func CreateCacheHitResponse(cachedResponse []byte) *ext_proc.ProcessingResponse {
+func CreateCacheHitResponse(cachedResponse []byte, isStreaming bool) *ext_proc.ProcessingResponse {
+	var responseBody []byte
+	var contentType string
+
+	if isStreaming {
+		// For streaming responses, convert cached JSON to SSE format
+		contentType = "text/event-stream"
+
+		// Parse the cached JSON response
+		var cachedCompletion openai.ChatCompletion
+		if err := json.Unmarshal(cachedResponse, &cachedCompletion); err != nil {
+			observability.Errorf("Error parsing cached response for streaming conversion: %v", err)
+			responseBody = []byte("data: {\"error\": \"Failed to convert cached response\"}\n\ndata: [DONE]\n\n")
+		} else {
+			// Convert chat.completion to chat.completion.chunk format
+			streamChunk := map[string]interface{}{
+				"id":      cachedCompletion.ID,
+				"object":  "chat.completion.chunk",
+				"created": cachedCompletion.Created,
+				"model":   cachedCompletion.Model,
+				"choices": []map[string]interface{}{},
+			}
+
+			// Convert choices from message format to delta format
+			for _, choice := range cachedCompletion.Choices {
+				streamChoice := map[string]interface{}{
+					"index": choice.Index,
+					"delta": map[string]interface{}{
+						"role":    choice.Message.Role,
+						"content": choice.Message.Content,
+					},
+					"finish_reason": choice.FinishReason,
+				}
+				streamChunk["choices"] = append(streamChunk["choices"].([]map[string]interface{}), streamChoice)
+			}
+
+			chunkJSON, err := json.Marshal(streamChunk)
+			if err != nil {
+				observability.Errorf("Error marshaling streaming cache response: %v", err)
+				responseBody = []byte("data: {\"error\": \"Failed to generate response\"}\n\ndata: [DONE]\n\n")
+			} else {
+				responseBody = []byte(fmt.Sprintf("data: %s\n\ndata: [DONE]\n\n", chunkJSON))
+			}
+		}
+	} else {
+		// For non-streaming responses, use cached JSON directly
+		contentType = "application/json"
+		responseBody = cachedResponse
+	}
+
 	immediateResponse := &ext_proc.ImmediateResponse{
 		Status: &typev3.HttpStatus{
 			Code: typev3.StatusCode_OK,
@@ -242,7 +291,7 @@ func CreateCacheHitResponse(cachedResponse []byte) *ext_proc.ProcessingResponse
 			{
 				Header: &core.HeaderValue{
 					Key:      "content-type",
-					RawValue: []byte("application/json"),
+					RawValue: []byte(contentType),
 				},
 			},
 			{
@@ -253,7 +302,7 @@ func CreateCacheHitResponse(cachedResponse []byte) *ext_proc.ProcessingResponse
 				},
 			},
 		},
-		Body: cachedResponse,
+		Body: responseBody,
 	}
 
 	return &ext_proc.ProcessingResponse{
diff --git a/src/semantic-router/pkg/utils/http/response_test.go b/src/semantic-router/pkg/utils/http/response_test.go
new file mode 100644
index 00000000..b53539fc
--- /dev/null
+++ b/src/semantic-router/pkg/utils/http/response_test.go
@@ -0,0 +1,250 @@
+package http
+
+import (
+	"encoding/json"
+	"strings"
+	"testing"
+
+	"github.com/openai/openai-go"
+)
+
+func TestCreateCacheHitResponse_NonStreaming(t *testing.T) {
+	// Create a sample cached response
+	cachedCompletion := openai.ChatCompletion{
+		ID:      "chatcmpl-test-123",
+		Object:  "chat.completion",
+		Created: 1234567890,
+		Model:   "test-model",
+		Choices: []openai.ChatCompletionChoice{
+			{
+				Index: 0,
+				Message: openai.ChatCompletionMessage{
+					Role:    "assistant",
+					Content: "This is a cached response.",
+				},
+				FinishReason: "stop",
+			},
+		},
+		Usage: openai.CompletionUsage{
+			PromptTokens:     10,
+			CompletionTokens: 5,
+			TotalTokens:      15,
+		},
+	}
+
+	cachedResponse, err := json.Marshal(cachedCompletion)
+	if err != nil {
+		t.Fatalf("Failed to marshal cached response: %v", err)
+	}
+
+	// Test non-streaming response
+	response := CreateCacheHitResponse(cachedResponse, false)
+
+	// Verify response structure
+	if response == nil {
+		t.Fatal("Response is nil")
+	}
+
+	immediateResp := response.GetImmediateResponse()
+	if immediateResp == nil {
+		t.Fatal("ImmediateResponse is nil")
+	}
+
+	// Verify status code
+	if immediateResp.Status.Code.String() != "OK" {
+		t.Errorf("Expected status OK, got %s", immediateResp.Status.Code.String())
+	}
+
+	// Verify content-type header
+	var contentType string
+	var cacheHit string
+	for _, header := range immediateResp.Headers.SetHeaders {
+		if header.Header.Key == "content-type" {
+			contentType = string(header.Header.RawValue)
+		}
+		if header.Header.Key == "x-vsr-cache-hit" {
+			cacheHit = string(header.Header.RawValue)
+		}
+	}
+
+	if contentType != "application/json" {
+		t.Errorf("Expected content-type application/json, got %s", contentType)
+	}
+
+	if cacheHit != "true" {
+		t.Errorf("Expected x-vsr-cache-hit true, got %s", cacheHit)
+	}
+
+	// Verify body is unchanged
+	if string(immediateResp.Body) != string(cachedResponse) {
+		t.Error("Body was modified for non-streaming response")
+	}
+
+	// Verify body can be parsed as JSON
+	var parsedResponse openai.ChatCompletion
+	if err := json.Unmarshal(immediateResp.Body, &parsedResponse); err != nil {
+		t.Errorf("Failed to parse response body as JSON: %v", err)
+	}
+
+	if parsedResponse.Object != "chat.completion" {
+		t.Errorf("Expected object chat.completion, got %s", parsedResponse.Object)
+	}
+}
+
+func TestCreateCacheHitResponse_Streaming(t *testing.T) {
+	// Create a sample cached response
+	cachedCompletion := openai.ChatCompletion{
+		ID:      "chatcmpl-test-456",
+		Object:  "chat.completion",
+		Created: 1234567890,
+		Model:   "test-model",
+		Choices: []openai.ChatCompletionChoice{
+			{
+				Index: 0,
+				Message: openai.ChatCompletionMessage{
+					Role:    "assistant",
+					Content: "This is a cached streaming response.",
+				},
+				FinishReason: "stop",
+			},
+		},
+		Usage: openai.CompletionUsage{
+			PromptTokens:     10,
+			CompletionTokens: 5,
+			TotalTokens:      15,
+		},
+	}
+
+	cachedResponse, err := json.Marshal(cachedCompletion)
+	if err != nil {
+		t.Fatalf("Failed to marshal cached response: %v", err)
+	}
+
+	// Test streaming response
+	response := CreateCacheHitResponse(cachedResponse, true)
+
+	// Verify response structure
+	if response == nil {
+		t.Fatal("Response is nil")
+	}
+
+	immediateResp := response.GetImmediateResponse()
+	if immediateResp == nil {
+		t.Fatal("ImmediateResponse is nil")
+	}
+
+	// Verify status code
+	if immediateResp.Status.Code.String() != "OK" {
+		t.Errorf("Expected status OK, got %s", immediateResp.Status.Code.String())
+	}
+
+	// Verify content-type header
+	var contentType string
+	var cacheHit string
+	for _, header := range immediateResp.Headers.SetHeaders {
+		if header.Header.Key == "content-type" {
+			contentType = string(header.Header.RawValue)
+		}
+		if header.Header.Key == "x-vsr-cache-hit" {
+			cacheHit = string(header.Header.RawValue)
+		}
+	}
+
+	if contentType != "text/event-stream" {
+		t.Errorf("Expected content-type text/event-stream, got %s", contentType)
+	}
+
+	if cacheHit != "true" {
+		t.Errorf("Expected x-vsr-cache-hit true, got %s", cacheHit)
+	}
+
+	// Verify body is in SSE format
+	bodyStr := string(immediateResp.Body)
+	if !strings.HasPrefix(bodyStr, "data: ") {
+		t.Error("Body does not start with 'data: ' prefix")
+	}
+
+	if !strings.Contains(bodyStr, "data: [DONE]") {
+		t.Error("Body does not contain 'data: [DONE]' terminator")
+	}
+
+	// Parse the SSE data
+	lines := strings.Split(bodyStr, "\n")
+	var dataLine string
+	for _, line := range lines {
+		if strings.HasPrefix(line, "data: ") && !strings.Contains(line, "[DONE]") {
+			dataLine = strings.TrimPrefix(line, "data: ")
+			break
+		}
+	}
+
+	if dataLine == "" {
+		t.Fatal("No data line found in SSE response")
+	}
+
+	// Parse the JSON chunk
+	var chunk map[string]interface{}
+	if err := json.Unmarshal([]byte(dataLine), &chunk); err != nil {
+		t.Fatalf("Failed to parse SSE data as JSON: %v", err)
+	}
+
+	// Verify chunk structure
+	if chunk["object"] != "chat.completion.chunk" {
+		t.Errorf("Expected object chat.completion.chunk, got %v", chunk["object"])
+	}
+
+	if chunk["id"] != "chatcmpl-test-456" {
+		t.Errorf("Expected id chatcmpl-test-456, got %v", chunk["id"])
+	}
+
+	// Verify choices structure
+	choices, ok := chunk["choices"].([]interface{})
+	if !ok || len(choices) == 0 {
+		t.Fatal("Choices not found or empty")
+	}
+
+	choice := choices[0].(map[string]interface{})
+	delta, ok := choice["delta"].(map[string]interface{})
+	if !ok {
+		t.Fatal("Delta not found in choice")
+	}
+
+	if delta["role"] != "assistant" {
+		t.Errorf("Expected role assistant, got %v", delta["role"])
+	}
+
+	if delta["content"] != "This is a cached streaming response." {
+		t.Errorf("Expected content 'This is a cached streaming response.', got %v", delta["content"])
+	}
+
+	if choice["finish_reason"] != "stop" {
+		t.Errorf("Expected finish_reason stop, got %v", choice["finish_reason"])
+	}
+}
+
+func TestCreateCacheHitResponse_StreamingWithInvalidJSON(t *testing.T) {
+	// Test with invalid JSON
+	invalidJSON := []byte("invalid json")
+
+	response := CreateCacheHitResponse(invalidJSON, true)
+
+	// Verify response structure
+	if response == nil {
+		t.Fatal("Response is nil")
+	}
+
+	immediateResp := response.GetImmediateResponse()
+	if immediateResp == nil {
+		t.Fatal("ImmediateResponse is nil")
+	}
+
+	// Verify error response
+	bodyStr := string(immediateResp.Body)
+	if !strings.Contains(bodyStr, "error") {
+		t.Error("Expected error message in response body")
+	}
+
+	if !strings.Contains(bodyStr, "data: [DONE]") {
+		t.Error("Expected SSE terminator even in error case")
+	}
+}