Skip to content

Commit ac3ef2c

Browse files
committed
test: refactor InitializeJailbreakClassifier and add test cases
Signed-off-by: Alex Wang <[email protected]>
1 parent ffe1b09 commit ac3ef2c

File tree

2 files changed

+119
-21
lines changed

2 files changed

+119
-21
lines changed

src/semantic-router/pkg/utils/classification/classifier.go

Lines changed: 44 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -37,6 +37,40 @@ func createCategoryInference(useModernBERT bool) CategoryInference {
3737
return &LinearCategoryInference{}
3838
}
3939

40+
// JailbreakInitializer abstracts the setup step of a jailbreak classifier so that
// the backend can be selected (or replaced by a test double) behind one method.
type JailbreakInitializer interface {
41+
// Init prepares the classifier for the given model. numClasses is variadic;
// whether it is consulted depends on the implementation.
Init(modelID string, useCPU bool, numClasses ...int) error
42+
}
43+
44+
// LinearJailbreakInitializer is the JailbreakInitializer whose Init calls
// candle_binding.InitJailbreakClassifier with an explicit class count.
type LinearJailbreakInitializer struct{}
45+
46+
func (c *LinearJailbreakInitializer) Init(modelID string, useCPU bool, numClasses ...int) error {
47+
err := candle_binding.InitJailbreakClassifier(modelID, numClasses[0], useCPU)
48+
if err != nil {
49+
return fmt.Errorf("failed to initialize jailbreak classifier: %w", err)
50+
}
51+
log.Printf("Initialized linear jailbreak classifier with %d classes", numClasses[0])
52+
return nil
53+
}
54+
55+
// ModernBertJailbreakInitializer is the JailbreakInitializer whose Init calls
// candle_binding.InitModernBertJailbreakClassifier.
type ModernBertJailbreakInitializer struct{}
56+
57+
func (c *ModernBertJailbreakInitializer) Init(modelID string, useCPU bool, numClasses ...int) error {
58+
err := candle_binding.InitModernBertJailbreakClassifier(modelID, useCPU)
59+
if err != nil {
60+
return fmt.Errorf("failed to initialize ModernBERT jailbreak classifier: %w", err)
61+
}
62+
log.Printf("Initialized ModernBERT jailbreak classifier (classes auto-detected from model)")
63+
return nil
64+
}
65+
66+
// createJailbreakInitializer creates the appropriate jailbreak initializer based on configuration
67+
func createJailbreakInitializer(useModernBERT bool) JailbreakInitializer {
68+
if useModernBERT {
69+
return &ModernBertJailbreakInitializer{}
70+
}
71+
return &LinearJailbreakInitializer{}
72+
}
73+
4074
// JailbreakInference is the classification seam for jailbreak detection: given a
// text it yields a candle_binding.ClassResult or an error.
type JailbreakInference interface {
4175
Classify(text string) (candle_binding.ClassResult, error)
4276
}
@@ -105,9 +139,10 @@ type PIIAnalysisResult struct {
105139
// Classifier handles text classification, model selection, and jailbreak detection functionality
106140
type Classifier struct {
107141
// Dependencies
108-
categoryInference CategoryInference
109-
jailbreakInference JailbreakInference
110-
piiInference PIIInference
142+
categoryInference CategoryInference
143+
jailbreakInitializer JailbreakInitializer
144+
jailbreakInference JailbreakInference
145+
piiInference PIIInference
111146

112147
Config *config.RouterConfig
113148
CategoryMapping *CategoryMapping
@@ -124,9 +159,10 @@ type Classifier struct {
124159
// NewClassifier creates a new classifier with model selection and jailbreak detection capabilities
125160
func NewClassifier(cfg *config.RouterConfig, categoryMapping *CategoryMapping, piiMapping *PIIMapping, jailbreakMapping *JailbreakMapping, modelTTFT map[string]float64) *Classifier {
126161
return &Classifier{
127-
categoryInference: createCategoryInference(cfg.Classifier.CategoryModel.UseModernBERT),
128-
jailbreakInference: createJailbreakInference(cfg.PromptGuard.UseModernBERT),
129-
piiInference: createPIIInference(),
162+
categoryInference: createCategoryInference(cfg.Classifier.CategoryModel.UseModernBERT),
163+
jailbreakInitializer: createJailbreakInitializer(cfg.PromptGuard.UseModernBERT),
164+
jailbreakInference: createJailbreakInference(cfg.PromptGuard.UseModernBERT),
165+
piiInference: createPIIInference(),
130166

131167
Config: cfg,
132168
CategoryMapping: categoryMapping,
@@ -149,21 +185,8 @@ func (c *Classifier) InitializeJailbreakClassifier() error {
149185
return fmt.Errorf("not enough jailbreak types for classification, need at least 2, got %d", numClasses)
150186
}
151187

152-
var err error
153-
if c.Config.PromptGuard.UseModernBERT {
154-
// Initialize ModernBERT jailbreak classifier
155-
err = candle_binding.InitModernBertJailbreakClassifier(c.Config.PromptGuard.ModelID, c.Config.PromptGuard.UseCPU)
156-
if err != nil {
157-
return fmt.Errorf("failed to initialize ModernBERT jailbreak classifier: %w", err)
158-
}
159-
log.Printf("Initialized ModernBERT jailbreak classifier (classes auto-detected from model)")
160-
} else {
161-
// Initialize linear jailbreak classifier
162-
err = candle_binding.InitJailbreakClassifier(c.Config.PromptGuard.ModelID, numClasses, c.Config.PromptGuard.UseCPU)
163-
if err != nil {
164-
return fmt.Errorf("failed to initialize jailbreak classifier: %w", err)
165-
}
166-
log.Printf("Initialized linear jailbreak classifier with %d classes", numClasses)
188+
if err := c.jailbreakInitializer.Init(c.Config.PromptGuard.ModelID, c.Config.PromptGuard.UseCPU, numClasses); err != nil {
189+
return err
167190
}
168191

169192
c.JailbreakInitialized = true

src/semantic-router/pkg/utils/classification/classifier_test.go

Lines changed: 75 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -424,6 +424,81 @@ func (m *MockJailbreakInference) Classify(text string) (candle_binding.ClassResu
424424
return m.classifyResult, m.classifyError
425425
}
426426

427+
// MockJailbreakInitializer is a test double for JailbreakInitializer.
type MockJailbreakInitializer struct {
428+
// InitError is the value Init returns; nil simulates a successful initialization.
InitError error
429+
}
430+
431+
// Init ignores all of its arguments and returns the configured InitError.
func (m *MockJailbreakInitializer) Init(modelID string, useCPU bool, numClasses ...int) error {
432+
return m.InitError
433+
}
434+
435+
// Specs for Classifier.InitializeJailbreakClassifier, driven through a
// MockJailbreakInitializer so that no real model is loaded.
var _ = Describe("initialize jailbreak classifier", func() {
436+
var (
437+
classifier *Classifier
438+
mockJailbreakInitializer *MockJailbreakInitializer
439+
)
440+
441+
// Each spec starts from a classifier with a two-label jailbreak mapping and a
// mock initializer that succeeds.
BeforeEach(func() {
442+
mockJailbreakInitializer = &MockJailbreakInitializer{InitError: nil}
443+
cfg := &config.RouterConfig{}
444+
cfg.PromptGuard.Enabled = true
445+
cfg.PromptGuard.ModelID = "test-model"
446+
cfg.PromptGuard.JailbreakMappingPath = "test-mapping"
447+
cfg.PromptGuard.Threshold = 0.7
448+
classifier = &Classifier{
449+
jailbreakInitializer: mockJailbreakInitializer,
450+
Config: cfg,
451+
JailbreakMapping: &JailbreakMapping{
452+
LabelToIdx: map[string]int{"jailbreak": 0, "benign": 1},
453+
IdxToLabel: map[string]string{"0": "jailbreak", "1": "benign"},
454+
},
455+
JailbreakInitialized: false,
456+
}
457+
})
458+
459+
// Happy path: no error is expected and the initialized flag is expected to be set.
It("should initialize jailbreak classifier", func() {
460+
err := classifier.InitializeJailbreakClassifier()
461+
Expect(err).ToNot(HaveOccurred())
462+
Expect(classifier.JailbreakInitialized).To(BeTrue())
463+
})
464+
465+
// With a nil mapping the call is expected to succeed without marking the
// classifier as initialized.
Context("when jailbreak mapping is not initialized", func() {
466+
It("should return nil", func() {
467+
classifier.JailbreakMapping = nil
468+
err := classifier.InitializeJailbreakClassifier()
469+
Expect(err).ToNot(HaveOccurred())
470+
Expect(classifier.JailbreakInitialized).To(BeFalse())
471+
})
472+
})
473+
474+
// A mapping with a single label is expected to be rejected with an error.
Context("when not enough jailbreak types", func() {
475+
It("should return error", func() {
476+
classifier.JailbreakMapping = &JailbreakMapping{
477+
LabelToIdx: map[string]int{"jailbreak": 0},
478+
IdxToLabel: map[string]string{"0": "jailbreak"},
479+
}
480+
481+
err := classifier.InitializeJailbreakClassifier()
482+
483+
Expect(err).To(HaveOccurred())
484+
Expect(err.Error()).To(ContainSubstring("not enough jailbreak types for classification"))
485+
Expect(classifier.JailbreakInitialized).To(BeFalse())
486+
})
487+
})
488+
489+
// An error from the initializer is expected to surface to the caller and leave
// the initialized flag unset.
Context("when initialize jailbreak classifier fails", func() {
490+
It("should return error", func() {
491+
mockJailbreakInitializer.InitError = errors.New("initialize jailbreak classifier failed")
492+
493+
err := classifier.InitializeJailbreakClassifier()
494+
495+
Expect(err).To(HaveOccurred())
496+
Expect(err.Error()).To(ContainSubstring("initialize jailbreak classifier failed"))
497+
Expect(classifier.JailbreakInitialized).To(BeFalse())
498+
})
499+
})
500+
})
501+
427502
var _ = Describe("jailbreak detection", func() {
428503
var (
429504
classifier *Classifier

0 commit comments

Comments
 (0)