Reviewed revision of the UnifiedClassifier wiring patch.
Note: whitespace inside context lines was reconstructed following gofmt conventions, and the
"index" post-image hashes predate the review changes. If context does not match byte-for-byte,
apply with `git apply --ignore-whitespace`.

diff --git a/src/semantic-router/pkg/classification/classifier.go b/src/semantic-router/pkg/classification/classifier.go
index 608132600..6c327d0a9 100644
--- a/src/semantic-router/pkg/classification/classifier.go
+++ b/src/semantic-router/pkg/classification/classifier.go
@@ -212,6 +212,9 @@ type Classifier struct {
 	mcpCategoryInitializer MCPCategoryInitializer
 	mcpCategoryInference   MCPCategoryInference
 
+	// UnifiedClassifier (LoRA models), when non-nil, is tried before the in-tree classifier.
+	UnifiedClassifier *UnifiedClassifier
+
 	Config          *config.RouterConfig
 	CategoryMapping *CategoryMapping
 	PIIMapping      *PIIMapping
@@ -539,7 +542,12 @@ func (c *Classifier) ClassifyCategoryWithEntropy(text string) (string, float64, 
 		}
 	}
 
-	// Try in-tree first if properly configured
+	// Prefer the UnifiedClassifier (LoRA models) when wired in; its result or error is returned as-is
+	if c.UnifiedClassifier != nil {
+		return c.classifyWithUnifiedClassifier(text)
+	}
+
+	// Try in-tree classifier if properly configured
 	if c.IsCategoryEnabled() && c.categoryInference != nil {
 		return c.classifyCategoryWithEntropyInTree(text)
 	}
@@ -587,6 +595,62 @@ func (c *Classifier) makeReasoningDecisionForKeywordCategory(category string) en
 	}
 }
 
+// classifyWithUnifiedClassifier classifies text with the UnifiedClassifier (LoRA
+// models) by submitting it as a single-item batch, and derives the reasoning
+// decision from the category configuration.
+func (c *Classifier) classifyWithUnifiedClassifier(text string) (string, float64, entropy.ReasoningDecision, error) {
+	results, err := c.UnifiedClassifier.ClassifyBatch([]string{text})
+	if err != nil {
+		return "", 0.0, entropy.ReasoningDecision{}, fmt.Errorf("unified classifier error: %w", err)
+	}
+
+	if len(results.IntentResults) == 0 {
+		return "", 0.0, entropy.ReasoningDecision{}, fmt.Errorf("no classification results from unified classifier")
+	}
+
+	intentResult := results.IntentResults[0]
+	category := intentResult.Category
+	confidence := float64(intentResult.Confidence)
+
+	return category, confidence, c.makeReasoningDecisionForCategory(category, confidence), nil
+}
+
+// makeReasoningDecisionForCategory builds the reasoning decision for a category
+// picked by the unified classifier. UseReasoning comes from the first model score
+// of the configured category whose name matches case-insensitively; it stays
+// false when the config is nil, no category matches, or the flag is unset.
+func (c *Classifier) makeReasoningDecisionForCategory(category string, confidence float64) entropy.ReasoningDecision {
+	normalizedCategory := strings.TrimSpace(category)
+	useReasoning := false
+
+	// Guard against a Classifier built without a config: dereferencing a nil
+	// c.Config below would panic on the request path.
+	if c.Config != nil {
+		for _, cat := range c.Config.Categories {
+			if !strings.EqualFold(cat.Name, normalizedCategory) {
+				continue
+			}
+			if len(cat.ModelScores) > 0 && cat.ModelScores[0].UseReasoning != nil {
+				useReasoning = *cat.ModelScores[0].UseReasoning
+			}
+			break
+		}
+	}
+
+	return entropy.ReasoningDecision{
+		UseReasoning:     useReasoning,
+		Confidence:       confidence,
+		DecisionReason:   "unified_lora_classification",
+		FallbackStrategy: "lora_based_classification",
+		TopCategories: []entropy.CategoryProbability{
+			{
+				Category:    category,
+				Probability: float32(confidence),
+			},
+		},
+	}
+}
+
 // classifyCategoryWithEntropyInTree performs category classification with entropy using in-tree model
 func (c *Classifier) classifyCategoryWithEntropyInTree(text string) (string, float64, entropy.ReasoningDecision, error) {
 	if !c.IsCategoryEnabled() {
diff --git a/src/semantic-router/pkg/extproc/router.go b/src/semantic-router/pkg/extproc/router.go
index afee16668..a4c5d3144 100644
--- a/src/semantic-router/pkg/extproc/router.go
+++ b/src/semantic-router/pkg/extproc/router.go
@@ -23,6 +23,7 @@ type OpenAIRouter struct {
 	Config               *config.RouterConfig
 	CategoryDescriptions []string
 	Classifier           *classification.Classifier
+	ClassificationSvc    *services.ClassificationService // auto-discovered service, or the legacy one on discovery failure
 	PIIChecker           *pii.PolicyChecker
 	Cache                cache.CacheBackend
 	ToolsDatabase        *tools.ToolsDatabase
@@ -143,20 +144,26 @@ func NewOpenAIRouter(configPath string) (*OpenAIRouter, error) {
 
 	// Create global classification service for API access with auto-discovery
 	// This will prioritize LoRA models over legacy ModernBERT
+	var classificationSvc *services.ClassificationService
 	autoSvc, err := services.NewClassificationServiceWithAutoDiscovery(cfg)
 	if err != nil {
 		logging.Warnf("Auto-discovery failed during router initialization: %v, using legacy classifier", err)
-		services.NewClassificationService(classifier, cfg)
+		classificationSvc = services.NewClassificationService(classifier, cfg)
 	} else {
-		logging.Infof("Router initialization: Using auto-discovered unified classifier")
-		// The service is already set as global in NewUnifiedClassificationService
-		_ = autoSvc
+		classificationSvc = autoSvc
+		// HasUnifiedClassifier already guarantees a non-nil, initialized
+		// UnifiedClassifier, so it can be handed to the legacy Classifier directly.
+		if classificationSvc.HasUnifiedClassifier() {
+			classifier.UnifiedClassifier = classificationSvc.GetUnifiedClassifier()
+			logging.Infof("Router using UnifiedClassifier (LoRA models) for category classification")
+		}
 	}
 
 	router := &OpenAIRouter{
 		Config:               cfg,
 		CategoryDescriptions: categoryDescriptions,
 		Classifier:           classifier,
+		ClassificationSvc:    classificationSvc,
 		PIIChecker:           piiChecker,
 		Cache:                semanticCache,
 		ToolsDatabase:        toolsDatabase,
diff --git a/src/semantic-router/pkg/services/classification.go b/src/semantic-router/pkg/services/classification.go
index f83ed9e59..19107c348 100644
--- a/src/semantic-router/pkg/services/classification.go
+++ b/src/semantic-router/pkg/services/classification.go
@@ -541,6 +541,12 @@ func (s *ClassificationService) HasUnifiedClassifier() bool {
 	return s.unifiedClassifier != nil && s.unifiedClassifier.IsInitialized()
 }
 
+// GetUnifiedClassifier returns the service's UnifiedClassifier so other components
+// can delegate to it. The result may be nil; check HasUnifiedClassifier first.
+func (s *ClassificationService) GetUnifiedClassifier() *classification.UnifiedClassifier {
+	return s.unifiedClassifier
+}
+
 // GetUnifiedClassifierStats returns statistics about the unified classifier
 func (s *ClassificationService) GetUnifiedClassifierStats() map[string]interface{} {
 	if s.unifiedClassifier == nil {