@@ -37,6 +37,40 @@ func createCategoryInference(useModernBERT bool) CategoryInference {
3737 return & LinearCategoryInference {}
3838}
3939
40+ type JailbreakInitializer interface {
41+ Init (modelID string , useCPU bool , numClasses ... int ) error
42+ }
43+
44+ type LinearJailbreakInitializer struct {}
45+
46+ func (c * LinearJailbreakInitializer ) Init (modelID string , useCPU bool , numClasses ... int ) error {
47+ err := candle_binding .InitJailbreakClassifier (modelID , numClasses [0 ], useCPU )
48+ if err != nil {
49+ return fmt .Errorf ("failed to initialize jailbreak classifier: %w" , err )
50+ }
51+ log .Printf ("Initialized linear jailbreak classifier with %d classes" , numClasses [0 ])
52+ return nil
53+ }
54+
55+ type ModernBertJailbreakInitializer struct {}
56+
57+ func (c * ModernBertJailbreakInitializer ) Init (modelID string , useCPU bool , numClasses ... int ) error {
58+ err := candle_binding .InitModernBertJailbreakClassifier (modelID , useCPU )
59+ if err != nil {
60+ return fmt .Errorf ("failed to initialize ModernBERT jailbreak classifier: %w" , err )
61+ }
62+ log .Printf ("Initialized ModernBERT jailbreak classifier (classes auto-detected from model)" )
63+ return nil
64+ }
65+
66+ // createJailbreakInitializer creates the appropriate jailbreak initializer based on configuration
67+ func createJailbreakInitializer (useModernBERT bool ) JailbreakInitializer {
68+ if useModernBERT {
69+ return & ModernBertJailbreakInitializer {}
70+ }
71+ return & LinearJailbreakInitializer {}
72+ }
73+
4074type JailbreakInference interface {
4175 Classify (text string ) (candle_binding.ClassResult , error )
4276}
@@ -105,9 +139,10 @@ type PIIAnalysisResult struct {
105139// Classifier handles text classification, model selection, and jailbreak detection functionality
106140type Classifier struct {
107141 // Dependencies
108- categoryInference CategoryInference
109- jailbreakInference JailbreakInference
110- piiInference PIIInference
142+ categoryInference CategoryInference
143+ jailbreakInitializer JailbreakInitializer
144+ jailbreakInference JailbreakInference
145+ piiInference PIIInference
111146
112147 Config * config.RouterConfig
113148 CategoryMapping * CategoryMapping
@@ -124,9 +159,10 @@ type Classifier struct {
124159// NewClassifier creates a new classifier with model selection and jailbreak detection capabilities
125160func NewClassifier (cfg * config.RouterConfig , categoryMapping * CategoryMapping , piiMapping * PIIMapping , jailbreakMapping * JailbreakMapping , modelTTFT map [string ]float64 ) * Classifier {
126161 return & Classifier {
127- categoryInference : createCategoryInference (cfg .Classifier .CategoryModel .UseModernBERT ),
128- jailbreakInference : createJailbreakInference (cfg .PromptGuard .UseModernBERT ),
129- piiInference : createPIIInference (),
162+ categoryInference : createCategoryInference (cfg .Classifier .CategoryModel .UseModernBERT ),
163+ jailbreakInitializer : createJailbreakInitializer (cfg .PromptGuard .UseModernBERT ),
164+ jailbreakInference : createJailbreakInference (cfg .PromptGuard .UseModernBERT ),
165+ piiInference : createPIIInference (),
130166
131167 Config : cfg ,
132168 CategoryMapping : categoryMapping ,
@@ -149,21 +185,8 @@ func (c *Classifier) InitializeJailbreakClassifier() error {
149185 return fmt .Errorf ("not enough jailbreak types for classification, need at least 2, got %d" , numClasses )
150186 }
151187
152- var err error
153- if c .Config .PromptGuard .UseModernBERT {
154- // Initialize ModernBERT jailbreak classifier
155- err = candle_binding .InitModernBertJailbreakClassifier (c .Config .PromptGuard .ModelID , c .Config .PromptGuard .UseCPU )
156- if err != nil {
157- return fmt .Errorf ("failed to initialize ModernBERT jailbreak classifier: %w" , err )
158- }
159- log .Printf ("Initialized ModernBERT jailbreak classifier (classes auto-detected from model)" )
160- } else {
161- // Initialize linear jailbreak classifier
162- err = candle_binding .InitJailbreakClassifier (c .Config .PromptGuard .ModelID , numClasses , c .Config .PromptGuard .UseCPU )
163- if err != nil {
164- return fmt .Errorf ("failed to initialize jailbreak classifier: %w" , err )
165- }
166- log .Printf ("Initialized linear jailbreak classifier with %d classes" , numClasses )
188+ if err := c .jailbreakInitializer .Init (c .Config .PromptGuard .ModelID , c .Config .PromptGuard .UseCPU , numClasses ); err != nil {
189+ return err
167190 }
168191
169192 c .JailbreakInitialized = true
0 commit comments