Merged
524 changes: 266 additions & 258 deletions config/config.yaml

Large diffs are not rendered by default.

21 changes: 21 additions & 0 deletions src/semantic-router/pkg/config/config.go
@@ -192,6 +192,13 @@ type VLLMEndpoint struct {
}

// ModelParams represents configuration for model-specific parameters
type ModelPricing struct {
// Price in USD per 1M prompt tokens
PromptUSDPer1M float64 `yaml:"prompt_usd_per_1m,omitempty"`
// Price in USD per 1M completion tokens
CompletionUSDPer1M float64 `yaml:"completion_usd_per_1m,omitempty"`
}

type ModelParams struct {
// Number of parameters in the model
ParamCount float64 `yaml:"param_count"`
@@ -207,6 +214,9 @@ type ModelParams struct {

// Preferred endpoints for this model (optional)
PreferredEndpoints []string `yaml:"preferred_endpoints,omitempty"`

// Optional pricing used for cost computation
Pricing ModelPricing `yaml:"pricing,omitempty"`
}

// PIIPolicy represents the PII (Personally Identifiable Information) policy for a model
@@ -364,6 +374,17 @@ func (c *RouterConfig) GetModelContextSize(modelName string, defaultValue float6
return defaultValue
}

// GetModelPricing returns pricing in USD per 1M tokens for prompt and completion.
func (c *RouterConfig) GetModelPricing(modelName string) (promptUSDPer1M float64, completionUSDPer1M float64, ok bool) {
if modelConfig, okc := c.ModelConfig[modelName]; okc {
p := modelConfig.Pricing
if p.PromptUSDPer1M != 0 || p.CompletionUSDPer1M != 0 {
return p.PromptUSDPer1M, p.CompletionUSDPer1M, true
}
}
return 0, 0, false
}

// GetModelPIIPolicy returns the PII policy for a given model
// If the model is not found in the config, returns a default policy that allows all PII
func (c *RouterConfig) GetModelPIIPolicy(modelName string) PIIPolicy {
66 changes: 64 additions & 2 deletions src/semantic-router/pkg/extproc/request_handler.go
@@ -14,6 +14,7 @@ import (

"github.com/vllm-project/semantic-router/semantic-router/pkg/cache"
"github.com/vllm-project/semantic-router/semantic-router/pkg/metrics"
"github.com/vllm-project/semantic-router/semantic-router/pkg/observability"
"github.com/vllm-project/semantic-router/semantic-router/pkg/utils/http"
"github.com/vllm-project/semantic-router/semantic-router/pkg/utils/pii"
)
@@ -173,7 +174,7 @@ func (r *OpenAIRouter) handleRequestBody(v *ext_proc.ProcessingRequest_RequestBo
userContent, nonUserMessages := extractUserAndNonUserContent(openAIRequest)

// Perform security checks
if response, shouldReturn := r.performSecurityChecks(userContent, nonUserMessages); shouldReturn {
if response, shouldReturn := r.performSecurityChecks(ctx, userContent, nonUserMessages); shouldReturn {
return response, nil
}

@@ -187,7 +188,7 @@
}

// performSecurityChecks performs PII and jailbreak detection
func (r *OpenAIRouter) performSecurityChecks(userContent string, nonUserMessages []string) (*ext_proc.ProcessingResponse, bool) {
func (r *OpenAIRouter) performSecurityChecks(ctx *RequestContext, userContent string, nonUserMessages []string) (*ext_proc.ProcessingResponse, bool) {
// Perform PII classification on all message content
allContent := pii.ExtractAllContent(userContent, nonUserMessages)

@@ -212,6 +213,13 @@ func (r *OpenAIRouter) performSecurityChecks(userContent string, nonUserMessages
log.Printf("JAILBREAK ATTEMPT BLOCKED: %s (confidence: %.3f)", jailbreakType, confidence)

// Return immediate jailbreak violation response
// Structured log for security block
observability.LogEvent("security_block", map[string]interface{}{
"reason_code": "jailbreak_detected",
"jailbreak_type": jailbreakType,
"confidence": confidence,
"request_id": ctx.RequestID,
})
jailbreakResponse := http.CreateJailbreakViolationResponse(jailbreakType, confidence)
return jailbreakResponse, true
} else {
@@ -241,6 +249,13 @@ func (r *OpenAIRouter) handleCaching(ctx *RequestContext) (*ext_proc.ProcessingR
if err != nil {
log.Printf("Error searching cache: %v", err)
} else if found {
// Record and log cache hit
metrics.RecordCacheHit()
observability.LogEvent("cache_hit", map[string]interface{}{
"request_id": ctx.RequestID,
"model": requestModel,
"query": requestQuery,
})
// Return immediate response from cache
response := http.CreateCacheHitResponse(cachedResponse)
return response, true
@@ -313,19 +328,33 @@ func (r *OpenAIRouter) handleModelRouting(openAIRequest *openai.ChatCompletionNe
// Select the best allowed model from this category
matchedModel = r.Classifier.SelectBestModelFromList(allowedModels, categoryName)
log.Printf("Selected alternative model %s that passes PII policy", matchedModel)
// Record reason code for selecting alternative due to PII
metrics.RecordRoutingReasonCode("pii_policy_alternative_selected", matchedModel)
} else {
log.Printf("No models in category %s pass PII policy, using default", categoryName)
matchedModel = r.Config.DefaultModel
// Check if default model passes policy
defaultAllowed, defaultDeniedPII, _ := r.PIIChecker.CheckPolicy(matchedModel, detectedPII)
if !defaultAllowed {
log.Printf("Default model also violates PII policy, returning error")
observability.LogEvent("routing_block", map[string]interface{}{
"reason_code": "pii_policy_denied_default_model",
"request_id": ctx.RequestID,
"model": matchedModel,
"denied_pii": defaultDeniedPII,
})
piiResponse := http.CreatePIIViolationResponse(matchedModel, defaultDeniedPII)
return piiResponse, nil
}
}
} else {
log.Printf("Could not determine category, returning PII violation for model %s", matchedModel)
observability.LogEvent("routing_block", map[string]interface{}{
"reason_code": "pii_policy_denied",
"request_id": ctx.RequestID,
"model": matchedModel,
"denied_pii": deniedPII,
})
piiResponse := http.CreatePIIViolationResponse(matchedModel, deniedPII)
return piiResponse, nil
}
@@ -424,6 +453,20 @@ func (r *OpenAIRouter) handleModelRouting(openAIRequest *openai.ChatCompletionNe
}

log.Printf("Use new model: %s", matchedModel)

// Structured log for routing decision (auto)
observability.LogEvent("routing_decision", map[string]interface{}{
"reason_code": "auto_routing",
"request_id": ctx.RequestID,
"original_model": originalModel,
"selected_model": matchedModel,
"category": categoryName,
"reasoning_enabled": useReasoning,
"reasoning_effort": effortForMetrics,
"selected_endpoint": selectedEndpoint,
"routing_latency_ms": time.Since(ctx.ProcessingStartTime).Milliseconds(),
})
metrics.RecordRoutingReasonCode("auto_routing", matchedModel)
}
}
} else if originalModel != "auto" {
@@ -438,6 +481,12 @@ func (r *OpenAIRouter) handleModelRouting(openAIRequest *openai.ChatCompletionNe
// Continue with request on error
} else if !allowed {
log.Printf("Model %s violates PII policy, returning error", originalModel)
observability.LogEvent("routing_block", map[string]interface{}{
"reason_code": "pii_policy_denied",
"request_id": ctx.RequestID,
"model": originalModel,
"denied_pii": deniedPII,
})
piiResponse := http.CreatePIIViolationResponse(originalModel, deniedPII)
return piiResponse, nil
}
@@ -472,6 +521,19 @@ func (r *OpenAIRouter) handleModelRouting(openAIRequest *openai.ChatCompletionNe
},
},
}
// Structured log for routing decision (explicit model)
observability.LogEvent("routing_decision", map[string]interface{}{
"reason_code": "model_specified",
"request_id": ctx.RequestID,
"original_model": originalModel,
"selected_model": originalModel,
"category": "",
"reasoning_enabled": false,
"reasoning_effort": "",
"selected_endpoint": selectedEndpoint,
"routing_latency_ms": time.Since(ctx.ProcessingStartTime).Milliseconds(),
})
metrics.RecordRoutingReasonCode("model_specified", originalModel)
}

// Save the actual model that will be used for token tracking
30 changes: 30 additions & 0 deletions src/semantic-router/pkg/extproc/response_handler.go
@@ -9,6 +9,7 @@ import (

"github.com/openai/openai-go"
"github.com/vllm-project/semantic-router/semantic-router/pkg/metrics"
"github.com/vllm-project/semantic-router/semantic-router/pkg/observability"
)

// handleResponseHeaders processes the response headers
@@ -52,6 +53,35 @@ func (r *OpenAIRouter) handleResponseBody(v *ext_proc.ProcessingRequest_Response
)
metrics.RecordModelCompletionLatency(ctx.RequestModel, completionLatency.Seconds())
r.Classifier.DecrementModelLoad(ctx.RequestModel)

// Compute and record cost if pricing is configured
if r.Config != nil {
promptRatePer1M, completionRatePer1M, ok := r.Config.GetModelPricing(ctx.RequestModel)
if ok {
costUSD := (float64(promptTokens)*promptRatePer1M + float64(completionTokens)*completionRatePer1M) / 1_000_000.0
metrics.RecordModelCostUSD(ctx.RequestModel, costUSD)
observability.LogEvent("llm_usage", map[string]interface{}{
"request_id": ctx.RequestID,
"model": ctx.RequestModel,
"prompt_tokens": promptTokens,
"completion_tokens": completionTokens,
"total_tokens": promptTokens + completionTokens,
"completion_latency_ms": completionLatency.Milliseconds(),
"cost_usd": costUSD,
})
} else {
observability.LogEvent("llm_usage", map[string]interface{}{
"request_id": ctx.RequestID,
"model": ctx.RequestModel,
"prompt_tokens": promptTokens,
"completion_tokens": completionTokens,
"total_tokens": promptTokens + completionTokens,
"completion_latency_ms": completionLatency.Milliseconds(),
"cost_usd": 0.0,
"pricing": "not_configured",
})
}
}
}

// Check if this request has a pending cache entry
37 changes: 37 additions & 0 deletions src/semantic-router/pkg/metrics/metrics.go
@@ -102,6 +102,15 @@ var (
[]string{"model"},
)

// ModelCostUSD tracks the total USD cost attributed to each model
ModelCostUSD = promauto.NewCounterVec(
prometheus.CounterOpts{
Name: "llm_model_cost_usd_total",
Help: "The total USD cost attributed to each LLM model",
},
[]string{"model"},
)

Review thread on this metric:

Member: Should USD be a specific naming suffix for this metric?

Member: If we want to support more units, a dimension in the metric would be easier to extend.

Contributor (author): Your suggestion is good. Maybe I can change the metric name to llm_model_cost_total and add a currency label, e.g. llm_model_cost_total{currency="USD"}. But a trade-off needs to be made here.

Pros:

  • Easy to extend, allowing us to add CNY, EUR, etc. in the future as needed
  • Clearer: one metric name for "cost," with currency as the dimension

Cons:

  • Aggregations must be done per currency (sum by (currency, model) …). It's easy to accidentally sum different currencies if you forget to filter.

WDYT? @Xunzhuo @rootfs

Member: For the cons, maybe we should have clear docs around the metrics; it is a bit awkward that the number of metrics grows linearly as units are added.

Collaborator: Adding a currency label is a good idea!

Contributor (author): I created an issue for tracking: #75

// ModelTokens tracks the number of tokens used by each model
ModelTokens = promauto.NewCounterVec(
prometheus.CounterOpts{
Expand Down Expand Up @@ -138,6 +147,15 @@ var (
[]string{"source_model", "target_model"},
)

// RoutingReasonCodes tracks routing decisions by reason_code and model
RoutingReasonCodes = promauto.NewCounterVec(
prometheus.CounterOpts{
Name: "llm_routing_reason_codes_total",
Help: "The total number of routing decisions by reason code and model",
},
[]string{"reason_code", "model"},
)

// ModelCompletionLatency tracks the latency of completions by model
ModelCompletionLatency = promauto.NewHistogramVec(
prometheus.HistogramOpts{
@@ -238,6 +256,25 @@ func RecordModelTokens(model string, tokens float64) {
ModelTokens.WithLabelValues(model).Add(tokens)
}

// RecordModelCostUSD adds the dollar cost attributed to a specific model
func RecordModelCostUSD(model string, usd float64) {
if usd < 0 {
return
}
ModelCostUSD.WithLabelValues(model).Add(usd)
}

// RecordRoutingReasonCode increments the counter for a routing decision reason code and model
func RecordRoutingReasonCode(reasonCode, model string) {
if reasonCode == "" {
reasonCode = "unknown"
}
if model == "" {
model = "unknown"
}
RoutingReasonCodes.WithLabelValues(reasonCode, model).Inc()
}

// RecordModelTokensDetailed records detailed token usage (prompt and completion)
func RecordModelTokensDetailed(model string, promptTokens, completionTokens float64) {
// Record in both the aggregated and detailed metrics
28 changes: 28 additions & 0 deletions src/semantic-router/pkg/observability/logging.go
@@ -0,0 +1,28 @@
package observability

import (
"encoding/json"
"log"
"time"
)

Review thread on the "log" import:

Collaborator: (comment not captured in this view)

Contributor (author): I prefer zap, but I'll create an issue to see what others think.

Contributor (author): This one: #76

// LogEvent emits a structured JSON log line with a standard envelope
// Fields provided by callers take precedence and will not be overwritten.
func LogEvent(event string, fields map[string]interface{}) {
if fields == nil {
fields = map[string]interface{}{}
}
if _, ok := fields["event"]; !ok {
fields["event"] = event
}
if _, ok := fields["ts"]; !ok {
fields["ts"] = time.Now().UTC().Format(time.RFC3339Nano)
}
b, err := json.Marshal(fields)
if err != nil {
// Fallback to regular log on marshal error
log.Printf("event=%s marshal_error=%v fields_len=%d", event, err, len(fields))
return
}
log.Println(string(b))
}
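For context, a minimal, self-contained usage sketch mirroring the calls added in request_handler.go; the event name and field values are illustrative. json.Marshal sorts map keys alphabetically, and the standard logger prepends its date/time prefix unless flags are changed, so the emitted line looks roughly like the comment below:

```go
package main

import (
	"github.com/vllm-project/semantic-router/semantic-router/pkg/observability"
)

func main() {
	// Illustrative call; mirrors the routing_decision events emitted in request_handler.go.
	observability.LogEvent("routing_decision", map[string]interface{}{
		"reason_code":    "auto_routing",
		"selected_model": "phi4",
	})
	// Example output (one line, after the default log prefix):
	// {"event":"routing_decision","reason_code":"auto_routing","selected_model":"phi4","ts":"2025-01-01T12:00:00.123456789Z"}
}
```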
49 changes: 49 additions & 0 deletions website/docs/api/router.md
@@ -218,6 +218,55 @@ sum by (family, effort) (
)
```

### Cost and Routing Metrics

The router exposes additional metrics for cost accounting and routing decisions.

- `llm_model_cost_usd_total{model}`
- Description: Total accumulated USD cost attributed to each model (computed from token usage and per-1M pricing).
- Labels:
- model: model name used for the request

- `llm_routing_reason_codes_total{reason_code, model}`
- Description: Count of routing decisions by reason code and selected model.
- Labels:
- reason_code: why a routing decision happened (e.g., auto_routing, model_specified, pii_policy_alternative_selected)
- model: final selected model

Example PromQL:

```prometheus
# Cost by model over the last hour
sum by (model) (increase(llm_model_cost_usd_total[1h]))

# Routing decisions by reason code over the last 15 minutes
sum by (reason_code) (increase(llm_routing_reason_codes_total[15m]))
```

### Pricing Configuration

Provide per-1M pricing for your models so the router can compute request cost and emit metrics/logs.

```yaml
model_config:
  phi4:
    pricing:
      prompt_usd_per_1m: 200.0
      completion_usd_per_1m: 600.0
  "mistral-small3.1":
    pricing:
      prompt_usd_per_1m: 300.0
      completion_usd_per_1m: 900.0
  "gemma3:27b":
    pricing:
      prompt_usd_per_1m: 500.0
      completion_usd_per_1m: 1500.0
```

Notes:
- Pricing is optional; if omitted, cost is treated as 0 and only token metrics are emitted.
- Cost is computed as: (prompt_tokens * prompt_usd_per_1m + completion_tokens * completion_usd_per_1m) / 1_000_000.
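- Example: with the `phi4` pricing above, a request that uses 1,200 prompt tokens and 400 completion tokens costs (1200 × 200 + 400 × 600) / 1,000,000 = 0.48 USD.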

## gRPC ExtProc API

For direct integration with the ExtProc protocol: