Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 59 additions & 1 deletion src/semantic-router/pkg/classification/classifier.go
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,9 @@ type Classifier struct {
mcpCategoryInitializer MCPCategoryInitializer
mcpCategoryInference MCPCategoryInference

// NEW: Unified classifier for LoRA models (preferred when available)
UnifiedClassifier *UnifiedClassifier

Config *config.RouterConfig
CategoryMapping *CategoryMapping
PIIMapping *PIIMapping
Expand Down Expand Up @@ -539,7 +542,12 @@ func (c *Classifier) ClassifyCategoryWithEntropy(text string) (string, float64,
}
}

// Try in-tree first if properly configured
// Try UnifiedClassifier (LoRA models) first - highest accuracy
if c.UnifiedClassifier != nil {
return c.classifyWithUnifiedClassifier(text)
}

// Try in-tree classifier if properly configured
if c.IsCategoryEnabled() && c.categoryInference != nil {
return c.classifyCategoryWithEntropyInTree(text)
}
Expand Down Expand Up @@ -587,6 +595,56 @@ func (c *Classifier) makeReasoningDecisionForKeywordCategory(category string) en
}
}

// classifyWithUnifiedClassifier classifies text through c.UnifiedClassifier.
//
// The text is submitted as a one-element batch and only the first intent
// result is used. It returns that result's category, its confidence widened
// to float64, and the reasoning decision produced by
// makeReasoningDecisionForCategory for that pair.
//
// A non-nil error is returned, together with zero values, when the batch call
// fails (the underlying error is wrapped) or when the batch yields no intent
// results.
func (c *Classifier) classifyWithUnifiedClassifier(text string) (string, float64, entropy.ReasoningDecision, error) {
	// Zero-value decision shared by both failure paths.
	var noDecision entropy.ReasoningDecision

	batch, batchErr := c.UnifiedClassifier.ClassifyBatch([]string{text})
	switch {
	case batchErr != nil:
		return "", 0, noDecision, fmt.Errorf("unified classifier error: %w", batchErr)
	case len(batch.IntentResults) == 0:
		return "", 0, noDecision, fmt.Errorf("no classification results from unified classifier")
	}

	// A single input was submitted, so only the first result is relevant.
	top := batch.IntentResults[0]
	score := float64(top.Confidence)

	return top.Category, score, c.makeReasoningDecisionForCategory(top.Category, score), nil
}

// makeReasoningDecisionForCategory builds the reasoning decision reported for
// a category chosen by the unified classifier.
//
// The category is matched against c.Config.Categories by name: the argument is
// trimmed and lower-cased, each configured name is lower-cased, and the first
// equal entry wins. UseReasoning is taken from that entry's first ModelScores
// element when it is present and its UseReasoning pointer is set; in every
// other case UseReasoning stays false.
//
// The returned decision carries the supplied confidence and lists the category
// (as passed in, not normalized) as its single top category.
func (c *Classifier) makeReasoningDecisionForCategory(category string, confidence float64) entropy.ReasoningDecision {
	decision := entropy.ReasoningDecision{
		Confidence:       confidence,
		DecisionReason:   "unified_lora_classification",
		FallbackStrategy: "lora_based_classification",
		TopCategories: []entropy.CategoryProbability{
			{Category: category, Probability: float32(confidence)},
		},
	}

	wanted := strings.ToLower(strings.TrimSpace(category))
	for _, candidate := range c.Config.Categories {
		if strings.ToLower(candidate.Name) != wanted {
			continue
		}
		// Only the first model score of the matching category is consulted.
		if scores := candidate.ModelScores; len(scores) > 0 && scores[0].UseReasoning != nil {
			decision.UseReasoning = *scores[0].UseReasoning
		}
		break
	}

	return decision
}

// classifyCategoryWithEntropyInTree performs category classification with entropy using in-tree model
func (c *Classifier) classifyCategoryWithEntropyInTree(text string) (string, float64, entropy.ReasoningDecision, error) {
if !c.IsCategoryEnabled() {
Expand Down
17 changes: 13 additions & 4 deletions src/semantic-router/pkg/extproc/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ type OpenAIRouter struct {
Config *config.RouterConfig
CategoryDescriptions []string
Classifier *classification.Classifier
ClassificationSvc *services.ClassificationService // NEW: Use service with UnifiedClassifier
PIIChecker *pii.PolicyChecker
Cache cache.CacheBackend
ToolsDatabase *tools.ToolsDatabase
Expand Down Expand Up @@ -143,20 +144,28 @@ func NewOpenAIRouter(configPath string) (*OpenAIRouter, error) {

// Create global classification service for API access with auto-discovery
// This will prioritize LoRA models over legacy ModernBERT
var classificationSvc *services.ClassificationService
autoSvc, err := services.NewClassificationServiceWithAutoDiscovery(cfg)
if err != nil {
logging.Warnf("Auto-discovery failed during router initialization: %v, using legacy classifier", err)
services.NewClassificationService(classifier, cfg)
classificationSvc = services.NewClassificationService(classifier, cfg)
} else {
logging.Infof("Router initialization: Using auto-discovered unified classifier")
// The service is already set as global in NewUnifiedClassificationService
_ = autoSvc
classificationSvc = autoSvc
if classificationSvc.HasUnifiedClassifier() {
// Wire the UnifiedClassifier from the service to the legacy Classifier for delegation
unifiedClassifier := classificationSvc.GetUnifiedClassifier()
if unifiedClassifier != nil {
classifier.UnifiedClassifier = unifiedClassifier
logging.Infof("Router using UnifiedClassifier (LoRA models) for category classification")
}
}
}

router := &OpenAIRouter{
Config: cfg,
CategoryDescriptions: categoryDescriptions,
Classifier: classifier,
ClassificationSvc: classificationSvc, // NEW: Store the service
PIIChecker: piiChecker,
Cache: semanticCache,
ToolsDatabase: toolsDatabase,
Expand Down
5 changes: 5 additions & 0 deletions src/semantic-router/pkg/services/classification.go
Original file line number Diff line number Diff line change
Expand Up @@ -541,6 +541,11 @@ func (s *ClassificationService) HasUnifiedClassifier() bool {
return s.unifiedClassifier != nil && s.unifiedClassifier.IsInitialized()
}

// GetUnifiedClassifier returns the service's unifiedClassifier field unchanged
// so that callers can delegate classification to it directly.
//
// No nil or initialization check is performed here, so the result may be nil;
// HasUnifiedClassifier reports whether the classifier is both set and
// initialized.
func (s *ClassificationService) GetUnifiedClassifier() *classification.UnifiedClassifier {
return s.unifiedClassifier
}

// GetUnifiedClassifierStats returns statistics about the unified classifier
func (s *ClassificationService) GetUnifiedClassifierStats() map[string]interface{} {
if s.unifiedClassifier == nil {
Expand Down
Loading