diff --git a/config/config-hybrid-example.yaml b/config/config-hybrid-example.yaml new file mode 100644 index 00000000..9b70051a --- /dev/null +++ b/config/config-hybrid-example.yaml @@ -0,0 +1,235 @@ +bert_model: + model_id: sentence-transformers/all-MiniLM-L12-v2 + threshold: 0.6 + use_cpu: true + +semantic_cache: + enabled: true + backend_type: "memory" + similarity_threshold: 0.8 + max_entries: 1000 + ttl_seconds: 3600 + eviction_policy: "fifo" + +tools: + enabled: true + top_k: 3 + similarity_threshold: 0.2 + tools_db_path: "config/tools_db.json" + fallback_to_empty: true + +prompt_guard: + enabled: true + use_modernbert: true + model_id: "models/jailbreak_classifier_modernbert-base_model" + threshold: 0.7 + use_cpu: true + jailbreak_mapping_path: "models/jailbreak_classifier_modernbert-base_model/jailbreak_type_mapping.json" + +# vLLM Endpoints Configuration +vllm_endpoints: + - name: "endpoint1" + address: "127.0.0.1" + port: 8000 + models: + - "openai/gpt-oss-20b" + - "math-specialized-model" + weight: 1 + health_check_path: "/health" + +model_config: + "openai/gpt-oss-20b": + reasoning_family: "gpt-oss" + preferred_endpoints: ["endpoint1"] + pii_policy: + allow_by_default: true + "math-specialized-model": + reasoning_family: "gpt-oss" + preferred_endpoints: ["endpoint1"] + pii_policy: + allow_by_default: true + +# Classifier configuration +classifier: + category_model: + model_id: "models/category_classifier_modernbert-base_model" + use_modernbert: true + threshold: 0.6 + use_cpu: true + category_mapping_path: "models/category_classifier_modernbert-base_model/category_mapping.json" + pii_model: + model_id: "models/pii_classifier_modernbert-base_presidio_token_model" + use_modernbert: true + threshold: 0.7 + use_cpu: true + pii_mapping_path: "models/pii_classifier_modernbert-base_presidio_token_model/pii_type_mapping.json" + +# Hybrid Routing Configuration +routing_strategy: + type: "hybrid" # Options: "model", "rules", "hybrid" + + model_routing: + enabled: true + fallback_to_rules: false + confidence_threshold: 0.7 + + rule_routing: + enabled: true + fallback_to_model: true + evaluation_timeout_ms: 100 + +# Custom Routing Rules +routing_rules: + - name: "enterprise-math-routing" + description: "Route complex math problems to specialized model" + enabled: true + priority: 100 + + conditions: + - type: "category_classification" + category: "math" + threshold: 0.8 + operator: "gte" + - type: "content_complexity" + metric: "token_count" + threshold: 50 + operator: "gt" + + actions: + - type: "route_to_model" + model: "math-specialized-model" + - type: "enable_reasoning" + enable_reasoning: true + reasoning_effort: "high" + + evaluation: + timeout_ms: 100 + fallback_action: "use_model_classification" + + - name: "premium-user-routing" + description: "Route premium users to best available models" + enabled: true + priority: 90 + + conditions: + - type: "request_header" + header_name: "x-user-tier" + value: "premium" + operator: "equals" + + actions: + - type: "route_to_model" + model: "openai/gpt-oss-20b" + - type: "enable_reasoning" + enable_reasoning: true + reasoning_effort: "medium" + + - name: "content-filter" + description: "Block inappropriate content" + enabled: true + priority: 150 + + conditions: + - type: "pattern_match" + pattern_match: "inappropriate" + operator: "contains" + + actions: + - type: "block_request" + block_with_message: "Content violates usage policy" + + - name: "simple-query-optimization" + description: "Route simple queries to efficient models" + enabled: true + priority: 50 + + conditions: + - type: "content_complexity" + metric: "token_count" + threshold: 20 + operator: "lt" + + actions: + - type: "route_to_model" + model: "openai/gpt-oss-20b" + - type: "enable_reasoning" + enable_reasoning: false + +# Categories with model scores (used by model-based routing) +categories: + - name: business + model_scores: + - model: openai/gpt-oss-20b + score: 0.7 + use_reasoning: false + - name: law + model_scores: + - model: openai/gpt-oss-20b + score: 0.4 + use_reasoning: false + - name: psychology + model_scores: + - model: openai/gpt-oss-20b + score: 0.6 + use_reasoning: false + - name: biology + model_scores: + - model: openai/gpt-oss-20b + score: 0.9 + use_reasoning: false + - name: chemistry + model_scores: + - model: openai/gpt-oss-20b + score: 0.6 + use_reasoning: true + - name: history + model_scores: + - model: openai/gpt-oss-20b + score: 0.7 + use_reasoning: false + - name: other + model_scores: + - model: openai/gpt-oss-20b + score: 0.7 + use_reasoning: false + - name: health + model_scores: + - model: openai/gpt-oss-20b + score: 0.8 + use_reasoning: false + - name: math + model_scores: + - model: math-specialized-model + score: 0.9 + use_reasoning: true + - model: openai/gpt-oss-20b + score: 0.7 + use_reasoning: true + - name: computer science + model_scores: + - model: openai/gpt-oss-20b + score: 0.8 + use_reasoning: true + - name: economics + model_scores: + - model: openai/gpt-oss-20b + score: 0.6 + use_reasoning: false + - name: engineering + model_scores: + - model: openai/gpt-oss-20b + score: 0.8 + use_reasoning: true + - name: physics + model_scores: + - model: openai/gpt-oss-20b + score: 0.8 + use_reasoning: true + +default_model: openai/gpt-oss-20b +default_reasoning_effort: medium + +reasoning_families: + gpt-oss: + type: "reasoning_effort" + parameter: "reasoning_effort" \ No newline at end of file diff --git a/examples/CONFIGURATION_COMPARISON.md b/examples/CONFIGURATION_COMPARISON.md new file mode 100644 index 00000000..4e1dd0fa --- /dev/null +++ b/examples/CONFIGURATION_COMPARISON.md @@ -0,0 +1,136 @@ +# Hybrid Routing Configuration Comparison + +## Before: Model-Only Routing (Black Box) + +```yaml +# Original semantic router - limited interpretability +categories: + - name: math + model_scores: + - model: openai/gpt-oss-20b + score: 0.9 + use_reasoning: true + +default_model: openai/gpt-oss-20b + +# Problems: +# - No visibility into routing decisions +# - Cannot customize routing logic beyond categories +# - No threshold control per use case +# - No request blocking capabilities +# - No explanation of why a model was selected +``` + +## After: Hybrid Routing (Interpretable & Configurable) + +```yaml +# New hybrid approach - full control and transparency +routing_strategy: + type: "hybrid" + model_routing: + enabled: true + confidence_threshold: 0.7 + rule_routing: + enabled: true + fallback_to_model: true + +routing_rules: + - name: "enterprise-math-routing" + description: "Route complex math to specialized model" + enabled: true + priority: 100 + + conditions: + - type: "category_classification" + category: "math" + threshold: 0.8 + operator: "gte" + - type: "content_complexity" + metric: "token_count" + threshold: 50 + operator: "gt" + + actions: + - type: "route_to_model" + model: "math-specialized-model" + - type: "enable_reasoning" + enable_reasoning: true + reasoning_effort: "high" + + - name: "premium-user-routing" + description: "Premium users get best models" + enabled: true + priority: 90 + + conditions: + - type: "request_header" + header_name: "x-user-tier" + value: "premium" + operator: "equals" + + actions: + - type: "route_to_model" + model: "premium-model" + + - name: "content-filter" + description: "Block inappropriate content" + enabled: true + priority: 150 + + conditions: + - type: "pattern_match" + pattern_match: "inappropriate" + operator: "contains" + + actions: + - type: "block_request" + block_with_message: "Content violates policy" + +# Benefits: +# ✅ Full transparency: Know exactly why each decision was made +# ✅ Custom logic: Business rules beyond ML categories +# ✅ Configurable thresholds: Fine-tune sensitivity per use case +# ✅ Request blocking: Security and policy enforcement +# ✅ Rule precedence: Control decision priority +# ✅ Real-time updates: Modify rules without restart +# ✅ Audit trail: Detailed decision explanations +``` + +## Decision Explanation Example + +```json +{ + "rule_matched": true, + "selected_model": "math-specialized-model", + "use_reasoning": true, + "reasoning_effort": "high", + "explanation": { + "decision_type": "rule_based", + "rule_name": "enterprise-math-routing", + "matched_conditions": [ + { + "condition_type": "pattern_match", + "matched": true, + "details": "Pattern 'math' found in content" + }, + { + "condition_type": "content_complexity", + "matched": true, + "actual_value": 15, + "threshold": 50, + "details": "token_count: 15 > 50" + } + ], + "executed_actions": [ + { + "action_type": "route_to_model", + "executed": true, + "details": "Routed to model: math-specialized-model" + } + ], + "reasoning": "Rule 'enterprise-math-routing' matched based on content analysis", + "confidence": 0.95 + }, + "evaluation_time_ms": 2 +} +``` \ No newline at end of file diff --git a/examples/hybrid-routing-demo.go b/examples/hybrid-routing-demo.go new file mode 100644 index 00000000..e9a5639b --- /dev/null +++ b/examples/hybrid-routing-demo.go @@ -0,0 +1,211 @@ +package main + +import ( + "context" + "fmt" + "log" + + "github.com/vllm-project/semantic-router/src/semantic-router/pkg/config" + "github.com/vllm-project/semantic-router/src/semantic-router/pkg/rules" +) + +func main() { + fmt.Println("=== Semantic Router - Hybrid Routing Demonstration ===") + + // Create example configuration with routing rules + cfg := &config.RouterConfig{ + DefaultModel: "default-model", + DefaultReasoningEffort: "medium", + RoutingStrategy: config.RoutingStrategyConfig{ + Type: "hybrid", + ModelRouting: config.ModelRoutingConfig{ + Enabled: true, + FallbackToRules: false, + ConfidenceThreshold: 0.7, + }, + RuleRouting: config.RuleRoutingConfig{ + Enabled: true, + FallbackToModel: true, + EvaluationTimeoutMs: 100, + }, + }, + RoutingRules: []config.RoutingRule{ + { + Name: "math-specialization", + Description: "Route math problems to specialized model", + Enabled: true, + Priority: 100, + Conditions: []config.RuleCondition{ + { + Type: "pattern_match", + PatternMatch: "math", + Operator: "contains", + }, + { + Type: "content_complexity", + Metric: "token_count", + Threshold: 5, + Operator: "gte", + }, + }, + Actions: []config.RuleAction{ + { + Type: "route_to_model", + Model: "math-specialized-model", + }, + { + Type: "enable_reasoning", + EnableReasoning: true, + ReasoningEffort: "high", + }, + }, + }, + { + Name: "premium-user", + Description: "Route premium users to best models", + Enabled: true, + Priority: 90, + Conditions: []config.RuleCondition{ + { + Type: "request_header", + HeaderName: "x-user-tier", + Value: "premium", + Operator: "equals", + }, + }, + Actions: []config.RuleAction{ + { + Type: "route_to_model", + Model: "premium-model", + }, + }, + }, + { + Name: "content-filter", + Description: "Block inappropriate content", + Enabled: true, + Priority: 150, + Conditions: []config.RuleCondition{ + { + Type: "pattern_match", + PatternMatch: "forbidden", + Operator: "contains", + }, + }, + Actions: []config.RuleAction{ + { + Type: "block_request", + BlockWithMessage: "Content violates usage policy", + }, + }, + }, + }, + } + + // Create hybrid router (without classifier for demo) + hybridRouter := rules.NewHybridRouter(cfg, nil) + + fmt.Printf("Created hybrid router with %d rules\n\n", hybridRouter.GetRuleCount()) + + // Test scenarios + testScenarios := []struct { + name string + content string + headers map[string]string + }{ + { + name: "Math Problem", + content: "solve this complex math equation: 2x + 3y = 10", + headers: map[string]string{}, + }, + { + name: "Premium User Request", + content: "write a story about cats", + headers: map[string]string{"x-user-tier": "premium"}, + }, + { + name: "Blocked Content", + content: "this contains forbidden content", + headers: map[string]string{}, + }, + { + name: "Simple Query", + content: "hi", + headers: map[string]string{}, + }, + { + name: "Long Math Problem (Premium User)", + content: "solve this very complex mathematical proof involving advanced calculus and linear algebra", + headers: map[string]string{"x-user-tier": "premium"}, + }, + } + + for i, scenario := range testScenarios { + fmt.Printf("--- Test %d: %s ---\n", i+1, scenario.name) + fmt.Printf("Content: %s\n", scenario.content) + fmt.Printf("Headers: %v\n", scenario.headers) + + // Route the request + decision, err := hybridRouter.RouteRequest( + context.Background(), + scenario.content, + nil, + scenario.headers, + "auto", + ) + + if err != nil { + log.Printf("Error routing request: %v", err) + continue + } + + // Display results + fmt.Printf("\n🎯 Routing Decision:\n") + if decision.RuleMatched { + fmt.Printf(" ✅ Rule Matched: %s\n", decision.MatchedRule.Name) + fmt.Printf(" 📝 Rule Description: %s\n", decision.MatchedRule.Description) + } else { + fmt.Printf(" 📊 Model-based routing (no rules matched)\n") + } + + fmt.Printf(" 🚀 Selected Model: %s\n", decision.SelectedModel) + fmt.Printf(" 🧠 Use Reasoning: %v\n", decision.UseReasoning) + if decision.UseReasoning { + fmt.Printf(" 💪 Reasoning Effort: %s\n", decision.ReasoningEffort) + } + + if decision.BlockRequest { + fmt.Printf(" 🚫 Request Blocked: %s\n", decision.BlockMessage) + } + + fmt.Printf(" ⏱️ Evaluation Time: %d ms\n", decision.EvaluationTimeMs) + fmt.Printf(" 🔍 Decision Type: %s\n", decision.Explanation.DecisionType) + fmt.Printf(" 💡 Reasoning: %s\n", decision.Explanation.Reasoning) + + if decision.RuleMatched && len(decision.Explanation.MatchedConditions) > 0 { + fmt.Printf(" 📋 Matched Conditions:\n") + for _, condition := range decision.Explanation.MatchedConditions { + status := "❌" + if condition.Matched { + status = "✅" + } + fmt.Printf(" %s %s: %s\n", status, condition.ConditionType, condition.Details) + } + } + + if decision.RuleMatched && len(decision.Explanation.ExecutedActions) > 0 { + fmt.Printf(" ⚡ Executed Actions:\n") + for _, action := range decision.Explanation.ExecutedActions { + status := "❌" + if action.Executed { + status = "✅" + } + fmt.Printf(" %s %s: %s\n", status, action.ActionType, action.Details) + } + } + + fmt.Println() + } + + fmt.Println("=== Demonstration Complete ===") +} \ No newline at end of file diff --git a/src/semantic-router/pkg/api/server.go b/src/semantic-router/pkg/api/server.go index 6fe8dfd3..57283ae9 100644 --- a/src/semantic-router/pkg/api/server.go +++ b/src/semantic-router/pkg/api/server.go @@ -12,6 +12,7 @@ import ( "github.com/vllm-project/semantic-router/src/semantic-router/pkg/config" "github.com/vllm-project/semantic-router/src/semantic-router/pkg/metrics" "github.com/vllm-project/semantic-router/src/semantic-router/pkg/observability" + "github.com/vllm-project/semantic-router/src/semantic-router/pkg/rules" "github.com/vllm-project/semantic-router/src/semantic-router/pkg/services" ) @@ -19,6 +20,7 @@ import ( type ClassificationAPIServer struct { classificationSvc *services.ClassificationService config *config.RouterConfig + ruleAPI *rules.RuleManagementAPI // Add rule management API } // ModelsInfoResponse represents the response for models info endpoint @@ -102,6 +104,11 @@ type ClassificationOptions struct { // StartClassificationAPI starts the Classification API server func StartClassificationAPI(configPath string, port int) error { + return StartClassificationAPIWithHybridRouter(configPath, port, nil) +} + +// StartClassificationAPIWithHybridRouter starts the Classification API server with rule management +func StartClassificationAPIWithHybridRouter(configPath string, port int, hybridRouter *rules.HybridRouter) error { // Load configuration cfg, err := config.LoadConfig(configPath) if err != nil { @@ -137,10 +144,18 @@ func StartClassificationAPI(configPath string, port int) error { metrics.SetBatchMetricsConfig(metricsConfig) } + // Create rule management API if hybrid router is provided + var ruleAPI *rules.RuleManagementAPI + if hybridRouter != nil { + ruleAPI = rules.NewRuleManagementAPI(hybridRouter, cfg) + observability.Infof("Rule management API enabled") + } + // Create server instance apiServer := &ClassificationAPIServer{ classificationSvc: classificationSvc, config: cfg, + ruleAPI: ruleAPI, } // Create HTTP server with routes @@ -203,6 +218,12 @@ func (s *ClassificationAPIServer) setupRoutes() *http.ServeMux { mux.HandleFunc("GET /config/classification", s.handleGetConfig) mux.HandleFunc("PUT /config/classification", s.handleUpdateConfig) + // Rule management endpoints (if rule API is available) + if s.ruleAPI != nil { + s.ruleAPI.RegisterRoutes(mux) + observability.Infof("Rule management API routes registered") + } + return mux } diff --git a/src/semantic-router/pkg/config/config.go b/src/semantic-router/pkg/config/config.go index 18828570..5a8d85a0 100644 --- a/src/semantic-router/pkg/config/config.go +++ b/src/semantic-router/pkg/config/config.go @@ -87,6 +87,12 @@ type RouterConfig struct { // API configuration for classification endpoints API APIConfig `yaml:"api"` + + // Routing strategy configuration for hybrid approach + RoutingStrategy RoutingStrategyConfig `yaml:"routing_strategy,omitempty"` + + // Custom routing rules for interpretable routing + RoutingRules []RoutingRule `yaml:"routing_rules,omitempty"` } // APIConfig represents configuration for API endpoints @@ -248,6 +254,106 @@ const ( PIITypeZipCode = "ZIP_CODE" // ZIP/Postal codes ) +// RoutingStrategyConfig represents configuration for routing strategy +type RoutingStrategyConfig struct { + // Strategy type: "model", "rules", or "hybrid" + Type string `yaml:"type,omitempty"` + + // Model-based routing configuration + ModelRouting ModelRoutingConfig `yaml:"model_routing,omitempty"` + + // Rule-based routing configuration + RuleRouting RuleRoutingConfig `yaml:"rule_routing,omitempty"` +} + +// ModelRoutingConfig represents configuration for model-based routing +type ModelRoutingConfig struct { + // Enable model-based routing + Enabled bool `yaml:"enabled"` + + // Fall back to rules if model confidence is low + FallbackToRules bool `yaml:"fallback_to_rules,omitempty"` + + // Confidence threshold for model-based routing + ConfidenceThreshold float64 `yaml:"confidence_threshold,omitempty"` +} + +// RuleRoutingConfig represents configuration for rule-based routing +type RuleRoutingConfig struct { + // Enable rule-based routing + Enabled bool `yaml:"enabled"` + + // Fall back to model if no rules match + FallbackToModel bool `yaml:"fallback_to_model,omitempty"` + + // Timeout for rule evaluation in milliseconds + EvaluationTimeoutMs int `yaml:"evaluation_timeout_ms,omitempty"` +} + +// RoutingRule represents a single routing rule +type RoutingRule struct { + // Rule identification + Name string `yaml:"name"` + Description string `yaml:"description,omitempty"` + Enabled bool `yaml:"enabled"` + Priority int `yaml:"priority,omitempty"` + + // Rule conditions (all must be satisfied) + Conditions []RuleCondition `yaml:"conditions"` + + // Actions to execute when rule matches + Actions []RuleAction `yaml:"actions"` + + // Rule evaluation configuration + Evaluation RuleEvaluation `yaml:"evaluation,omitempty"` +} + +// RuleCondition represents a condition in a routing rule +type RuleCondition struct { + // Condition type + Type string `yaml:"type"` + + // Condition parameters (varies by type) + Category string `yaml:"category,omitempty"` // For category_classification + Threshold float64 `yaml:"threshold,omitempty"` // For threshold-based conditions + Operator string `yaml:"operator,omitempty"` // Comparison operator (gte, gt, lt, lte, equals, contains) + Value string `yaml:"value,omitempty"` // For string/boolean comparisons + Metric string `yaml:"metric,omitempty"` // For content_complexity conditions + Permission string `yaml:"permission,omitempty"` // For user_permission conditions + HeaderName string `yaml:"header_name,omitempty"` // For request_header conditions + PatternMatch string `yaml:"pattern_match,omitempty"` // For pattern matching conditions + TimeRange string `yaml:"time_range,omitempty"` // For time-based conditions + ExternalEndpoint string `yaml:"external_endpoint,omitempty"` // For external API conditions +} + +// RuleAction represents an action to execute when a rule matches +type RuleAction struct { + // Action type + Type string `yaml:"type"` + + // Action parameters (varies by type) + Model string `yaml:"model,omitempty"` // For route_to_model + Weight int `yaml:"weight,omitempty"` // For weighted routing + EnableReasoning bool `yaml:"enable_reasoning,omitempty"` // For enable_reasoning + ReasoningEffort string `yaml:"reasoning_effort,omitempty"` // For reasoning configuration + MaxSteps int `yaml:"max_steps,omitempty"` // For reasoning configuration + Headers map[string]string `yaml:"headers,omitempty"` // For set_headers + BlockWithMessage string `yaml:"block_with_message,omitempty"` // For block_request + RedirectToModel string `yaml:"redirect_to_model,omitempty"` // For redirect actions +} + +// RuleEvaluation represents evaluation configuration for a rule +type RuleEvaluation struct { + // Timeout for rule evaluation in milliseconds + TimeoutMs int `yaml:"timeout_ms,omitempty"` + + // Action to take on evaluation timeout + FallbackAction string `yaml:"fallback_action,omitempty"` + + // Whether to cache evaluation results + CacheResults bool `yaml:"cache_results,omitempty"` +} + // GetCacheSimilarityThreshold returns the effective threshold for the semantic cache func (c *RouterConfig) GetCacheSimilarityThreshold() float32 { if c.SemanticCache.SimilarityThreshold != nil { diff --git a/src/semantic-router/pkg/extproc/model_selector.go b/src/semantic-router/pkg/extproc/model_selector.go index 2f16bbf3..fee34eb3 100644 --- a/src/semantic-router/pkg/extproc/model_selector.go +++ b/src/semantic-router/pkg/extproc/model_selector.go @@ -1,14 +1,66 @@ package extproc import ( + "context" + "github.com/vllm-project/semantic-router/src/semantic-router/pkg/observability" + "github.com/vllm-project/semantic-router/src/semantic-router/pkg/rules" ) -// classifyAndSelectBestModel chooses best models based on category classification and model quality and expected TTFT +// classifyAndSelectBestModel chooses best models based on hybrid routing (rules + ML) func (r *OpenAIRouter) classifyAndSelectBestModel(query string) string { + return r.classifyAndSelectBestModelWithContext(context.Background(), query, nil, "") +} + +// classifyAndSelectBestModelWithContext chooses best models using hybrid routing with full context +func (r *OpenAIRouter) classifyAndSelectBestModelWithContext(ctx context.Context, query string, headers map[string]string, originalModel string) string { + // Use hybrid router for enhanced routing decisions + if r.HybridRouter != nil { + decision, err := r.HybridRouter.RouteRequest(ctx, query, nil, headers, originalModel) + if err != nil { + observability.Errorf("Hybrid routing failed, falling back to legacy classifier: %v", err) + return r.Classifier.ClassifyAndSelectBestModel(query) + } + + observability.Infof("Hybrid routing decision: model=%s, rule_matched=%v, reasoning=%v", + decision.SelectedModel, decision.RuleMatched, decision.UseReasoning) + + return decision.SelectedModel + } + + // Fallback to legacy classifier return r.Classifier.ClassifyAndSelectBestModel(query) } +// getRoutingDecisionWithExplanation returns the full routing decision with explanation +func (r *OpenAIRouter) getRoutingDecisionWithExplanation(ctx context.Context, query string, headers map[string]string, originalModel string) (*rules.RoutingDecision, error) { + if r.HybridRouter != nil { + return r.HybridRouter.RouteRequest(ctx, query, nil, headers, originalModel) + } + + // Create a basic decision for legacy routing + selectedModel := r.Classifier.ClassifyAndSelectBestModel(query) + categoryName, confidence, _ := r.Classifier.ClassifyCategory(query) + + decision := &rules.RoutingDecision{ + RuleMatched: false, + SelectedModel: selectedModel, + UseReasoning: false, // Would need to check category config + Headers: make(map[string]string), + Explanation: rules.DecisionExplanation{ + DecisionType: "model_based", + CategoryClassification: &rules.CategoryClassificationResult{ + Category: categoryName, + Confidence: float64(confidence), + }, + Reasoning: "Legacy model-based routing", + Confidence: float64(confidence), + }, + } + + return decision, nil +} + // findCategoryForClassification determines the category for the given text using classification func (r *OpenAIRouter) findCategoryForClassification(query string) string { if len(r.CategoryDescriptions) == 0 { diff --git a/src/semantic-router/pkg/extproc/request_handler.go b/src/semantic-router/pkg/extproc/request_handler.go index 867333de..aaeb425c 100644 --- a/src/semantic-router/pkg/extproc/request_handler.go +++ b/src/semantic-router/pkg/extproc/request_handler.go @@ -1,6 +1,7 @@ package extproc import ( + "context" "encoding/json" "strings" "time" @@ -313,8 +314,37 @@ func (r *OpenAIRouter) handleModelRouting(openAIRequest *openai.ChatCompletionNe } if classificationText != "" { - // Find the most similar task description or classify, then select best model - matchedModel := r.classifyAndSelectBestModel(classificationText) + // Use hybrid router for enhanced routing with rule support + routingDecision, err := r.getRoutingDecisionWithExplanation(context.Background(), classificationText, ctx.Headers, originalModel) + if err != nil { + observability.Errorf("Hybrid routing failed: %v", err) + metrics.RecordRequestError(ctx.RequestModel, "routing_failed") + return nil, status.Errorf(codes.Internal, "routing failed: %v", err) + } + + matchedModel := routingDecision.SelectedModel + + // Log routing decision details + if routingDecision.RuleMatched { + observability.Infof("Rule-based routing: rule=%s, model=%s, reasoning=%v", + routingDecision.MatchedRule.Name, matchedModel, routingDecision.UseReasoning) + metrics.RecordRoutingReasonCode("rule_based", matchedModel) + } else { + observability.Infof("Model-based routing: category=%s, confidence=%.2f, model=%s", + routingDecision.Explanation.CategoryClassification.Category, + routingDecision.Explanation.CategoryClassification.Confidence, + matchedModel) + metrics.RecordRoutingReasonCode("model_based", matchedModel) + } + + // Check for request blocking + if routingDecision.BlockRequest { + observability.Warnf("Request blocked by routing rule: %s", routingDecision.BlockMessage) + metrics.RecordRequestError(ctx.RequestModel, "rule_block") + blockResponse := http.CreateJailbreakViolationResponse("rule_block", 1.0) + return blockResponse, nil + } + if matchedModel != originalModel && matchedModel != "" { // Get detected PII for policy checking allContent := pii.ExtractAllContent(userContent, nonUserMessages) diff --git a/src/semantic-router/pkg/extproc/router.go b/src/semantic-router/pkg/extproc/router.go index 90eed7c5..3c6028aa 100644 --- a/src/semantic-router/pkg/extproc/router.go +++ b/src/semantic-router/pkg/extproc/router.go @@ -10,6 +10,7 @@ import ( "github.com/vllm-project/semantic-router/src/semantic-router/pkg/cache" "github.com/vllm-project/semantic-router/src/semantic-router/pkg/config" "github.com/vllm-project/semantic-router/src/semantic-router/pkg/observability" + "github.com/vllm-project/semantic-router/src/semantic-router/pkg/rules" "github.com/vllm-project/semantic-router/src/semantic-router/pkg/services" "github.com/vllm-project/semantic-router/src/semantic-router/pkg/tools" "github.com/vllm-project/semantic-router/src/semantic-router/pkg/utils/classification" @@ -24,6 +25,7 @@ type OpenAIRouter struct { PIIChecker *pii.PolicyChecker Cache cache.CacheBackend ToolsDatabase *tools.ToolsDatabase + HybridRouter *rules.HybridRouter // New hybrid router for rule-based routing } // Ensure OpenAIRouter implements the ext_proc calls @@ -150,6 +152,11 @@ func NewOpenAIRouter(configPath string) (*OpenAIRouter, error) { _ = autoSvc } + // Create hybrid router for rule-based routing + hybridRouter := rules.NewHybridRouter(cfg, classifier) + observability.Infof("Hybrid router created with %d rules, strategy: %s", + hybridRouter.GetRuleCount(), cfg.RoutingStrategy.Type) + router := &OpenAIRouter{ Config: cfg, CategoryDescriptions: categoryDescriptions, @@ -157,6 +164,7 @@ func NewOpenAIRouter(configPath string) (*OpenAIRouter, error) { PIIChecker: piiChecker, Cache: semanticCache, ToolsDatabase: toolsDatabase, + HybridRouter: hybridRouter, } // Log reasoning configuration after router is created diff --git a/src/semantic-router/pkg/rules/README.md b/src/semantic-router/pkg/rules/README.md new file mode 100644 index 00000000..4121033c --- /dev/null +++ b/src/semantic-router/pkg/rules/README.md @@ -0,0 +1,209 @@ +# Hybrid Routing Rules System + +This package implements a configurable and interpretable routing rules system that extends the semantic router to support both model-based and rule-based routing approaches. + +## Features + +### Core Capabilities +- **Hybrid routing approach**: Support both model-based classification AND user-defined rules +- **Transparent decision-making**: Every routing decision provides a clear explanation of which rules fired and why +- **User-defined rules**: Ability to create custom routing logic with multiple condition types +- **Configurable thresholds**: Full control over classification sensitivity and decision boundaries +- **Rule precedence**: Ability to define when rules take precedence over model classification +- **Real-time evaluation**: Rules are evaluated in real-time without service restart + +### Supported Rule Conditions + +#### 1. Category Classification +```yaml +- type: "category_classification" + category: "math" + threshold: 0.8 + operator: "gte" +``` + +#### 2. Content Complexity +```yaml +- type: "content_complexity" + metric: "token_count" # or "character_count", "line_count" + threshold: 50 + operator: "gt" +``` + +#### 3. Request Headers +```yaml +- type: "request_header" + header_name: "x-user-tier" + value: "premium" + operator: "equals" +``` + +#### 4. Pattern Matching +```yaml +- type: "pattern_match" + pattern_match: "math" + operator: "contains" +``` + +#### 5. Time-based Conditions +```yaml +- type: "time_based" + time_range: "business_hours" +``` + +### Supported Rule Actions + +#### 1. Route to Model +```yaml +- type: "route_to_model" + model: "math-specialized-model" +``` + +#### 2. Enable Reasoning +```yaml +- type: "enable_reasoning" + enable_reasoning: true + reasoning_effort: "high" +``` + +#### 3. Set Headers +```yaml +- type: "set_headers" + headers: + x-routing-decision: "rule-based" + x-model-tier: "premium" +``` + +#### 4. Block Request +```yaml +- type: "block_request" + block_with_message: "Content violates usage policy" +``` + +## Configuration Example + +```yaml +# Hybrid Routing Configuration +routing_strategy: + type: "hybrid" # Options: "model", "rules", "hybrid" + + model_routing: + enabled: true + fallback_to_rules: false + confidence_threshold: 0.7 + + rule_routing: + enabled: true + fallback_to_model: true + evaluation_timeout_ms: 100 + +# Custom Routing Rules +routing_rules: + - name: "enterprise-math-routing" + description: "Route complex math problems to specialized model" + enabled: true + priority: 100 + + conditions: + - type: "category_classification" + category: "math" + threshold: 0.8 + operator: "gte" + - type: "content_complexity" + metric: "token_count" + threshold: 50 + operator: "gt" + + actions: + - type: "route_to_model" + model: "math-specialized-model" + - type: "enable_reasoning" + enable_reasoning: true + reasoning_effort: "high" + + evaluation: + timeout_ms: 100 + fallback_action: "use_model_classification" +``` + +## API Endpoints + +The rule management API provides the following endpoints: + +### Rule Management +- `GET /api/v1/rules` - List all rules +- `POST /api/v1/rules` - Create new rule +- `GET /api/v1/rules/{name}` - Get specific rule +- `PUT /api/v1/rules/{name}` - Update rule +- `DELETE /api/v1/rules/{name}` - Delete rule + +### Rule Evaluation and Debugging +- `POST /api/v1/rules/evaluate` - Evaluate rules for request +- `GET /api/v1/rules/explain/{id}` - Get decision explanation +- `POST /api/v1/rules/test` - Test rule with sample data + +## Usage + +### Basic Usage + +```go +// Create hybrid router +hybridRouter := rules.NewHybridRouter(config, classifier) + +// Route a request +decision, err := hybridRouter.RouteRequest( + ctx, + userContent, + nonUserContent, + headers, + originalModel, +) + +// Check decision +if decision.RuleMatched { + fmt.Printf("Rule matched: %s\n", decision.MatchedRule.Name) + fmt.Printf("Selected model: %s\n", decision.SelectedModel) + fmt.Printf("Use reasoning: %v\n", decision.UseReasoning) +} +``` + +### Decision Explanation + +Every routing decision includes detailed explanations: + +```go +type DecisionExplanation struct { + DecisionType string // "rule_based", "model_based", "fallback" + RuleName string // Name of matched rule + MatchedConditions []ConditionResult // Details of condition evaluation + ExecutedActions []ActionResult // Details of action execution + Reasoning string // Human-readable explanation + Confidence float64 // Confidence score +} +``` + +## Testing + +Run the comprehensive test suite: + +```bash +cd src/semantic-router +LD_LIBRARY_PATH=../../candle-binding/target/release go test ./pkg/rules -v +``` + +Run the interactive demonstration: + +```bash +cd src/semantic-router +LD_LIBRARY_PATH=../../candle-binding/target/release go run ../../examples/hybrid-routing-demo.go +``` + +## Architecture + +The hybrid routing system consists of three main components: + +1. **RuleEngine**: Evaluates routing rules and conditions +2. **HybridRouter**: Orchestrates rule-based and model-based routing +3. **RuleManagementAPI**: Provides HTTP endpoints for rule management + +The system maintains backward compatibility with existing model-based routing while adding powerful rule-based capabilities for interpretable and configurable routing decisions. \ No newline at end of file diff --git a/src/semantic-router/pkg/rules/api.go b/src/semantic-router/pkg/rules/api.go new file mode 100644 index 00000000..f108cc78 --- /dev/null +++ b/src/semantic-router/pkg/rules/api.go @@ -0,0 +1,359 @@ +package rules + +import ( + "encoding/json" + "fmt" + "net/http" + "strconv" + "time" + + "github.com/vllm-project/semantic-router/src/semantic-router/pkg/config" + "github.com/vllm-project/semantic-router/src/semantic-router/pkg/observability" +) + +// RuleManagementAPI provides HTTP endpoints for rule management +type RuleManagementAPI struct { + hybridRouter *HybridRouter + config *config.RouterConfig +} + +// NewRuleManagementAPI creates a new rule management API instance +func NewRuleManagementAPI(hybridRouter *HybridRouter, routerConfig *config.RouterConfig) *RuleManagementAPI { + return &RuleManagementAPI{ + hybridRouter: hybridRouter, + config: routerConfig, + } +} + +// Rule management request/response types + +type CreateRuleRequest struct { + Rule config.RoutingRule `json:"rule"` +} + +type CreateRuleResponse struct { + ID string `json:"id"` + Message string `json:"message"` +} + +type ListRulesResponse struct { + Rules []RuleInfo `json:"rules"` + Count int `json:"count"` +} + +type RuleInfo struct { + Name string `json:"name"` + Description string `json:"description"` + Enabled bool `json:"enabled"` + Priority int `json:"priority"` + Conditions int `json:"condition_count"` + Actions int `json:"action_count"` +} + +type RuleEvaluationRequest struct { + Content string `json:"content"` + Headers map[string]string `json:"headers,omitempty"` + OriginalModel string `json:"original_model,omitempty"` +} + +type RuleEvaluationResponse struct { + Decision *RoutingDecision `json:"decision"` + Success bool `json:"success"` + Error string `json:"error,omitempty"` +} + +type DecisionExplanationResponse struct { + DecisionID string `json:"decision_id"` + Explanation DecisionExplanation `json:"explanation"` + Timestamp time.Time `json:"timestamp"` +} + +// RegisterRoutes registers rule management API routes +func (api *RuleManagementAPI) RegisterRoutes(mux *http.ServeMux) { + // Rule management endpoints + mux.HandleFunc("/api/v1/rules", api.handleRules) + mux.HandleFunc("/api/v1/rules/", api.handleRuleOperations) + + // Rule evaluation and debugging endpoints + mux.HandleFunc("/api/v1/rules/evaluate", api.handleRuleEvaluation) + mux.HandleFunc("/api/v1/rules/explain/", api.handleDecisionExplanation) + mux.HandleFunc("/api/v1/rules/test", api.handleRuleTest) +} + +// handleRules handles GET /api/v1/rules (list) and POST /api/v1/rules (create) +func (api *RuleManagementAPI) handleRules(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodGet: + api.listRules(w, r) + case http.MethodPost: + api.createRule(w, r) + default: + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + } +} + +// handleRuleOperations handles rule-specific operations (GET, PUT, DELETE /api/v1/rules/{id}) +func (api *RuleManagementAPI) handleRuleOperations(w http.ResponseWriter, r *http.Request) { + // Extract rule ID from path + ruleName := r.URL.Path[len("/api/v1/rules/"):] + if ruleName == "" { + http.Error(w, "Rule name required", http.StatusBadRequest) + return + } + + switch r.Method { + case http.MethodGet: + api.getRule(w, r, ruleName) + case http.MethodPut: + api.updateRule(w, r, ruleName) + case http.MethodDelete: + api.deleteRule(w, r, ruleName) + default: + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + } +} + +// listRules returns all configured rules +func (api *RuleManagementAPI) listRules(w http.ResponseWriter, r *http.Request) { + rules := make([]RuleInfo, 0, len(api.config.RoutingRules)) + + for _, rule := range api.config.RoutingRules { + ruleInfo := RuleInfo{ + Name: rule.Name, + Description: rule.Description, + Enabled: rule.Enabled, + Priority: rule.Priority, + Conditions: len(rule.Conditions), + Actions: len(rule.Actions), + } + rules = append(rules, ruleInfo) + } + + response := ListRulesResponse{ + Rules: rules, + Count: len(rules), + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(response) +} + +// createRule creates a new routing rule +func (api *RuleManagementAPI) createRule(w http.ResponseWriter, r *http.Request) { + var req CreateRuleRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, fmt.Sprintf("Invalid request body: %v", err), http.StatusBadRequest) + return + } + + // Validate rule + if req.Rule.Name == "" { + http.Error(w, "Rule name is required", http.StatusBadRequest) + return + } + + // Check if rule already exists + for _, existingRule := range api.config.RoutingRules { + if existingRule.Name == req.Rule.Name { + http.Error(w, "Rule with this name already exists", http.StatusConflict) + return + } + } + + // Add rule to configuration (in real implementation, this would persist to storage) + api.config.RoutingRules = append(api.config.RoutingRules, req.Rule) + + observability.Infof("Created new routing rule: %s", req.Rule.Name) + + response := CreateRuleResponse{ + ID: req.Rule.Name, + Message: "Rule created successfully", + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusCreated) + json.NewEncoder(w).Encode(response) +} + +// getRule returns a specific rule +func (api *RuleManagementAPI) getRule(w http.ResponseWriter, r *http.Request, ruleName string) { + for _, rule := range api.config.RoutingRules { + if rule.Name == ruleName { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(rule) + return + } + } + + http.Error(w, "Rule not found", http.StatusNotFound) +} + +// updateRule updates an existing rule +func (api *RuleManagementAPI) updateRule(w http.ResponseWriter, r *http.Request, ruleName string) { + var updatedRule config.RoutingRule + if err := json.NewDecoder(r.Body).Decode(&updatedRule); err != nil { + http.Error(w, fmt.Sprintf("Invalid request body: %v", err), http.StatusBadRequest) + return + } + + // Find and update rule + for i, rule := range api.config.RoutingRules { + if rule.Name == ruleName { + // Ensure name consistency + updatedRule.Name = ruleName + api.config.RoutingRules[i] = updatedRule + + observability.Infof("Updated routing rule: %s", ruleName) + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(updatedRule) + return + } + } + + http.Error(w, "Rule not found", http.StatusNotFound) +} + +// deleteRule deletes a rule +func (api *RuleManagementAPI) deleteRule(w http.ResponseWriter, r *http.Request, ruleName string) { + for i, rule := range api.config.RoutingRules { + if rule.Name == ruleName { + // Remove rule from slice + api.config.RoutingRules = append(api.config.RoutingRules[:i], api.config.RoutingRules[i+1:]...) + + observability.Infof("Deleted routing rule: %s", ruleName) + + w.WriteHeader(http.StatusNoContent) + return + } + } + + http.Error(w, "Rule not found", http.StatusNotFound) +} + +// handleRuleEvaluation evaluates rules against provided content +func (api *RuleManagementAPI) handleRuleEvaluation(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + var req RuleEvaluationRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, fmt.Sprintf("Invalid request body: %v", err), http.StatusBadRequest) + return + } + + if req.Content == "" { + http.Error(w, "Content is required", http.StatusBadRequest) + return + } + + // Create evaluation context + headers := req.Headers + if headers == nil { + headers = make(map[string]string) + } + + // Evaluate rules using hybrid router + decision, err := api.hybridRouter.RouteRequest(r.Context(), req.Content, nil, headers, req.OriginalModel) + + response := RuleEvaluationResponse{ + Success: err == nil, + } + + if err != nil { + response.Error = err.Error() + observability.Errorf("Rule evaluation failed: %v", err) + } else { + response.Decision = decision + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(response) +} + +// handleDecisionExplanation provides detailed explanation for a decision +func (api *RuleManagementAPI) handleDecisionExplanation(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + // Extract decision ID from path + decisionID := r.URL.Path[len("/api/v1/rules/explain/"):] + if decisionID == "" { + http.Error(w, "Decision ID required", http.StatusBadRequest) + return + } + + // In a real implementation, this would lookup stored decision explanations + // For now, return a placeholder response + response := DecisionExplanationResponse{ + DecisionID: decisionID, + Timestamp: time.Now(), + Explanation: DecisionExplanation{ + DecisionType: "placeholder", + Reasoning: "Decision explanation feature requires persistent storage implementation", + }, + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(response) +} + +// handleRuleTest tests a rule with sample data +func (api *RuleManagementAPI) handleRuleTest(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + var req struct { + Rule config.RoutingRule `json:"rule"` + Content string `json:"content"` + Headers map[string]string `json:"headers,omitempty"` + } + + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, fmt.Sprintf("Invalid request body: %v", err), http.StatusBadRequest) + return + } + + // Create a temporary rule engine for testing + testRules := []config.RoutingRule{req.Rule} + testEngine := NewRuleEngine(testRules, api.hybridRouter.classifier, api.config) + + // Create evaluation context + headers := req.Headers + if headers == nil { + headers = make(map[string]string) + } + + evalCtx := &EvaluationContext{ + UserContent: req.Content, + NonUserContent: nil, + AllContent: req.Content, + Headers: headers, + RequestID: "test-" + strconv.FormatInt(time.Now().Unix(), 10), + Timestamp: time.Now(), + OriginalModel: "test", + ExternalData: make(map[string]interface{}), + } + + // Evaluate the test rule + decision, err := testEngine.EvaluateRules(r.Context(), evalCtx) + + response := RuleEvaluationResponse{ + Success: err == nil, + } + + if err != nil { + response.Error = err.Error() + } else { + response.Decision = decision + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(response) +} \ No newline at end of file diff --git a/src/semantic-router/pkg/rules/engine.go b/src/semantic-router/pkg/rules/engine.go new file mode 100644 index 00000000..ed6c9714 --- /dev/null +++ b/src/semantic-router/pkg/rules/engine.go @@ -0,0 +1,451 @@ +package rules + +import ( + "context" + "fmt" + "sort" + "strings" + "time" + + "github.com/vllm-project/semantic-router/src/semantic-router/pkg/config" + "github.com/vllm-project/semantic-router/src/semantic-router/pkg/observability" + "github.com/vllm-project/semantic-router/src/semantic-router/pkg/utils/classification" +) + +// RuleEngine evaluates routing rules and provides routing decisions +type RuleEngine struct { + rules []config.RoutingRule + classifier *classification.Classifier + config *config.RouterConfig +} + +// NewRuleEngine creates a new rule engine instance +func NewRuleEngine(rules []config.RoutingRule, classifier *classification.Classifier, routerConfig *config.RouterConfig) *RuleEngine { + // Sort rules by priority (higher priority first) + sortedRules := make([]config.RoutingRule, len(rules)) + copy(sortedRules, rules) + sort.Slice(sortedRules, func(i, j int) bool { + return sortedRules[i].Priority > sortedRules[j].Priority + }) + + return &RuleEngine{ + rules: sortedRules, + classifier: classifier, + config: routerConfig, + } +} + +// RoutingDecision represents the result of rule evaluation +type RoutingDecision struct { + // Whether any rule matched + RuleMatched bool + + // The matching rule (if any) + MatchedRule *config.RoutingRule + + // Selected model + SelectedModel string + + // Reasoning configuration + UseReasoning bool + ReasoningEffort string + + // Decision explanation + Explanation DecisionExplanation + + // Whether to block the request + BlockRequest bool + BlockMessage string + + // Additional headers to set + Headers map[string]string + + // Processing time for rule evaluation + EvaluationTimeMs int64 +} + +// DecisionExplanation provides detailed explanation of the routing decision +type DecisionExplanation struct { + // Decision type: "rule_based", "model_based", or "fallback" + DecisionType string + + // For rule-based decisions + RuleName string + MatchedConditions []ConditionResult + ExecutedActions []ActionResult + + // For model-based decisions + CategoryClassification *CategoryClassificationResult + + // Reasoning behind the decision + Reasoning string + + // Confidence score for the decision + Confidence float64 +} + +// ConditionResult represents the result of evaluating a rule condition +type ConditionResult struct { + ConditionType string + Matched bool + ActualValue interface{} + ExpectedValue interface{} + Confidence float64 + Details string +} + +// ActionResult represents the result of executing a rule action +type ActionResult struct { + ActionType string + Executed bool + Details string + Error string +} + +// CategoryClassificationResult represents ML classification results +type CategoryClassificationResult struct { + Category string + Confidence float64 + Probabilities map[string]float64 +} + +// EvaluationContext contains context for rule evaluation +type EvaluationContext struct { + // Request content + UserContent string + NonUserContent []string + AllContent string + + // Request metadata + Headers map[string]string + RequestID string + Timestamp time.Time + + // Model information + OriginalModel string + + // External context (can be extended) + ExternalData map[string]interface{} +} + +// EvaluateRules evaluates all rules against the given context and returns routing decision +func (re *RuleEngine) EvaluateRules(ctx context.Context, evalCtx *EvaluationContext) (*RoutingDecision, error) { + startTime := time.Now() + + decision := &RoutingDecision{ + RuleMatched: false, + SelectedModel: re.config.DefaultModel, + UseReasoning: false, + ReasoningEffort: re.config.DefaultReasoningEffort, + Headers: make(map[string]string), + Explanation: DecisionExplanation{ + DecisionType: "fallback", + Reasoning: "No rules matched, using default model", + }, + } + + observability.Infof("Evaluating %d routing rules", len(re.rules)) + + // Evaluate rules in priority order + for _, rule := range re.rules { + if !rule.Enabled { + continue + } + + observability.Infof("Evaluating rule: %s (priority: %d)", rule.Name, rule.Priority) + + // Check if rule matches + matched, conditionResults, err := re.evaluateRuleConditions(ctx, &rule, evalCtx) + if err != nil { + observability.Errorf("Error evaluating rule %s: %v", rule.Name, err) + continue + } + + if matched { + observability.Infof("Rule %s matched! Executing actions", rule.Name) + + // Execute rule actions + actionResults, err := re.executeRuleActions(ctx, &rule, evalCtx, decision) + if err != nil { + observability.Errorf("Error executing actions for rule %s: %v", rule.Name, err) + continue + } + + // Update decision with rule match + decision.RuleMatched = true + decision.MatchedRule = &rule + decision.Explanation = DecisionExplanation{ + DecisionType: "rule_based", + RuleName: rule.Name, + MatchedConditions: conditionResults, + ExecutedActions: actionResults, + Reasoning: fmt.Sprintf("Rule '%s' matched and actions executed", rule.Name), + Confidence: re.calculateRuleConfidence(conditionResults), + } + + // Rule matched and executed, stop evaluation + break + } + } + + // Calculate evaluation time + decision.EvaluationTimeMs = time.Since(startTime).Milliseconds() + + observability.Infof("Rule evaluation completed in %dms, rule matched: %v", + decision.EvaluationTimeMs, decision.RuleMatched) + + return decision, nil +} + +// evaluateRuleConditions evaluates all conditions for a rule +func (re *RuleEngine) evaluateRuleConditions(ctx context.Context, rule *config.RoutingRule, evalCtx *EvaluationContext) (bool, []ConditionResult, error) { + results := make([]ConditionResult, 0, len(rule.Conditions)) + allMatched := true + + for _, condition := range rule.Conditions { + result, err := re.evaluateCondition(ctx, &condition, evalCtx) + if err != nil { + return false, results, fmt.Errorf("failed to evaluate condition %s: %w", condition.Type, err) + } + + results = append(results, result) + if !result.Matched { + allMatched = false + } + } + + return allMatched, results, nil +} + +// evaluateCondition evaluates a single condition +func (re *RuleEngine) evaluateCondition(ctx context.Context, condition *config.RuleCondition, evalCtx *EvaluationContext) (ConditionResult, error) { + result := ConditionResult{ + ConditionType: condition.Type, + Matched: false, + } + + switch condition.Type { + case "category_classification": + return re.evaluateCategoryCondition(condition, evalCtx) + case "content_complexity": + return re.evaluateContentComplexityCondition(condition, evalCtx) + case "request_header": + return re.evaluateHeaderCondition(condition, evalCtx) + case "time_based": + return re.evaluateTimeCondition(condition, evalCtx) + case "pattern_match": + return re.evaluatePatternCondition(condition, evalCtx) + default: + return result, fmt.Errorf("unsupported condition type: %s", condition.Type) + } +} + +// evaluateCategoryCondition evaluates category classification conditions +func (re *RuleEngine) evaluateCategoryCondition(condition *config.RuleCondition, evalCtx *EvaluationContext) (ConditionResult, error) { + result := ConditionResult{ + ConditionType: condition.Type, + ExpectedValue: fmt.Sprintf("%s %s %.2f", condition.Category, condition.Operator, condition.Threshold), + } + + if re.classifier == nil { + return result, fmt.Errorf("classifier not available for category classification") + } + + // Perform classification + categoryName, confidence, err := re.classifier.ClassifyCategory(evalCtx.AllContent) + if err != nil { + return result, fmt.Errorf("classification failed: %w", err) + } + + result.ActualValue = fmt.Sprintf("%s (confidence: %.2f)", categoryName, confidence) + result.Confidence = float64(confidence) + result.Details = fmt.Sprintf("Classified as '%s' with confidence %.2f", categoryName, confidence) + + // Check if category matches and confidence meets threshold + categoryMatches := (condition.Category == "" || categoryName == condition.Category) + confidenceMatches := re.compareFloat(float64(confidence), condition.Threshold, condition.Operator) + + result.Matched = categoryMatches && confidenceMatches + + return result, nil +} + +// evaluateContentComplexityCondition evaluates content complexity conditions +func (re *RuleEngine) evaluateContentComplexityCondition(condition *config.RuleCondition, evalCtx *EvaluationContext) (ConditionResult, error) { + result := ConditionResult{ + ConditionType: condition.Type, + ExpectedValue: fmt.Sprintf("%s %s %.2f", condition.Metric, condition.Operator, condition.Threshold), + } + + var actualValue float64 + switch condition.Metric { + case "token_count": + // Simple token count estimation (split by whitespace) + actualValue = float64(len(strings.Fields(evalCtx.AllContent))) + case "character_count": + actualValue = float64(len(evalCtx.AllContent)) + case "line_count": + actualValue = float64(len(strings.Split(evalCtx.AllContent, "\n"))) + default: + return result, fmt.Errorf("unsupported complexity metric: %s", condition.Metric) + } + + result.ActualValue = actualValue + result.Details = fmt.Sprintf("%s: %.0f", condition.Metric, actualValue) + result.Matched = re.compareFloat(actualValue, condition.Threshold, condition.Operator) + + return result, nil +} + +// evaluateHeaderCondition evaluates request header conditions +func (re *RuleEngine) evaluateHeaderCondition(condition *config.RuleCondition, evalCtx *EvaluationContext) (ConditionResult, error) { + result := ConditionResult{ + ConditionType: condition.Type, + ExpectedValue: fmt.Sprintf("%s %s %s", condition.HeaderName, condition.Operator, condition.Value), + } + + headerValue, exists := evalCtx.Headers[condition.HeaderName] + if !exists { + result.ActualValue = "" + result.Details = fmt.Sprintf("Header '%s' not found", condition.HeaderName) + result.Matched = false + return result, nil + } + + result.ActualValue = headerValue + result.Details = fmt.Sprintf("Header '%s' = '%s'", condition.HeaderName, headerValue) + result.Matched = re.compareString(headerValue, condition.Value, condition.Operator) + + return result, nil +} + +// evaluateTimeCondition evaluates time-based conditions +func (re *RuleEngine) evaluateTimeCondition(condition *config.RuleCondition, evalCtx *EvaluationContext) (ConditionResult, error) { + result := ConditionResult{ + ConditionType: condition.Type, + ExpectedValue: condition.TimeRange, + } + + // Simple time range check (could be extended) + currentHour := evalCtx.Timestamp.Hour() + result.ActualValue = fmt.Sprintf("Hour: %d", currentHour) + result.Details = fmt.Sprintf("Current time: %s", evalCtx.Timestamp.Format("15:04:05")) + + // For now, always match (this could be extended with proper time range parsing) + result.Matched = true + + return result, nil +} + +// evaluatePatternCondition evaluates pattern matching conditions +func (re *RuleEngine) evaluatePatternCondition(condition *config.RuleCondition, evalCtx *EvaluationContext) (ConditionResult, error) { + result := ConditionResult{ + ConditionType: condition.Type, + ExpectedValue: condition.PatternMatch, + } + + // Simple pattern matching (contains check) + matched := strings.Contains(strings.ToLower(evalCtx.AllContent), strings.ToLower(condition.PatternMatch)) + + result.ActualValue = evalCtx.AllContent + result.Details = fmt.Sprintf("Pattern '%s' in content", condition.PatternMatch) + result.Matched = matched + + return result, nil +} + +// executeRuleActions executes all actions for a matched rule +func (re *RuleEngine) executeRuleActions(ctx context.Context, rule *config.RoutingRule, evalCtx *EvaluationContext, decision *RoutingDecision) ([]ActionResult, error) { + results := make([]ActionResult, 0, len(rule.Actions)) + + for _, action := range rule.Actions { + result := re.executeAction(&action, evalCtx, decision) + results = append(results, result) + } + + return results, nil +} + +// executeAction executes a single rule action +func (re *RuleEngine) executeAction(action *config.RuleAction, evalCtx *EvaluationContext, decision *RoutingDecision) ActionResult { + result := ActionResult{ + ActionType: action.Type, + Executed: false, + } + + switch action.Type { + case "route_to_model": + if action.Model != "" { + decision.SelectedModel = action.Model + result.Executed = true + result.Details = fmt.Sprintf("Routed to model: %s", action.Model) + } + case "enable_reasoning": + decision.UseReasoning = action.EnableReasoning + if action.ReasoningEffort != "" { + decision.ReasoningEffort = action.ReasoningEffort + } + result.Executed = true + result.Details = fmt.Sprintf("Reasoning: %v, Effort: %s", action.EnableReasoning, action.ReasoningEffort) + case "set_headers": + for key, value := range action.Headers { + decision.Headers[key] = value + } + result.Executed = true + result.Details = fmt.Sprintf("Set %d headers", len(action.Headers)) + case "block_request": + decision.BlockRequest = true + decision.BlockMessage = action.BlockWithMessage + result.Executed = true + result.Details = fmt.Sprintf("Blocked: %s", action.BlockWithMessage) + default: + result.Error = fmt.Sprintf("Unsupported action type: %s", action.Type) + } + + return result +} + +// Helper functions for condition evaluation + +func (re *RuleEngine) compareFloat(actual, expected float64, operator string) bool { + switch operator { + case "gte", ">=": + return actual >= expected + case "gt", ">": + return actual > expected + case "lte", "<=": + return actual <= expected + case "lt", "<": + return actual < expected + case "equals", "==": + return actual == expected + default: + return false + } +} + +func (re *RuleEngine) compareString(actual, expected, operator string) bool { + switch operator { + case "equals": + return actual == expected + case "contains": + return strings.Contains(strings.ToLower(actual), strings.ToLower(expected)) + default: + return false + } +} + +func (re *RuleEngine) calculateRuleConfidence(conditionResults []ConditionResult) float64 { + if len(conditionResults) == 0 { + return 0.0 + } + + total := 0.0 + for _, result := range conditionResults { + if result.Matched { + total += result.Confidence + } + } + + return total / float64(len(conditionResults)) +} \ No newline at end of file diff --git a/src/semantic-router/pkg/rules/engine_test.go b/src/semantic-router/pkg/rules/engine_test.go new file mode 100644 index 00000000..4a3e7cb9 --- /dev/null +++ b/src/semantic-router/pkg/rules/engine_test.go @@ -0,0 +1,349 @@ +package rules + +import ( + "context" + "testing" + "time" + + "github.com/vllm-project/semantic-router/src/semantic-router/pkg/config" +) + +func TestRuleEngine_BasicEvaluation(t *testing.T) { + // Create test rules + rules := []config.RoutingRule{ + { + Name: "math-rule", + Description: "Route math problems to math model", + Enabled: true, + Priority: 100, + Conditions: []config.RuleCondition{ + { + Type: "pattern_match", + PatternMatch: "math", + Operator: "contains", + }, + }, + Actions: []config.RuleAction{ + { + Type: "route_to_model", + Model: "math-model", + }, + }, + }, + } + + // Create test config + cfg := &config.RouterConfig{ + DefaultModel: "default-model", + } + + // Create rule engine + engine := NewRuleEngine(rules, nil, cfg) + + // Create evaluation context + evalCtx := &EvaluationContext{ + AllContent: "solve this math problem", + Headers: make(map[string]string), + Timestamp: time.Now(), + } + + // Evaluate rules + decision, err := engine.EvaluateRules(context.Background(), evalCtx) + if err != nil { + t.Fatalf("Rule evaluation failed: %v", err) + } + + // Check results + if !decision.RuleMatched { + t.Error("Expected rule to match") + } + + if decision.SelectedModel != "math-model" { + t.Errorf("Expected model 'math-model', got '%s'", decision.SelectedModel) + } + + if decision.MatchedRule.Name != "math-rule" { + t.Errorf("Expected rule 'math-rule', got '%s'", decision.MatchedRule.Name) + } +} + +func TestRuleEngine_ContentComplexity(t *testing.T) { + // Create rule with content complexity condition + rules := []config.RoutingRule{ + { + Name: "long-content-rule", + Enabled: true, + Priority: 100, + Conditions: []config.RuleCondition{ + { + Type: "content_complexity", + Metric: "token_count", + Threshold: 5, + Operator: "gt", + }, + }, + Actions: []config.RuleAction{ + { + Type: "route_to_model", + Model: "complex-model", + }, + }, + }, + } + + cfg := &config.RouterConfig{DefaultModel: "default-model"} + engine := NewRuleEngine(rules, nil, cfg) + + // Test with long content (should match) + evalCtx := &EvaluationContext{ + AllContent: "this is a very long piece of content that should trigger the rule", + Headers: make(map[string]string), + Timestamp: time.Now(), + } + + decision, err := engine.EvaluateRules(context.Background(), evalCtx) + if err != nil { + t.Fatalf("Rule evaluation failed: %v", err) + } + + if !decision.RuleMatched { + t.Error("Expected rule to match for long content") + } + + // Test with short content (should not match) + evalCtx.AllContent = "short content" + decision, err = engine.EvaluateRules(context.Background(), evalCtx) + if err != nil { + t.Fatalf("Rule evaluation failed: %v", err) + } + + if decision.RuleMatched { + t.Error("Expected rule not to match for short content") + } +} + +func TestRuleEngine_HeaderConditions(t *testing.T) { + rules := []config.RoutingRule{ + { + Name: "api-key-rule", + Enabled: true, + Priority: 100, + Conditions: []config.RuleCondition{ + { + Type: "request_header", + HeaderName: "x-api-key", + Value: "premium", + Operator: "equals", + }, + }, + Actions: []config.RuleAction{ + { + Type: "route_to_model", + Model: "premium-model", + }, + }, + }, + } + + cfg := &config.RouterConfig{DefaultModel: "default-model"} + engine := NewRuleEngine(rules, nil, cfg) + + // Test with matching header + evalCtx := &EvaluationContext{ + AllContent: "test content", + Headers: map[string]string{"x-api-key": "premium"}, + Timestamp: time.Now(), + } + + decision, err := engine.EvaluateRules(context.Background(), evalCtx) + if err != nil { + t.Fatalf("Rule evaluation failed: %v", err) + } + + if !decision.RuleMatched { + t.Error("Expected rule to match for premium API key") + } + + // Test with wrong header value + evalCtx.Headers["x-api-key"] = "basic" + decision, err = engine.EvaluateRules(context.Background(), evalCtx) + if err != nil { + t.Fatalf("Rule evaluation failed: %v", err) + } + + if decision.RuleMatched { + t.Error("Expected rule not to match for basic API key") + } +} + +func TestRuleEngine_MultipleConditions(t *testing.T) { + rules := []config.RoutingRule{ + { + Name: "complex-rule", + Enabled: true, + Priority: 100, + Conditions: []config.RuleCondition{ + { + Type: "pattern_match", + PatternMatch: "math", + Operator: "contains", + }, + { + Type: "content_complexity", + Metric: "token_count", + Threshold: 3, + Operator: "gte", + }, + }, + Actions: []config.RuleAction{ + { + Type: "route_to_model", + Model: "advanced-math-model", + }, + }, + }, + } + + cfg := &config.RouterConfig{DefaultModel: "default-model"} + engine := NewRuleEngine(rules, nil, cfg) + + // Test content that matches both conditions + evalCtx := &EvaluationContext{ + AllContent: "solve this complex math problem", + Headers: make(map[string]string), + Timestamp: time.Now(), + } + + decision, err := engine.EvaluateRules(context.Background(), evalCtx) + if err != nil { + t.Fatalf("Rule evaluation failed: %v", err) + } + + if !decision.RuleMatched { + t.Error("Expected rule to match when all conditions are met") + } + + // Test content that matches only one condition + evalCtx.AllContent = "math" + decision, err = engine.EvaluateRules(context.Background(), evalCtx) + if err != nil { + t.Fatalf("Rule evaluation failed: %v", err) + } + + if decision.RuleMatched { + t.Error("Expected rule not to match when only one condition is met") + } +} + +func TestRuleEngine_RulePriority(t *testing.T) { + rules := []config.RoutingRule{ + { + Name: "low-priority-rule", + Enabled: true, + Priority: 50, + Conditions: []config.RuleCondition{ + { + Type: "pattern_match", + PatternMatch: "test", + Operator: "contains", + }, + }, + Actions: []config.RuleAction{ + { + Type: "route_to_model", + Model: "low-priority-model", + }, + }, + }, + { + Name: "high-priority-rule", + Enabled: true, + Priority: 100, + Conditions: []config.RuleCondition{ + { + Type: "pattern_match", + PatternMatch: "test", + Operator: "contains", + }, + }, + Actions: []config.RuleAction{ + { + Type: "route_to_model", + Model: "high-priority-model", + }, + }, + }, + } + + cfg := &config.RouterConfig{DefaultModel: "default-model"} + engine := NewRuleEngine(rules, nil, cfg) + + evalCtx := &EvaluationContext{ + AllContent: "test content", + Headers: make(map[string]string), + Timestamp: time.Now(), + } + + decision, err := engine.EvaluateRules(context.Background(), evalCtx) + if err != nil { + t.Fatalf("Rule evaluation failed: %v", err) + } + + if !decision.RuleMatched { + t.Error("Expected a rule to match") + } + + // Should match the high priority rule first + if decision.SelectedModel != "high-priority-model" { + t.Errorf("Expected high priority model, got '%s'", decision.SelectedModel) + } +} + +func TestRuleEngine_BlockAction(t *testing.T) { + rules := []config.RoutingRule{ + { + Name: "block-rule", + Enabled: true, + Priority: 100, + Conditions: []config.RuleCondition{ + { + Type: "pattern_match", + PatternMatch: "forbidden", + Operator: "contains", + }, + }, + Actions: []config.RuleAction{ + { + Type: "block_request", + BlockWithMessage: "Content contains forbidden patterns", + }, + }, + }, + } + + cfg := &config.RouterConfig{DefaultModel: "default-model"} + engine := NewRuleEngine(rules, nil, cfg) + + evalCtx := &EvaluationContext{ + AllContent: "this contains forbidden content", + Headers: make(map[string]string), + Timestamp: time.Now(), + } + + decision, err := engine.EvaluateRules(context.Background(), evalCtx) + if err != nil { + t.Fatalf("Rule evaluation failed: %v", err) + } + + if !decision.RuleMatched { + t.Error("Expected rule to match") + } + + if !decision.BlockRequest { + t.Error("Expected request to be blocked") + } + + if decision.BlockMessage != "Content contains forbidden patterns" { + t.Errorf("Expected block message 'Content contains forbidden patterns', got '%s'", decision.BlockMessage) + } +} \ No newline at end of file diff --git a/src/semantic-router/pkg/rules/hybrid_router.go b/src/semantic-router/pkg/rules/hybrid_router.go new file mode 100644 index 00000000..89c3e646 --- /dev/null +++ b/src/semantic-router/pkg/rules/hybrid_router.go @@ -0,0 +1,283 @@ +package rules + +import ( + "context" + "fmt" + "strings" + "time" + + "github.com/vllm-project/semantic-router/src/semantic-router/pkg/config" + "github.com/vllm-project/semantic-router/src/semantic-router/pkg/observability" + "github.com/vllm-project/semantic-router/src/semantic-router/pkg/utils/classification" +) + +// HybridRouter combines rule-based and model-based routing +type HybridRouter struct { + ruleEngine *RuleEngine + classifier *classification.Classifier + config *config.RouterConfig +} + +// NewHybridRouter creates a new hybrid router instance +func NewHybridRouter(routerConfig *config.RouterConfig, classifier *classification.Classifier) *HybridRouter { + var ruleEngine *RuleEngine + + // Initialize rule engine if rules are configured + if len(routerConfig.RoutingRules) > 0 { + ruleEngine = NewRuleEngine(routerConfig.RoutingRules, classifier, routerConfig) + } + + return &HybridRouter{ + ruleEngine: ruleEngine, + classifier: classifier, + config: routerConfig, + } +} + +// RouteRequest determines the best model and configuration for a request +func (hr *HybridRouter) RouteRequest(ctx context.Context, userContent string, nonUserContent []string, headers map[string]string, originalModel string) (*RoutingDecision, error) { + startTime := time.Now() + + // Combine all content for analysis + allContent := userContent + if len(nonUserContent) > 0 { + allContent = strings.Join(append([]string{userContent}, nonUserContent...), " ") + } + + // Create evaluation context + evalCtx := &EvaluationContext{ + UserContent: userContent, + NonUserContent: nonUserContent, + AllContent: allContent, + Headers: headers, + RequestID: headers["x-request-id"], + Timestamp: time.Now(), + OriginalModel: originalModel, + ExternalData: make(map[string]interface{}), + } + + // Determine routing strategy + strategy := hr.getRoutingStrategy() + observability.Infof("Using routing strategy: %s", strategy) + + var decision *RoutingDecision + var err error + + switch strategy { + case "rules": + decision, err = hr.routeWithRules(ctx, evalCtx) + case "model": + decision, err = hr.routeWithModel(ctx, evalCtx) + case "hybrid": + decision, err = hr.routeWithHybrid(ctx, evalCtx) + default: + // Fallback to model-based routing + decision, err = hr.routeWithModel(ctx, evalCtx) + } + + if err != nil { + return nil, fmt.Errorf("routing failed: %w", err) + } + + // Update evaluation time + decision.EvaluationTimeMs = time.Since(startTime).Milliseconds() + + observability.Infof("Routing completed: model=%s, reasoning=%v, time=%dms", + decision.SelectedModel, decision.UseReasoning, decision.EvaluationTimeMs) + + return decision, nil +} + +// getRoutingStrategy determines which routing strategy to use +func (hr *HybridRouter) getRoutingStrategy() string { + if hr.config.RoutingStrategy.Type != "" { + return hr.config.RoutingStrategy.Type + } + + // Auto-determine strategy based on configuration + hasRules := len(hr.config.RoutingRules) > 0 && hr.ruleEngine != nil + hasModel := hr.classifier != nil + + if hasRules && hasModel { + return "hybrid" + } else if hasRules { + return "rules" + } else if hasModel { + return "model" + } + + return "model" // Default fallback +} + +// routeWithRules uses only rule-based routing +func (hr *HybridRouter) routeWithRules(ctx context.Context, evalCtx *EvaluationContext) (*RoutingDecision, error) { + if hr.ruleEngine == nil { + return hr.createFallbackDecision(evalCtx, "No rule engine available") + } + + decision, err := hr.ruleEngine.EvaluateRules(ctx, evalCtx) + if err != nil { + return nil, fmt.Errorf("rule evaluation failed: %w", err) + } + + // If no rule matched and fallback to model is enabled, try model routing + if !decision.RuleMatched && hr.config.RoutingStrategy.RuleRouting.FallbackToModel { + observability.Infof("No rules matched, falling back to model-based routing") + modelDecision, modelErr := hr.routeWithModel(ctx, evalCtx) + if modelErr == nil { + // Merge model decision with rule decision + decision.SelectedModel = modelDecision.SelectedModel + decision.UseReasoning = modelDecision.UseReasoning + decision.ReasoningEffort = modelDecision.ReasoningEffort + decision.Explanation.DecisionType = "fallback_to_model" + decision.Explanation.CategoryClassification = modelDecision.Explanation.CategoryClassification + decision.Explanation.Reasoning = "Rules did not match, used model-based routing" + } + } + + return decision, nil +} + +// routeWithModel uses only model-based routing +func (hr *HybridRouter) routeWithModel(ctx context.Context, evalCtx *EvaluationContext) (*RoutingDecision, error) { + if hr.classifier == nil { + return hr.createFallbackDecision(evalCtx, "No classifier available") + } + + // Perform classification + categoryName, confidence, err := hr.classifier.ClassifyCategory(evalCtx.AllContent) + if err != nil { + return nil, fmt.Errorf("classification failed: %w", err) + } + + // Get model for category + selectedModel := hr.classifier.ClassifyAndSelectBestModel(evalCtx.AllContent) + if selectedModel == "" { + selectedModel = hr.config.DefaultModel + } + + // Get reasoning configuration for the category + useReasoning, reasoningEffort := hr.getReasoningConfig(categoryName, selectedModel) + + decision := &RoutingDecision{ + RuleMatched: false, + SelectedModel: selectedModel, + UseReasoning: useReasoning, + ReasoningEffort: reasoningEffort, + Headers: make(map[string]string), + Explanation: DecisionExplanation{ + DecisionType: "model_based", + CategoryClassification: &CategoryClassificationResult{ + Category: categoryName, + Confidence: float64(confidence), + }, + Reasoning: fmt.Sprintf("Model-based classification selected '%s' with confidence %.2f", categoryName, confidence), + Confidence: float64(confidence), + }, + } + + return decision, nil +} + +// routeWithHybrid uses hybrid routing (rules first, then model) +func (hr *HybridRouter) routeWithHybrid(ctx context.Context, evalCtx *EvaluationContext) (*RoutingDecision, error) { + // First try rule-based routing + decision, err := hr.routeWithRules(ctx, evalCtx) + if err != nil { + return nil, fmt.Errorf("rule evaluation failed in hybrid mode: %w", err) + } + + // If rule matched, use rule decision + if decision.RuleMatched { + observability.Infof("Rule matched in hybrid mode: %s", decision.MatchedRule.Name) + return decision, nil + } + + // No rule matched, try model-based routing + observability.Infof("No rules matched in hybrid mode, trying model-based routing") + + // Check model routing configuration + if !hr.config.RoutingStrategy.ModelRouting.Enabled { + observability.Infof("Model routing disabled, using rule decision with default model") + return decision, nil + } + + modelDecision, err := hr.routeWithModel(ctx, evalCtx) + if err != nil { + observability.Errorf("Model routing failed in hybrid mode: %v", err) + return decision, nil // Return rule decision as fallback + } + + // Check if model confidence meets threshold + confidenceThreshold := hr.config.RoutingStrategy.ModelRouting.ConfidenceThreshold + if confidenceThreshold > 0 && modelDecision.Explanation.Confidence < confidenceThreshold { + observability.Infof("Model confidence %.2f below threshold %.2f, using default routing", + modelDecision.Explanation.Confidence, confidenceThreshold) + return decision, nil + } + + // Use model decision + modelDecision.Explanation.DecisionType = "hybrid_model" + modelDecision.Explanation.Reasoning = fmt.Sprintf("No rules matched, used model-based routing with confidence %.2f", + modelDecision.Explanation.Confidence) + + return modelDecision, nil +} + +// createFallbackDecision creates a fallback decision when routing fails +func (hr *HybridRouter) createFallbackDecision(evalCtx *EvaluationContext, reason string) (*RoutingDecision, error) { + decision := &RoutingDecision{ + RuleMatched: false, + SelectedModel: hr.config.DefaultModel, + UseReasoning: false, + ReasoningEffort: hr.config.DefaultReasoningEffort, + Headers: make(map[string]string), + Explanation: DecisionExplanation{ + DecisionType: "fallback", + Reasoning: reason, + Confidence: 0.0, + }, + } + + return decision, nil +} + +// getReasoningConfig determines reasoning configuration for a category/model +func (hr *HybridRouter) getReasoningConfig(categoryName, modelName string) (bool, string) { + // Find category configuration + for _, category := range hr.config.Categories { + if category.Name == categoryName { + // Find model in category + for _, modelScore := range category.ModelScores { + if modelScore.Model == modelName && modelScore.UseReasoning != nil { + reasoningEffort := category.ReasoningEffort + if reasoningEffort == "" { + reasoningEffort = hr.config.DefaultReasoningEffort + } + return *modelScore.UseReasoning, reasoningEffort + } + } + } + } + + // Default reasoning configuration + return false, hr.config.DefaultReasoningEffort +} + +// IsRulesEnabled returns true if rule-based routing is enabled +func (hr *HybridRouter) IsRulesEnabled() bool { + return hr.ruleEngine != nil && hr.config.RoutingStrategy.RuleRouting.Enabled +} + +// IsModelEnabled returns true if model-based routing is enabled +func (hr *HybridRouter) IsModelEnabled() bool { + return hr.classifier != nil && hr.config.RoutingStrategy.ModelRouting.Enabled +} + +// GetRuleCount returns the number of configured rules +func (hr *HybridRouter) GetRuleCount() int { + if hr.ruleEngine == nil { + return 0 + } + return len(hr.ruleEngine.rules) +} \ No newline at end of file