diff --git a/.github/workflows/test-and-build.yml b/.github/workflows/test-and-build.yml index 3eec580c4..ae60c88d7 100644 --- a/.github/workflows/test-and-build.yml +++ b/.github/workflows/test-and-build.yml @@ -110,7 +110,8 @@ jobs: docker ps --filter "name=milvus-semantic-cache" - name: Run semantic router tests run: make test env: CI: true CI_MINIMAL_MODELS: ${{ github.event_name == 'pull_request' }} + SKIP_TOOL_CALL_TESTS: true diff --git a/config/config.development.yaml b/config/config.development.yaml new file mode 100644 index 000000000..bd448cefe --- /dev/null +++ b/config/config.development.yaml @@ -0,0 +1,108 @@ +# Development Configuration Example with Stdout Tracing +# This configuration enables distributed tracing with a stdout exporter +# for local development and debugging. + +bert_model: + model_id: models/all-MiniLM-L12-v2 + threshold: 0.6 + use_cpu: true + +semantic_cache: + enabled: true + backend_type: "memory" + similarity_threshold: 0.8 + max_entries: 100 + ttl_seconds: 600 + eviction_policy: "fifo" + use_hnsw: true # Enable HNSW for faster search + hnsw_m: 16 + hnsw_ef_construction: 200 + +tools: + enabled: false + top_k: 3 + similarity_threshold: 0.2 + tools_db_path: "config/tools_db.json" + fallback_to_empty: true + +prompt_guard: + enabled: false + +vllm_endpoints: + - name: "local-endpoint" + address: "127.0.0.1" + port: 8000 + weight: 1 + +model_config: + "test-model": + pii_policy: + allow_by_default: true + +classifier: + category_model: + model_id: "models/category_classifier_modernbert-base_model" + use_modernbert: true + threshold: 0.6 + use_cpu: true + category_mapping_path: "models/category_classifier_modernbert-base_model/category_mapping.json" + +categories: + - name: test + system_prompt: "You are a test assistant."
+ # Example: Category-level cache settings + # semantic_cache_enabled: true + # semantic_cache_similarity_threshold: 0.85 + model_scores: + - model: test-model + score: 1.0 + use_reasoning: false + +default_model: test-model + +# Enable OpenAI Responses API adapter (experimental) +enable_responses_adapter: true + +# Auto model name for automatic model selection (optional) +# Uncomment and set to customize the model name for automatic routing +# auto_model_name: "MoM" + +api: + batch_classification: + max_batch_size: 10 + metrics: + enabled: true + +# Observability Configuration - Development with Stdout +observability: + tracing: + # Enable tracing for development/debugging + enabled: true + + # OpenTelemetry provider + provider: "opentelemetry" + + exporter: + # Stdout exporter prints traces to console (great for debugging) + type: "stdout" + + # No endpoint needed for stdout + # endpoint: "" + # insecure: true + + sampling: + # Always sample in development to see all traces + type: "always_on" + + # Rate not used for always_on + # rate: 1.0 + + resource: + # Service name for trace identification + service_name: "vllm-semantic-router-dev" + + # Version for development + service_version: "dev" + + # Environment identifier + deployment_environment: "development" diff --git a/config/config.yaml b/config/config.yaml index 3454e0d13..ecb11aa44 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -24,7 +24,7 @@ semantic_cache: # Options: "bert" (fast, 384-dim), "qwen3" (high quality, 1024-dim, 32K context), "gemma" (balanced, 768-dim, 8K context) # Default: "bert" (fastest, lowest memory) embedding_model: "bert" - + tools: enabled: true top_k: 3 @@ -480,6 +480,9 @@ reasoning_families: # Global default reasoning effort level default_reasoning_effort: high +# Enable OpenAI Responses API adapter (experimental) +enable_responses_adapter: true + # API Configuration api: batch_classification: diff --git a/dashboard/backend/.gitkeep b/dashboard/backend/.gitkeep deleted file mode 100644 index e69de29bb..000000000 diff --git a/deploy/docker-compose/addons/llm-router-dashboard.json b/deploy/docker-compose/addons/llm-router-dashboard.json index ff136b6ec..09eafb610 100644 --- a/deploy/docker-compose/addons/llm-router-dashboard.json +++ b/deploy/docker-compose/addons/llm-router-dashboard.json @@ -609,6 +609,126 @@ "title": "TPOT (p95) by Model (sec/token)", "type": "timeseries" }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS}" + }, + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisBorderShow": false, + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "Seconds", + "axisPlacement": "auto", + "barAlignment": 0, + "barWidthFactor": 0.6, + "drawStyle": "line", + "fillOpacity": 10, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "insertNulls": false, + "lineInterpolation": "smooth", + "lineWidth": 1, + "pointSize": 5, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "auto", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + } + ] + }, + "unit": "s" + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 12, + "y": 24 + }, + "id": 7, + "options": { + "legend": { + "calcs": [ + "mean", + "max", + "lastNotNull" + ], + "displayMode": "table", + 
"placement": "bottom", + "showLegend": true + }, + "tooltip": { + "hideZeros": false, + "mode": "multi", + "sort": "none" + } + }, + "pluginVersion": "11.5.1", + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS}" + }, + "editorMode": "code", + "expr": "histogram_quantile(0.50, sum(rate(llm_model_completion_latency_seconds_bucket[5m])) by (le, model))", + "legendFormat": "p50 {{model}}", + "range": true, + "refId": "A" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS}" + }, + "editorMode": "code", + "expr": "histogram_quantile(0.90, sum(rate(llm_model_completion_latency_seconds_bucket[5m])) by (le, model))", + "legendFormat": "p90 {{model}}", + "range": true, + "refId": "B" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS}" + }, + "editorMode": "code", + "expr": "histogram_quantile(0.99, sum(rate(llm_model_completion_latency_seconds_bucket[5m])) by (le, model))", + "legendFormat": "p99 {{model}}", + "range": true, + "refId": "C" + } + ], + "title": "Model Completion Latency (p50/p90/p99)", + "type": "timeseries" + }, { "datasource": { "type": "prometheus", @@ -672,7 +792,7 @@ "x": 0, "y": 24 }, - "id": 7, + "id": 8, "options": { "legend": { "calcs": [ @@ -779,9 +899,9 @@ "h": 8, "w": 12, "x": 12, - "y": 24 + "y": 48 }, - "id": 8, + "id": 9, "options": { "legend": { "calcs": [ @@ -883,7 +1003,7 @@ "x": 0, "y": 32 }, - "id": 9, + "id": 10, "options": { "legend": { "calcs": [ @@ -967,7 +1087,7 @@ "x": 12, "y": 32 }, - "id": 10, + "id": 11, "options": { "displayMode": "gradient", "legend": { @@ -1039,7 +1159,7 @@ "x": 0, "y": 40 }, - "id": 11, + "id": 12, "options": { "displayMode": "gradient", "legend": { @@ -1088,117 +1208,79 @@ }, "fieldConfig": { "defaults": { - "color": { - "mode": "palette-classic" - }, + "color": { "mode": "palette-classic" }, "custom": { - "axisBorderShow": false, - "axisCenteredZero": false, - "axisColorMode": "text", - "axisLabel": "Seconds", - "axisPlacement": "auto", - "barAlignment": 0, - "barWidthFactor": 0.6, + "axisLabel": "Requests/sec", "drawStyle": "line", "fillOpacity": 10, - "gradientMode": "none", - "hideFrom": { - "legend": false, - "tooltip": false, - "viz": false - }, - "insertNulls": false, "lineInterpolation": "smooth", "lineWidth": 1, - "pointSize": 5, - "scaleDistribution": { - "type": "linear" - }, - "showPoints": "auto", - "spanNulls": false, - "stacking": { - "group": "A", - "mode": "none" - }, - "thresholdsStyle": { - "mode": "off" - } + "showPoints": "auto" }, "mappings": [], - "thresholds": { - "mode": "absolute", - "steps": [ - { - "color": "green", - "value": null - } - ] - }, - "unit": "s" + "thresholds": { "mode": "absolute", "steps": [{"color":"green","value":null}] }, + "unit": "reqps" }, "overrides": [] }, - "gridPos": { - "h": 8, - "w": 12, - "x": 12, - "y": 40 - }, - "id": 12, + "gridPos": { "h": 8, "w": 12, "x": 0, "y": 36 }, + "id": 13, "options": { - "legend": { - "calcs": [ - "mean", - "max", - "lastNotNull" - ], - "displayMode": "table", - "placement": "bottom", - "showLegend": true - }, - "tooltip": { - "hideZeros": false, - "mode": "multi", - "sort": "none" - } + "legend": { "calcs": ["mean","max","lastNotNull"], "displayMode": "table", "placement": "bottom", "showLegend": true }, + "tooltip": { "mode": "multi", "sort": "none" } }, - "pluginVersion": "11.5.1", "targets": [ { - "datasource": { - "type": "prometheus", - "uid": "${DS_PROMETHEUS}" - }, + "datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" }, "editorMode": "code", 
- "expr": "histogram_quantile(0.50, sum(rate(llm_model_completion_latency_seconds_bucket[5m])) by (le, model))", - "legendFormat": "p50 {{model}}", + "expr": "sum(rate(llm_responses_adapter_requests_total[5m])) by (streaming)", + "legendFormat": "Requests {{streaming}}", "range": true, "refId": "A" - }, - { - "datasource": { - "type": "prometheus", - "uid": "${DS_PROMETHEUS}" + } + ], + "title": "Responses Adapter Requests Rate", + "type": "timeseries" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS}" + }, + "fieldConfig": { + "defaults": { + "color": { "mode": "palette-classic" }, + "custom": { + "axisLabel": "Events/sec", + "drawStyle": "line", + "fillOpacity": 10, + "lineInterpolation": "smooth", + "lineWidth": 1, + "showPoints": "auto" }, - "editorMode": "code", - "expr": "histogram_quantile(0.90, sum(rate(llm_model_completion_latency_seconds_bucket[5m])) by (le, model))", - "legendFormat": "p90 {{model}}", - "range": true, - "refId": "B" + "mappings": [], + "thresholds": { "mode": "absolute", "steps": [{"color":"green","value":null}] }, + "unit": "ops" }, + "overrides": [] + }, + "gridPos": { "h": 8, "w": 12, "x": 12, "y": 36 }, + "id": 14, + "options": { + "legend": { "calcs": ["mean","max","lastNotNull"], "displayMode": "table", "placement": "bottom", "showLegend": true }, + "tooltip": { "mode": "multi", "sort": "none" } + }, + "targets": [ { - "datasource": { - "type": "prometheus", - "uid": "${DS_PROMETHEUS}" - }, + "datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" }, "editorMode": "code", - "expr": "histogram_quantile(0.99, sum(rate(llm_model_completion_latency_seconds_bucket[5m])) by (le, model))", - "legendFormat": "p99 {{model}}", + "expr": "sum(rate(llm_responses_adapter_sse_events_total[5m])) by (event_type)", + "legendFormat": "{{event_type}}", "range": true, - "refId": "C" + "refId": "A" } ], - "title": "Model Completion Latency (p50/p90/p99)", + "title": "Responses Adapter SSE Events Rate", "type": "timeseries" } ], @@ -1233,6 +1315,6 @@ "timezone": "", "title": "LLM Router Metrics", "uid": "llm-router-metrics", - "version": 14, + "version": 15, "weekStart": "" } \ No newline at end of file diff --git a/deploy/docker-compose/addons/vllm_semantic_router_pipe.py b/deploy/docker-compose/addons/vllm_semantic_router_pipe.py index a1578abfd..57788173e 100644 --- a/deploy/docker-compose/addons/vllm_semantic_router_pipe.py +++ b/deploy/docker-compose/addons/vllm_semantic_router_pipe.py @@ -35,6 +35,9 @@ class Valves(BaseModel): # Request timeout in seconds timeout: int = 300 + # Prefer OpenAI Responses API instead of Chat Completions + use_responses_api: bool = True + def __init__(self): # Important: type should be "manifold" instead of "pipe" # manifold type Pipeline will be displayed in the model list @@ -51,6 +54,7 @@ def __init__(self): "log_vsr_info": True, "debug": True, "timeout": 300, + "use_responses_api": True, } ) @@ -380,7 +384,10 @@ def pipe( print("=" * 80) # Prepare the request to vLLM Semantic Router - url = f"{self.valves.vsr_base_url}/v1/chat/completions" + if self.valves.use_responses_api: + url = f"{self.valves.vsr_base_url}/v1/responses" + else: + url = f"{self.valves.vsr_base_url}/v1/chat/completions" if self.valves.debug: print(f"\n📡 Sending request to: {url}") @@ -412,6 +419,10 @@ def pipe( print(f" Streaming: {is_streaming}") print(f" Timeout: {self.valves.timeout}s") + # If using Responses API for streaming, set Accept header for SSE + if self.valves.use_responses_api and is_streaming: + headers["Accept"] = 
"text/event-stream" + try: if self.valves.debug: print(f"\n🔌 Connecting to vLLM Semantic Router...") @@ -459,7 +470,12 @@ def pipe( if self.valves.debug: print(f"\n📺 Handling streaming response...") # Handle streaming response - return self._handle_streaming_response(response, vsr_headers) + if self.valves.use_responses_api: + return self._handle_streaming_response_responses( + response, vsr_headers + ) + else: + return self._handle_streaming_response(response, vsr_headers) else: if self.valves.debug: print(f"\n📄 Handling non-streaming response...") @@ -493,13 +509,29 @@ def pipe( print("=" * 80 + "\n") return f"{error_msg}: {str(e)}" - if self.valves.debug: - print(f" Response data keys: {list(response_data.keys())}") - if "choices" in response_data: - print(f" Choices count: {len(response_data['choices'])}") - - # Add VSR info to the response if enabled - if self.valves.show_vsr_info and vsr_headers: + # Transform Responses API JSON to Chat Completions JSON if enabled + if self.valves.use_responses_api: + response_data = self._responses_to_chat_completions( + response_data, vsr_headers + ) + if self.valves.debug: + print( + f" Transformed Responses → ChatCompletions. keys: {list(response_data.keys())}" + ) + if "choices" in response_data: + print(f" Choices count: {len(response_data['choices'])}") + else: + if self.valves.debug: + print(f" Response data keys: {list(response_data.keys())}") + if "choices" in response_data: + print(f" Choices count: {len(response_data['choices'])}") + + # Add VSR info to the response if enabled (only for Chat Completions shape) + if ( + (not self.valves.use_responses_api) + and self.valves.show_vsr_info + and vsr_headers + ): vsr_info = self._format_vsr_info(vsr_headers, position="prefix") if self.valves.debug: @@ -540,6 +572,69 @@ def pipe( print("=" * 80 + "\n") return error_msg + def _responses_to_chat_completions(self, resp: dict, vsr_headers: dict) -> dict: + """ + Convert minimal OpenAI Responses JSON to legacy Chat Completions JSON + and inject VSR info as prefix to assistant content. 
+ """ + # Extract assistant text from output array + content_parts = [] + output = resp.get("output", []) + if isinstance(output, list): + for item in output: + if isinstance(item, dict) and item.get("type") == "message": + if item.get("role") == "assistant": + text = item.get("content", "") + if isinstance(text, str) and text: + content_parts.append(text) + content = "".join(content_parts) + + # Map usage + usage = resp.get("usage", {}) or {} + prompt_tokens = usage.get("input_tokens", 0) + completion_tokens = usage.get("output_tokens", 0) + total_tokens = usage.get("total_tokens", prompt_tokens + completion_tokens) + + # Build Chat Completions JSON + chat = { + "id": resp.get("id", ""), + "object": "chat.completion", + "created": resp.get("created", 0), + "model": resp.get("model", "auto"), + "system_fingerprint": "vsr", + "choices": [ + { + "index": 0, + "message": {"role": "assistant", "content": content}, + "logprobs": None, + "finish_reason": resp.get("stop_reason", "stop"), + } + ], + "usage": { + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "total_tokens": total_tokens, + "prompt_tokens_details": {"cached_tokens": 0}, + "completion_tokens_details": {"reasoning_tokens": 0}, + }, + "token_usage": { + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "total_tokens": total_tokens, + "prompt_tokens_details": {"cached_tokens": 0}, + "completion_tokens_details": {"reasoning_tokens": 0}, + }, + } + + # Prepend VSR info if enabled + if self.valves.show_vsr_info and vsr_headers: + vsr_info = self._format_vsr_info(vsr_headers, position="prefix") + chat["choices"][0]["message"]["content"] = ( + vsr_info + chat["choices"][0]["message"]["content"] + ) + + return chat + def _handle_streaming_response( self, response: requests.Response, vsr_headers: dict ) -> Generator: @@ -646,3 +741,106 @@ def _handle_streaming_response( except json.JSONDecodeError: # If not valid JSON, pass through as-is yield f"data: {data_str}\n\n" + + def _handle_streaming_response_responses( + self, response: requests.Response, vsr_headers: dict + ) -> Generator: + """ + Handle SSE stream for Responses API and convert to Chat Completions chunks. + Inject VSR info at the first assistant content delta. 
+ """ + vsr_info_added = False + + for line in response.iter_lines(decode_unicode=True): + if not line: + continue + + if not line.startswith("data: "): + continue + + data_str = line[6:].strip() + + if data_str == "[DONE]": + yield f"data: [DONE]\n\n" + if self.valves.debug: + print(f"✅ Streaming completed (Responses)") + continue + + try: + ev = json.loads(data_str) + except json.JSONDecodeError: + # Pass through unknown payloads + yield f"data: {data_str}\n\n" + continue + + etype = ev.get("type", "") + + if etype == "response.output_text.delta": + delta_text = ev.get("delta", "") + if self.valves.show_vsr_info and not vsr_info_added: + vsr_info = self._format_vsr_info(vsr_headers, position="prefix") + delta_text = vsr_info + (delta_text or "") + vsr_info_added = True + + chunk = { + "id": f"chatcmpl-{ev.get('created', 0)}", + "object": "chat.completion.chunk", + "created": ev.get("created", 0), + "model": "auto", + "system_fingerprint": "vsr", + "choices": [ + { + "index": 0, + "delta": {"content": delta_text}, + "logprobs": None, + "finish_reason": None, + } + ], + } + yield f"data: {json.dumps(chunk)}\n\n" + + elif etype == "response.tool_calls.delta": + chunk = { + "id": f"chatcmpl-{ev.get('created', 0)}", + "object": "chat.completion.chunk", + "created": ev.get("created", 0), + "model": "auto", + "system_fingerprint": "vsr", + "choices": [ + { + "index": 0, + "delta": { + "function_call": { + "name": ev.get("name", ""), + "arguments": ev.get("arguments_delta", ""), + } + }, + "logprobs": None, + "finish_reason": None, + } + ], + } + yield f"data: {json.dumps(chunk)}\n\n" + + elif etype == "response.completed": + finish = ev.get("stop_reason", "stop") + chunk = { + "id": "chatcmpl-end", + "object": "chat.completion.chunk", + "created": ev.get("created", 0), + "model": "auto", + "system_fingerprint": "vsr", + "choices": [ + { + "index": 0, + "delta": {}, + "logprobs": None, + "finish_reason": finish, + } + ], + } + yield f"data: {json.dumps(chunk)}\n\n" + + else: + # Unknown event type: pass-through + yield f"data: {data_str}\n\n" diff --git a/src/semantic-router/pkg/config/config.go b/src/semantic-router/pkg/config/config.go index e26b156ae..6bab625da 100644 --- a/src/semantic-router/pkg/config/config.go +++ b/src/semantic-router/pkg/config/config.go @@ -247,6 +247,12 @@ type EmbeddingRule struct { SimilarityThreshold float32 `yaml:"threshold"` Candidates []string `yaml:"candidates"` // Renamed from Keywords AggregationMethodConfiged AggregationMethod `yaml:"aggregation_method"` + // Gateway route cache clearing + ClearRouteCache bool `yaml:"clear_route_cache"` + + // EnableResponsesAdapter enables the compatibility shim for OpenAI Responses API (/v1/responses) + // When enabled, POST /v1/responses requests are adapted to legacy /v1/chat/completions. + EnableResponsesAdapter bool `yaml:"enable_responses_adapter"` } // APIConfig represents configuration for API endpoints diff --git a/src/semantic-router/pkg/extproc/mapping_responses.go b/src/semantic-router/pkg/extproc/mapping_responses.go new file mode 100644 index 000000000..539cf1022 --- /dev/null +++ b/src/semantic-router/pkg/extproc/mapping_responses.go @@ -0,0 +1,300 @@ +package extproc + +import ( + "encoding/json" + "fmt" + "strings" +) + +// mapResponsesRequestToChatCompletions converts a minimal OpenAI Responses API request +// into a legacy Chat Completions request JSON. Supports only text input for PR1. 
+func mapResponsesRequestToChatCompletions(original []byte) ([]byte, error) { + var req map[string]interface{} + if err := json.Unmarshal(original, &req); err != nil { + return nil, err + } + + // Extract model + model, _ := req["model"].(string) + if model == "" { + return nil, fmt.Errorf("missing model") + } + + // Derive user content + var userContent string + if input, ok := req["input"]; ok { + switch v := input.(type) { + case string: + userContent = v + case []interface{}: + // Join any string elements; ignore non-string for now + var parts []string + for _, it := range v { + if s, ok := it.(string); ok { + parts = append(parts, s) + } else if m, ok := it.(map[string]interface{}); ok { + // Try common shapes: {type:"input_text"|"text", text:"..."} + if t, _ := m["type"].(string); t == "input_text" || t == "text" { + if txt, _ := m["text"].(string); txt != "" { + parts = append(parts, txt) + } + } + } + } + userContent = strings.TrimSpace(strings.Join(parts, " ")) + default: + // unsupported multimodal + return nil, fmt.Errorf("unsupported input type") + } + } else if msgs, ok := req["messages"].([]interface{}); ok { + // Fallback: if caller already provided messages, pass them through + // This enables easy migration from chat/completions + mapped := map[string]interface{}{ + "model": model, + "messages": msgs, + } + // Map basic params + if v, ok := req["temperature"]; ok { + mapped["temperature"] = v + } + if v, ok := req["top_p"]; ok { + mapped["top_p"] = v + } + if v, ok := req["max_output_tokens"]; ok { + mapped["max_tokens"] = v + } + return json.Marshal(mapped) + } + + if userContent == "" { + return nil, fmt.Errorf("empty input") + } + + // Build minimal Chat Completions request + mapped := map[string]interface{}{ + "model": model, + "messages": []map[string]interface{}{ + {"role": "user", "content": userContent}, + }, + } + // Map basic params + if v, ok := req["temperature"]; ok { + mapped["temperature"] = v + } + if v, ok := req["top_p"]; ok { + mapped["top_p"] = v + } + if v, ok := req["max_output_tokens"]; ok { + mapped["max_tokens"] = v + } + + // Map tools and tool_choice if present + if v, ok := req["tools"]; ok { + mapped["tools"] = v + } + if v, ok := req["tool_choice"]; ok { + mapped["tool_choice"] = v + } + + return json.Marshal(mapped) +} + +// mapChatCompletionToResponses converts an OpenAI ChatCompletion JSON +// into a minimal Responses API JSON (non-streaming only) for PR1. 
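+// Example (illustrative): a chat.completion whose first choice has content "hi" and
+// finish_reason "stop" maps to {"object":"response","output":[{"type":"message",
+// "role":"assistant","content":"hi"}],"stop_reason":"stop","usage":{...}}.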
+func mapChatCompletionToResponses(chatCompletionJSON []byte) ([]byte, error) { + var parsed struct { + ID string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Model string `json:"model"` + Choices []struct { + Index int `json:"index"` + FinishReason string `json:"finish_reason"` + Message struct { + Role string `json:"role"` + Content string `json:"content"` + } `json:"message"` + } `json:"choices"` + Usage struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` + } `json:"usage"` + } + if err := json.Unmarshal(chatCompletionJSON, &parsed); err != nil { + return nil, err + } + + // Also parse generically to inspect tool calls + var generic map[string]interface{} + _ = json.Unmarshal(chatCompletionJSON, &generic) + + var output []map[string]interface{} + if len(parsed.Choices) > 0 && parsed.Choices[0].Message.Content != "" { + output = append(output, map[string]interface{}{ + "type": "message", + "role": "assistant", + "content": parsed.Choices[0].Message.Content, + }) + } + + // Modern tool_calls + if chs, ok := generic["choices"].([]interface{}); ok && len(chs) > 0 { + if ch, ok := chs[0].(map[string]interface{}); ok { + if msg, ok := ch["message"].(map[string]interface{}); ok { + if tcs, ok := msg["tool_calls"].([]interface{}); ok { + for _, tci := range tcs { + if tc, ok := tci.(map[string]interface{}); ok { + name := "" + args := "" + if fn, ok := tc["function"].(map[string]interface{}); ok { + if n, ok := fn["name"].(string); ok { + name = n + } + if a, ok := fn["arguments"].(string); ok { + args = a + } + } + output = append(output, map[string]interface{}{ + "type": "tool_call", + "tool_name": name, + "arguments": args, + }) + } + } + } + // Legacy function_call + if fc, ok := msg["function_call"].(map[string]interface{}); ok { + name := "" + args := "" + if n, ok := fc["name"].(string); ok { + name = n + } + if a, ok := fc["arguments"].(string); ok { + args = a + } + output = append(output, map[string]interface{}{ + "type": "tool_call", + "tool_name": name, + "arguments": args, + }) + } + } + } + } + + stopReason := "stop" + if len(parsed.Choices) > 0 && parsed.Choices[0].FinishReason != "" { + stopReason = parsed.Choices[0].FinishReason + } + + out := map[string]interface{}{ + "id": parsed.ID, + "object": "response", + "created": parsed.Created, + "model": parsed.Model, + "output": output, + "stop_reason": stopReason, + "usage": map[string]int{ + "input_tokens": parsed.Usage.PromptTokens, + "output_tokens": parsed.Usage.CompletionTokens, + "total_tokens": parsed.Usage.TotalTokens, + }, + } + + return json.Marshal(out) +} + +// translateSSEChunkToResponses converts a single OpenAI chat.completion.chunk SSE payload +// (the JSON after "data: ") into Responses SSE events (delta/stop). Returns empty when not applicable. 
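+// Example (illustrative): a chunk whose delta carries content "Hi" yields
+// {"type":"response.output_text.delta","delta":"Hi","created":...}; a non-empty
+// finish_reason additionally yields {"type":"response.completed","stop_reason":"stop"}.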
+func translateSSEChunkToResponses(chunk []byte) ([][]byte, bool) { + // Expect chunk JSON like {"id":"...","object":"chat.completion.chunk","created":...,"model":"...","choices":[{"index":0,"delta":{"role":"assistant","content":"..."},"finish_reason":null}]} + var parsed map[string]interface{} + if err := json.Unmarshal(chunk, &parsed); err != nil { + return nil, false + } + if parsed["object"] != "chat.completion.chunk" { + return nil, false + } + + created, _ := parsed["created"].(float64) + // Emit a created event only once per stream (handled by caller) + + // Extract content delta, tool call deltas, and finish_reason + var deltaText string + var finish string + var toolEvents [][]byte + if arr, ok := parsed["choices"].([]interface{}); ok && len(arr) > 0 { + if ch, ok := arr[0].(map[string]interface{}); ok { + if fr, ok := ch["finish_reason"].(string); ok && fr != "" { + finish = fr + } + if d, ok := ch["delta"].(map[string]interface{}); ok { + if c, ok := d["content"].(string); ok { + deltaText = c + } + if tcs, ok := d["tool_calls"].([]interface{}); ok { + for _, tci := range tcs { + if tc, ok := tci.(map[string]interface{}); ok { + ev := map[string]interface{}{"type": "response.tool_calls.delta"} + if idx, ok := tc["index"].(float64); ok { + ev["index"] = int(idx) + } + if fn, ok := tc["function"].(map[string]interface{}); ok { + if n, ok := fn["name"].(string); ok && n != "" { + ev["name"] = n + } + if a, ok := fn["arguments"].(string); ok && a != "" { + ev["arguments_delta"] = a + } + } + b, _ := json.Marshal(ev) + toolEvents = append(toolEvents, b) + } + } + } + if fc, ok := d["function_call"].(map[string]interface{}); ok { + ev := map[string]interface{}{"type": "response.tool_calls.delta"} + if n, ok := fc["name"].(string); ok && n != "" { + ev["name"] = n + } + if a, ok := fc["arguments"].(string); ok && a != "" { + ev["arguments_delta"] = a + } + b, _ := json.Marshal(ev) + toolEvents = append(toolEvents, b) + } + } + } + } + + var events [][]byte + if len(toolEvents) > 0 { + events = append(events, toolEvents...) 
+ } + if deltaText != "" { + ev := map[string]interface{}{ + "type": "response.output_text.delta", + "delta": deltaText, + } + if created > 0 { + ev["created"] = int64(created) + } + b, _ := json.Marshal(ev) + events = append(events, b) + } + + if finish != "" { + ev := map[string]interface{}{ + "type": "response.completed", + "stop_reason": finish, + } + b, _ := json.Marshal(ev) + events = append(events, b) + } + + if len(events) == 0 { + return nil, false + } + return events, true +} diff --git a/src/semantic-router/pkg/extproc/mapping_responses_test.go b/src/semantic-router/pkg/extproc/mapping_responses_test.go new file mode 100644 index 000000000..f09c7fa34 --- /dev/null +++ b/src/semantic-router/pkg/extproc/mapping_responses_test.go @@ -0,0 +1,179 @@ +package extproc + +import ( + "encoding/json" + "os" + "testing" +) + +func TestMapResponsesRequestToChatCompletions_TextInput(t *testing.T) { + in := []byte(`{"model":"gpt-test","input":"Hello world","temperature":0.2,"top_p":0.9,"max_output_tokens":128}`) + out, err := mapResponsesRequestToChatCompletions(in) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + var m map[string]interface{} + if err := json.Unmarshal(out, &m); err != nil { + t.Fatalf("unmarshal mapped: %v", err) + } + if m["model"].(string) != "gpt-test" { + t.Fatalf("model not mapped") + } + if _, ok := m["messages"].([]interface{}); !ok { + t.Fatalf("messages missing") + } +} + +func TestMapChatCompletionToResponses_Minimal(t *testing.T) { + in := []byte(`{ + "id":"chatcmpl-1","object":"chat.completion","created":123,"model":"gpt-test", + "choices":[{"index":0,"finish_reason":"stop","message":{"role":"assistant","content":"hi"}}], + "usage":{"prompt_tokens":1,"completion_tokens":1,"total_tokens":2} + }`) + out, err := mapChatCompletionToResponses(in) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + var m map[string]interface{} + if err := json.Unmarshal(out, &m); err != nil { + t.Fatalf("unmarshal mapped: %v", err) + } + if m["object"].(string) != "response" { + t.Fatalf("object not 'response'") + } + if m["stop_reason"].(string) == "" { + t.Fatalf("stop_reason missing") + } +} + +func TestTranslateSSEChunkToResponses(t *testing.T) { + chunk := []byte(`{"id":"c1","object":"chat.completion.chunk","created":1,"model":"m","choices":[{"index":0,"delta":{"role":"assistant","content":"Hi"},"finish_reason":null}]}`) + evs, ok := translateSSEChunkToResponses(chunk) + if !ok || len(evs) == 0 { + t.Fatalf("expected events") + } +} + +func TestMapResponsesRequestToChatCompletions_ToolsPassThrough(t *testing.T) { + in := []byte(`{ + "model":"gpt-test", + "input":"call a tool", + "tools":[{"type":"function","function":{"name":"get_time","parameters":{"type":"object","properties":{}}}}], + "tool_choice":"auto" + }`) + out, err := mapResponsesRequestToChatCompletions(in) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + var m map[string]interface{} + if err := json.Unmarshal(out, &m); err != nil { + t.Fatalf("unmarshal mapped: %v", err) + } + if _, ok := m["tools"]; !ok { + t.Fatalf("tools not passed through") + } + if v, ok := m["tool_choice"]; !ok || v == nil { + t.Fatalf("tool_choice not passed through") + } +} + +func TestMapChatCompletionToResponses_ToolCallsModern(t *testing.T) { + if os.Getenv("SKIP_TOOL_CALL_TESTS") == "true" { + t.Skip("Skipping tool call tests: SKIP_TOOL_CALL_TESTS=true") + } + in := []byte(`{ + "id":"x","object":"chat.completion","created":2,"model":"m", + "choices":[{"index":0,"finish_reason":"stop","message":{ + 
"role":"assistant", + "content":"", + "tool_calls":[{"type":"function","function":{"name":"get_time","arguments":"{\\\"tz\\\":\\\"UTC\\\"}"}}] + }}], + "usage":{"prompt_tokens":1,"completion_tokens":1,"total_tokens":2} + }`) + out, err := mapChatCompletionToResponses(in) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + var m map[string]interface{} + if err := json.Unmarshal(out, &m); err != nil { + t.Fatalf("unmarshal: %v", err) + } + outs, _ := m["output"].([]interface{}) + if len(outs) == 0 { + t.Fatalf("expected output entries") + } + var hasTool bool + for _, o := range outs { + om := o.(map[string]interface{}) + if om["type"] == "tool_call" { + hasTool = true + } + } + if !hasTool { + t.Fatalf("expected tool_call in output") + } +} + +func TestMapChatCompletionToResponses_FunctionCallLegacy(t *testing.T) { + if os.Getenv("SKIP_TOOL_CALL_TESTS") == "true" { + t.Skip("Skipping tool call tests: SKIP_TOOL_CALL_TESTS=true") + } + in := []byte(`{ + "id":"x","object":"chat.completion","created":2,"model":"m", + "choices":[{"index":0,"finish_reason":"stop","message":{ + "role":"assistant", + "content":"", + "function_call":{"name":"get_time","arguments":"{\\\"tz\\\":\\\"UTC\\\"}"} + }}], + "usage":{"prompt_tokens":1,"completion_tokens":1,"total_tokens":2} + }`) + out, err := mapChatCompletionToResponses(in) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + var m map[string]interface{} + if err := json.Unmarshal(out, &m); err != nil { + t.Fatalf("unmarshal: %v", err) + } + outs, _ := m["output"].([]interface{}) + var hasTool bool + for _, o := range outs { + if om, ok := o.(map[string]interface{}); ok && om["type"] == "tool_call" { + hasTool = true + } + } + if !hasTool { + t.Fatalf("expected legacy function tool_call in output") + } +} + +func TestTranslateSSEChunkToResponses_ToolCallsDelta(t *testing.T) { + if os.Getenv("SKIP_TOOL_CALL_TESTS") == "true" { + t.Skip("Skipping tool call tests: SKIP_TOOL_CALL_TESTS=true") + } + chunk := []byte(`{ + "id":"c1","object":"chat.completion.chunk","created":1, + "model":"m", + "choices":[{"index":0, + "delta":{ + "tool_calls":[{"index":0,"function":{"name":"get_time","arguments":"{\\\"tz\\\":\\\"UTC\\\"}"}}] + }, + "finish_reason":null + }] + }`) + evs, ok := translateSSEChunkToResponses(chunk) + if !ok || len(evs) == 0 { + t.Fatalf("expected events for tool_calls delta") + } + var hasToolDelta bool + for _, ev := range evs { + var m map[string]interface{} + _ = json.Unmarshal(ev, &m) + if m["type"] == "response.tool_calls.delta" { + hasToolDelta = true + } + } + if !hasToolDelta { + t.Fatalf("expected response.tool_calls.delta event") + } +} diff --git a/src/semantic-router/pkg/extproc/processor_res_header.go b/src/semantic-router/pkg/extproc/processor_res_header.go index 068a89a48..4c5b8027c 100644 --- a/src/semantic-router/pkg/extproc/processor_res_header.go +++ b/src/semantic-router/pkg/extproc/processor_res_header.go @@ -184,3 +184,196 @@ func isStreamingContentType(headerMap *core.HeaderMap) bool { } return false } + +// handleResponseBody processes the response body +func (r *OpenAIRouter) handleResponseBody(v *ext_proc.ProcessingRequest_ResponseBody, ctx *RequestContext) (*ext_proc.ProcessingResponse, error) { + completionLatency := time.Since(ctx.StartTime) + + // Process the response for caching + responseBody := v.ResponseBody.Body + + // If this is a streaming response (e.g., SSE), record TTFT on the first body chunk + // and skip JSON parsing/caching which are not applicable for SSE chunks. 
+ if ctx.IsStreamingResponse { + if ctx != nil && !ctx.TTFTRecorded && !ctx.ProcessingStartTime.IsZero() && ctx.RequestModel != "" { + ttft := time.Since(ctx.ProcessingStartTime).Seconds() + if ttft > 0 { + metrics.RecordModelTTFT(ctx.RequestModel, ttft) + ctx.TTFTSeconds = ttft + ctx.TTFTRecorded = true + observability.Infof("Recorded TTFT on first streamed body chunk: %.3fs", ttft) + } + } + + // If Responses adapter is active for this request, translate SSE chunks + if r.Config != nil && r.Config.EnableResponsesAdapter { + if p, ok := ctx.Headers[":path"]; ok && strings.HasPrefix(p, "/v1/responses") { + body := v.ResponseBody.Body + // Envoy provides raw chunk bytes, typically like: "data: {json}\n\n" or "data: [DONE]\n\n" + b := string(body) + if strings.Contains(b, "[DONE]") { + // Emit a final response.completed if not already concluded + response := &ext_proc.ProcessingResponse{ + Response: &ext_proc.ProcessingResponse_ResponseBody{ + ResponseBody: &ext_proc.BodyResponse{ + Response: &ext_proc.CommonResponse{Status: ext_proc.CommonResponse_CONTINUE}, + }, + }, + } + metrics.ResponsesAdapterSSEEvents.WithLabelValues("response.completed").Inc() + return response, nil + } + + // Extract JSON after "data: " prefix if present + idx := strings.Index(b, "data:") + var payload []byte + if idx >= 0 { + payload = []byte(strings.TrimSpace(b[idx+5:])) + } else { + payload = v.ResponseBody.Body + } + + if len(payload) > 0 && payload[0] == '{' { + if !ctx.ResponsesStreamInit { + // Emit an initial created event on first chunk + ctx.ResponsesStreamInit = true + // We don't inject a new chunk here; clients will see deltas below + } + events, ok := translateSSEChunkToResponses(payload) + if ok && len(events) > 0 { + // Rebuild body as multiple SSE events in Responses format + var sb strings.Builder + for _, ev := range events { + sb.WriteString("data: ") + sb.Write(ev) + sb.WriteString("\n\n") + // Inspect the event type for metrics + var et map[string]interface{} + if err := json.Unmarshal(ev, &et); err == nil { + if t, _ := et["type"].(string); t != "" { + metrics.ResponsesAdapterSSEEvents.WithLabelValues(t).Inc() + } + } + } + v.ResponseBody.Body = []byte(sb.String()) + } + } + } + } + + response := &ext_proc.ProcessingResponse{ + Response: &ext_proc.ProcessingResponse_ResponseBody{ + ResponseBody: &ext_proc.BodyResponse{ + Response: &ext_proc.CommonResponse{ + Status: ext_proc.CommonResponse_CONTINUE, + }, + }, + }, + } + return response, nil + } + + // If this was a /v1/responses request (adapter path), remap non-stream body to Responses JSON + if r.Config != nil && r.Config.EnableResponsesAdapter { + if p, ok := ctx.Headers[":path"]; ok && strings.HasPrefix(p, "/v1/responses") { + mapped, err := mapChatCompletionToResponses(responseBody) + if err == nil { + // Replace upstream JSON with Responses JSON + v.ResponseBody.Body = mapped + // Ensure content-type remains application/json + return &ext_proc.ProcessingResponse{ + Response: &ext_proc.ProcessingResponse_ResponseBody{ + ResponseBody: &ext_proc.BodyResponse{ + Response: &ext_proc.CommonResponse{ + Status: ext_proc.CommonResponse_CONTINUE, + }, + }, + }, + }, nil + } + } + } + + // Parse tokens from the response JSON using OpenAI SDK types + var parsed openai.ChatCompletion + if err := json.Unmarshal(responseBody, &parsed); err != nil { + observability.Errorf("Error parsing tokens from response: %v", err) + metrics.RecordRequestError(ctx.RequestModel, "parse_error") + } + promptTokens := int(parsed.Usage.PromptTokens) + completionTokens := 
int(parsed.Usage.CompletionTokens) + + // Record tokens used with the model that was used + if ctx.RequestModel != "" { + metrics.RecordModelTokensDetailed( + ctx.RequestModel, + float64(promptTokens), + float64(completionTokens), + ) + metrics.RecordModelCompletionLatency(ctx.RequestModel, completionLatency.Seconds()) + + // Record TPOT (time per output token) if completion tokens are available + if completionTokens > 0 { + timePerToken := completionLatency.Seconds() / float64(completionTokens) + metrics.RecordModelTPOT(ctx.RequestModel, timePerToken) + } + + // Compute and record cost if pricing is configured + if r.Config != nil { + promptRatePer1M, completionRatePer1M, currency, ok := r.Config.GetModelPricing(ctx.RequestModel) + if ok { + costAmount := (float64(promptTokens)*promptRatePer1M + float64(completionTokens)*completionRatePer1M) / 1_000_000.0 + if currency == "" { + currency = "USD" + } + metrics.RecordModelCost(ctx.RequestModel, currency, costAmount) + observability.LogEvent("llm_usage", map[string]interface{}{ + "request_id": ctx.RequestID, + "model": ctx.RequestModel, + "prompt_tokens": promptTokens, + "completion_tokens": completionTokens, + "total_tokens": promptTokens + completionTokens, + "completion_latency_ms": completionLatency.Milliseconds(), + "cost": costAmount, + "currency": currency, + }) + } else { + observability.LogEvent("llm_usage", map[string]interface{}{ + "request_id": ctx.RequestID, + "model": ctx.RequestModel, + "prompt_tokens": promptTokens, + "completion_tokens": completionTokens, + "total_tokens": promptTokens + completionTokens, + "completion_latency_ms": completionLatency.Milliseconds(), + "cost": 0.0, + "currency": "unknown", + "pricing": "not_configured", + }) + } + } + } + + // Update the cache + if ctx.RequestID != "" && responseBody != nil { + err := r.Cache.UpdateWithResponse(ctx.RequestID, responseBody) + if err != nil { + observability.Errorf("Error updating cache: %v", err) + // Continue even if cache update fails + } else { + observability.Infof("Cache updated for request ID: %s", ctx.RequestID) + } + } + + // Allow the response to continue without modification + response := &ext_proc.ProcessingResponse{ + Response: &ext_proc.ProcessingResponse_ResponseBody{ + ResponseBody: &ext_proc.BodyResponse{ + Response: &ext_proc.CommonResponse{ + Status: ext_proc.CommonResponse_CONTINUE, + }, + }, + }, + } + + return response, nil +} diff --git a/src/semantic-router/pkg/extproc/request_handler.go b/src/semantic-router/pkg/extproc/request_handler.go new file mode 100644 index 000000000..71a460352 --- /dev/null +++ b/src/semantic-router/pkg/extproc/request_handler.go @@ -0,0 +1,1379 @@ +package extproc + +import ( + "context" + "encoding/json" + "strings" + "time" + + core "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" + ext_proc "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3" + typev3 "github.com/envoyproxy/go-control-plane/envoy/type/v3" + "github.com/openai/openai-go" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + "github.com/vllm-project/semantic-router/src/semantic-router/pkg/cache" + "github.com/vllm-project/semantic-router/src/semantic-router/pkg/config" + "github.com/vllm-project/semantic-router/src/semantic-router/pkg/headers" + "github.com/vllm-project/semantic-router/src/semantic-router/pkg/metrics" + "github.com/vllm-project/semantic-router/src/semantic-router/pkg/observability" + 
"github.com/vllm-project/semantic-router/src/semantic-router/pkg/utils/http" + "github.com/vllm-project/semantic-router/src/semantic-router/pkg/utils/pii" +) + +// parseOpenAIRequest parses the raw JSON using the OpenAI SDK types +func parseOpenAIRequest(data []byte) (*openai.ChatCompletionNewParams, error) { + var req openai.ChatCompletionNewParams + if err := json.Unmarshal(data, &req); err != nil { + return nil, err + } + return &req, nil +} + +// extractStreamParam extracts the stream parameter from the original request body +func extractStreamParam(originalBody []byte) bool { + var requestMap map[string]interface{} + if err := json.Unmarshal(originalBody, &requestMap); err != nil { + return false + } + + if streamValue, exists := requestMap["stream"]; exists { + if stream, ok := streamValue.(bool); ok { + return stream + } + } + return false +} + +// serializeOpenAIRequestWithStream converts request back to JSON, preserving the stream parameter from original request +func serializeOpenAIRequestWithStream(req *openai.ChatCompletionNewParams, hasStreamParam bool) ([]byte, error) { + // First serialize the SDK object + sdkBytes, err := json.Marshal(req) + if err != nil { + return nil, err + } + + // If original request had stream parameter, add it back + if hasStreamParam { + var sdkMap map[string]interface{} + if err := json.Unmarshal(sdkBytes, &sdkMap); err == nil { + sdkMap["stream"] = true + if modifiedBytes, err := json.Marshal(sdkMap); err == nil { + return modifiedBytes, nil + } + } + } + + return sdkBytes, nil +} + +// shouldClearRouteCache checks if route cache should be cleared +func (r *OpenAIRouter) shouldClearRouteCache() bool { + // Check if feature is enabled + return r.Config.ClearRouteCache +} + +// addSystemPromptToRequestBody adds a system prompt to the beginning of the messages array in the JSON request body +// Returns the modified body, whether the system prompt was actually injected, and any error +func addSystemPromptToRequestBody(requestBody []byte, systemPrompt string, mode string) ([]byte, bool, error) { + if systemPrompt == "" { + return requestBody, false, nil + } + + // Parse the JSON request body + var requestMap map[string]interface{} + if err := json.Unmarshal(requestBody, &requestMap); err != nil { + return nil, false, err + } + + // Get the messages array + messagesInterface, ok := requestMap["messages"] + if !ok { + return requestBody, false, nil // No messages array, return original + } + + messages, ok := messagesInterface.([]interface{}) + if !ok { + return requestBody, false, nil // Messages is not an array, return original + } + + // Check if there's already a system message at the beginning + hasSystemMessage := false + var existingSystemContent string + if len(messages) > 0 { + if firstMsg, ok := messages[0].(map[string]interface{}); ok { + if role, ok := firstMsg["role"].(string); ok && role == "system" { + hasSystemMessage = true + if content, ok := firstMsg["content"].(string); ok { + existingSystemContent = content + } + } + } + } + + // Handle different injection modes + var finalSystemContent string + var logMessage string + + switch mode { + case "insert": + if hasSystemMessage { + // Insert mode: prepend category prompt to existing system message + finalSystemContent = systemPrompt + "\n\n" + existingSystemContent + logMessage = "Inserted category-specific system prompt before existing system message" + } else { + // No existing system message, just use the category prompt + finalSystemContent = systemPrompt + logMessage = "Added 
category-specific system prompt (insert mode, no existing system message)" + } + case "replace": + fallthrough + default: + // Replace mode: use only the category prompt + finalSystemContent = systemPrompt + if hasSystemMessage { + logMessage = "Replaced existing system message with category-specific system prompt" + } else { + logMessage = "Added category-specific system prompt to the beginning of messages" + } + } + + // Create the final system message + systemMessage := map[string]interface{}{ + "role": "system", + "content": finalSystemContent, + } + + if hasSystemMessage { + // Update the existing system message + messages[0] = systemMessage + } else { + // Prepend the system message to the beginning of the messages array + messages = append([]interface{}{systemMessage}, messages...) + } + + observability.Infof("%s (mode: %s)", logMessage, mode) + + // Update the messages in the request map + requestMap["messages"] = messages + + // Marshal back to JSON + modifiedBody, err := json.Marshal(requestMap) + return modifiedBody, true, err +} + +// extractUserAndNonUserContent extracts content from request messages +func extractUserAndNonUserContent(req *openai.ChatCompletionNewParams) (string, []string) { + var userContent string + var nonUser []string + + for _, msg := range req.Messages { + // Extract content based on message type + var textContent string + var role string + + if msg.OfUser != nil { + role = "user" + // Handle user message content + if msg.OfUser.Content.OfString.Value != "" { + textContent = msg.OfUser.Content.OfString.Value + } else if len(msg.OfUser.Content.OfArrayOfContentParts) > 0 { + // Extract text from content parts + var parts []string + for _, part := range msg.OfUser.Content.OfArrayOfContentParts { + if part.OfText != nil { + parts = append(parts, part.OfText.Text) + } + } + textContent = strings.Join(parts, " ") + } + } else if msg.OfSystem != nil { + role = "system" + if msg.OfSystem.Content.OfString.Value != "" { + textContent = msg.OfSystem.Content.OfString.Value + } else if len(msg.OfSystem.Content.OfArrayOfContentParts) > 0 { + // Extract text from content parts + var parts []string + for _, part := range msg.OfSystem.Content.OfArrayOfContentParts { + if part.Text != "" { + parts = append(parts, part.Text) + } + } + textContent = strings.Join(parts, " ") + } + } else if msg.OfAssistant != nil { + role = "assistant" + if msg.OfAssistant.Content.OfString.Value != "" { + textContent = msg.OfAssistant.Content.OfString.Value + } else if len(msg.OfAssistant.Content.OfArrayOfContentParts) > 0 { + // Extract text from content parts + var parts []string + for _, part := range msg.OfAssistant.Content.OfArrayOfContentParts { + if part.OfText != nil { + parts = append(parts, part.OfText.Text) + } + } + textContent = strings.Join(parts, " ") + } + } + + // Categorize by role + if role == "user" { + userContent = textContent + } else if role != "" { + nonUser = append(nonUser, textContent) + } + } + + return userContent, nonUser +} + +// RequestContext holds the context for processing a request +type RequestContext struct { + Headers map[string]string + RequestID string + OriginalRequestBody []byte + RequestModel string + RequestQuery string + StartTime time.Time + ProcessingStartTime time.Time + + // Streaming detection + ExpectStreamingResponse bool // set from request Accept header or stream parameter + IsStreamingResponse bool // set from response Content-Type + + // TTFT tracking + TTFTRecorded bool + TTFTSeconds float64 + + // Responses SSE translation state + 
ResponsesStreamInit bool + + // VSR decision tracking + VSRSelectedCategory string // The category selected by VSR + VSRReasoningMode string // "on" or "off" - whether reasoning mode was determined to be used + VSRSelectedModel string // The model selected by VSR + VSRCacheHit bool // Whether this request hit the cache + VSRInjectedSystemPrompt bool // Whether a system prompt was injected into the request + + // Tracing context + TraceContext context.Context // OpenTelemetry trace context for span propagation +} + +// handleRequestHeaders processes the request headers +func (r *OpenAIRouter) handleRequestHeaders(v *ext_proc.ProcessingRequest_RequestHeaders, ctx *RequestContext) (*ext_proc.ProcessingResponse, error) { + // Record start time for overall request processing + ctx.StartTime = time.Now() + observability.Infof("Received request headers") + + // Initialize trace context from incoming headers + baseCtx := context.Background() + headerMap := make(map[string]string) + for _, h := range v.RequestHeaders.Headers.Headers { + headerValue := h.Value + if headerValue == "" && len(h.RawValue) > 0 { + headerValue = string(h.RawValue) + } + headerMap[h.Key] = headerValue + } + + // Extract trace context from headers (if present) + ctx.TraceContext = observability.ExtractTraceContext(baseCtx, headerMap) + + // Start root span for the request + spanCtx, span := observability.StartSpan(ctx.TraceContext, observability.SpanRequestReceived, + trace.WithSpanKind(trace.SpanKindServer)) + ctx.TraceContext = spanCtx + defer span.End() + + // Store headers for later use + requestHeaders := v.RequestHeaders.Headers + for _, h := range requestHeaders.Headers { + // Prefer Value when available; fall back to RawValue + headerValue := h.Value + if headerValue == "" && len(h.RawValue) > 0 { + headerValue = string(h.RawValue) + } + observability.Debugf("Processing header: %s=%s", h.Key, headerValue) + + ctx.Headers[h.Key] = headerValue + // Store request ID if present (case-insensitive) + if strings.ToLower(h.Key) == headers.RequestID { + ctx.RequestID = headerValue + } + } + + // Set request metadata on span + if ctx.RequestID != "" { + observability.SetSpanAttributes(span, + attribute.String(observability.AttrRequestID, ctx.RequestID)) + } + + method := ctx.Headers[":method"] + path := ctx.Headers[":path"] + observability.SetSpanAttributes(span, + attribute.String(observability.AttrHTTPMethod, method), + attribute.String(observability.AttrHTTPPath, path)) + + // Detect if the client expects a streaming response (SSE) + if accept, ok := ctx.Headers["accept"]; ok { + if strings.Contains(strings.ToLower(accept), "text/event-stream") { + ctx.ExpectStreamingResponse = true + observability.Infof("Client expects streaming response based on Accept header") + } + } + + // Check if this is a GET request to /v1/models + if method == "GET" && strings.HasPrefix(path, "/v1/models") { + observability.Infof("Handling /v1/models request with path: %s", path) + return r.handleModelsRequest(path) + } + + // Responses adapter: detect POST /v1/responses and gate by feature flag + if method == "POST" && strings.HasPrefix(path, "/v1/responses") { + if r.Config == nil || !r.Config.EnableResponsesAdapter { + observability.Warnf("/v1/responses requested but adapter disabled") + return r.createErrorResponse(404, "Responses API not enabled"), nil + } + + // Metrics: record that adapter is handling this request + metrics.ResponsesAdapterRequests.WithLabelValues("false").Inc() + + // Prepare header mutation to rewrite :path to legacy 
chat completions + // Actual body mapping occurs in handleRequestBody + newPath := strings.Replace(path, "/v1/responses", "/v1/chat/completions", 1) + + headerMutation := &ext_proc.HeaderMutation{ + // Remove content-length because body will be mutated later + RemoveHeaders: []string{"content-length"}, + SetHeaders: []*core.HeaderValueOption{ + { + Header: &core.HeaderValue{ + Key: ":path", + RawValue: []byte(newPath), + }, + }, + }, + } + + response := &ext_proc.ProcessingResponse{ + Response: &ext_proc.ProcessingResponse_RequestHeaders{ + RequestHeaders: &ext_proc.HeadersResponse{ + Response: &ext_proc.CommonResponse{ + Status: ext_proc.CommonResponse_CONTINUE, + HeaderMutation: headerMutation, + }, + }, + }, + } + + observability.Infof("Rewriting /v1/responses to %s (headers phase)", newPath) + return response, nil + } + + // Prepare base response + response := &ext_proc.ProcessingResponse{ + Response: &ext_proc.ProcessingResponse_RequestHeaders{ + RequestHeaders: &ext_proc.HeadersResponse{ + Response: &ext_proc.CommonResponse{ + Status: ext_proc.CommonResponse_CONTINUE, + // No HeaderMutation - will be handled in body phase + }, + }, + }, + } + + // If streaming is expected, we rely on Envoy config to set response_body_mode: STREAMED for SSE. + // Some Envoy/control-plane versions may not support per-message ModeOverride; avoid compile-time coupling here. + // The Accept header is still recorded on context for downstream logic. + + return response, nil +} + +// handleRequestBody processes the request body +func (r *OpenAIRouter) handleRequestBody(v *ext_proc.ProcessingRequest_RequestBody, ctx *RequestContext) (*ext_proc.ProcessingResponse, error) { + observability.Infof("Received request body %s", string(v.RequestBody.GetBody())) + // Record start time for model routing + ctx.ProcessingStartTime = time.Now() + // Save the original request body + ctx.OriginalRequestBody = v.RequestBody.GetBody() + + // Extract stream parameter from original request and update ExpectStreamingResponse if needed + hasStreamParam := extractStreamParam(ctx.OriginalRequestBody) + if hasStreamParam { + observability.Infof("Original request contains stream parameter: true") + ctx.ExpectStreamingResponse = true // Set this if stream param is found + } + + // If path was /v1/responses and adapter enabled, map request JSON to ChatCompletion + if r.Config != nil && r.Config.EnableResponsesAdapter { + if p, ok := ctx.Headers[":path"]; ok && strings.HasPrefix(p, "/v1/responses") { + mapped, err := mapResponsesRequestToChatCompletions(ctx.OriginalRequestBody) + if err != nil { + observability.Errorf("Responses→Chat mapping failed: %v", err) + metrics.RecordRequestError(ctx.RequestModel, "parse_error") + return r.createErrorResponse(400, "Invalid /v1/responses payload"), nil + } + + // Replace original body with mapped body for downstream processing + ctx.OriginalRequestBody = mapped + + // No-op for Accept header here; downstream content negotiation remains unchanged + } + } + + // Parse the OpenAI request using SDK types + openAIRequest, err := parseOpenAIRequest(ctx.OriginalRequestBody) + if err != nil { + observability.Errorf("Error parsing OpenAI request: %v", err) + metrics.RecordRequestError(ctx.RequestModel, "parse_error") + metrics.RecordModelRequest(ctx.RequestModel) + return nil, status.Errorf(codes.InvalidArgument, "invalid request body: %v", err) + } + + // Store the original model + originalModel := openAIRequest.Model + observability.Infof("Original model: %s", originalModel) + + // Set model on span + 
if ctx.TraceContext != nil { + _, span := observability.StartSpan(ctx.TraceContext, "parse_request") + observability.SetSpanAttributes(span, + attribute.String(observability.AttrOriginalModel, originalModel)) + span.End() + } + + // Record the initial request to this model (count all requests) + metrics.RecordModelRequest(originalModel) + // Also set the model on context early so error metrics can label it + if ctx.RequestModel == "" { + ctx.RequestModel = originalModel + } + + // Get content from messages + userContent, nonUserMessages := extractUserAndNonUserContent(openAIRequest) + + // Classify the request early to determine category for security checks and cache settings + var categoryName string + if r.Config != nil && r.Config.IsAutoModelName(originalModel) && (len(nonUserMessages) > 0 || userContent != "") { + // Determine text to use for classification + var classificationText string + if len(userContent) > 0 { + classificationText = userContent + } else if len(nonUserMessages) > 0 { + classificationText = strings.Join(nonUserMessages, " ") + } + if classificationText != "" { + categoryName = r.findCategoryForClassification(classificationText) + observability.Debugf("Classified request to category: %s", categoryName) + } + } + + // Perform security checks with category-specific settings + if response, shouldReturn := r.performSecurityChecks(ctx, userContent, nonUserMessages, categoryName); shouldReturn { + return response, nil + } + + // Handle caching with category-specific settings + if response, shouldReturn := r.handleCaching(ctx, categoryName); shouldReturn { + return response, nil + } + + // Handle model selection and routing + return r.handleModelRouting(openAIRequest, originalModel, userContent, nonUserMessages, ctx) +} + +// performSecurityChecks performs PII and jailbreak detection with category-specific settings +func (r *OpenAIRouter) performSecurityChecks(ctx *RequestContext, userContent string, nonUserMessages []string, categoryName string) (*ext_proc.ProcessingResponse, bool) { + // Perform PII classification on all message content + allContent := pii.ExtractAllContent(userContent, nonUserMessages) + + // Check if jailbreak detection is enabled for this category + jailbreakEnabled := r.Classifier.IsJailbreakEnabled() + if categoryName != "" && r.Config != nil { + // Use category-specific setting if available + jailbreakEnabled = jailbreakEnabled && r.Config.IsJailbreakEnabledForCategory(categoryName) + } + + // Get category-specific threshold + jailbreakThreshold := r.Config.PromptGuard.Threshold + if categoryName != "" && r.Config != nil { + jailbreakThreshold = r.Config.GetJailbreakThresholdForCategory(categoryName) + } + + // Perform jailbreak detection on all message content + if jailbreakEnabled { + // Start jailbreak detection span + spanCtx, span := observability.StartSpan(ctx.TraceContext, observability.SpanJailbreakDetection) + defer span.End() + + startTime := time.Now() + hasJailbreak, jailbreakDetections, err := r.Classifier.AnalyzeContentForJailbreakWithThreshold(allContent, jailbreakThreshold) + detectionTime := time.Since(startTime).Milliseconds() + + observability.SetSpanAttributes(span, + attribute.Int64(observability.AttrJailbreakDetectionTimeMs, detectionTime)) + + if err != nil { + observability.Errorf("Error performing jailbreak analysis: %v", err) + observability.RecordError(span, err) + // Continue processing despite jailbreak analysis error + metrics.RecordRequestError(ctx.RequestModel, "classification_failed") + } else if hasJailbreak { + // 
Find the first jailbreak detection for response + var jailbreakType string + var confidence float32 + for _, detection := range jailbreakDetections { + if detection.IsJailbreak { + jailbreakType = detection.JailbreakType + confidence = detection.Confidence + break + } + } + + observability.SetSpanAttributes(span, + attribute.Bool(observability.AttrJailbreakDetected, true), + attribute.String(observability.AttrJailbreakType, jailbreakType), + attribute.String(observability.AttrSecurityAction, "blocked")) + + observability.Warnf("JAILBREAK ATTEMPT BLOCKED: %s (confidence: %.3f)", jailbreakType, confidence) + + // Return immediate jailbreak violation response + // Structured log for security block + observability.LogEvent("security_block", map[string]interface{}{ + "reason_code": "jailbreak_detected", + "jailbreak_type": jailbreakType, + "confidence": confidence, + "request_id": ctx.RequestID, + }) + // Count this as a blocked request + metrics.RecordRequestError(ctx.RequestModel, "jailbreak_block") + jailbreakResponse := http.CreateJailbreakViolationResponse(jailbreakType, confidence, ctx.ExpectStreamingResponse) + ctx.TraceContext = spanCtx + return jailbreakResponse, true + } else { + observability.SetSpanAttributes(span, + attribute.Bool(observability.AttrJailbreakDetected, false)) + observability.Infof("No jailbreak detected in request content") + ctx.TraceContext = spanCtx + } + } + + return nil, false +} + +// handleCaching handles cache lookup and storage with category-specific settings +func (r *OpenAIRouter) handleCaching(ctx *RequestContext, categoryName string) (*ext_proc.ProcessingResponse, bool) { + // Extract the model and query for cache lookup + requestModel, requestQuery, err := cache.ExtractQueryFromOpenAIRequest(ctx.OriginalRequestBody) + if err != nil { + observability.Errorf("Error extracting query from request: %v", err) + // Continue without caching + return nil, false + } + + ctx.RequestModel = requestModel + ctx.RequestQuery = requestQuery + + // Check if caching is enabled for this category + cacheEnabled := r.Config.SemanticCache.Enabled + if categoryName != "" { + cacheEnabled = r.Config.IsCacheEnabledForCategory(categoryName) + } + + if requestQuery != "" && r.Cache.IsEnabled() && cacheEnabled { + // Get category-specific threshold + threshold := r.Config.GetCacheSimilarityThreshold() + if categoryName != "" { + threshold = r.Config.GetCacheSimilarityThresholdForCategory(categoryName) + } + + // Start cache lookup span + spanCtx, span := observability.StartSpan(ctx.TraceContext, observability.SpanCacheLookup) + defer span.End() + + startTime := time.Now() + // Try to find a similar cached response using category-specific threshold + cachedResponse, found, cacheErr := r.Cache.FindSimilarWithThreshold(requestModel, requestQuery, threshold) + lookupTime := time.Since(startTime).Milliseconds() + + observability.SetSpanAttributes(span, + attribute.String(observability.AttrCacheKey, requestQuery), + attribute.Bool(observability.AttrCacheHit, found), + attribute.Int64(observability.AttrCacheLookupTimeMs, lookupTime), + attribute.String(observability.AttrCategoryName, categoryName), + attribute.Float64("cache.threshold", float64(threshold))) + + if cacheErr != nil { + observability.Errorf("Error searching cache: %v", cacheErr) + observability.RecordError(span, cacheErr) + } else if found { + // Mark this request as a cache hit + ctx.VSRCacheHit = true + // Log cache hit + observability.LogEvent("cache_hit", map[string]interface{}{ + "request_id": ctx.RequestID, + 
"model": requestModel, + "query": requestQuery, + "category": categoryName, + "threshold": threshold, + }) + // Return immediate response from cache + response := http.CreateCacheHitResponse(cachedResponse, ctx.ExpectStreamingResponse) + ctx.TraceContext = spanCtx + return response, true + } + ctx.TraceContext = spanCtx + } + + // Cache miss, store the request for later + err = r.Cache.AddPendingRequest(ctx.RequestID, requestModel, requestQuery, ctx.OriginalRequestBody) + if err != nil { + observability.Errorf("Error adding pending request to cache: %v", err) + // Continue without caching + } + + return nil, false +} + +// handleModelRouting handles model selection and routing logic +func (r *OpenAIRouter) handleModelRouting(openAIRequest *openai.ChatCompletionNewParams, originalModel, userContent string, nonUserMessages []string, ctx *RequestContext) (*ext_proc.ProcessingResponse, error) { + // Create default response with CONTINUE status + response := &ext_proc.ProcessingResponse{ + Response: &ext_proc.ProcessingResponse_RequestBody{ + RequestBody: &ext_proc.BodyResponse{ + Response: &ext_proc.CommonResponse{ + Status: ext_proc.CommonResponse_CONTINUE, + }, + }, + }, + } + + // Only change the model if the original model is an auto model name (supports both "auto" and configured AutoModelName for backward compatibility) + actualModel := originalModel + var selectedEndpoint string + isAutoModel := r.Config != nil && r.Config.IsAutoModelName(originalModel) + if isAutoModel && (len(nonUserMessages) > 0 || userContent != "") { + observability.Infof("Using Auto Model Selection (model=%s)", originalModel) + // Determine text to use for classification/similarity + var classificationText string + if len(userContent) > 0 { + classificationText = userContent + } else if len(nonUserMessages) > 0 { + // Fall back to user content if no system/assistant messages + classificationText = strings.Join(nonUserMessages, " ") + } + + if classificationText != "" { + // Start classification span + classifyCtx, classifySpan := observability.StartSpan(ctx.TraceContext, observability.SpanClassification) + classifyStart := time.Now() + + // Find the most similar task description or classify, then select best model + matchedModel := r.classifyAndSelectBestModel(classificationText) + classifyTime := time.Since(classifyStart).Milliseconds() + + // Get category information for the span + categoryName := r.findCategoryForClassification(classificationText) + + observability.SetSpanAttributes(classifySpan, + attribute.String(observability.AttrCategoryName, categoryName), + attribute.String(observability.AttrClassifierType, "bert"), + attribute.Int64(observability.AttrClassificationTimeMs, classifyTime)) + classifySpan.End() + ctx.TraceContext = classifyCtx + + if matchedModel != originalModel && matchedModel != "" { + // Start PII detection span if enabled + allContent := pii.ExtractAllContent(userContent, nonUserMessages) + if r.PIIChecker.IsPIIEnabled(matchedModel) { + piiCtx, piiSpan := observability.StartSpan(ctx.TraceContext, observability.SpanPIIDetection) + piiStart := time.Now() + + observability.Infof("PII policy enabled for model %s", matchedModel) + detectedPII := r.Classifier.DetectPIIInContent(allContent) + + piiTime := time.Since(piiStart).Milliseconds() + piiDetected := len(detectedPII) > 0 + + observability.SetSpanAttributes(piiSpan, + attribute.Bool(observability.AttrPIIDetected, piiDetected), + attribute.Int64(observability.AttrPIIDetectionTimeMs, piiTime)) + + if piiDetected { + // Convert detected PII 
to comma-separated string + piiTypesStr := strings.Join(detectedPII, ",") + observability.SetSpanAttributes(piiSpan, + attribute.String(observability.AttrPIITypes, piiTypesStr)) + } + + piiSpan.End() + ctx.TraceContext = piiCtx + + // Check if the initially selected model passes PII policy + allowed, deniedPII, err := r.PIIChecker.CheckPolicy(matchedModel, detectedPII) + if err != nil { + observability.Errorf("Error checking PII policy for model %s: %v", matchedModel, err) + // Continue with original selection on error + } else if !allowed { + observability.Warnf("Initially selected model %s violates PII policy, finding alternative", matchedModel) + // Find alternative models from the same category that pass PII policy + categoryName := r.findCategoryForClassification(classificationText) + if categoryName != "" { + alternativeModels := r.Classifier.GetModelsForCategory(categoryName) + allowedModels := r.PIIChecker.FilterModelsForPII(alternativeModels, detectedPII) + if len(allowedModels) > 0 { + // Select the best allowed model from this category + matchedModel = r.Classifier.SelectBestModelFromList(allowedModels, categoryName) + observability.Infof("Selected alternative model %s that passes PII policy", matchedModel) + // Record reason code for selecting alternative due to PII + metrics.RecordRoutingReasonCode("pii_policy_alternative_selected", matchedModel) + } else { + observability.Warnf("No models in category %s pass PII policy, using default", categoryName) + matchedModel = r.Config.DefaultModel + // Check if default model passes policy + defaultAllowed, defaultDeniedPII, _ := r.PIIChecker.CheckPolicy(matchedModel, detectedPII) + if !defaultAllowed { + observability.Errorf("Default model also violates PII policy, returning error") + observability.LogEvent("routing_block", map[string]interface{}{ + "reason_code": "pii_policy_denied_default_model", + "request_id": ctx.RequestID, + "model": matchedModel, + "denied_pii": defaultDeniedPII, + }) + metrics.RecordRequestError(matchedModel, "pii_policy_denied") + piiResponse := http.CreatePIIViolationResponse(matchedModel, defaultDeniedPII, ctx.ExpectStreamingResponse) + return piiResponse, nil + } + } + } else { + observability.Warnf("Could not determine category, returning PII violation for model %s", matchedModel) + observability.LogEvent("routing_block", map[string]interface{}{ + "reason_code": "pii_policy_denied", + "request_id": ctx.RequestID, + "model": matchedModel, + "denied_pii": deniedPII, + }) + metrics.RecordRequestError(matchedModel, "pii_policy_denied") + piiResponse := http.CreatePIIViolationResponse(matchedModel, deniedPII, ctx.ExpectStreamingResponse) + return piiResponse, nil + } + } + } + + observability.Infof("Routing to model: %s", matchedModel) + + // Start routing decision span + routingCtx, routingSpan := observability.StartSpan(ctx.TraceContext, observability.SpanRoutingDecision) + + // Check reasoning mode for this category using entropy-based approach + useReasoning, categoryName, reasoningDecision := r.getEntropyBasedReasoningModeAndCategory(userContent) + observability.Infof("Entropy-based reasoning decision for this query: %v on [%s] model (confidence: %.3f, reason: %s)", + useReasoning, matchedModel, reasoningDecision.Confidence, reasoningDecision.DecisionReason) + // Record reasoning decision metric with the effort that will be applied if enabled + effortForMetrics := r.getReasoningEffort(categoryName, matchedModel) + metrics.RecordReasoningDecision(categoryName, matchedModel, useReasoning, effortForMetrics) + 
+ // Set routing attributes on span + observability.SetSpanAttributes(routingSpan, + attribute.String(observability.AttrRoutingStrategy, "auto"), + attribute.String(observability.AttrRoutingReason, reasoningDecision.DecisionReason), + attribute.String(observability.AttrOriginalModel, originalModel), + attribute.String(observability.AttrSelectedModel, matchedModel), + attribute.Bool(observability.AttrReasoningEnabled, useReasoning), + attribute.String(observability.AttrReasoningEffort, effortForMetrics)) + + routingSpan.End() + ctx.TraceContext = routingCtx + + // Track VSR decision information + ctx.VSRSelectedCategory = categoryName + ctx.VSRSelectedModel = matchedModel + if useReasoning { + ctx.VSRReasoningMode = "on" + } else { + ctx.VSRReasoningMode = "off" + } + + // Track the model routing change + metrics.RecordModelRouting(originalModel, matchedModel) + + // Update the actual model that will be used + actualModel = matchedModel + + // Start backend selection span + backendCtx, backendSpan := observability.StartSpan(ctx.TraceContext, observability.SpanBackendSelection) + + // Select the best endpoint for this model + endpointAddress, endpointFound := r.Config.SelectBestEndpointAddressForModel(matchedModel) + if endpointFound { + selectedEndpoint = endpointAddress + observability.Infof("Selected endpoint address: %s for model: %s", selectedEndpoint, matchedModel) + + // Extract endpoint name from config + endpoints := r.Config.GetEndpointsForModel(matchedModel) + if len(endpoints) > 0 { + observability.SetSpanAttributes(backendSpan, + attribute.String(observability.AttrEndpointName, endpoints[0].Name), + attribute.String(observability.AttrEndpointAddress, selectedEndpoint)) + } + } else { + observability.Warnf("No endpoint found for model %s, using fallback", matchedModel) + } + + backendSpan.End() + ctx.TraceContext = backendCtx + + // Modify the model in the request + openAIRequest.Model = matchedModel + + // Serialize the modified request with stream parameter preserved + modifiedBody, err := serializeOpenAIRequestWithStream(openAIRequest, ctx.ExpectStreamingResponse) + if err != nil { + observability.Errorf("Error serializing modified request: %v", err) + metrics.RecordRequestError(actualModel, "serialization_error") + return nil, status.Errorf(codes.Internal, "error serializing modified request: %v", err) + } + + modifiedBody, err = r.setReasoningModeToRequestBody(modifiedBody, useReasoning, categoryName) + if err != nil { + observability.Errorf("Error setting reasoning mode %v to request: %v", useReasoning, err) + metrics.RecordRequestError(actualModel, "serialization_error") + return nil, status.Errorf(codes.Internal, "error setting reasoning mode: %v", err) + } + + // Add category-specific system prompt if configured + if categoryName != "" { + // Try to get the most up-to-date category configuration from global config first + // This ensures API updates are reflected immediately + globalConfig := config.GetConfig() + var category *config.Category + if globalConfig != nil { + category = globalConfig.GetCategoryByName(categoryName) + } + + // If not found in global config, fall back to router's config (for tests and initial setup) + if category == nil { + category = r.Classifier.GetCategoryByName(categoryName) + } + + if category != nil && category.SystemPrompt != "" && category.IsSystemPromptEnabled() { + // Start system prompt injection span + promptCtx, promptSpan := observability.StartSpan(ctx.TraceContext, observability.SpanSystemPromptInjection) + + mode := 
category.GetSystemPromptMode() + var injected bool + modifiedBody, injected, err = addSystemPromptToRequestBody(modifiedBody, category.SystemPrompt, mode) + if err != nil { + observability.Errorf("Error adding system prompt to request: %v", err) + observability.RecordError(promptSpan, err) + metrics.RecordRequestError(actualModel, "serialization_error") + promptSpan.End() + return nil, status.Errorf(codes.Internal, "error adding system prompt: %v", err) + } + + observability.SetSpanAttributes(promptSpan, + attribute.Bool("system_prompt.injected", injected), + attribute.String("system_prompt.mode", mode), + attribute.String(observability.AttrCategoryName, categoryName)) + + if injected { + ctx.VSRInjectedSystemPrompt = true + observability.Infof("Added category-specific system prompt for category: %s (mode: %s)", categoryName, mode) + } + + // Log metadata about system prompt injection (avoid logging sensitive user data) + observability.Infof("System prompt injection completed for category: %s, body size: %d bytes", categoryName, len(modifiedBody)) + + promptSpan.End() + ctx.TraceContext = promptCtx + } else if category != nil && category.SystemPrompt != "" && !category.IsSystemPromptEnabled() { + observability.Infof("System prompt disabled for category: %s", categoryName) + } + } + + // Create body mutation with the modified body + bodyMutation := &ext_proc.BodyMutation{ + Mutation: &ext_proc.BodyMutation_Body{ + Body: modifiedBody, + }, + } + + // Create header mutation with content-length removal AND all necessary routing headers + // (body phase HeaderMutation replaces header phase completely) + setHeaders := []*core.HeaderValueOption{} + if selectedEndpoint != "" { + setHeaders = append(setHeaders, &core.HeaderValueOption{ + Header: &core.HeaderValue{ + Key: headers.GatewayDestinationEndpoint, + RawValue: []byte(selectedEndpoint), + }, + }) + } + if actualModel != "" { + setHeaders = append(setHeaders, &core.HeaderValueOption{ + Header: &core.HeaderValue{ + Key: headers.SelectedModel, + RawValue: []byte(actualModel), + }, + }) + } + + headerMutation := &ext_proc.HeaderMutation{ + RemoveHeaders: []string{"content-length"}, + SetHeaders: setHeaders, + } + + observability.Debugf("ActualModel = '%s'", actualModel) + + // Set the response with body mutation and content-length removal + response = &ext_proc.ProcessingResponse{ + Response: &ext_proc.ProcessingResponse_RequestBody{ + RequestBody: &ext_proc.BodyResponse{ + Response: &ext_proc.CommonResponse{ + Status: ext_proc.CommonResponse_CONTINUE, + HeaderMutation: headerMutation, + BodyMutation: bodyMutation, + }, + }, + }, + } + + observability.Infof("Use new model: %s", matchedModel) + + // Structured log for routing decision (auto) + observability.LogEvent("routing_decision", map[string]interface{}{ + "reason_code": "auto_routing", + "request_id": ctx.RequestID, + "original_model": originalModel, + "selected_model": matchedModel, + "category": categoryName, + "reasoning_enabled": useReasoning, + "reasoning_effort": effortForMetrics, + "selected_endpoint": selectedEndpoint, + "routing_latency_ms": time.Since(ctx.ProcessingStartTime).Milliseconds(), + }) + metrics.RecordRoutingReasonCode("auto_routing", matchedModel) + } + } + } else if !isAutoModel { + observability.Infof("Using specified model: %s", originalModel) + // Track VSR decision information for non-auto models + ctx.VSRSelectedModel = originalModel + ctx.VSRReasoningMode = "off" // Non-auto models don't use reasoning mode by default + // For non-auto models, check PII policy 
compliance + allContent := pii.ExtractAllContent(userContent, nonUserMessages) + detectedPII := r.Classifier.DetectPIIInContent(allContent) + + allowed, deniedPII, err := r.PIIChecker.CheckPolicy(originalModel, detectedPII) + if err != nil { + observability.Errorf("Error checking PII policy for model %s: %v", originalModel, err) + // Continue with request on error + } else if !allowed { + observability.Errorf("Model %s violates PII policy, returning error", originalModel) + observability.LogEvent("routing_block", map[string]interface{}{ + "reason_code": "pii_policy_denied", + "request_id": ctx.RequestID, + "model": originalModel, + "denied_pii": deniedPII, + }) + metrics.RecordRequestError(originalModel, "pii_policy_denied") + piiResponse := http.CreatePIIViolationResponse(originalModel, deniedPII, ctx.ExpectStreamingResponse) + return piiResponse, nil + } + + // Select the best endpoint for the specified model + endpointAddress, endpointFound := r.Config.SelectBestEndpointAddressForModel(originalModel) + if endpointFound { + selectedEndpoint = endpointAddress + observability.Infof("Selected endpoint address: %s for model: %s", selectedEndpoint, originalModel) + } else { + observability.Warnf("No endpoint found for model %s, using fallback", originalModel) + } + setHeaders := []*core.HeaderValueOption{} + if selectedEndpoint != "" { + setHeaders = append(setHeaders, &core.HeaderValueOption{ + Header: &core.HeaderValue{ + Key: headers.GatewayDestinationEndpoint, + RawValue: []byte(selectedEndpoint), + }, + }) + } + // Set x-selected-model header for non-auto models + setHeaders = append(setHeaders, &core.HeaderValueOption{ + Header: &core.HeaderValue{ + Key: "x-selected-model", + RawValue: []byte(originalModel), + }, + }) + // Create CommonResponse with cache clearing if enabled + commonResponse := &ext_proc.CommonResponse{ + Status: ext_proc.CommonResponse_CONTINUE, + HeaderMutation: &ext_proc.HeaderMutation{ + SetHeaders: setHeaders, + }, + } + + // Check if route cache should be cleared + if r.shouldClearRouteCache() { + commonResponse.ClearRouteCache = true + } + + // Set the response with body mutation and content-length removal + response = &ext_proc.ProcessingResponse{ + Response: &ext_proc.ProcessingResponse_RequestBody{ + RequestBody: &ext_proc.BodyResponse{ + Response: commonResponse, + }, + }, + } + // Structured log for routing decision (explicit model) + observability.LogEvent("routing_decision", map[string]interface{}{ + "reason_code": "model_specified", + "request_id": ctx.RequestID, + "original_model": originalModel, + "selected_model": originalModel, + "category": "", + "reasoning_enabled": false, + "reasoning_effort": "", + "selected_endpoint": selectedEndpoint, + "routing_latency_ms": time.Since(ctx.ProcessingStartTime).Milliseconds(), + }) + metrics.RecordRoutingReasonCode("model_specified", originalModel) + } + + // Check if route cache should be cleared (only for auto models, non-auto models handle this in their own path) + // isAutoModel already determined at the beginning of this function using IsAutoModelName + if isAutoModel && r.shouldClearRouteCache() { + // Access the CommonResponse that's already created in this function + if response.GetRequestBody() != nil && response.GetRequestBody().GetResponse() != nil { + response.GetRequestBody().GetResponse().ClearRouteCache = true + observability.Debugf("Setting ClearRouteCache=true (feature enabled) for auto model") + } + } + + // Save the actual model that will be used for token tracking + ctx.RequestModel = 
actualModel + + // Handle tool selection based on tool_choice field + if err := r.handleToolSelection(openAIRequest, userContent, nonUserMessages, &response, ctx); err != nil { + observability.Errorf("Error in tool selection: %v", err) + // Continue without failing the request + } + + // Record the routing latency + routingLatency := time.Since(ctx.ProcessingStartTime) + metrics.RecordModelRoutingLatency(routingLatency.Seconds()) + + return response, nil +} + +// handleToolSelection handles automatic tool selection based on semantic similarity +func (r *OpenAIRouter) handleToolSelection(openAIRequest *openai.ChatCompletionNewParams, userContent string, nonUserMessages []string, response **ext_proc.ProcessingResponse, ctx *RequestContext) error { + // Check if tool_choice is set to "auto" + if openAIRequest.ToolChoice.OfAuto.Value == "auto" { + // Continue with tool selection logic + } else { + return nil // Not auto tool selection + } + + // Get text for tools classification + var classificationText string + if len(userContent) > 0 { + classificationText = userContent + } else if len(nonUserMessages) > 0 { + classificationText = strings.Join(nonUserMessages, " ") + } + + if classificationText == "" { + observability.Infof("No content available for tool classification") + return nil + } + + if !r.ToolsDatabase.IsEnabled() { + observability.Infof("Tools database is disabled") + return nil + } + + // Get configuration for tool selection + topK := r.Config.Tools.TopK + if topK <= 0 { + topK = 3 // Default to 3 tools + } + + // Find similar tools based on the query + selectedTools, err := r.ToolsDatabase.FindSimilarTools(classificationText, topK) + if err != nil { + if r.Config.Tools.FallbackToEmpty { + observability.Warnf("Tool selection failed, falling back to no tools: %v", err) + openAIRequest.Tools = nil + return r.updateRequestWithTools(openAIRequest, response, ctx) + } + metrics.RecordRequestError(getModelFromCtx(ctx), "classification_failed") + return err + } + + if len(selectedTools) == 0 { + if r.Config.Tools.FallbackToEmpty { + observability.Infof("No suitable tools found, falling back to no tools") + openAIRequest.Tools = nil + } else { + observability.Infof("No suitable tools found above threshold") + openAIRequest.Tools = []openai.ChatCompletionToolParam{} // Empty array + } + } else { + // Convert selected tools to OpenAI SDK tool format + tools := make([]openai.ChatCompletionToolParam, len(selectedTools)) + for i, tool := range selectedTools { + // Convert the tool to OpenAI SDK format + toolBytes, err := json.Marshal(tool) + if err != nil { + metrics.RecordRequestError(getModelFromCtx(ctx), "serialization_error") + return err + } + var sdkTool openai.ChatCompletionToolParam + if err := json.Unmarshal(toolBytes, &sdkTool); err != nil { + return err + } + tools[i] = sdkTool + } + + openAIRequest.Tools = tools + observability.Infof("Auto-selected %d tools for query: %s", len(selectedTools), classificationText) + } + + return r.updateRequestWithTools(openAIRequest, response, ctx) +} + +// updateRequestWithTools updates the request body with the selected tools +func (r *OpenAIRouter) updateRequestWithTools(openAIRequest *openai.ChatCompletionNewParams, response **ext_proc.ProcessingResponse, ctx *RequestContext) error { + // Re-serialize the request with modified tools and preserved stream parameter + modifiedBody, err := serializeOpenAIRequestWithStream(openAIRequest, ctx.ExpectStreamingResponse) + if err != nil { + return err + } + + // Create body mutation with the modified body 
+ bodyMutation := &ext_proc.BodyMutation{ + Mutation: &ext_proc.BodyMutation_Body{ + Body: modifiedBody, + }, + } + + // Create header mutation with content-length removal AND all necessary routing headers + // (body phase HeaderMutation replaces header phase completely) + + // Get the headers that should have been set in the main routing + var selectedEndpoint, actualModel string + + // These should be available from the existing response + if (*response).GetRequestBody() != nil && (*response).GetRequestBody().GetResponse() != nil && + (*response).GetRequestBody().GetResponse().GetHeaderMutation() != nil && + (*response).GetRequestBody().GetResponse().GetHeaderMutation().GetSetHeaders() != nil { + for _, header := range (*response).GetRequestBody().GetResponse().GetHeaderMutation().GetSetHeaders() { + switch header.Header.Key { + case headers.GatewayDestinationEndpoint: + selectedEndpoint = header.Header.Value + case headers.SelectedModel: + actualModel = header.Header.Value + } + } + } + + setHeaders := []*core.HeaderValueOption{} + if selectedEndpoint != "" { + setHeaders = append(setHeaders, &core.HeaderValueOption{ + Header: &core.HeaderValue{ + Key: headers.GatewayDestinationEndpoint, + RawValue: []byte(selectedEndpoint), + }, + }) + } + if actualModel != "" { + setHeaders = append(setHeaders, &core.HeaderValueOption{ + Header: &core.HeaderValue{ + Key: headers.SelectedModel, + RawValue: []byte(actualModel), + }, + }) + } + + // Intentionally do not mutate Authorization header here + + headerMutation := &ext_proc.HeaderMutation{ + RemoveHeaders: []string{"content-length"}, + SetHeaders: setHeaders, + } + + // Create CommonResponse + commonResponse := &ext_proc.CommonResponse{ + Status: ext_proc.CommonResponse_CONTINUE, + HeaderMutation: headerMutation, + BodyMutation: bodyMutation, + } + + // Check if route cache should be cleared + if r.shouldClearRouteCache() { + commonResponse.ClearRouteCache = true + observability.Debugf("Setting ClearRouteCache=true (feature enabled) in updateRequestWithTools") + } + + // Update the response with body mutation and content-length removal + *response = &ext_proc.ProcessingResponse{ + Response: &ext_proc.ProcessingResponse_RequestBody{ + RequestBody: &ext_proc.BodyResponse{ + Response: commonResponse, + }, + }, + } + + return nil +} + +// OpenAIModel represents a single model in the OpenAI /v1/models response +type OpenAIModel struct { + ID string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + OwnedBy string `json:"owned_by"` + Description string `json:"description,omitempty"` // Optional description for Chat UI + LogoURL string `json:"logo_url,omitempty"` // Optional logo URL for Chat UI +} + +// OpenAIModelList is the container for the models list response +type OpenAIModelList struct { + Object string `json:"object"` + Data []OpenAIModel `json:"data"` +} + +// handleModelsRequest handles GET /v1/models requests and returns a direct response +// Whether to include configured models is controlled by the config's IncludeConfigModelsInList setting (default: false) +func (r *OpenAIRouter) handleModelsRequest(_ string) (*ext_proc.ProcessingResponse, error) { + now := time.Now().Unix() + + // Start with the configured auto model name (or default "MoM") + // The model list uses the actual configured name, not "auto" + // However, "auto" is still accepted as an alias in request handling for backward compatibility + models := []OpenAIModel{} + + // Add the effective auto model name (configured or default "MoM") + if 
r.Config != nil { + effectiveAutoModelName := r.Config.GetEffectiveAutoModelName() + models = append(models, OpenAIModel{ + ID: effectiveAutoModelName, + Object: "model", + Created: now, + OwnedBy: "vllm-semantic-router", + Description: "Intelligent Router for Mixture-of-Models", + LogoURL: "https://github.com/vllm-project/semantic-router/blob/main/website/static/img/vllm.png", // You can customize this URL + }) + } else { + // Fallback if no config + models = append(models, OpenAIModel{ + ID: "MoM", + Object: "model", + Created: now, + OwnedBy: "vllm-semantic-router", + Description: "Intelligent Router for Mixture-of-Models", + LogoURL: "https://github.com/vllm-project/semantic-router/blob/main/website/static/img/vllm.png", // You can customize this URL + }) + } + + // Append underlying models from config (if available and configured to include them) + if r.Config != nil && r.Config.IncludeConfigModelsInList { + for _, m := range r.Config.GetAllModels() { + // Skip if already added as the configured auto model name (avoid duplicates) + if m == r.Config.GetEffectiveAutoModelName() { + continue + } + models = append(models, OpenAIModel{ + ID: m, + Object: "model", + Created: now, + OwnedBy: "vllm-semantic-router", + }) + } + } + + resp := OpenAIModelList{ + Object: "list", + Data: models, + } + + return r.createJSONResponse(200, resp), nil +} + +// statusCodeToEnum converts HTTP status code to typev3.StatusCode enum +func statusCodeToEnum(statusCode int) typev3.StatusCode { + switch statusCode { + case 200: + return typev3.StatusCode_OK + case 400: + return typev3.StatusCode_BadRequest + case 404: + return typev3.StatusCode_NotFound + case 500: + return typev3.StatusCode_InternalServerError + default: + return typev3.StatusCode_OK + } +} + +// createJSONResponseWithBody creates a direct response with pre-marshaled JSON body +func (r *OpenAIRouter) createJSONResponseWithBody(statusCode int, jsonBody []byte) *ext_proc.ProcessingResponse { + return &ext_proc.ProcessingResponse{ + Response: &ext_proc.ProcessingResponse_ImmediateResponse{ + ImmediateResponse: &ext_proc.ImmediateResponse{ + Status: &typev3.HttpStatus{ + Code: statusCodeToEnum(statusCode), + }, + Headers: &ext_proc.HeaderMutation{ + SetHeaders: []*core.HeaderValueOption{ + { + Header: &core.HeaderValue{ + Key: "content-type", + RawValue: []byte("application/json"), + }, + }, + }, + }, + Body: jsonBody, + }, + }, + } +} + +// createJSONResponse creates a direct response with JSON content +func (r *OpenAIRouter) createJSONResponse(statusCode int, data interface{}) *ext_proc.ProcessingResponse { + jsonData, err := json.Marshal(data) + if err != nil { + observability.Errorf("Failed to marshal JSON response: %v", err) + return r.createErrorResponse(500, "Internal server error") + } + + return r.createJSONResponseWithBody(statusCode, jsonData) +} + +// createErrorResponse creates a direct error response +func (r *OpenAIRouter) createErrorResponse(statusCode int, message string) *ext_proc.ProcessingResponse { + errorResp := map[string]interface{}{ + "error": map[string]interface{}{ + "message": message, + "type": "invalid_request_error", + "code": statusCode, + }, + } + + jsonData, err := json.Marshal(errorResp) + if err != nil { + observability.Errorf("Failed to marshal error response: %v", err) + jsonData = []byte(`{"error":{"message":"Internal server error","type":"internal_error","code":500}}`) + // Use 500 status code for fallback error + statusCode = 500 + } + + return r.createJSONResponseWithBody(statusCode, jsonData) +} diff 
--git a/src/semantic-router/pkg/observability/metrics/metrics.go b/src/semantic-router/pkg/observability/metrics/metrics.go index 50fdd6376..9725cd5d8 100644 --- a/src/semantic-router/pkg/observability/metrics/metrics.go +++ b/src/semantic-router/pkg/observability/metrics/metrics.go @@ -380,6 +380,24 @@ var ( }, []string{"fallback_reason", "fallback_strategy"}, ) + + // ResponsesAdapterRequests counts requests handled via the Responses adapter + ResponsesAdapterRequests = promauto.NewCounterVec( + prometheus.CounterOpts{ + Name: "llm_responses_adapter_requests_total", + Help: "Total number of /v1/responses requests handled by the adapter", + }, + []string{"streaming"}, + ) + + // ResponsesAdapterSSEEvents counts emitted Responses SSE events during translation + ResponsesAdapterSSEEvents = promauto.NewCounterVec( + prometheus.CounterOpts{ + Name: "llm_responses_adapter_sse_events_total", + Help: "Total number of Responses SSE events emitted by the adapter", + }, + []string{"event_type"}, + ) ) // RecordModelRequest increments the counter for requests to a specific model diff --git a/website/docs/api/router.md b/website/docs/api/router.md index c86cde4fc..5540b2122 100644 --- a/website/docs/api/router.md +++ b/website/docs/api/router.md @@ -11,17 +11,21 @@ The Semantic Router operates as an ExtProc server that processes HTTP requests t ### Ports and endpoint mapping - 8801 (HTTP, Envoy public entry) + - Typical client entry for OpenAI-compatible requests like `POST /v1/chat/completions`. - Can proxy `GET /v1/models` to Router 8080 if you add an Envoy route; otherwise `/v1/models` at 8801 may return “no healthy upstream”. + - Experimental: `POST /v1/responses` is supported via a compatibility adapter when `enable_responses_adapter: true` is set in config. Text-only inputs are mapped to legacy Chat Completions under the hood; mapping for streaming responses will be added in a future release (see the example request after this list). - 8080 (HTTP, Classification API) - - `GET /v1/models` → OpenAI-compatible model list (includes synthetic `MoM`) - - `GET /health` → Classification API health + + - `GET /v1/models` → OpenAI-compatible model list (includes synthetic `MoM`) + - `GET /health` → Classification API health - `GET /info/models` → Loaded classifier models + system info - `GET /info/classifier` → Classifier configuration details - `POST /api/v1/classify/intent|pii|security|batch` → Direct classification utilities - 50051 (gRPC, ExtProc) + - Envoy External Processing (ExtProc) for in-path classification/routing of `/v1/chat/completions`. - Not an HTTP port; not directly accessible via curl.
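A minimal sketch of calling the experimental Responses adapter is shown below. It assumes the Envoy entry point is reachable at `localhost:8801`, that `enable_responses_adapter: true` is set in the router config, and that the adapter accepts an OpenAI Responses-style payload (`model` plus a plain-text `input`) for the text-only case; the exact field names the adapter accepts are an assumption here, so verify them against your deployment.

```python
import requests

# Sketch: non-streaming, text-only request to the experimental /v1/responses adapter.
# Assumes enable_responses_adapter: true and the Envoy public entry on port 8801.
resp = requests.post(
    "http://localhost:8801/v1/responses",
    json={
        "model": "MoM",  # synthetic router model; "auto" is also accepted as an alias
        "input": "What is the derivative of x^2?",  # assumed Responses-style text input
    },
    timeout=30,
)
print(resp.status_code)
print(resp.json())
```

Because text-only inputs are translated to a legacy Chat Completions call, the usual routing, caching, and security checks apply, and each such request should also be reflected in the `llm_responses_adapter_requests_total{streaming}` counter added above.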
@@ -36,7 +40,7 @@ sequenceDiagram participant Envoy participant Router participant Backend - + Client->>Envoy: POST /v1/chat/completions Envoy->>Router: ExtProc Request Router->>Router: Classify & Route @@ -63,9 +67,24 @@ Lists available models and includes a synthetic "MoM" (Mixture of Models) model { "object": "list", "data": [ - { "id": "MoM", "object": "model", "created": 1726890000, "owned_by": "semantic-router" }, - { "id": "gpt-4o-mini", "object": "model", "created": 1726890000, "owned_by": "upstream-endpoint" }, - { "id": "llama-3.1-8b-instruct", "object": "model", "created": 1726890000, "owned_by": "upstream-endpoint" } + { + "id": "MoM", + "object": "model", + "created": 1726890000, + "owned_by": "semantic-router" + }, + { + "id": "gpt-4o-mini", + "object": "model", + "created": 1726890000, + "owned_by": "upstream-endpoint" + }, + { + "id": "llama-3.1-8b-instruct", + "object": "model", + "created": 1726890000, + "owned_by": "upstream-endpoint" + } ] } ``` @@ -86,7 +105,7 @@ Notes: "model": "gpt-3.5-turbo", "messages": [ { - "role": "user", + "role": "user", "content": "What is the derivative of x^2?" } ], @@ -109,7 +128,7 @@ Notes: ```json { "id": "chatcmpl-abc123", - "object": "chat.completion", + "object": "chat.completion", "created": 1677858242, "model": "gpt-3.5-turbo", "choices": [ @@ -156,12 +175,12 @@ The router adds metadata headers to both requests and responses: ### Response Headers (Added by Router) -| Header | Description | Example | -|--------|-------------|---------| -| `x-processing-time` | Total processing time (ms) | `45` | -| `x-classification-time` | Classification time (ms) | `12` | -| `x-security-checks` | Security check results | `pii:false,jailbreak:false` | -| `x-tools-selected` | Number of tools selected | `2` | +| Header | Description | Example | +| ----------------------- | -------------------------- | --------------------------- | +| `x-processing-time` | Total processing time (ms) | `45` | +| `x-classification-time` | Classification time (ms) | `12` | +| `x-security-checks` | Security check results | `pii:false,jailbreak:false` | +| `x-tools-selected` | Number of tools selected | `2` | ## Health Check API @@ -178,7 +197,7 @@ The router provides health check endpoints for monitoring: "uptime": 3600, "models": { "category_classifier": "loaded", - "pii_detector": "loaded", + "pii_detector": "loaded", "jailbreak_guard": "loaded" }, "cache": { @@ -188,7 +207,7 @@ The router provides health check endpoints for monitoring: }, "endpoints": { "endpoint1": "healthy", - "endpoint2": "healthy", + "endpoint2": "healthy", "endpoint3": "degraded" } } @@ -211,7 +230,7 @@ semantic_router_request_duration_seconds{endpoint="endpoint1"} 0.045 semantic_router_classification_accuracy{category="mathematics"} 0.94 semantic_router_classification_duration_seconds 0.012 -# Cache metrics +# Cache metrics semantic_router_cache_hit_ratio 0.73 semantic_router_cache_size 1247 @@ -231,6 +250,7 @@ llm_request_errors_total{model="phi4",reason="pii_policy_denied"} 8 The router exposes dedicated Prometheus counters to monitor reasoning mode decisions and template usage across model families. These metrics are emitted by the router and can be scraped by your Prometheus server. - `llm_reasoning_decisions_total{category, model, enabled, effort}` + - Description: Count of reasoning decisions made per category and selected model, with whether reasoning was enabled and the applied effort level. 
- Labels: - category: category name determined during routing @@ -239,6 +259,7 @@ The router exposes dedicated Prometheus counters to monitor reasoning mode decis - effort: effort level used when enabled (e.g., low|medium|high) - `llm_reasoning_template_usage_total{family, param}` + - Description: Count of times a model-family-specific template parameter was applied to requests. - Labels: - family: normalized model family (e.g., qwen3, deepseek, gpt-oss, gpt) @@ -274,6 +295,7 @@ sum by (family, effort) ( The router exposes additional metrics for cost accounting and routing decisions. - `llm_model_cost_total{model, currency}` + - Description: Total accumulated cost attributed to each model (computed from token usage and per-1M pricing), labeled by currency. - Labels: - model: model name used for the request @@ -383,7 +405,7 @@ model_config: Notes: - Pricing is optional; if omitted, cost is treated as 0 and only token metrics are emitted. -- Cost is computed as: (prompt_tokens * prompt_per_1m + completion_tokens * completion_per_1m) / 1_000_000 (in the configured currency). +- Cost is computed as: `(prompt_tokens * prompt_per_1m + completion_tokens * completion_per_1m) / 1_000_000` (in the configured currency). ## gRPC ExtProc API @@ -463,14 +485,14 @@ func (r *Router) handleRequestBody(body *ProcessingRequest_RequestBody) *Process ### HTTP Status Codes -| Status | Description | -|--------|-------------| -| 200 | Success | -| 400 | Bad Request (malformed input) | -| 403 | Forbidden (security violation) | -| 429 | Too Many Requests (rate limited) | -| 500 | Internal Server Error | -| 503 | Service Unavailable (backend down) | +| Status | Description | +| ------ | ---------------------------------- | +| 200 | Success | +| 400 | Bad Request (malformed input) | +| 403 | Forbidden (security violation) | +| 429 | Too Many Requests (rate limited) | +| 500 | Internal Server Error | +| 503 | Service Unavailable (backend down) | ## Configuration API @@ -499,15 +521,17 @@ For real-time streaming responses: **Endpoint:** `ws://localhost:8801/v1/chat/stream` ```javascript -const ws = new WebSocket('ws://localhost:8801/v1/chat/stream'); +const ws = new WebSocket("ws://localhost:8801/v1/chat/stream"); -ws.send(JSON.stringify({ - "model": "gpt-3.5-turbo", - "messages": [{"role": "user", "content": "Tell me a story"}], - "stream": true -})); +ws.send( + JSON.stringify({ + model: "gpt-3.5-turbo", + messages: [{ role: "user", content: "Tell me a story" }], + stream: true, + }) +); -ws.onmessage = function(event) { +ws.onmessage = function (event) { const chunk = JSON.parse(event.data); console.log(chunk.choices[0].delta.content); }; @@ -523,7 +547,7 @@ import requests class SemanticRouterClient: def __init__(self, base_url="http://localhost:8801"): self.base_url = base_url - + def chat_completion(self, messages, model="gpt-3.5-turbo", **kwargs): response = requests.post( f"{self.base_url}/v1/chat/completions", @@ -534,7 +558,7 @@ class SemanticRouterClient: } ) return response.json() - + def get_health(self): response = requests.get(f"{self.base_url}/health") return response.json() @@ -550,36 +574,36 @@ result = client.chat_completion([ ```javascript class SemanticRouterClient { - constructor(baseUrl = 'http://localhost:8801') { - this.baseUrl = baseUrl; - } - - async chatCompletion(messages, model = 'gpt-3.5-turbo', options = {}) { - const response = await fetch(`${this.baseUrl}/v1/chat/completions`, { - method: 'POST', - headers: { - 'Content-Type': 'application/json' - }, - body: JSON.stringify({ -
messages, - ...options - }) - }); - - return response.json(); - } - - async getHealth() { - const response = await fetch(`${this.baseUrl}/health`); - return response.json(); - } + constructor(baseUrl = "http://localhost:8801") { + this.baseUrl = baseUrl; + } + + async chatCompletion(messages, model = "gpt-3.5-turbo", options = {}) { + const response = await fetch(`${this.baseUrl}/v1/chat/completions`, { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify({ + model, + messages, + ...options, + }), + }); + + return response.json(); + } + + async getHealth() { + const response = await fetch(`${this.baseUrl}/health`); + return response.json(); + } } // Usage const client = new SemanticRouterClient(); const result = await client.chatCompletion([ - { role: 'user', content: 'Solve x^2 + 5x + 6 = 0' } + { role: "user", content: "Solve x^2 + 5x + 6 = 0" }, ]); ``` @@ -621,7 +645,7 @@ X-RateLimit-Retry-After: 60 # Include relevant context messages = [ { - "role": "system", + "role": "system", "content": "You are a mathematics tutor." }, { @@ -651,7 +675,7 @@ try: handle_router_error(response['error']) else: process_response(response) - + except requests.exceptions.Timeout: handle_timeout_error() except requests.exceptions.ConnectionError: