From ea1a40d1d7533a264826973995a2cb3d9577edbc Mon Sep 17 00:00:00 2001 From: Alex Wang Date: Fri, 5 Sep 2025 00:12:47 -0700 Subject: [PATCH 1/5] feat: add test framework for classifier with dependency injection Signed-off-by: Alex Wang --- .../pkg/utils/classification/classifier.go | 24 +++- .../utils/classification/classifier_test.go | 115 ++++++++++++++++++ 2 files changed, 137 insertions(+), 2 deletions(-) create mode 100644 src/semantic-router/pkg/utils/classification/classifier_test.go diff --git a/src/semantic-router/pkg/utils/classification/classifier.go b/src/semantic-router/pkg/utils/classification/classifier.go index 8ec77430..f41f0ef7 100644 --- a/src/semantic-router/pkg/utils/classification/classifier.go +++ b/src/semantic-router/pkg/utils/classification/classifier.go @@ -12,6 +12,22 @@ import ( "github.com/vllm-project/semantic-router/semantic-router/pkg/metrics" ) +type ModelInference interface { + // Category classification + ClassifyText(text string) (candle_binding.ClassResult, error) + ClassifyModernBertText(text string) (candle_binding.ClassResult, error) +} + +type CandleModelInference struct{} + +func (c *CandleModelInference) ClassifyText(text string) (candle_binding.ClassResult, error) { + return candle_binding.ClassifyText(text) +} + +func (c *CandleModelInference) ClassifyModernBertText(text string) (candle_binding.ClassResult, error) { + return candle_binding.ClassifyModernBertText(text) +} + // JailbreakDetection represents the result of jailbreak analysis for a piece of content type JailbreakDetection struct { Content string `json:"content"` @@ -40,6 +56,9 @@ type PIIAnalysisResult struct { // Classifier handles text classification, model selection, and jailbreak detection functionality type Classifier struct { + // Dependencies + modelInference ModelInference + Config *config.RouterConfig CategoryMapping *CategoryMapping PIIMapping *PIIMapping @@ -55,6 +74,7 @@ type Classifier struct { // NewClassifier creates a new classifier with model selection and jailbreak detection capabilities func NewClassifier(cfg *config.RouterConfig, categoryMapping *CategoryMapping, piiMapping *PIIMapping, jailbreakMapping *JailbreakMapping, modelTTFT map[string]float64) *Classifier { return &Classifier{ + modelInference: &CandleModelInference{}, Config: cfg, CategoryMapping: categoryMapping, PIIMapping: piiMapping, @@ -202,10 +222,10 @@ func (c *Classifier) ClassifyCategory(text string) (string, float64, error) { start := time.Now() if c.Config.Classifier.CategoryModel.UseModernBERT { // Use ModernBERT classifier - result, err = candle_binding.ClassifyModernBertText(text) + result, err = c.modelInference.ClassifyModernBertText(text) } else { // Use linear classifier - result, err = candle_binding.ClassifyText(text) + result, err = c.modelInference.ClassifyText(text) } metrics.RecordClassifierLatency("category", time.Since(start).Seconds()) diff --git a/src/semantic-router/pkg/utils/classification/classifier_test.go b/src/semantic-router/pkg/utils/classification/classifier_test.go new file mode 100644 index 00000000..876f6542 --- /dev/null +++ b/src/semantic-router/pkg/utils/classification/classifier_test.go @@ -0,0 +1,115 @@ +package classification + +import ( + "errors" + "testing" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + + candle_binding "github.com/vllm-project/semantic-router/candle-binding" + "github.com/vllm-project/semantic-router/semantic-router/pkg/config" +) + +func TestClassifier(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "Classifier Suite") +} + +// MockModelInference implements ModelInference interface for testing +type MockModelInference struct { + classifyTextResult candle_binding.ClassResult + classifyTextError error + classifyModernBertResult candle_binding.ClassResult + classifyModernBertError error +} + +func (m *MockModelInference) ClassifyText(text string) (candle_binding.ClassResult, error) { + return m.classifyTextResult, m.classifyTextError +} + +func (m *MockModelInference) ClassifyModernBertText(text string) (candle_binding.ClassResult, error) { + return m.classifyModernBertResult, m.classifyModernBertError +} + +var _ = Describe("Classifier", func() { + var ( + classifier *Classifier + mockModel *MockModelInference + ) + + BeforeEach(func() { + mockModel = &MockModelInference{} + cfg := &config.RouterConfig{} + cfg.Classifier.CategoryModel.Threshold = 0.5 // Set threshold for testing + + classifier = &Classifier{ + modelInference: mockModel, + Config: cfg, + CategoryMapping: &CategoryMapping{ + CategoryToIdx: map[string]int{"technology": 0, "sports": 1, "politics": 2}, + IdxToCategory: map[string]string{"0": "technology", "1": "sports", "2": "politics"}, + }, + } + }) + + Describe("ClassifyCategory", func() { + Context("when classification succeeds with high confidence", func() { + It("should return the correct category", func() { + mockModel.classifyTextResult = candle_binding.ClassResult{ + Class: 2, + Confidence: 0.95, + } + + category, score, err := classifier.ClassifyCategory("This is about politics") + + Expect(err).To(BeNil()) + Expect(category).To(Equal("politics")) + Expect(score).To(BeNumerically("~", 0.95, 0.001)) + }) + }) + + Context("when classification has low confidence below threshold", func() { + It("should return empty category", func() { + mockModel.classifyTextResult = candle_binding.ClassResult{ + Class: 0, + Confidence: 0.3, + } + + category, score, err := classifier.ClassifyCategory("Ambiguous text") + + Expect(err).To(BeNil()) + Expect(category).To(Equal("")) + Expect(score).To(BeNumerically("~", 0.3, 0.001)) + }) + }) + + Context("when BERT model returns error", func() { + It("should return unknown category with zero score", func() { + mockModel.classifyTextError = errors.New("model inference failed") + + category, score, err := classifier.ClassifyCategory("Some text") + + Expect(err).ToNot(BeNil()) + Expect(category).To(Equal("")) + Expect(score).To(BeNumerically("~", 0.0, 0.001)) + }) + }) + + Context("when input is empty or invalid", func() { + It("should handle empty text gracefully", func() { + mockModel.classifyTextResult = candle_binding.ClassResult{ + Class: 0, + Confidence: 0.8, + } + + category, score, err := classifier.ClassifyCategory("") + + // Should still attempt classification + Expect(err).To(BeNil()) + Expect(category).To(Equal("technology")) + Expect(score).To(BeNumerically("~", 0.8, 0.001)) + }) + }) + }) +}) From 5d0bbd1b8d63c8b6e2ecc75c2ce39904ee969db8 Mon Sep 17 00:00:00 2001 From: Alex Wang Date: Fri, 5 Sep 2025 09:12:25 -0700 Subject: [PATCH 2/5] update dependency injection Signed-off-by: Alex Wang --- .../pkg/utils/classification/classifier.go | 66 +++++---- .../utils/classification/classifier_test.go | 128 +++++++++--------- 2 files changed, 110 insertions(+), 84 deletions(-) diff --git a/src/semantic-router/pkg/utils/classification/classifier.go b/src/semantic-router/pkg/utils/classification/classifier.go index f41f0ef7..85783fbb 100644 --- a/src/semantic-router/pkg/utils/classification/classifier.go +++ b/src/semantic-router/pkg/utils/classification/classifier.go @@ -12,22 +12,38 @@ import ( "github.com/vllm-project/semantic-router/semantic-router/pkg/metrics" ) -type ModelInference interface { - // Category classification - ClassifyText(text string) (candle_binding.ClassResult, error) - ClassifyModernBertText(text string) (candle_binding.ClassResult, error) +type CategoryInference interface { + Classify(text string) (candle_binding.ClassResult, error) } -type CandleModelInference struct{} +type LinearCategoryInference struct{} -func (c *CandleModelInference) ClassifyText(text string) (candle_binding.ClassResult, error) { +func (c *LinearCategoryInference) Classify(text string) (candle_binding.ClassResult, error) { return candle_binding.ClassifyText(text) } -func (c *CandleModelInference) ClassifyModernBertText(text string) (candle_binding.ClassResult, error) { +type ModernBERTCategoryInference struct{} + +func (c *ModernBERTCategoryInference) Classify(text string) (candle_binding.ClassResult, error) { return candle_binding.ClassifyModernBertText(text) } +type JailbreakInference interface { + Classify(text string) (candle_binding.ClassResult, error) +} + +type LinearJailbreakInference struct{} + +func (c *LinearJailbreakInference) Classify(text string) (candle_binding.ClassResult, error) { + return candle_binding.ClassifyJailbreakText(text) +} + +type ModernBERTJailbreakInference struct{} + +func (c *ModernBERTJailbreakInference) Classify(text string) (candle_binding.ClassResult, error) { + return candle_binding.ClassifyModernBertJailbreakText(text) +} + // JailbreakDetection represents the result of jailbreak analysis for a piece of content type JailbreakDetection struct { Content string `json:"content"` @@ -57,7 +73,8 @@ type PIIAnalysisResult struct { // Classifier handles text classification, model selection, and jailbreak detection functionality type Classifier struct { // Dependencies - modelInference ModelInference + categoryInference CategoryInference + jailbreakInference JailbreakInference Config *config.RouterConfig CategoryMapping *CategoryMapping @@ -73,8 +90,23 @@ type Classifier struct { // NewClassifier creates a new classifier with model selection and jailbreak detection capabilities func NewClassifier(cfg *config.RouterConfig, categoryMapping *CategoryMapping, piiMapping *PIIMapping, jailbreakMapping *JailbreakMapping, modelTTFT map[string]float64) *Classifier { + var categoryInference CategoryInference + if cfg.Classifier.CategoryModel.UseModernBERT { + categoryInference = &ModernBERTCategoryInference{} + } else { + categoryInference = &LinearCategoryInference{} + } + + var jailbreakInference JailbreakInference + if cfg.PromptGuard.UseModernBERT { + jailbreakInference = &ModernBERTJailbreakInference{} + } else { + jailbreakInference = &LinearJailbreakInference{} + } + return &Classifier{ - modelInference: &CandleModelInference{}, + categoryInference: categoryInference, + jailbreakInference: jailbreakInference, Config: cfg, CategoryMapping: categoryMapping, PIIMapping: piiMapping, @@ -137,13 +169,7 @@ func (c *Classifier) CheckForJailbreak(text string) (bool, string, float32, erro var err error start := time.Now() - if c.Config.PromptGuard.UseModernBERT { - // Use ModernBERT jailbreak classifier - result, err = candle_binding.ClassifyModernBertJailbreakText(text) - } else { - // Use linear jailbreak classifier - result, err = candle_binding.ClassifyJailbreakText(text) - } + result, err = c.jailbreakInference.Classify(text) metrics.RecordClassifierLatency("jailbreak", time.Since(start).Seconds()) if err != nil { @@ -220,13 +246,7 @@ func (c *Classifier) ClassifyCategory(text string) (string, float64, error) { var err error start := time.Now() - if c.Config.Classifier.CategoryModel.UseModernBERT { - // Use ModernBERT classifier - result, err = c.modelInference.ClassifyModernBertText(text) - } else { - // Use linear classifier - result, err = c.modelInference.ClassifyText(text) - } + result, err = c.categoryInference.Classify(text) metrics.RecordClassifierLatency("category", time.Since(start).Seconds()) if err != nil { diff --git a/src/semantic-router/pkg/utils/classification/classifier_test.go b/src/semantic-router/pkg/utils/classification/classifier_test.go index 876f6542..4c293317 100644 --- a/src/semantic-router/pkg/utils/classification/classifier_test.go +++ b/src/semantic-router/pkg/utils/classification/classifier_test.go @@ -16,36 +16,29 @@ func TestClassifier(t *testing.T) { RunSpecs(t, "Classifier Suite") } -// MockModelInference implements ModelInference interface for testing -type MockModelInference struct { - classifyTextResult candle_binding.ClassResult - classifyTextError error - classifyModernBertResult candle_binding.ClassResult - classifyModernBertError error +type MockCategoryInference struct { + classifyResult candle_binding.ClassResult + classifyError error } -func (m *MockModelInference) ClassifyText(text string) (candle_binding.ClassResult, error) { - return m.classifyTextResult, m.classifyTextError +func (m *MockCategoryInference) Classify(text string) (candle_binding.ClassResult, error) { + return m.classifyResult, m.classifyError } -func (m *MockModelInference) ClassifyModernBertText(text string) (candle_binding.ClassResult, error) { - return m.classifyModernBertResult, m.classifyModernBertError -} - -var _ = Describe("Classifier", func() { +var _ = Describe("ClassifyCategory", func() { var ( - classifier *Classifier - mockModel *MockModelInference + classifier *Classifier + mockCategoryModel *MockCategoryInference ) BeforeEach(func() { - mockModel = &MockModelInference{} + mockCategoryModel = &MockCategoryInference{} cfg := &config.RouterConfig{} cfg.Classifier.CategoryModel.Threshold = 0.5 // Set threshold for testing classifier = &Classifier{ - modelInference: mockModel, - Config: cfg, + categoryInference: mockCategoryModel, + Config: cfg, CategoryMapping: &CategoryMapping{ CategoryToIdx: map[string]int{"technology": 0, "sports": 1, "politics": 2}, IdxToCategory: map[string]string{"0": "technology", "1": "sports", "2": "politics"}, @@ -53,63 +46,76 @@ var _ = Describe("Classifier", func() { } }) - Describe("ClassifyCategory", func() { - Context("when classification succeeds with high confidence", func() { - It("should return the correct category", func() { - mockModel.classifyTextResult = candle_binding.ClassResult{ - Class: 2, - Confidence: 0.95, - } + Context("when classification succeeds with high confidence", func() { + It("should return the correct category", func() { + mockCategoryModel.classifyResult = candle_binding.ClassResult{ + Class: 2, + Confidence: 0.95, + } + + category, score, err := classifier.ClassifyCategory("This is about politics") + + Expect(err).To(BeNil()) + Expect(category).To(Equal("politics")) + Expect(score).To(BeNumerically("~", 0.95, 0.001)) + }) + }) + + Context("when classification has low confidence below threshold", func() { + It("should return empty category", func() { + mockCategoryModel.classifyResult = candle_binding.ClassResult{ + Class: 0, + Confidence: 0.3, + } - category, score, err := classifier.ClassifyCategory("This is about politics") + category, score, err := classifier.ClassifyCategory("Ambiguous text") - Expect(err).To(BeNil()) - Expect(category).To(Equal("politics")) - Expect(score).To(BeNumerically("~", 0.95, 0.001)) - }) + Expect(err).To(BeNil()) + Expect(category).To(Equal("")) + Expect(score).To(BeNumerically("~", 0.3, 0.001)) }) + }) - Context("when classification has low confidence below threshold", func() { - It("should return empty category", func() { - mockModel.classifyTextResult = candle_binding.ClassResult{ - Class: 0, - Confidence: 0.3, - } + Context("when BERT model returns error", func() { + It("should return empty category with zero score", func() { + mockCategoryModel.classifyError = errors.New("model inference failed") - category, score, err := classifier.ClassifyCategory("Ambiguous text") + category, score, err := classifier.ClassifyCategory("Some text") - Expect(err).To(BeNil()) - Expect(category).To(Equal("")) - Expect(score).To(BeNumerically("~", 0.3, 0.001)) - }) + Expect(err).ToNot(BeNil()) + Expect(category).To(Equal("")) + Expect(score).To(BeNumerically("~", 0.0, 0.001)) }) + }) - Context("when BERT model returns error", func() { - It("should return unknown category with zero score", func() { - mockModel.classifyTextError = errors.New("model inference failed") + Context("when input is empty or invalid", func() { + It("should handle empty text gracefully", func() { + mockCategoryModel.classifyResult = candle_binding.ClassResult{ + Class: 0, + Confidence: 0.8, + } - category, score, err := classifier.ClassifyCategory("Some text") + category, score, err := classifier.ClassifyCategory("") - Expect(err).ToNot(BeNil()) - Expect(category).To(Equal("")) - Expect(score).To(BeNumerically("~", 0.0, 0.001)) - }) + // Should still attempt classification + Expect(err).To(BeNil()) + Expect(category).To(Equal("technology")) + Expect(score).To(BeNumerically("~", 0.8, 0.001)) }) + }) - Context("when input is empty or invalid", func() { - It("should handle empty text gracefully", func() { - mockModel.classifyTextResult = candle_binding.ClassResult{ - Class: 0, - Confidence: 0.8, - } + Context("when category mapping is invalid", func() { + It("should handle invalid category mapping gracefully", func() { + mockCategoryModel.classifyResult = candle_binding.ClassResult{ + Class: 9, + Confidence: 0.8, + } - category, score, err := classifier.ClassifyCategory("") + category, score, err := classifier.ClassifyCategory("Some text") - // Should still attempt classification - Expect(err).To(BeNil()) - Expect(category).To(Equal("technology")) - Expect(score).To(BeNumerically("~", 0.8, 0.001)) - }) + Expect(err).To(BeNil()) + Expect(category).To(Equal("")) + Expect(score).To(BeNumerically("~", 0.8, 0.001)) }) }) }) From db445945726b7130c8e1df6283a358b649efc0ea Mon Sep 17 00:00:00 2001 From: Alex Wang Date: Fri, 5 Sep 2025 09:34:10 -0700 Subject: [PATCH 3/5] finish jailbreak tests Signed-off-by: Alex Wang --- .../utils/classification/classifier_test.go | 119 +++++++++++++++++- 1 file changed, 116 insertions(+), 3 deletions(-) diff --git a/src/semantic-router/pkg/utils/classification/classifier_test.go b/src/semantic-router/pkg/utils/classification/classifier_test.go index 4c293317..04385951 100644 --- a/src/semantic-router/pkg/utils/classification/classifier_test.go +++ b/src/semantic-router/pkg/utils/classification/classifier_test.go @@ -61,7 +61,7 @@ var _ = Describe("ClassifyCategory", func() { }) }) - Context("when classification has low confidence below threshold", func() { + Context("when classification confidence is below threshold", func() { It("should return empty category", func() { mockCategoryModel.classifyResult = candle_binding.ClassResult{ Class: 0, @@ -76,7 +76,7 @@ var _ = Describe("ClassifyCategory", func() { }) }) - Context("when BERT model returns error", func() { + Context("when model inference fails", func() { It("should return empty category with zero score", func() { mockCategoryModel.classifyError = errors.New("model inference failed") @@ -104,7 +104,7 @@ var _ = Describe("ClassifyCategory", func() { }) }) - Context("when category mapping is invalid", func() { + Context("when class index is not found in category mapping", func() { It("should handle invalid category mapping gracefully", func() { mockCategoryModel.classifyResult = candle_binding.ClassResult{ Class: 9, @@ -119,3 +119,116 @@ var _ = Describe("ClassifyCategory", func() { }) }) }) + +type MockJailbreakInference struct { + classifyResult candle_binding.ClassResult + classifyError error +} + +func (m *MockJailbreakInference) Classify(text string) (candle_binding.ClassResult, error) { + return m.classifyResult, m.classifyError +} + +var _ = Describe("CheckForJailbreak", func() { + var ( + classifier *Classifier + mockJailbreakModel *MockJailbreakInference + ) + + BeforeEach(func() { + mockJailbreakModel = &MockJailbreakInference{} + cfg := &config.RouterConfig{} + cfg.PromptGuard.Enabled = true + cfg.PromptGuard.ModelID = "test-model" + cfg.PromptGuard.JailbreakMappingPath = "test-mapping" + cfg.PromptGuard.Threshold = 0.7 + + classifier = &Classifier{ + jailbreakInference: mockJailbreakModel, + Config: cfg, + JailbreakMapping: &JailbreakMapping{ + LabelToIdx: map[string]int{"jailbreak": 0, "benign": 1}, + IdxToLabel: map[string]string{"0": "jailbreak", "1": "benign"}, + }, + JailbreakInitialized: true, + } + }) + + Context("when jailbreak is detected with high confidence", func() { + It("should return true with jailbreak type", func() { + mockJailbreakModel.classifyResult = candle_binding.ClassResult{ + Class: 0, + Confidence: 0.9, + } + + isJailbreak, jailbreakType, confidence, err := classifier.CheckForJailbreak("This is a jailbreak attempt") + + Expect(err).To(BeNil()) + Expect(isJailbreak).To(BeTrue()) + Expect(jailbreakType).To(Equal("jailbreak")) + Expect(confidence).To(BeNumerically("~", 0.9, 0.001)) + }) + }) + + Context("when text is benign with high confidence", func() { + It("should return false with benign type", func() { + mockJailbreakModel.classifyResult = candle_binding.ClassResult{ + Class: 1, + Confidence: 0.9, + } + + isJailbreak, jailbreakType, confidence, err := classifier.CheckForJailbreak("This is a normal question") + + Expect(err).To(BeNil()) + Expect(isJailbreak).To(BeFalse()) + Expect(jailbreakType).To(Equal("benign")) + Expect(confidence).To(BeNumerically("~", 0.9, 0.001)) + }) + }) + + Context("when jailbreak confidence is below threshold", func() { + It("should return false even if classified as jailbreak", func() { + mockJailbreakModel.classifyResult = candle_binding.ClassResult{ + Class: 0, + Confidence: 0.5, + } + + isJailbreak, jailbreakType, confidence, err := classifier.CheckForJailbreak("Ambiguous text") + + Expect(err).To(BeNil()) + Expect(isJailbreak).To(BeFalse()) + Expect(jailbreakType).To(Equal("jailbreak")) + Expect(confidence).To(BeNumerically("~", 0.5, 0.001)) + }) + }) + + Context("when model inference fails", func() { + It("should return error", func() { + mockJailbreakModel.classifyError = errors.New("model inference failed") + + isJailbreak, jailbreakType, confidence, err := classifier.CheckForJailbreak("Some text") + + Expect(err).ToNot(BeNil()) + Expect(err.Error()).To(ContainSubstring("jailbreak classification failed")) + Expect(isJailbreak).To(BeFalse()) + Expect(jailbreakType).To(Equal("")) + Expect(confidence).To(BeNumerically("~", 0.0, 0.001)) + }) + }) + + Context("when class index is not found in jailbreak mapping", func() { + It("should return error for unknown class", func() { + mockJailbreakModel.classifyResult = candle_binding.ClassResult{ + Class: 9, + Confidence: 0.9, + } + + isJailbreak, jailbreakType, confidence, err := classifier.CheckForJailbreak("Some text") + + Expect(err).ToNot(BeNil()) + Expect(isJailbreak).To(BeFalse()) + Expect(jailbreakType).To(Equal("")) + Expect(confidence).To(BeNumerically("~", 0.0, 0.001)) + }) + }) +}) From 69e8e92d6fe84dc3d1359ecfaf7a1928f60d8c36 Mon Sep 17 00:00:00 2001 From: Alex Wang Date: Fri, 5 Sep 2025 11:25:49 -0700 Subject: [PATCH 4/5] finish pii tests Signed-off-by: Alex Wang --- .../pkg/utils/classification/classifier.go | 35 +++- .../utils/classification/classifier_test.go | 198 ++++++++++++++++-- 2 files changed, 211 insertions(+), 22 deletions(-) diff --git a/src/semantic-router/pkg/utils/classification/classifier.go b/src/semantic-router/pkg/utils/classification/classifier.go index 85783fbb..107c7e72 100644 --- a/src/semantic-router/pkg/utils/classification/classifier.go +++ b/src/semantic-router/pkg/utils/classification/classifier.go @@ -22,9 +22,9 @@ func (c *LinearCategoryInference) Classify(text string) (candle_binding.ClassRes return candle_binding.ClassifyText(text) } -type ModernBERTCategoryInference struct{} +type ModernBertCategoryInference struct{} -func (c *ModernBERTCategoryInference) Classify(text string) (candle_binding.ClassResult, error) { +func (c *ModernBertCategoryInference) Classify(text string) (candle_binding.ClassResult, error) { return candle_binding.ClassifyModernBertText(text) } @@ -38,12 +38,22 @@ func (c *LinearJailbreakInference) Classify(text string) (candle_binding.ClassRe return candle_binding.ClassifyJailbreakText(text) } -type ModernBERTJailbreakInference struct{} +type ModernBertJailbreakInference struct{} -func (c *ModernBERTJailbreakInference) Classify(text string) (candle_binding.ClassResult, error) { +func (c *ModernBertJailbreakInference) Classify(text string) (candle_binding.ClassResult, error) { return candle_binding.ClassifyModernBertJailbreakText(text) } +type PIIInference interface { + ClassifyTokens(text string, configPath string) (candle_binding.TokenClassificationResult, error) +} + +type ModernBertPIIInference struct{} + +func (c *ModernBertPIIInference) ClassifyTokens(text string, configPath string) (candle_binding.TokenClassificationResult, error) { + return candle_binding.ClassifyModernBertPIITokens(text, configPath) +} + // JailbreakDetection represents the result of jailbreak analysis for a piece of content type JailbreakDetection struct { Content string `json:"content"` @@ -75,6 +85,7 @@ type Classifier struct { // Dependencies categoryInference CategoryInference jailbreakInference JailbreakInference + piiInference PIIInference Config *config.RouterConfig CategoryMapping *CategoryMapping @@ -92,21 +103,25 @@ type Classifier struct { func NewClassifier(cfg *config.RouterConfig, categoryMapping *CategoryMapping, piiMapping *PIIMapping, jailbreakMapping *JailbreakMapping, modelTTFT map[string]float64) *Classifier { var categoryInference CategoryInference if cfg.Classifier.CategoryModel.UseModernBERT { - categoryInference = &ModernBERTCategoryInference{} + categoryInference = &ModernBertCategoryInference{} } else { categoryInference = &LinearCategoryInference{} } var jailbreakInference JailbreakInference if cfg.PromptGuard.UseModernBERT { - jailbreakInference = &ModernBERTJailbreakInference{} + jailbreakInference = &ModernBertJailbreakInference{} } else { jailbreakInference = &LinearJailbreakInference{} } + piiInference := &ModernBertPIIInference{} + return &Classifier{ - categoryInference: categoryInference, - jailbreakInference: jailbreakInference, + categoryInference: categoryInference, + jailbreakInference: jailbreakInference, + piiInference: piiInference, + Config: cfg, CategoryMapping: categoryMapping, PIIMapping: piiMapping, @@ -289,7 +304,7 @@ func (c *Classifier) ClassifyPII(text string) ([]string, error) { // Use ModernBERT PII token classifier for entity detection configPath := fmt.Sprintf("%s/config.json", c.Config.Classifier.PIIModel.ModelID) start := time.Now() - tokenResult, err := candle_binding.ClassifyModernBertPIITokens(text, configPath) + tokenResult, err := c.piiInference.ClassifyTokens(text, configPath) metrics.RecordClassifierLatency("pii", time.Since(start).Seconds()) if err != nil { return nil, fmt.Errorf("PII token classification error: %w", err) @@ -371,7 +386,7 @@ func (c *Classifier) AnalyzeContentForPII(contentList []string) (bool, []PIIAnal // Use ModernBERT PII token classifier for detailed analysis configPath := fmt.Sprintf("%s/config.json", c.Config.Classifier.PIIModel.ModelID) start := time.Now() - tokenResult, err := candle_binding.ClassifyModernBertPIITokens(content, configPath) + tokenResult, err := c.piiInference.ClassifyTokens(content, configPath) metrics.RecordClassifierLatency("pii", time.Since(start).Seconds()) if err != nil { log.Printf("Error analyzing content %d: %v", i, err) diff --git a/src/semantic-router/pkg/utils/classification/classifier_test.go b/src/semantic-router/pkg/utils/classification/classifier_test.go index 04385951..afa60092 100644 --- a/src/semantic-router/pkg/utils/classification/classifier_test.go +++ b/src/semantic-router/pkg/utils/classification/classifier_test.go @@ -34,7 +34,7 @@ var _ = Describe("ClassifyCategory", func() { BeforeEach(func() { mockCategoryModel = &MockCategoryInference{} cfg := &config.RouterConfig{} - cfg.Classifier.CategoryModel.Threshold = 0.5 // Set threshold for testing + cfg.Classifier.CategoryModel.Threshold = 0.5 classifier = &Classifier{ categoryInference: mockCategoryModel, @@ -55,7 +55,7 @@ var _ = Describe("ClassifyCategory", func() { category, score, err := classifier.ClassifyCategory("This is about politics") - Expect(err).To(BeNil()) + Expect(err).ToNot(HaveOccurred()) Expect(category).To(Equal("politics")) Expect(score).To(BeNumerically("~", 0.95, 0.001)) }) @@ -70,7 +70,7 @@ var _ = Describe("ClassifyCategory", func() { category, score, err := classifier.ClassifyCategory("Ambiguous text") - Expect(err).To(BeNil()) + Expect(err).ToNot(HaveOccurred()) Expect(category).To(Equal("")) Expect(score).To(BeNumerically("~", 0.3, 0.001)) }) @@ -82,7 +82,8 @@ var _ = Describe("ClassifyCategory", func() { category, score, err := classifier.ClassifyCategory("Some text") - Expect(err).ToNot(BeNil()) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("classification error")) Expect(category).To(Equal("")) Expect(score).To(BeNumerically("~", 0.0, 0.001)) }) @@ -97,8 +98,7 @@ var _ = Describe("ClassifyCategory", func() { category, score, err := classifier.ClassifyCategory("") - // Should still attempt classification - Expect(err).To(BeNil()) + Expect(err).ToNot(HaveOccurred()) Expect(category).To(Equal("technology")) Expect(score).To(BeNumerically("~", 0.8, 0.001)) }) @@ -113,7 +113,7 @@ var _ = Describe("ClassifyCategory", func() { category, score, err := classifier.ClassifyCategory("Some text") - Expect(err).To(BeNil()) + Expect(err).ToNot(HaveOccurred()) Expect(category).To(Equal("")) Expect(score).To(BeNumerically("~", 0.8, 0.001)) }) @@ -163,7 +163,7 @@ var _ = Describe("CheckForJailbreak", func() { isJailbreak, jailbreakType, confidence, err := classifier.CheckForJailbreak("This is a jailbreak attempt") - Expect(err).To(BeNil()) + Expect(err).ToNot(HaveOccurred()) Expect(isJailbreak).To(BeTrue()) Expect(jailbreakType).To(Equal("jailbreak")) Expect(confidence).To(BeNumerically("~", 0.9, 0.001)) @@ -179,7 +179,7 @@ var _ = Describe("CheckForJailbreak", func() { isJailbreak, jailbreakType, confidence, err := classifier.CheckForJailbreak("This is a normal question") - Expect(err).To(BeNil()) + Expect(err).ToNot(HaveOccurred()) Expect(isJailbreak).To(BeFalse()) Expect(jailbreakType).To(Equal("benign")) Expect(confidence).To(BeNumerically("~", 0.9, 0.001)) @@ -195,7 +195,7 @@ var _ = Describe("CheckForJailbreak", func() { isJailbreak, jailbreakType, confidence, err := classifier.CheckForJailbreak("Ambiguous text") - Expect(err).To(BeNil()) + Expect(err).ToNot(HaveOccurred()) Expect(isJailbreak).To(BeFalse()) Expect(jailbreakType).To(Equal("jailbreak")) Expect(confidence).To(BeNumerically("~", 0.5, 0.001)) @@ -208,7 +208,7 @@ var _ = Describe("CheckForJailbreak", func() { isJailbreak, jailbreakType, confidence, err := classifier.CheckForJailbreak("Some text") - Expect(err).ToNot(BeNil()) + Expect(err).To(HaveOccurred()) Expect(err.Error()).To(ContainSubstring("jailbreak classification failed")) Expect(isJailbreak).To(BeFalse()) Expect(jailbreakType).To(Equal("")) @@ -225,10 +225,184 @@ var _ = Describe("CheckForJailbreak", func() { isJailbreak, jailbreakType, confidence, err := classifier.CheckForJailbreak("Some text") - Expect(err).ToNot(BeNil()) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("unknown jailbreak class index")) Expect(isJailbreak).To(BeFalse()) Expect(jailbreakType).To(Equal("")) Expect(confidence).To(BeNumerically("~", 0.0, 0.001)) }) }) }) + +type PIIInferenceResponse struct { + classifyTokensResult candle_binding.TokenClassificationResult + classifyTokensError error +} + +type MockPIIInference struct { + PIIInferenceResponse + responseMap map[string]PIIInferenceResponse +} + +func (m *MockPIIInference) ClassifyTokens(text string, configPath string) (candle_binding.TokenClassificationResult, error) { + if response, exists := m.responseMap[text]; exists { + return response.classifyTokensResult, response.classifyTokensError + } + return m.classifyTokensResult, m.classifyTokensError +} + +var _ = Describe("PIIClassification", func() { + var ( + classifier *Classifier + mockPIIModel *MockPIIInference + ) + + BeforeEach(func() { + mockPIIModel = &MockPIIInference{} + cfg := &config.RouterConfig{} + cfg.Classifier.PIIModel.ModelID = "test-pii-model" + cfg.Classifier.PIIModel.Threshold = 0.7 + + classifier = &Classifier{ + piiInference: mockPIIModel, + Config: cfg, + PIIMapping: &PIIMapping{ + LabelToIdx: map[string]int{"PERSON": 0, "EMAIL": 1}, + IdxToLabel: map[string]string{"0": "PERSON", "1": "EMAIL"}, + }, + } + }) + + Describe("ClassifyPII", func() { + Context("when PII entities are detected above threshold", func() { + It("should return detected PII types", func() { + mockPIIModel.classifyTokensResult = candle_binding.TokenClassificationResult{ + Entities: []candle_binding.TokenEntity{ + { + EntityType: "PERSON", + Text: "John Doe", + Start: 0, + End: 8, + Confidence: 0.9, + }, + { + EntityType: "EMAIL", + Text: "john@example.com", + Start: 9, + End: 25, + Confidence: 0.8, + }, + }, + } + + piiTypes, err := classifier.ClassifyPII("John Doe john@example.com") + + Expect(err).ToNot(HaveOccurred()) + Expect(piiTypes).To(ConsistOf("PERSON", "EMAIL")) + }) + }) + + Context("when PII entities are detected below threshold", func() { + It("should filter out low confidence entities", func() { + mockPIIModel.classifyTokensResult = candle_binding.TokenClassificationResult{ + Entities: []candle_binding.TokenEntity{ + { + EntityType: "PERSON", + Text: "John Doe", + Start: 0, + End: 8, + Confidence: 0.9, + }, + { + EntityType: "EMAIL", + Text: "john@example.com", + Start: 9, + End: 25, + Confidence: 0.5, + }, + }, + } + + piiTypes, err := classifier.ClassifyPII("John Doe john@example.com") + + Expect(err).ToNot(HaveOccurred()) + Expect(piiTypes).To(ConsistOf("PERSON")) + }) + }) + + Context("when no PII is detected", func() { + It("should return empty list", func() { + mockPIIModel.classifyTokensResult = candle_binding.TokenClassificationResult{ + Entities: []candle_binding.TokenEntity{}, + } + + piiTypes, err := classifier.ClassifyPII("Some text") + + Expect(err).ToNot(HaveOccurred()) + Expect(piiTypes).To(BeEmpty()) + }) + }) + + Context("when model inference fails", func() { + It("should return error", func() { + mockPIIModel.classifyTokensError = errors.New("PII model inference failed") + + piiTypes, err := classifier.ClassifyPII("Some text") + + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("PII token classification error")) + Expect(piiTypes).To(BeNil()) + }) + }) + }) + + Describe("AnalyzeContentForPII", func() { + Context("when some texts contain PII", func() { + It("should return detailed analysis for each text", func() { + mockPIIModel.responseMap = make(map[string]PIIInferenceResponse) + mockPIIModel.responseMap["Alice Smith"] = PIIInferenceResponse{ + classifyTokensResult: candle_binding.TokenClassificationResult{ + Entities: []candle_binding.TokenEntity{ + { + EntityType: "PERSON", + Text: "Alice", + Start: 0, + End: 5, + Confidence: 0.9, + }, + }, + }, + classifyTokensError: nil, + } + + mockPIIModel.responseMap["No PII here"] = PIIInferenceResponse{} + + contentList := []string{"Alice Smith", "No PII here"} + hasPII, results, err := classifier.AnalyzeContentForPII(contentList) + + Expect(err).ToNot(HaveOccurred()) + Expect(hasPII).To(BeTrue()) + Expect(results).To(HaveLen(2)) + Expect(results[0].HasPII).To(BeTrue()) + Expect(results[0].Entities).To(HaveLen(1)) + Expect(results[0].Entities[0].EntityType).To(Equal("PERSON")) + Expect(results[0].Entities[0].Text).To(Equal("Alice")) + Expect(results[1].HasPII).To(BeFalse()) + Expect(results[1].Entities).To(BeEmpty()) + }) + }) + + Context("when model inference fails", func() { + It("should return false for hasPII and empty results", func() { + mockPIIModel.classifyTokensError = errors.New("model failed") + + contentList := []string{"Text 1", "Text 2"} + hasPII, results, err := classifier.AnalyzeContentForPII(contentList) + + Expect(err).ToNot(HaveOccurred()) + Expect(hasPII).To(BeFalse()) + Expect(results).To(BeEmpty()) + }) + }) + }) +}) From 39911f464ac8afcce2247c30be93705df8cc5c0f Mon Sep 17 00:00:00 2001 From: Alex Wang Date: Fri, 5 Sep 2025 12:43:37 -0700 Subject: [PATCH 5/5] refactor Signed-off-by: Alex Wang --- .../pkg/utils/classification/classifier.go | 43 +++++++++++-------- 1 file changed, 24 insertions(+), 19 deletions(-) diff --git a/src/semantic-router/pkg/utils/classification/classifier.go b/src/semantic-router/pkg/utils/classification/classifier.go index 107c7e72..97ed00e0 100644 --- a/src/semantic-router/pkg/utils/classification/classifier.go +++ b/src/semantic-router/pkg/utils/classification/classifier.go @@ -28,6 +28,14 @@ func (c *ModernBertCategoryInference) Classify(text string) (candle_binding.Clas return candle_binding.ClassifyModernBertText(text) } +// createCategoryInference creates the appropriate category inference based on configuration +func createCategoryInference(useModernBERT bool) CategoryInference { + if useModernBERT { + return &ModernBertCategoryInference{} + } + return &LinearCategoryInference{} +} + type JailbreakInference interface { Classify(text string) (candle_binding.ClassResult, error) } @@ -44,6 +52,14 @@ func (c *ModernBertJailbreakInference) Classify(text string) (candle_binding.Cla return candle_binding.ClassifyModernBertJailbreakText(text) } +// createJailbreakInference creates the appropriate jailbreak inference based on configuration +func createJailbreakInference(useModernBERT bool) JailbreakInference { + if useModernBERT { + return &ModernBertJailbreakInference{} + } + return &LinearJailbreakInference{} +} + type PIIInference interface { ClassifyTokens(text string, configPath string) (candle_binding.TokenClassificationResult, error) } @@ -54,6 +70,11 @@ func (c *ModernBertPIIInference) ClassifyTokens(text string, configPath string) return candle_binding.ClassifyModernBertPIITokens(text, configPath) } +// createPIIInference creates the appropriate PII inference (currently only ModernBERT) +func createPIIInference() PIIInference { + return &ModernBertPIIInference{} +} + // JailbreakDetection represents the result of jailbreak analysis for a piece of content type JailbreakDetection struct { Content string `json:"content"` @@ -101,26 +122,10 @@ type Classifier struct { // NewClassifier creates a new classifier with model selection and jailbreak detection capabilities func NewClassifier(cfg *config.RouterConfig, categoryMapping *CategoryMapping, piiMapping *PIIMapping, jailbreakMapping *JailbreakMapping, modelTTFT map[string]float64) *Classifier { - var categoryInference CategoryInference - if cfg.Classifier.CategoryModel.UseModernBERT { - categoryInference = &ModernBertCategoryInference{} - } else { - categoryInference = &LinearCategoryInference{} - } - - var jailbreakInference JailbreakInference - if cfg.PromptGuard.UseModernBERT { - jailbreakInference = &ModernBertJailbreakInference{} - } else { - jailbreakInference = &LinearJailbreakInference{} - } - - piiInference := &ModernBertPIIInference{} - return &Classifier{ - categoryInference: categoryInference, - jailbreakInference: jailbreakInference, - piiInference: piiInference, + categoryInference: createCategoryInference(cfg.Classifier.CategoryModel.UseModernBERT), + jailbreakInference: createJailbreakInference(cfg.PromptGuard.UseModernBERT), + piiInference: createPIIInference(), Config: cfg, CategoryMapping: categoryMapping,