@@ -28,6 +28,14 @@ func (c *ModernBertCategoryInference) Classify(text string) (candle_binding.Clas
2828 return candle_binding.ClassifyModernBertText(text)
2929}
3030
31+ // createCategoryInference creates the appropriate category inference based on configuration
32+ func createCategoryInference(useModernBERT bool) CategoryInference {
33+ if useModernBERT {
34+ return &ModernBertCategoryInference{}
35+ }
36+ return &LinearCategoryInference{}
37+ }
38+
3139type JailbreakInference interface {
3240 Classify(text string) (candle_binding.ClassResult, error)
3341}
@@ -44,6 +52,14 @@ func (c *ModernBertJailbreakInference) Classify(text string) (candle_binding.Cla
4452 return candle_binding.ClassifyModernBertJailbreakText(text)
4553}
4654
55+ // createJailbreakInference creates the appropriate jailbreak inference based on configuration
56+ func createJailbreakInference(useModernBERT bool) JailbreakInference {
57+ if useModernBERT {
58+ return &ModernBertJailbreakInference{}
59+ }
60+ return &LinearJailbreakInference{}
61+ }
62+
4763type PIIInference interface {
4864 ClassifyTokens(text string, configPath string) (candle_binding.TokenClassificationResult, error)
4965}
@@ -54,6 +70,11 @@ func (c *ModernBertPIIInference) ClassifyTokens(text string, configPath string)
5470 return candle_binding.ClassifyModernBertPIITokens(text, configPath)
5571}
5672
73+ // createPIIInference creates the appropriate PII inference (currently only ModernBERT)
74+ func createPIIInference() PIIInference {
75+ return &ModernBertPIIInference{}
76+ }
77+
5778// JailbreakDetection represents the result of jailbreak analysis for a piece of content
5879type JailbreakDetection struct {
5980 Content string `json:"content"`
@@ -101,26 +122,10 @@ type Classifier struct {
101122
102123// NewClassifier creates a new classifier with model selection and jailbreak detection capabilities
103124func NewClassifier(cfg *config.RouterConfig, categoryMapping *CategoryMapping, piiMapping *PIIMapping, jailbreakMapping *JailbreakMapping, modelTTFT map[string]float64) *Classifier {
104- var categoryInference CategoryInference
105- if cfg.Classifier.CategoryModel.UseModernBERT {
106- categoryInference = &ModernBertCategoryInference{}
107- } else {
108- categoryInference = &LinearCategoryInference{}
109- }
110-
111- var jailbreakInference JailbreakInference
112- if cfg.PromptGuard.UseModernBERT {
113- jailbreakInference = &ModernBertJailbreakInference{}
114- } else {
115- jailbreakInference = &LinearJailbreakInference{}
116- }
117-
118- piiInference := &ModernBertPIIInference{}
119-
120125 return &Classifier{
121- categoryInference: categoryInference ,
122- jailbreakInference: jailbreakInference ,
123- piiInference: piiInference ,
126+ categoryInference: createCategoryInference(cfg.Classifier.CategoryModel.UseModernBERT) ,
127+ jailbreakInference: createJailbreakInference(cfg.PromptGuard.UseModernBERT) ,
128+ piiInference: createPIIInference() ,
124129
125130 Config: cfg,
126131 CategoryMapping: categoryMapping,
0 commit comments