Skip to content

Commit 067be2c

Browse files
yossiovadiaclaude
andcommitted
fix: wire UnifiedClassifier to ExtProc router for consistent LoRA-based classification
This change ensures the ExtProc router uses the same UnifiedClassifier (LoRA-based) instance as the Classification API, fixing inconsistent model selection behavior. **Problem:** - Classification API (port 8080) used UnifiedClassifier (LoRA models) - ExtProc router (port 8801) used legacy Classifier (traditional BERT) - This caused different classification results for the same query, leading to incorrect model selection in category-based routing **Solution:** 1. Wire UnifiedClassifier from ClassificationService to legacy Classifier 2. Add delegation in Classifier.ClassifyCategoryWithEntropy() to use UnifiedClassifier when available 3. Add GetUnifiedClassifier() method to ClassificationService **Changes:** - router.go: Wire UnifiedClassifier to Classifier during initialization - classifier.go: Delegate to UnifiedClassifier before trying in-tree classifier, add classifyWithUnifiedClassifier() helper method - classification.go: Add GetUnifiedClassifier() getter method Related to vllm-project#640 Co-Authored-By: Claude <[email protected]> Signed-off-by: Yossi Ovadia <[email protected]>
1 parent 8838cbf commit 067be2c

File tree

3 files changed

+77
-5
lines changed

3 files changed

+77
-5
lines changed

src/semantic-router/pkg/classification/classifier.go

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,9 @@ type Classifier struct {
212212
mcpCategoryInitializer MCPCategoryInitializer
213213
mcpCategoryInference MCPCategoryInference
214214

215+
// NEW: Unified classifier for LoRA models (preferred when available)
216+
UnifiedClassifier *UnifiedClassifier
217+
215218
Config *config.RouterConfig
216219
CategoryMapping *CategoryMapping
217220
PIIMapping *PIIMapping
@@ -539,7 +542,12 @@ func (c *Classifier) ClassifyCategoryWithEntropy(text string) (string, float64,
539542
}
540543
}
541544

542-
// Try in-tree first if properly configured
545+
// Try UnifiedClassifier (LoRA models) first - highest accuracy
546+
if c.UnifiedClassifier != nil {
547+
return c.classifyWithUnifiedClassifier(text)
548+
}
549+
550+
// Try in-tree classifier if properly configured
543551
if c.IsCategoryEnabled() && c.categoryInference != nil {
544552
return c.classifyCategoryWithEntropyInTree(text)
545553
}
@@ -587,6 +595,56 @@ func (c *Classifier) makeReasoningDecisionForKeywordCategory(category string) en
587595
}
588596
}
589597

