diff --git a/src/semantic-router/pkg/services/classification.go b/src/semantic-router/pkg/services/classification.go index f58406b0..63890c61 100644 --- a/src/semantic-router/pkg/services/classification.go +++ b/src/semantic-router/pkg/services/classification.go @@ -3,6 +3,7 @@ package services import ( "fmt" "os" + "strings" "sync" "time" @@ -35,9 +36,9 @@ func NewClassificationService(classifier *classification.Classifier, config *con } // NewUnifiedClassificationService creates a new service with unified classifier -func NewUnifiedClassificationService(unifiedClassifier *classification.UnifiedClassifier, config *config.RouterConfig) *ClassificationService { +func NewUnifiedClassificationService(unifiedClassifier *classification.UnifiedClassifier, legacyClassifier *classification.Classifier, config *config.RouterConfig) *ClassificationService { service := &ClassificationService{ - classifier: nil, // Legacy classifier not used + classifier: legacyClassifier, unifiedClassifier: unifiedClassifier, config: config, } @@ -54,16 +55,69 @@ func NewClassificationServiceWithAutoDiscovery(config *config.RouterConfig) (*Cl observability.Debugf("Debug: Attempting to discover models in: ./models") // Always try to auto-discover and initialize unified classifier for batch processing - unifiedClassifier, err := classification.AutoInitializeUnifiedClassifier("./models") + // Use model path from config, fallback to "./models" if not specified + modelsPath := "./models" + if config != nil && config.Classifier.CategoryModel.ModelID != "" { + // Extract the models directory from the model path + // e.g., "models/category_classifier_modernbert-base_model" -> "models" + if idx := strings.Index(config.Classifier.CategoryModel.ModelID, "/"); idx > 0 { + modelsPath = config.Classifier.CategoryModel.ModelID[:idx] + } + } + unifiedClassifier, ucErr := classification.AutoInitializeUnifiedClassifier(modelsPath) + if ucErr != nil { + observability.Infof("Unified classifier auto-discovery failed: %v", ucErr) + } + // create legacy classifier + legacyClassifier, lcErr := createLegacyClassifier(config) + if lcErr != nil { + observability.Warnf("Legacy classifier initialization failed: %v", lcErr) + } + if unifiedClassifier == nil && legacyClassifier == nil { + observability.Warnf("No classifier initialized. Using placeholder service.") + } + return NewUnifiedClassificationService(unifiedClassifier, legacyClassifier, config), nil +} + +// createLegacyClassifier creates a legacy classifier with proper model loading +func createLegacyClassifier(config *config.RouterConfig) (*classification.Classifier, error) { + // Load category mapping + var categoryMapping *classification.CategoryMapping + if config.Classifier.CategoryModel.CategoryMappingPath != "" { + var err error + categoryMapping, err = classification.LoadCategoryMapping(config.Classifier.CategoryModel.CategoryMappingPath) + if err != nil { + return nil, fmt.Errorf("failed to load category mapping: %w", err) + } + } + + // Load PII mapping + var piiMapping *classification.PIIMapping + if config.Classifier.PIIModel.PIIMappingPath != "" { + var err error + piiMapping, err = classification.LoadPIIMapping(config.Classifier.PIIModel.PIIMappingPath) + if err != nil { + return nil, fmt.Errorf("failed to load PII mapping: %w", err) + } + } + + // Load jailbreak mapping + var jailbreakMapping *classification.JailbreakMapping + if config.PromptGuard.JailbreakMappingPath != "" { + var err error + jailbreakMapping, err = classification.LoadJailbreakMapping(config.PromptGuard.JailbreakMappingPath) + if err != nil { + return nil, fmt.Errorf("failed to load jailbreak mapping: %w", err) + } + } + + // Create classifier + classifier, err := classification.NewClassifier(config, categoryMapping, piiMapping, jailbreakMapping) if err != nil { - // Log the discovery failure but don't fail - fall back to legacy processing - observability.Infof("Unified classifier auto-discovery failed: %v. Using legacy processing.", err) - return NewClassificationService(nil, config), nil + return nil, fmt.Errorf("failed to create classifier: %w", err) } - // Success! Create service with unified classifier - observability.Infof("Unified classifier auto-discovered and initialized. Using batch processing.") - return NewUnifiedClassificationService(unifiedClassifier, config), nil + return classifier, nil } // GetGlobalClassificationService returns the global classification service instance