@@ -28,6 +28,14 @@ func (c *ModernBertCategoryInference) Classify(text string) (candle_binding.Clas
2828 return candle_binding .ClassifyModernBertText (text )
2929}
3030
31+ // createCategoryInference creates the appropriate category inference based on configuration
32+ func createCategoryInference (useModernBERT bool ) CategoryInference {
33+ if useModernBERT {
34+ return & ModernBertCategoryInference {}
35+ }
36+ return & LinearCategoryInference {}
37+ }
38+
3139type JailbreakInference interface {
3240 Classify (text string ) (candle_binding.ClassResult , error )
3341}
@@ -44,6 +52,14 @@ func (c *ModernBertJailbreakInference) Classify(text string) (candle_binding.Cla
4452 return candle_binding .ClassifyModernBertJailbreakText (text )
4553}
4654
55+ // createJailbreakInference creates the appropriate jailbreak inference based on configuration
56+ func createJailbreakInference (useModernBERT bool ) JailbreakInference {
57+ if useModernBERT {
58+ return & ModernBertJailbreakInference {}
59+ }
60+ return & LinearJailbreakInference {}
61+ }
62+
4763type PIIInference interface {
4864 ClassifyTokens (text string , configPath string ) (candle_binding.TokenClassificationResult , error )
4965}
@@ -54,6 +70,11 @@ func (c *ModernBertPIIInference) ClassifyTokens(text string, configPath string)
5470 return candle_binding .ClassifyModernBertPIITokens (text , configPath )
5571}
5672
73+ // createPIIInference creates the appropriate PII inference (currently only ModernBERT)
74+ func createPIIInference () PIIInference {
75+ return & ModernBertPIIInference {}
76+ }
77+
5778// JailbreakDetection represents the result of jailbreak analysis for a piece of content
5879type JailbreakDetection struct {
5980 Content string `json:"content"`
@@ -101,26 +122,10 @@ type Classifier struct {
101122
102123// NewClassifier creates a new classifier with model selection and jailbreak detection capabilities
103124func NewClassifier (cfg * config.RouterConfig , categoryMapping * CategoryMapping , piiMapping * PIIMapping , jailbreakMapping * JailbreakMapping , modelTTFT map [string ]float64 ) * Classifier {
104- var categoryInference CategoryInference
105- if cfg .Classifier .CategoryModel .UseModernBERT {
106- categoryInference = & ModernBertCategoryInference {}
107- } else {
108- categoryInference = & LinearCategoryInference {}
109- }
110-
111- var jailbreakInference JailbreakInference
112- if cfg .PromptGuard .UseModernBERT {
113- jailbreakInference = & ModernBertJailbreakInference {}
114- } else {
115- jailbreakInference = & LinearJailbreakInference {}
116- }
117-
118- piiInference := & ModernBertPIIInference {}
119-
120125 return & Classifier {
121- categoryInference : categoryInference ,
122- jailbreakInference : jailbreakInference ,
123- piiInference : piiInference ,
126+ categoryInference : createCategoryInference ( cfg . Classifier . CategoryModel . UseModernBERT ) ,
127+ jailbreakInference : createJailbreakInference ( cfg . PromptGuard . UseModernBERT ) ,
128+ piiInference : createPIIInference () ,
124129
125130 Config : cfg ,
126131 CategoryMapping : categoryMapping ,
0 commit comments