Skip to content

Commit 39911f4

Browse files
committed
refactor
Signed-off-by: Alex Wang <[email protected]>
1 parent 69e8e92 commit 39911f4

File tree

1 file changed

+24
-19
lines changed

1 file changed

+24
-19
lines changed

src/semantic-router/pkg/utils/classification/classifier.go

Lines changed: 24 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,14 @@ func (c *ModernBertCategoryInference) Classify(text string) (candle_binding.Clas
2828
return candle_binding.ClassifyModernBertText(text)
2929
}
3030

31+
// createCategoryInference creates the appropriate category inference based on configuration
32+
func createCategoryInference(useModernBERT bool) CategoryInference {
33+
if useModernBERT {
34+
return &ModernBertCategoryInference{}
35+
}
36+
return &LinearCategoryInference{}
37+
}
38+
3139
type JailbreakInference interface {
3240
Classify(text string) (candle_binding.ClassResult, error)
3341
}
@@ -44,6 +52,14 @@ func (c *ModernBertJailbreakInference) Classify(text string) (candle_binding.Cla
4452
return candle_binding.ClassifyModernBertJailbreakText(text)
4553
}
4654

55+
// createJailbreakInference creates the appropriate jailbreak inference based on configuration
56+
func createJailbreakInference(useModernBERT bool) JailbreakInference {
57+
if useModernBERT {
58+
return &ModernBertJailbreakInference{}
59+
}
60+
return &LinearJailbreakInference{}
61+
}
62+
4763
type PIIInference interface {
4864
ClassifyTokens(text string, configPath string) (candle_binding.TokenClassificationResult, error)
4965
}
@@ -54,6 +70,11 @@ func (c *ModernBertPIIInference) ClassifyTokens(text string, configPath string)
5470
return candle_binding.ClassifyModernBertPIITokens(text, configPath)
5571
}
5672

73+
// createPIIInference creates the appropriate PII inference (currently only ModernBERT)
74+
func createPIIInference() PIIInference {
75+
return &ModernBertPIIInference{}
76+
}
77+
5778
// JailbreakDetection represents the result of jailbreak analysis for a piece of content
5879
type JailbreakDetection struct {
5980
Content string `json:"content"`
@@ -101,26 +122,10 @@ type Classifier struct {
101122

102123
// NewClassifier creates a new classifier with model selection and jailbreak detection capabilities
103124
func NewClassifier(cfg *config.RouterConfig, categoryMapping *CategoryMapping, piiMapping *PIIMapping, jailbreakMapping *JailbreakMapping, modelTTFT map[string]float64) *Classifier {
104-
var categoryInference CategoryInference
105-
if cfg.Classifier.CategoryModel.UseModernBERT {
106-
categoryInference = &ModernBertCategoryInference{}
107-
} else {
108-
categoryInference = &LinearCategoryInference{}
109-
}
110-
111-
var jailbreakInference JailbreakInference
112-
if cfg.PromptGuard.UseModernBERT {
113-
jailbreakInference = &ModernBertJailbreakInference{}
114-
} else {
115-
jailbreakInference = &LinearJailbreakInference{}
116-
}
117-
118-
piiInference := &ModernBertPIIInference{}
119-
120125
return &Classifier{
121-
categoryInference: categoryInference,
122-
jailbreakInference: jailbreakInference,
123-
piiInference: piiInference,
126+
categoryInference: createCategoryInference(cfg.Classifier.CategoryModel.UseModernBERT),
127+
jailbreakInference: createJailbreakInference(cfg.PromptGuard.UseModernBERT),
128+
piiInference: createPIIInference(),
124129

125130
Config: cfg,
126131
CategoryMapping: categoryMapping,

0 commit comments

Comments
 (0)