Skip to content

Commit ea956c0

Browse files
authored
Merge pull request #57 from aeft/feature/classifier-test-framework
feat: add test framework for classifier with dependency injection
2 parents 80f062b + 15a7437 commit ea956c0

File tree

2 files changed

+484
-16
lines changed

2 files changed

+484
-16
lines changed

src/semantic-router/pkg/utils/classification/classifier.go

Lines changed: 76 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,69 @@ import (
1212
"github.com/vllm-project/semantic-router/semantic-router/pkg/metrics"
1313
)
1414

15+
type CategoryInference interface {
16+
Classify(text string) (candle_binding.ClassResult, error)
17+
}
18+
19+
type LinearCategoryInference struct{}
20+
21+
func (c *LinearCategoryInference) Classify(text string) (candle_binding.ClassResult, error) {
22+
return candle_binding.ClassifyText(text)
23+
}
24+
25+
type ModernBertCategoryInference struct{}
26+
27+
func (c *ModernBertCategoryInference) Classify(text string) (candle_binding.ClassResult, error) {
28+
return candle_binding.ClassifyModernBertText(text)
29+
}
30+
31+
// createCategoryInference creates the appropriate category inference based on configuration
32+
func createCategoryInference(useModernBERT bool) CategoryInference {
33+
if useModernBERT {
34+
return &ModernBertCategoryInference{}
35+
}
36+
return &LinearCategoryInference{}
37+
}
38+
39+
type JailbreakInference interface {
40+
Classify(text string) (candle_binding.ClassResult, error)
41+
}
42+
43+
type LinearJailbreakInference struct{}
44+
45+
func (c *LinearJailbreakInference) Classify(text string) (candle_binding.ClassResult, error) {
46+
return candle_binding.ClassifyJailbreakText(text)
47+
}
48+
49+
type ModernBertJailbreakInference struct{}
50+
51+
func (c *ModernBertJailbreakInference) Classify(text string) (candle_binding.ClassResult, error) {
52+
return candle_binding.ClassifyModernBertJailbreakText(text)
53+
}
54+
55+
// createJailbreakInference creates the appropriate jailbreak inference based on configuration
56+
func createJailbreakInference(useModernBERT bool) JailbreakInference {
57+
if useModernBERT {
58+
return &ModernBertJailbreakInference{}
59+
}
60+
return &LinearJailbreakInference{}
61+
}
62+
63+
type PIIInference interface {
64+
ClassifyTokens(text string, configPath string) (candle_binding.TokenClassificationResult, error)
65+
}
66+
67+
type ModernBertPIIInference struct{}
68+
69+
func (c *ModernBertPIIInference) ClassifyTokens(text string, configPath string) (candle_binding.TokenClassificationResult, error) {
70+
return candle_binding.ClassifyModernBertPIITokens(text, configPath)
71+
}
72+
73+
// createPIIInference creates the appropriate PII inference (currently only ModernBERT)
74+
func createPIIInference() PIIInference {
75+
return &ModernBertPIIInference{}
76+
}
77+
1578
// JailbreakDetection represents the result of jailbreak analysis for a piece of content
1679
type JailbreakDetection struct {
1780
Content string `json:"content"`
@@ -40,6 +103,11 @@ type PIIAnalysisResult struct {
40103

41104
// Classifier handles text classification, model selection, and jailbreak detection functionality
42105
type Classifier struct {
106+
// Dependencies
107+
categoryInference CategoryInference
108+
jailbreakInference JailbreakInference
109+
piiInference PIIInference
110+
43111
Config *config.RouterConfig
44112
CategoryMapping *CategoryMapping
45113
PIIMapping *PIIMapping
@@ -55,6 +123,10 @@ type Classifier struct {
55123
// NewClassifier creates a new classifier with model selection and jailbreak detection capabilities
56124
func NewClassifier(cfg *config.RouterConfig, categoryMapping *CategoryMapping, piiMapping *PIIMapping, jailbreakMapping *JailbreakMapping, modelTTFT map[string]float64) *Classifier {
57125
return &Classifier{
126+
categoryInference: createCategoryInference(cfg.Classifier.CategoryModel.UseModernBERT),
127+
jailbreakInference: createJailbreakInference(cfg.PromptGuard.UseModernBERT),
128+
piiInference: createPIIInference(),
129+
58130
Config: cfg,
59131
CategoryMapping: categoryMapping,
60132
PIIMapping: piiMapping,
@@ -117,13 +189,7 @@ func (c *Classifier) CheckForJailbreak(text string) (bool, string, float32, erro
117189
var err error
118190

119191
start := time.Now()
120-
if c.Config.PromptGuard.UseModernBERT {
121-
// Use ModernBERT jailbreak classifier
122-
result, err = candle_binding.ClassifyModernBertJailbreakText(text)
123-
} else {
124-
// Use linear jailbreak classifier
125-
result, err = candle_binding.ClassifyJailbreakText(text)
126-
}
192+
result, err = c.jailbreakInference.Classify(text)
127193
metrics.RecordClassifierLatency("jailbreak", time.Since(start).Seconds())
128194

129195
if err != nil {
@@ -200,13 +266,7 @@ func (c *Classifier) ClassifyCategory(text string) (string, float64, error) {
200266
var err error
201267

202268
start := time.Now()
203-
if c.Config.Classifier.CategoryModel.UseModernBERT {
204-
// Use ModernBERT classifier
205-
result, err = candle_binding.ClassifyModernBertText(text)
206-
} else {
207-
// Use linear classifier
208-
result, err = candle_binding.ClassifyText(text)
209-
}
269+
result, err = c.categoryInference.Classify(text)
210270
metrics.RecordClassifierLatency("category", time.Since(start).Seconds())
211271

212272
if err != nil {
@@ -249,7 +309,7 @@ func (c *Classifier) ClassifyPII(text string) ([]string, error) {
249309
// Use ModernBERT PII token classifier for entity detection
250310
configPath := fmt.Sprintf("%s/config.json", c.Config.Classifier.PIIModel.ModelID)
251311
start := time.Now()
252-
tokenResult, err := candle_binding.ClassifyModernBertPIITokens(text, configPath)
312+
tokenResult, err := c.piiInference.ClassifyTokens(text, configPath)
253313
metrics.RecordClassifierLatency("pii", time.Since(start).Seconds())
254314
if err != nil {
255315
return nil, fmt.Errorf("PII token classification error: %w", err)
@@ -331,7 +391,7 @@ func (c *Classifier) AnalyzeContentForPII(contentList []string) (bool, []PIIAnal
331391
// Use ModernBERT PII token classifier for detailed analysis
332392
configPath := fmt.Sprintf("%s/config.json", c.Config.Classifier.PIIModel.ModelID)
333393
start := time.Now()
334-
tokenResult, err := candle_binding.ClassifyModernBertPIITokens(content, configPath)
394+
tokenResult, err := c.piiInference.ClassifyTokens(content, configPath)
335395
metrics.RecordClassifierLatency("pii", time.Since(start).Seconds())
336396
if err != nil {
337397
log.Printf("Error analyzing content %d: %v", i, err)

0 commit comments

Comments
 (0)