598+
// classifyWithUnifiedClassifier uses UnifiedClassifier (LoRA models) for classification
599+
func (c *Classifier) classifyWithUnifiedClassifier(text string) (string, float64, entropy.ReasoningDecision, error) {
600+
// Use batch classification with single item
601+
results, err := c.UnifiedClassifier.ClassifyBatch([]string{text})
602+
if err != nil {
603+
return "", 0.0, entropy.ReasoningDecision{}, fmt.Errorf("unified classifier error: %w", err)
604+
}
605+
606+
if len(results.IntentResults) == 0 {
607+
return "", 0.0, entropy.ReasoningDecision{}, fmt.Errorf("no classification results from unified classifier")
608+
}
609+
610+
intentResult := results.IntentResults[0]
611+
category := intentResult.Category
612+
confidence := float64(intentResult.Confidence)
613+
614+
// Build reasoning decision based on category configuration
615+
reasoningDecision := c.makeReasoningDecisionForCategory(category, confidence)
616+
617+
return category, confidence, reasoningDecision, nil
618+
}
619+
620+
// makeReasoningDecisionForCategory creates reasoning decision based on category config
621+
func (c *Classifier) makeReasoningDecisionForCategory(category string, confidence float64) entropy.ReasoningDecision {
622+
normalizedCategory := strings.ToLower(strings.TrimSpace(category))
623+
useReasoning := false
624+
625+
for _, cat := range c.Config.Categories {
626+
if strings.ToLower(cat.Name) == normalizedCategory {
627+
if len(cat.ModelScores) > 0 && cat.ModelScores[0].UseReasoning != nil {
628+
useReasoning = *cat.ModelScores[0].UseReasoning
629+
}
630+
break
631+
}
632+
}
633+
634+
return entropy.ReasoningDecision{
635+
UseReasoning: useReasoning,
636+
Confidence: confidence,
637+
DecisionReason: "unified_lora_classification",
638+
FallbackStrategy: "lora_based_classification",
639+
TopCategories: []entropy.CategoryProbability{
640+
{
641+
Category: category,
642+
Probability: float32(confidence),
643+
},
644+
},
645+
}
646+
}
647+
590648
// classifyCategoryWithEntropyInTree performs category classification with entropy using in-tree model
591649
func (c *Classifier) classifyCategoryWithEntropyInTree(text string) (string, float64, entropy.ReasoningDecision, error) {
592650
if !c.IsCategoryEnabled() {

src/semantic-router/pkg/extproc/router.go

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ type OpenAIRouter struct {
2323
Config *config.RouterConfig
2424
CategoryDescriptions []string
2525
Classifier *classification.Classifier
26+
ClassificationSvc *services.ClassificationService // NEW: Use service with UnifiedClassifier
2627
PIIChecker *pii.PolicyChecker
2728
Cache cache.CacheBackend
2829
ToolsDatabase *tools.ToolsDatabase
@@ -143,20 +144,28 @@ func NewOpenAIRouter(configPath string) (*OpenAIRouter, error) {
143144

144145
// Create global classification service for API access with auto-discovery
145146
// This will prioritize LoRA models over legacy ModernBERT
147+
var classificationSvc *services.ClassificationService
146148
autoSvc, err := services.NewClassificationServiceWithAutoDiscovery(cfg)
147149
if err != nil {
148150
logging.Warnf("Auto-discovery failed during router initialization: %v, using legacy classifier", err)
149-
services.NewClassificationService(classifier, cfg)
151+
classificationSvc = services.NewClassificationService(classifier, cfg)
150152
} else {
151-
logging.Infof("Router initialization: Using auto-discovered unified classifier")
152-
// The service is already set as global in NewUnifiedClassificationService
153-
_ = autoSvc
153+
classificationSvc = autoSvc
154+
if classificationSvc.HasUnifiedClassifier() {
155+
// Wire the UnifiedClassifier from the service to the legacy Classifier for delegation
156+
unifiedClassifier := classificationSvc.GetUnifiedClassifier()
157+
if unifiedClassifier != nil {
158+
classifier.UnifiedClassifier = unifiedClassifier
159+
logging.Infof("Router using UnifiedClassifier (LoRA models) for category classification")
160+
}
161+
}
154162
}
155163

156164
router := &OpenAIRouter{
157165
Config: cfg,
158166
CategoryDescriptions: categoryDescriptions,
159167
Classifier: classifier,
168+
ClassificationSvc: classificationSvc, // NEW: Store the service
160169
PIIChecker: piiChecker,
161170
Cache: semanticCache,
162171
ToolsDatabase: toolsDatabase,

src/semantic-router/pkg/services/classification.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -541,6 +541,11 @@ func (s *ClassificationService) HasUnifiedClassifier() bool {
541541
return s.unifiedClassifier != nil && s.unifiedClassifier.IsInitialized()
542542
}
543543

544+
// GetUnifiedClassifier returns the UnifiedClassifier instance (for delegation)
545+
func (s *ClassificationService) GetUnifiedClassifier() *classification.UnifiedClassifier {
546+
return s.unifiedClassifier
547+
}
548+
544549
// GetUnifiedClassifierStats returns statistics about the unified classifier
545550
func (s *ClassificationService) GetUnifiedClassifierStats() map[string]interface{} {
546551
if s.unifiedClassifier == nil {

0 commit comments

Comments
 (0)