@@ -3,6 +3,7 @@ package services
3
3
import (
4
4
"fmt"
5
5
"os"
6
+ "strings"
6
7
"sync"
7
8
"time"
8
9
@@ -35,9 +36,9 @@ func NewClassificationService(classifier *classification.Classifier, config *con
35
36
}
36
37
37
38
// NewUnifiedClassificationService creates a new service with unified classifier
38
- func NewUnifiedClassificationService (unifiedClassifier * classification.UnifiedClassifier , config * config.RouterConfig ) * ClassificationService {
39
+ func NewUnifiedClassificationService (unifiedClassifier * classification.UnifiedClassifier , legacyClassifier * classification. Classifier , config * config.RouterConfig ) * ClassificationService {
39
40
service := & ClassificationService {
40
- classifier : nil , // Legacy classifier not used
41
+ classifier : legacyClassifier ,
41
42
unifiedClassifier : unifiedClassifier ,
42
43
config : config ,
43
44
}
@@ -54,16 +55,69 @@ func NewClassificationServiceWithAutoDiscovery(config *config.RouterConfig) (*Cl
54
55
observability .Debugf ("Debug: Attempting to discover models in: ./models" )
55
56
56
57
// Always try to auto-discover and initialize unified classifier for batch processing
57
- unifiedClassifier , err := classification .AutoInitializeUnifiedClassifier ("./models" )
58
+ // Use model path from config, fallback to "./models" if not specified
59
+ modelsPath := "./models"
60
+ if config != nil && config .Classifier .CategoryModel .ModelID != "" {
61
+ // Extract the models directory from the model path
62
+ // e.g., "models/category_classifier_modernbert-base_model" -> "models"
63
+ if idx := strings .Index (config .Classifier .CategoryModel .ModelID , "/" ); idx > 0 {
64
+ modelsPath = config .Classifier .CategoryModel .ModelID [:idx ]
65
+ }
66
+ }
67
+ unifiedClassifier , ucErr := classification .AutoInitializeUnifiedClassifier (modelsPath )
68
+ if ucErr != nil {
69
+ observability .Infof ("Unified classifier auto-discovery failed: %v" , ucErr )
70
+ }
71
+ // create legacy classifier
72
+ legacyClassifier , lcErr := createLegacyClassifier (config )
73
+ if lcErr != nil {
74
+ observability .Warnf ("Legacy classifier initialization failed: %v" , lcErr )
75
+ }
76
+ if unifiedClassifier == nil && legacyClassifier == nil {
77
+ observability .Warnf ("No classifier initialized. Using placeholder service." )
78
+ }
79
+ return NewUnifiedClassificationService (unifiedClassifier , legacyClassifier , config ), nil
80
+ }
81
+
82
+ // createLegacyClassifier creates a legacy classifier with proper model loading
83
+ func createLegacyClassifier (config * config.RouterConfig ) (* classification.Classifier , error ) {
84
+ // Load category mapping
85
+ var categoryMapping * classification.CategoryMapping
86
+ if config .Classifier .CategoryModel .CategoryMappingPath != "" {
87
+ var err error
88
+ categoryMapping , err = classification .LoadCategoryMapping (config .Classifier .CategoryModel .CategoryMappingPath )
89
+ if err != nil {
90
+ return nil , fmt .Errorf ("failed to load category mapping: %w" , err )
91
+ }
92
+ }
93
+
94
+ // Load PII mapping
95
+ var piiMapping * classification.PIIMapping
96
+ if config .Classifier .PIIModel .PIIMappingPath != "" {
97
+ var err error
98
+ piiMapping , err = classification .LoadPIIMapping (config .Classifier .PIIModel .PIIMappingPath )
99
+ if err != nil {
100
+ return nil , fmt .Errorf ("failed to load PII mapping: %w" , err )
101
+ }
102
+ }
103
+
104
+ // Load jailbreak mapping
105
+ var jailbreakMapping * classification.JailbreakMapping
106
+ if config .PromptGuard .JailbreakMappingPath != "" {
107
+ var err error
108
+ jailbreakMapping , err = classification .LoadJailbreakMapping (config .PromptGuard .JailbreakMappingPath )
109
+ if err != nil {
110
+ return nil , fmt .Errorf ("failed to load jailbreak mapping: %w" , err )
111
+ }
112
+ }
113
+
114
+ // Create classifier
115
+ classifier , err := classification .NewClassifier (config , categoryMapping , piiMapping , jailbreakMapping )
58
116
if err != nil {
59
- // Log the discovery failure but don't fail - fall back to legacy processing
60
- observability .Infof ("Unified classifier auto-discovery failed: %v. Using legacy processing." , err )
61
- return NewClassificationService (nil , config ), nil
117
+ return nil , fmt .Errorf ("failed to create classifier: %w" , err )
62
118
}
63
119
64
- // Success! Create service with unified classifier
65
- observability .Infof ("Unified classifier auto-discovered and initialized. Using batch processing." )
66
- return NewUnifiedClassificationService (unifiedClassifier , config ), nil
120
+ return classifier , nil
67
121
}
68
122
69
123
// GetGlobalClassificationService returns the global classification service instance
0 commit comments