@@ -140,35 +140,55 @@ func createJailbreakInference(useModernBERT bool) JailbreakInference {
140140}
141141
142142type PIIInitializer interface {
143- Init (modelID string , useCPU bool ) error
143+ Init (modelID string , useCPU bool , numClasses int ) error
144144}
145145
146- type ModernBertPIIInitializer struct {}
146+ type PIIInitializerImpl struct {
147+ usedModernBERT bool // Track which init path succeeded for inference routing
148+ }
149+
150+ func (c * PIIInitializerImpl ) Init (modelID string , useCPU bool , numClasses int ) error {
151+ // Try auto-detecting Candle BERT init first - checks for lora_config.json
152+ // This enables LoRA PII models when available
153+ success := candle_binding .InitCandleBertTokenClassifier (modelID , numClasses , useCPU )
154+ if success {
155+ c .usedModernBERT = false
156+ logging .Infof ("Initialized PII token classifier with auto-detection (LoRA or Traditional BERT)" )
157+ return nil
158+ }
147159
148- func (c * ModernBertPIIInitializer ) Init (modelID string , useCPU bool ) error {
160+ // Fallback to ModernBERT-specific init for backward compatibility
161+ // This handles models with incomplete configs (missing hidden_act, etc.)
162+ logging .Infof ("Auto-detection failed, falling back to ModernBERT PII initializer" )
149163 err := candle_binding .InitModernBertPIITokenClassifier (modelID , useCPU )
150164 if err != nil {
151- return err
165+ return fmt . Errorf ( "failed to initialize PII token classifier (both auto-detect and ModernBERT): %w" , err )
152166 }
153- logging .Infof ("Initialized ModernBERT PII token classifier for entity detection" )
167+ c .usedModernBERT = true
168+ logging .Infof ("Initialized ModernBERT PII token classifier (fallback mode)" )
154169 return nil
155170}
156171
157- // createPIIInitializer creates the appropriate PII initializer (currently only ModernBERT)
158- func createPIIInitializer () PIIInitializer { return & ModernBertPIIInitializer {} }
172+ // createPIIInitializer creates the PII initializer (auto-detecting)
173+ func createPIIInitializer () PIIInitializer {
174+ return & PIIInitializerImpl {}
175+ }
159176
160177type PIIInference interface {
161178 ClassifyTokens (text string , configPath string ) (candle_binding.TokenClassificationResult , error )
162179}
163180
164- type ModernBertPIIInference struct {}
181+ type PIIInferenceImpl struct {}
165182
166- func (c * ModernBertPIIInference ) ClassifyTokens (text string , configPath string ) (candle_binding.TokenClassificationResult , error ) {
167- return candle_binding .ClassifyModernBertPIITokens (text , configPath )
183+ func (c * PIIInferenceImpl ) ClassifyTokens (text string , configPath string ) (candle_binding.TokenClassificationResult , error ) {
184+ // Auto-detecting inference - uses whichever classifier was initialized (LoRA or Traditional)
185+ return candle_binding .ClassifyCandleBertTokens (text )
168186}
169187
170- // createPIIInference creates the appropriate PII inference (currently only ModernBERT)
171- func createPIIInference () PIIInference { return & ModernBertPIIInference {} }
188+ // createPIIInference creates the PII inference (auto-detecting)
189+ func createPIIInference () PIIInference {
190+ return & PIIInferenceImpl {}
191+ }
172192
173193// JailbreakDetection represents the result of jailbreak analysis for a piece of content
174194type JailbreakDetection struct {
@@ -348,7 +368,7 @@ func NewClassifier(cfg *config.RouterConfig, categoryMapping *CategoryMapping, p
348368
349369 // Add in-tree classifier if configured
350370 if cfg .CategoryModel .ModelID != "" {
351- options = append (options , withCategory (categoryMapping , createCategoryInitializer (cfg .UseModernBERT ), createCategoryInference (cfg .UseModernBERT )))
371+ options = append (options , withCategory (categoryMapping , createCategoryInitializer (cfg .CategoryModel . UseModernBERT ), createCategoryInference (cfg . CategoryModel .UseModernBERT )))
352372 }
353373
354374 // Add MCP classifier if configured
@@ -509,7 +529,8 @@ func (c *Classifier) initializePIIClassifier() error {
509529 return fmt .Errorf ("not enough PII types for classification, need at least 2, got %d" , numPIIClasses )
510530 }
511531
512- return c .piiInitializer .Init (c .Config .PIIModel .ModelID , c .Config .PIIModel .UseCPU )
532+ // Pass numClasses to support auto-detection
533+ return c .piiInitializer .Init (c .Config .PIIModel .ModelID , c .Config .PIIModel .UseCPU , numPIIClasses )
513534}
514535
515536// EvaluateAllRules evaluates all rule types and returns matched rule names
0 commit comments