@@ -18,67 +18,58 @@ type CategoryInitializer interface {
1818 Init (modelID string , useCPU bool , numClasses ... int ) error
1919}
2020
21- type LinearCategoryInitializer struct {}
22-
23- func (c * LinearCategoryInitializer ) Init (modelID string , useCPU bool , numClasses ... int ) error {
24- err := candle_binding .InitClassifier (modelID , numClasses [0 ], useCPU )
25- if err != nil {
26- return err
27- }
28- logging .Infof ("Initialized linear category classifier with %d classes" , numClasses [0 ])
29- return nil
21+ type CategoryInitializerImpl struct {
22+ usedModernBERT bool // Track which init path succeeded for inference routing
3023}
3124
32- type ModernBertCategoryInitializer struct {}
25+ func (c * CategoryInitializerImpl ) Init (modelID string , useCPU bool , numClasses ... int ) error {
26+ // Try auto-detecting Candle BERT init first - checks for lora_config.json
27+ // This enables LoRA Intent/Category models when available
28+ success := candle_binding .InitCandleBertClassifier (modelID , numClasses [0 ], useCPU )
29+ if success {
30+ c .usedModernBERT = false
31+ logging .Infof ("Initialized category classifier with auto-detection (LoRA or Traditional BERT)" )
32+ return nil
33+ }
3334
34- func (c * ModernBertCategoryInitializer ) Init (modelID string , useCPU bool , numClasses ... int ) error {
35+ // Fallback to ModernBERT-specific init for backward compatibility
36+ // This handles models with incomplete configs (missing hidden_act, etc.)
37+ logging .Infof ("Auto-detection failed, falling back to ModernBERT category initializer" )
3538 err := candle_binding .InitModernBertClassifier (modelID , useCPU )
3639 if err != nil {
37- return err
40+ return fmt . Errorf ( "failed to initialize category classifier (both auto-detect and ModernBERT): %w" , err )
3841 }
39- logging .Infof ("Initialized ModernBERT category classifier (classes auto-detected from model)" )
42+ c .usedModernBERT = true
43+ logging .Infof ("Initialized ModernBERT category classifier (fallback mode)" )
4044 return nil
4145}
4246
43- // createCategoryInitializer creates the appropriate category initializer based on configuration
44- func createCategoryInitializer (useModernBERT bool ) CategoryInitializer {
45- if useModernBERT {
46- return & ModernBertCategoryInitializer {}
47- }
48- return & LinearCategoryInitializer {}
47+ // createCategoryInitializer creates the category initializer (auto-detecting)
48+ func createCategoryInitializer () CategoryInitializer {
49+ return & CategoryInitializerImpl {}
4950}
5051
5152type CategoryInference interface {
5253 Classify (text string ) (candle_binding.ClassResult , error )
5354 ClassifyWithProbabilities (text string ) (candle_binding.ClassResultWithProbs , error )
5455}
5556
56- type LinearCategoryInference struct {}
57-
58- func (c * LinearCategoryInference ) Classify (text string ) (candle_binding.ClassResult , error ) {
59- return candle_binding .ClassifyText (text )
60- }
61-
62- func (c * LinearCategoryInference ) ClassifyWithProbabilities (text string ) (candle_binding.ClassResultWithProbs , error ) {
63- return candle_binding .ClassifyTextWithProbabilities (text )
64- }
65-
66- type ModernBertCategoryInference struct {}
57+ type CategoryInferenceImpl struct {}
6758
68- func (c * ModernBertCategoryInference ) Classify (text string ) (candle_binding.ClassResult , error ) {
69- return candle_binding .ClassifyModernBertText (text )
59+ func (c * CategoryInferenceImpl ) Classify (text string ) (candle_binding.ClassResult , error ) {
60+ // Auto-detecting inference - uses whichever classifier was initialized (LoRA or Traditional)
61+ return candle_binding .ClassifyCandleBertText (text )
7062}
7163
72- func (c * ModernBertCategoryInference ) ClassifyWithProbabilities (text string ) (candle_binding.ClassResultWithProbs , error ) {
64+ func (c * CategoryInferenceImpl ) ClassifyWithProbabilities (text string ) (candle_binding.ClassResultWithProbs , error ) {
65+ // Note: CandleBert doesn't have WithProbabilities yet, fall back to ModernBERT
66+ // This will work correctly if ModernBERT was initialized as fallback
7367 return candle_binding .ClassifyModernBertTextWithProbabilities (text )
7468}
7569
76- // createCategoryInference creates the appropriate category inference based on configuration
77- func createCategoryInference (useModernBERT bool ) CategoryInference {
78- if useModernBERT {
79- return & ModernBertCategoryInference {}
80- }
81- return & LinearCategoryInference {}
70+ // createCategoryInference creates the category inference (auto-detecting)
71+ func createCategoryInference () CategoryInference {
72+ return & CategoryInferenceImpl {}
8273}
8374
8475type JailbreakInitializer interface {
@@ -368,7 +359,7 @@ func NewClassifier(cfg *config.RouterConfig, categoryMapping *CategoryMapping, p
368359
369360 // Add in-tree classifier if configured
370361 if cfg .CategoryModel .ModelID != "" {
371- options = append (options , withCategory (categoryMapping , createCategoryInitializer (cfg . CategoryModel . UseModernBERT ), createCategoryInference (cfg . CategoryModel . UseModernBERT )))
362+ options = append (options , withCategory (categoryMapping , createCategoryInitializer (), createCategoryInference ()))
372363 }
373364
374365 // Add MCP classifier if configured
@@ -386,7 +377,7 @@ func NewClassifier(cfg *config.RouterConfig, categoryMapping *CategoryMapping, p
386377
387378// IsCategoryEnabled checks if category classification is properly configured
388379func (c * Classifier ) IsCategoryEnabled () bool {
389- return c .Config .CategoryModel .ModelID != "" && c .Config .CategoryMappingPath != "" && c .CategoryMapping != nil
380+ return c .Config .CategoryModel .ModelID != "" && c .Config .CategoryModel . CategoryMappingPath != "" && c .CategoryMapping != nil
390381}
391382
392383// initializeCategoryClassifier initializes the category classification model
0 commit comments