Skip to content

Commit 5d0bbd1

Browse files
committed
update dependency injection
Signed-off-by: Alex Wang <[email protected]>
1 parent ea1a40d commit 5d0bbd1

File tree

2 files changed

+110
-84
lines changed

2 files changed

+110
-84
lines changed

src/semantic-router/pkg/utils/classification/classifier.go

Lines changed: 43 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -12,22 +12,38 @@ import (
1212
"github.com/vllm-project/semantic-router/semantic-router/pkg/metrics"
1313
)
1414

15-
type ModelInference interface {
16-
// Category classification
17-
ClassifyText(text string) (candle_binding.ClassResult, error)
18-
ClassifyModernBertText(text string) (candle_binding.ClassResult, error)
15+
type CategoryInference interface {
16+
Classify(text string) (candle_binding.ClassResult, error)
1917
}
2018

21-
type CandleModelInference struct{}
19+
type LinearCategoryInference struct{}
2220

23-
func (c *CandleModelInference) ClassifyText(text string) (candle_binding.ClassResult, error) {
21+
func (c *LinearCategoryInference) Classify(text string) (candle_binding.ClassResult, error) {
2422
return candle_binding.ClassifyText(text)
2523
}
2624

27-
func (c *CandleModelInference) ClassifyModernBertText(text string) (candle_binding.ClassResult, error) {
25+
type ModernBERTCategoryInference struct{}
26+
27+
func (c *ModernBERTCategoryInference) Classify(text string) (candle_binding.ClassResult, error) {
2828
return candle_binding.ClassifyModernBertText(text)
2929
}
3030

31+
type JailbreakInference interface {
32+
Classify(text string) (candle_binding.ClassResult, error)
33+
}
34+
35+
type LinearJailbreakInference struct{}
36+
37+
func (c *LinearJailbreakInference) Classify(text string) (candle_binding.ClassResult, error) {
38+
return candle_binding.ClassifyJailbreakText(text)
39+
}
40+
41+
type ModernBERTJailbreakInference struct{}
42+
43+
func (c *ModernBERTJailbreakInference) Classify(text string) (candle_binding.ClassResult, error) {
44+
return candle_binding.ClassifyModernBertJailbreakText(text)
45+
}
46+
3147
// JailbreakDetection represents the result of jailbreak analysis for a piece of content
3248
type JailbreakDetection struct {
3349
Content string `json:"content"`
@@ -57,7 +73,8 @@ type PIIAnalysisResult struct {
5773
// Classifier handles text classification, model selection, and jailbreak detection functionality
5874
type Classifier struct {
5975
// Dependencies
60-
modelInference ModelInference
76+
categoryInference CategoryInference
77+
jailbreakInference JailbreakInference
6178

6279
Config *config.RouterConfig
6380
CategoryMapping *CategoryMapping
@@ -73,8 +90,23 @@ type Classifier struct {
7390

7491
// NewClassifier creates a new classifier with model selection and jailbreak detection capabilities
7592
func NewClassifier(cfg *config.RouterConfig, categoryMapping *CategoryMapping, piiMapping *PIIMapping, jailbreakMapping *JailbreakMapping, modelTTFT map[string]float64) *Classifier {
93+
var categoryInference CategoryInference
94+
if cfg.Classifier.CategoryModel.UseModernBERT {
95+
categoryInference = &ModernBERTCategoryInference{}
96+
} else {
97+
categoryInference = &LinearCategoryInference{}
98+
}
99+
100+
var jailbreakInference JailbreakInference
101+
if cfg.PromptGuard.UseModernBERT {
102+
jailbreakInference = &ModernBERTJailbreakInference{}
103+
} else {
104+
jailbreakInference = &LinearJailbreakInference{}
105+
}
106+
76107
return &Classifier{
77-
modelInference: &CandleModelInference{},
108+
categoryInference: categoryInference,
109+
jailbreakInference: jailbreakInference,
78110
Config: cfg,
79111
CategoryMapping: categoryMapping,
80112
PIIMapping: piiMapping,
@@ -137,13 +169,7 @@ func (c *Classifier) CheckForJailbreak(text string) (bool, string, float32, erro
137169
var err error
138170

139171
start := time.Now()
140-
if c.Config.PromptGuard.UseModernBERT {
141-
// Use ModernBERT jailbreak classifier
142-
result, err = candle_binding.ClassifyModernBertJailbreakText(text)
143-
} else {
144-
// Use linear jailbreak classifier
145-
result, err = candle_binding.ClassifyJailbreakText(text)
146-
}
172+
result, err = c.jailbreakInference.Classify(text)
147173
metrics.RecordClassifierLatency("jailbreak", time.Since(start).Seconds())
148174

149175
if err != nil {
@@ -220,13 +246,7 @@ func (c *Classifier) ClassifyCategory(text string) (string, float64, error) {
220246
var err error
221247

222248
start := time.Now()
223-
if c.Config.Classifier.CategoryModel.UseModernBERT {
224-
// Use ModernBERT classifier
225-
result, err = c.modelInference.ClassifyModernBertText(text)
226-
} else {
227-
// Use linear classifier
228-
result, err = c.modelInference.ClassifyText(text)
229-
}
249+
result, err = c.categoryInference.Classify(text)
230250
metrics.RecordClassifierLatency("category", time.Since(start).Seconds())
231251

232252
if err != nil {

src/semantic-router/pkg/utils/classification/classifier_test.go

Lines changed: 67 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -16,100 +16,106 @@ func TestClassifier(t *testing.T) {
1616
RunSpecs(t, "Classifier Suite")
1717
}
1818

19-
// MockModelInference implements ModelInference interface for testing
20-
type MockModelInference struct {
21-
classifyTextResult candle_binding.ClassResult
22-
classifyTextError error
23-
classifyModernBertResult candle_binding.ClassResult
24-
classifyModernBertError error
19+
type MockCategoryInference struct {
20+
classifyResult candle_binding.ClassResult
21+
classifyError error
2522
}
2623

27-
func (m *MockModelInference) ClassifyText(text string) (candle_binding.ClassResult, error) {
28-
return m.classifyTextResult, m.classifyTextError
24+
func (m *MockCategoryInference) Classify(text string) (candle_binding.ClassResult, error) {
25+
return m.classifyResult, m.classifyError
2926
}
3027

31-
func (m *MockModelInference) ClassifyModernBertText(text string) (candle_binding.ClassResult, error) {
32-
return m.classifyModernBertResult, m.classifyModernBertError
33-
}
34-
35-
var _ = Describe("Classifier", func() {
28+
var _ = Describe("ClassifyCategory", func() {
3629
var (
37-
classifier *Classifier
38-
mockModel *MockModelInference
30+
classifier *Classifier
31+
mockCategoryModel *MockCategoryInference
3932
)
4033

4134
BeforeEach(func() {
42-
mockModel = &MockModelInference{}
35+
mockCategoryModel = &MockCategoryInference{}
4336
cfg := &config.RouterConfig{}
4437
cfg.Classifier.CategoryModel.Threshold = 0.5 // Set threshold for testing
4538

4639
classifier = &Classifier{
47-
modelInference: mockModel,
48-
Config: cfg,
40+
categoryInference: mockCategoryModel,
41+
Config: cfg,
4942
CategoryMapping: &CategoryMapping{
5043
CategoryToIdx: map[string]int{"technology": 0, "sports": 1, "politics": 2},
5144
IdxToCategory: map[string]string{"0": "technology", "1": "sports", "2": "politics"},
5245
},
5346
}
5447
})
5548

56-
Describe("ClassifyCategory", func() {
57-
Context("when classification succeeds with high confidence", func() {
58-
It("should return the correct category", func() {
59-
mockModel.classifyTextResult = candle_binding.ClassResult{
60-
Class: 2,
61-
Confidence: 0.95,
62-
}
49+
Context("when classification succeeds with high confidence", func() {
50+
It("should return the correct category", func() {
51+
mockCategoryModel.classifyResult = candle_binding.ClassResult{
52+
Class: 2,
53+
Confidence: 0.95,
54+
}
55+
56+
category, score, err := classifier.ClassifyCategory("This is about politics")
57+
58+
Expect(err).To(BeNil())
59+
Expect(category).To(Equal("politics"))
60+
Expect(score).To(BeNumerically("~", 0.95, 0.001))
61+
})
62+
})
63+
64+
Context("when classification has low confidence below threshold", func() {
65+
It("should return empty category", func() {
66+
mockCategoryModel.classifyResult = candle_binding.ClassResult{
67+
Class: 0,
68+
Confidence: 0.3,
69+
}
6370

64-
category, score, err := classifier.ClassifyCategory("This is about politics")
71+
category, score, err := classifier.ClassifyCategory("Ambiguous text")
6572

66-
Expect(err).To(BeNil())
67-
Expect(category).To(Equal("politics"))
68-
Expect(score).To(BeNumerically("~", 0.95, 0.001))
69-
})
73+
Expect(err).To(BeNil())
74+
Expect(category).To(Equal(""))
75+
Expect(score).To(BeNumerically("~", 0.3, 0.001))
7076
})
77+
})
7178

72-
Context("when classification has low confidence below threshold", func() {
73-
It("should return empty category", func() {
74-
mockModel.classifyTextResult = candle_binding.ClassResult{
75-
Class: 0,
76-
Confidence: 0.3,
77-
}
79+
Context("when BERT model returns error", func() {
80+
It("should return empty category with zero score", func() {
81+
mockCategoryModel.classifyError = errors.New("model inference failed")
7882

79-
category, score, err := classifier.ClassifyCategory("Ambiguous text")
83+
category, score, err := classifier.ClassifyCategory("Some text")
8084

81-
Expect(err).To(BeNil())
82-
Expect(category).To(Equal(""))
83-
Expect(score).To(BeNumerically("~", 0.3, 0.001))
84-
})
85+
Expect(err).ToNot(BeNil())
86+
Expect(category).To(Equal(""))
87+
Expect(score).To(BeNumerically("~", 0.0, 0.001))
8588
})
89+
})
8690

87-
Context("when BERT model returns error", func() {
88-
It("should return unknown category with zero score", func() {
89-
mockModel.classifyTextError = errors.New("model inference failed")
91+
Context("when input is empty or invalid", func() {
92+
It("should handle empty text gracefully", func() {
93+
mockCategoryModel.classifyResult = candle_binding.ClassResult{
94+
Class: 0,
95+
Confidence: 0.8,
96+
}
9097

91-
category, score, err := classifier.ClassifyCategory("Some text")
98+
category, score, err := classifier.ClassifyCategory("")
9299

93-
Expect(err).ToNot(BeNil())
94-
Expect(category).To(Equal(""))
95-
Expect(score).To(BeNumerically("~", 0.0, 0.001))
96-
})
100+
// Should still attempt classification
101+
Expect(err).To(BeNil())
102+
Expect(category).To(Equal("technology"))
103+
Expect(score).To(BeNumerically("~", 0.8, 0.001))
97104
})
105+
})
98106

99-
Context("when input is empty or invalid", func() {
100-
It("should handle empty text gracefully", func() {
101-
mockModel.classifyTextResult = candle_binding.ClassResult{
102-
Class: 0,
103-
Confidence: 0.8,
104-
}
107+
Context("when category mapping is invalid", func() {
108+
It("should handle invalid category mapping gracefully", func() {
109+
mockCategoryModel.classifyResult = candle_binding.ClassResult{
110+
Class: 9,
111+
Confidence: 0.8,
112+
}
105113

106-
category, score, err := classifier.ClassifyCategory("")
114+
category, score, err := classifier.ClassifyCategory("Some text")
107115

108-
// Should still attempt classification
109-
Expect(err).To(BeNil())
110-
Expect(category).To(Equal("technology"))
111-
Expect(score).To(BeNumerically("~", 0.8, 0.001))
112-
})
116+
Expect(err).To(BeNil())
117+
Expect(category).To(Equal(""))
118+
Expect(score).To(BeNumerically("~", 0.8, 0.001))
113119
})
114120
})
115121
})

0 commit comments

Comments
 (0)