@@ -3,6 +3,7 @@ package services
33import (
44 "fmt"
55 "os"
6+ "strings"
67 "sync"
78 "time"
89
@@ -35,9 +36,9 @@ func NewClassificationService(classifier *classification.Classifier, config *con
3536}
3637
3738// NewUnifiedClassificationService creates a new service with unified classifier
38- func NewUnifiedClassificationService (unifiedClassifier * classification.UnifiedClassifier , config * config.RouterConfig ) * ClassificationService {
39+ func NewUnifiedClassificationService (unifiedClassifier * classification.UnifiedClassifier , legacyClassifier * classification. Classifier , config * config.RouterConfig ) * ClassificationService {
3940 service := & ClassificationService {
40- classifier : nil , // Legacy classifier not used
41+ classifier : legacyClassifier ,
4142 unifiedClassifier : unifiedClassifier ,
4243 config : config ,
4344 }
@@ -54,16 +55,69 @@ func NewClassificationServiceWithAutoDiscovery(config *config.RouterConfig) (*Cl
5455 observability .Debugf ("Debug: Attempting to discover models in: ./models" )
5556
5657 // Always try to auto-discover and initialize unified classifier for batch processing
57- unifiedClassifier , err := classification .AutoInitializeUnifiedClassifier ("./models" )
58+ // Use model path from config, fallback to "./models" if not specified
59+ modelsPath := "./models"
60+ if config != nil && config .Classifier .CategoryModel .ModelID != "" {
61+ // Extract the models directory from the model path
62+ // e.g., "models/category_classifier_modernbert-base_model" -> "models"
63+ if idx := strings .Index (config .Classifier .CategoryModel .ModelID , "/" ); idx > 0 {
64+ modelsPath = config .Classifier .CategoryModel .ModelID [:idx ]
65+ }
66+ }
67+ unifiedClassifier , ucErr := classification .AutoInitializeUnifiedClassifier (modelsPath )
68+ if ucErr != nil {
69+ observability .Infof ("Unified classifier auto-discovery failed: %v" , ucErr )
70+ }
71+ // create legacy classifier
72+ legacyClassifier , lcErr := createLegacyClassifier (config )
73+ if lcErr != nil {
74+ observability .Warnf ("Legacy classifier initialization failed: %v" , lcErr )
75+ }
76+ if unifiedClassifier == nil && legacyClassifier == nil {
77+ observability .Warnf ("No classifier initialized. Using placeholder service." )
78+ }
79+ return NewUnifiedClassificationService (unifiedClassifier , legacyClassifier , config ), nil
80+ }
81+
82+ // createLegacyClassifier creates a legacy classifier with proper model loading
83+ func createLegacyClassifier (config * config.RouterConfig ) (* classification.Classifier , error ) {
84+ // Load category mapping
85+ var categoryMapping * classification.CategoryMapping
86+ if config .Classifier .CategoryModel .CategoryMappingPath != "" {
87+ var err error
88+ categoryMapping , err = classification .LoadCategoryMapping (config .Classifier .CategoryModel .CategoryMappingPath )
89+ if err != nil {
90+ return nil , fmt .Errorf ("failed to load category mapping: %w" , err )
91+ }
92+ }
93+
94+ // Load PII mapping
95+ var piiMapping * classification.PIIMapping
96+ if config .Classifier .PIIModel .PIIMappingPath != "" {
97+ var err error
98+ piiMapping , err = classification .LoadPIIMapping (config .Classifier .PIIModel .PIIMappingPath )
99+ if err != nil {
100+ return nil , fmt .Errorf ("failed to load PII mapping: %w" , err )
101+ }
102+ }
103+
104+ // Load jailbreak mapping
105+ var jailbreakMapping * classification.JailbreakMapping
106+ if config .PromptGuard .JailbreakMappingPath != "" {
107+ var err error
108+ jailbreakMapping , err = classification .LoadJailbreakMapping (config .PromptGuard .JailbreakMappingPath )
109+ if err != nil {
110+ return nil , fmt .Errorf ("failed to load jailbreak mapping: %w" , err )
111+ }
112+ }
113+
114+ // Create classifier
115+ classifier , err := classification .NewClassifier (config , categoryMapping , piiMapping , jailbreakMapping )
58116 if err != nil {
59- // Log the discovery failure but don't fail - fall back to legacy processing
60- observability .Infof ("Unified classifier auto-discovery failed: %v. Using legacy processing." , err )
61- return NewClassificationService (nil , config ), nil
117+ return nil , fmt .Errorf ("failed to create classifier: %w" , err )
62118 }
63119
64- // Success! Create service with unified classifier
65- observability .Infof ("Unified classifier auto-discovered and initialized. Using batch processing." )
66- return NewUnifiedClassificationService (unifiedClassifier , config ), nil
120+ return classifier , nil
67121}
68122
69123// GetGlobalClassificationService returns the global classification service instance
0 commit comments