diff --git a/src/semantic-router/pkg/utils/classification/classifier.go b/src/semantic-router/pkg/utils/classification/classifier.go
index 8ec77430..97ed00e0 100644
--- a/src/semantic-router/pkg/utils/classification/classifier.go
+++ b/src/semantic-router/pkg/utils/classification/classifier.go
@@ -12,6 +12,69 @@ import (
 	"github.com/vllm-project/semantic-router/semantic-router/pkg/metrics"
 )
 
+type CategoryInference interface {
+	Classify(text string) (candle_binding.ClassResult, error)
+}
+
+type LinearCategoryInference struct{}
+
+func (c *LinearCategoryInference) Classify(text string) (candle_binding.ClassResult, error) {
+	return candle_binding.ClassifyText(text)
+}
+
+type ModernBertCategoryInference struct{}
+
+func (c *ModernBertCategoryInference) Classify(text string) (candle_binding.ClassResult, error) {
+	return candle_binding.ClassifyModernBertText(text)
+}
+
+// createCategoryInference creates the appropriate category inference based on configuration
+func createCategoryInference(useModernBERT bool) CategoryInference {
+	if useModernBERT {
+		return &ModernBertCategoryInference{}
+	}
+	return &LinearCategoryInference{}
+}
+
+type JailbreakInference interface {
+	Classify(text string) (candle_binding.ClassResult, error)
+}
+
+type LinearJailbreakInference struct{}
+
+func (c *LinearJailbreakInference) Classify(text string) (candle_binding.ClassResult, error) {
+	return candle_binding.ClassifyJailbreakText(text)
+}
+
+type ModernBertJailbreakInference struct{}
+
+func (c *ModernBertJailbreakInference) Classify(text string) (candle_binding.ClassResult, error) {
+	return candle_binding.ClassifyModernBertJailbreakText(text)
+}
+
+// createJailbreakInference creates the appropriate jailbreak inference based on configuration
+func createJailbreakInference(useModernBERT bool) JailbreakInference {
+	if useModernBERT {
+		return &ModernBertJailbreakInference{}
+	}
+	return &LinearJailbreakInference{}
+}
+
+type PIIInference interface {
+	ClassifyTokens(text string, configPath string) (candle_binding.TokenClassificationResult, error)
+}
+
+type ModernBertPIIInference struct{}
+
+func (c *ModernBertPIIInference) ClassifyTokens(text string, configPath string) (candle_binding.TokenClassificationResult, error) {
+	return candle_binding.ClassifyModernBertPIITokens(text, configPath)
+}
+
+// createPIIInference creates the appropriate PII inference (currently only ModernBERT)
+func createPIIInference() PIIInference {
+	return &ModernBertPIIInference{}
+}
+
 // JailbreakDetection represents the result of jailbreak analysis for a piece of content
 type JailbreakDetection struct {
 	Content       string  `json:"content"`
@@ -40,6 +103,11 @@ type PIIAnalysisResult struct {
 
 // Classifier handles text classification, model selection, and jailbreak detection functionality
 type Classifier struct {
+	// Dependencies
+	categoryInference  CategoryInference
+	jailbreakInference JailbreakInference
+	piiInference       PIIInference
+
 	Config          *config.RouterConfig
 	CategoryMapping *CategoryMapping
 	PIIMapping      *PIIMapping
@@ -55,6 +123,10 @@ type Classifier struct {
 // NewClassifier creates a new classifier with model selection and jailbreak detection capabilities
 func NewClassifier(cfg *config.RouterConfig, categoryMapping *CategoryMapping, piiMapping *PIIMapping, jailbreakMapping *JailbreakMapping, modelTTFT map[string]float64) *Classifier {
 	return &Classifier{
+		categoryInference:  createCategoryInference(cfg.Classifier.CategoryModel.UseModernBERT),
+		jailbreakInference: createJailbreakInference(cfg.PromptGuard.UseModernBERT),
+		piiInference:       createPIIInference(),
+
 		Config:          cfg,
 		CategoryMapping: categoryMapping,
 		PIIMapping:      piiMapping,
@@ -117,13 +189,7 @@ func (c *Classifier) CheckForJailbreak(text string) (bool, string, float32, erro
 	var err error
 
 	start := time.Now()
-	if c.Config.PromptGuard.UseModernBERT {
-		// Use ModernBERT jailbreak classifier
-		result, err = candle_binding.ClassifyModernBertJailbreakText(text)
-	} else {
-		// Use linear jailbreak classifier
-		result, err = candle_binding.ClassifyJailbreakText(text)
-	}
+	result, err = c.jailbreakInference.Classify(text)
 	metrics.RecordClassifierLatency("jailbreak", time.Since(start).Seconds())
 
 	if err != nil {
@@ -200,13 +266,7 @@ func (c *Classifier) ClassifyCategory(text string) (string, float64, error) {
 	var err error
 
 	start := time.Now()
-	if c.Config.Classifier.CategoryModel.UseModernBERT {
-		// Use ModernBERT classifier
-		result, err = candle_binding.ClassifyModernBertText(text)
-	} else {
-		// Use linear classifier
-		result, err = candle_binding.ClassifyText(text)
-	}
+	result, err = c.categoryInference.Classify(text)
 	metrics.RecordClassifierLatency("category", time.Since(start).Seconds())
 
 	if err != nil {
@@ -249,7 +309,7 @@ func (c *Classifier) ClassifyPII(text string) ([]string, error) {
 	// Use ModernBERT PII token classifier for entity detection
 	configPath := fmt.Sprintf("%s/config.json", c.Config.Classifier.PIIModel.ModelID)
 	start := time.Now()
-	tokenResult, err := candle_binding.ClassifyModernBertPIITokens(text, configPath)
+	tokenResult, err := c.piiInference.ClassifyTokens(text, configPath)
 	metrics.RecordClassifierLatency("pii", time.Since(start).Seconds())
 	if err != nil {
 		return nil, fmt.Errorf("PII token classification error: %w", err)
@@ -331,7 +391,7 @@ func (c *Classifier) AnalyzeContentForPII(contentList []string) (bool, []PIIAnal
 		// Use ModernBERT PII token classifier for detailed analysis
 		configPath := fmt.Sprintf("%s/config.json", c.Config.Classifier.PIIModel.ModelID)
 		start := time.Now()
-		tokenResult, err := candle_binding.ClassifyModernBertPIITokens(content, configPath)
+		tokenResult, err := c.piiInference.ClassifyTokens(content, configPath)
 		metrics.RecordClassifierLatency("pii", time.Since(start).Seconds())
 		if err != nil {
 			log.Printf("Error analyzing content %d: %v", i, err)
diff --git a/src/semantic-router/pkg/utils/classification/classifier_test.go b/src/semantic-router/pkg/utils/classification/classifier_test.go
new file mode 100644
index 00000000..afa60092
--- /dev/null
+++ b/src/semantic-router/pkg/utils/classification/classifier_test.go
@@ -0,0 +1,408 @@
+package classification
+
+import (
+	"errors"
+	"testing"
+
+	. "github.com/onsi/ginkgo/v2"
+	. "github.com/onsi/gomega"
+
+	candle_binding "github.com/vllm-project/semantic-router/candle-binding"
+	"github.com/vllm-project/semantic-router/semantic-router/pkg/config"
+)
+
+func TestClassifier(t *testing.T) {
+	RegisterFailHandler(Fail)
+	RunSpecs(t, "Classifier Suite")
+}
+
+type MockCategoryInference struct {
+	classifyResult candle_binding.ClassResult
+	classifyError  error
+}
+
+func (m *MockCategoryInference) Classify(text string) (candle_binding.ClassResult, error) {
+	return m.classifyResult, m.classifyError
+}
+
+var _ = Describe("ClassifyCategory", func() {
+	var (
+		classifier        *Classifier
+		mockCategoryModel *MockCategoryInference
+	)
+
+	BeforeEach(func() {
+		mockCategoryModel = &MockCategoryInference{}
+		cfg := &config.RouterConfig{}
+		cfg.Classifier.CategoryModel.Threshold = 0.5
+
+		classifier = &Classifier{
+			categoryInference: mockCategoryModel,
+			Config:            cfg,
+			CategoryMapping: &CategoryMapping{
+				CategoryToIdx: map[string]int{"technology": 0, "sports": 1, "politics": 2},
+				IdxToCategory: map[string]string{"0": "technology", "1": "sports", "2": "politics"},
+			},
+		}
+	})
+
+	Context("when classification succeeds with high confidence", func() {
+		It("should return the correct category", func() {
+			mockCategoryModel.classifyResult = candle_binding.ClassResult{
+				Class:      2,
+				Confidence: 0.95,
+			}
+
+			category, score, err := classifier.ClassifyCategory("This is about politics")
+
+			Expect(err).ToNot(HaveOccurred())
+			Expect(category).To(Equal("politics"))
+			Expect(score).To(BeNumerically("~", 0.95, 0.001))
+		})
+	})
+
+	Context("when classification confidence is below threshold", func() {
+		It("should return empty category", func() {
+			mockCategoryModel.classifyResult = candle_binding.ClassResult{
+				Class:      0,
+				Confidence: 0.3,
+			}
+
+			category, score, err := classifier.ClassifyCategory("Ambiguous text")
+
+			Expect(err).ToNot(HaveOccurred())
+			Expect(category).To(Equal(""))
+			Expect(score).To(BeNumerically("~", 0.3, 0.001))
+		})
+	})
+
+	Context("when model inference fails", func() {
+		It("should return empty category with zero score", func() {
+			mockCategoryModel.classifyError = errors.New("model inference failed")
+
+			category, score, err := classifier.ClassifyCategory("Some text")
+
+			Expect(err).To(HaveOccurred())
+			Expect(err.Error()).To(ContainSubstring("classification error"))
+			Expect(category).To(Equal(""))
+			Expect(score).To(BeNumerically("~", 0.0, 0.001))
+		})
+	})
+
+	Context("when input is empty or invalid", func() {
+		It("should handle empty text gracefully", func() {
+			mockCategoryModel.classifyResult = candle_binding.ClassResult{
+				Class:      0,
+				Confidence: 0.8,
+			}
+
+			category, score, err := classifier.ClassifyCategory("")
+
+			Expect(err).ToNot(HaveOccurred())
+			Expect(category).To(Equal("technology"))
+			Expect(score).To(BeNumerically("~", 0.8, 0.001))
+		})
+	})
+
+	Context("when class index is not found in category mapping", func() {
+		It("should handle invalid category mapping gracefully", func() {
+			mockCategoryModel.classifyResult = candle_binding.ClassResult{
+				Class:      9,
+				Confidence: 0.8,
+			}
+
+			category, score, err := classifier.ClassifyCategory("Some text")
+
+			Expect(err).ToNot(HaveOccurred())
+			Expect(category).To(Equal(""))
+			Expect(score).To(BeNumerically("~", 0.8, 0.001))
+		})
+	})
+})
+
+type MockJailbreakInference struct {
+	classifyResult candle_binding.ClassResult
+	classifyError  error
+}
+
+func (m *MockJailbreakInference) Classify(text string) (candle_binding.ClassResult, error) {
+	return m.classifyResult, m.classifyError
+}
+
+var _ = Describe("CheckForJailbreak", func() {
+	var (
+		classifier         *Classifier
+		mockJailbreakModel *MockJailbreakInference
+	)
+
+	BeforeEach(func() {
+		mockJailbreakModel = &MockJailbreakInference{}
+		cfg := &config.RouterConfig{}
+		cfg.PromptGuard.Enabled = true
+		cfg.PromptGuard.ModelID = "test-model"
+		cfg.PromptGuard.JailbreakMappingPath = "test-mapping"
+		cfg.PromptGuard.Threshold = 0.7
+
+		classifier = &Classifier{
+			jailbreakInference: mockJailbreakModel,
+			Config:             cfg,
+			JailbreakMapping: &JailbreakMapping{
+				LabelToIdx: map[string]int{"jailbreak": 0, "benign": 1},
+				IdxToLabel: map[string]string{"0": "jailbreak", "1": "benign"},
+			},
+			JailbreakInitialized: true,
+		}
+	})
+
+	Context("when jailbreak is detected with high confidence", func() {
+		It("should return true with jailbreak type", func() {
+			mockJailbreakModel.classifyResult = candle_binding.ClassResult{
+				Class:      0,
+				Confidence: 0.9,
+			}
+
+			isJailbreak, jailbreakType, confidence, err := classifier.CheckForJailbreak("This is a jailbreak attempt")
+
+			Expect(err).ToNot(HaveOccurred())
+			Expect(isJailbreak).To(BeTrue())
+			Expect(jailbreakType).To(Equal("jailbreak"))
+			Expect(confidence).To(BeNumerically("~", 0.9, 0.001))
+		})
+	})
+
+	Context("when text is benign with high confidence", func() {
+		It("should return false with benign type", func() {
+			mockJailbreakModel.classifyResult = candle_binding.ClassResult{
+				Class:      1,
+				Confidence: 0.9,
+			}
+
+			isJailbreak, jailbreakType, confidence, err := classifier.CheckForJailbreak("This is a normal question")
+
+			Expect(err).ToNot(HaveOccurred())
+			Expect(isJailbreak).To(BeFalse())
+			Expect(jailbreakType).To(Equal("benign"))
+			Expect(confidence).To(BeNumerically("~", 0.9, 0.001))
+		})
+	})
+
+	Context("when jailbreak confidence is below threshold", func() {
+		It("should return false even if classified as jailbreak", func() {
+			mockJailbreakModel.classifyResult = candle_binding.ClassResult{
+				Class:      0,
+				Confidence: 0.5,
+			}
+
+			isJailbreak, jailbreakType, confidence, err := classifier.CheckForJailbreak("Ambiguous text")
+
+			Expect(err).ToNot(HaveOccurred())
+			Expect(isJailbreak).To(BeFalse())
+			Expect(jailbreakType).To(Equal("jailbreak"))
+			Expect(confidence).To(BeNumerically("~", 0.5, 0.001))
+		})
+	})
+
+	Context("when model inference fails", func() {
+		It("should return error", func() {
+			mockJailbreakModel.classifyError = errors.New("model inference failed")
+
+			isJailbreak, jailbreakType, confidence, err := classifier.CheckForJailbreak("Some text")
+
+			Expect(err).To(HaveOccurred())
+			Expect(err.Error()).To(ContainSubstring("jailbreak classification failed"))
+			Expect(isJailbreak).To(BeFalse())
+			Expect(jailbreakType).To(Equal(""))
+			Expect(confidence).To(BeNumerically("~", 0.0, 0.001))
+		})
+	})
+
+	Context("when class index is not found in jailbreak mapping", func() {
+		It("should return error for unknown class", func() {
+			mockJailbreakModel.classifyResult = candle_binding.ClassResult{
+				Class:      9,
+				Confidence: 0.9,
+			}
+
+			isJailbreak, jailbreakType, confidence, err := classifier.CheckForJailbreak("Some text")
+
+			Expect(err).To(HaveOccurred())
+			Expect(err.Error()).To(ContainSubstring("unknown jailbreak class index"))
+			Expect(isJailbreak).To(BeFalse())
+			Expect(jailbreakType).To(Equal(""))
+			Expect(confidence).To(BeNumerically("~", 0.0, 0.001))
+		})
+	})
+})
+
+type PIIInferenceResponse struct {
+	classifyTokensResult candle_binding.TokenClassificationResult
+	classifyTokensError  error
+}
+
+type MockPIIInference struct {
+	PIIInferenceResponse
+	responseMap map[string]PIIInferenceResponse
+}
+
+func (m *MockPIIInference) ClassifyTokens(text string, configPath string) (candle_binding.TokenClassificationResult, error) {
+	if response, exists := m.responseMap[text]; exists {
+		return response.classifyTokensResult, response.classifyTokensError
+	}
+	return m.classifyTokensResult, m.classifyTokensError
+}
+
+var _ = Describe("PIIClassification", func() {
+	var (
+		classifier   *Classifier
+		mockPIIModel *MockPIIInference
+	)
+
+	BeforeEach(func() {
+		mockPIIModel = &MockPIIInference{}
+		cfg := &config.RouterConfig{}
+		cfg.Classifier.PIIModel.ModelID = "test-pii-model"
+		cfg.Classifier.PIIModel.Threshold = 0.7
+
+		classifier = &Classifier{
+			piiInference: mockPIIModel,
+			Config:       cfg,
+			PIIMapping: &PIIMapping{
+				LabelToIdx: map[string]int{"PERSON": 0, "EMAIL": 1},
+				IdxToLabel: map[string]string{"0": "PERSON", "1": "EMAIL"},
+			},
+		}
+	})
+
+	Describe("ClassifyPII", func() {
+		Context("when PII entities are detected above threshold", func() {
+			It("should return detected PII types", func() {
+				mockPIIModel.classifyTokensResult = candle_binding.TokenClassificationResult{
+					Entities: []candle_binding.TokenEntity{
+						{
+							EntityType: "PERSON",
+							Text:       "John Doe",
+							Start:      0,
+							End:        8,
+							Confidence: 0.9,
+						},
+						{
+							EntityType: "EMAIL",
+							Text:       "john@example.com",
+							Start:      9,
+							End:        25,
+							Confidence: 0.8,
+						},
+					},
+				}
+
+				piiTypes, err := classifier.ClassifyPII("John Doe john@example.com")
+
+				Expect(err).ToNot(HaveOccurred())
+				Expect(piiTypes).To(ConsistOf("PERSON", "EMAIL"))
+			})
+		})
+
+		Context("when PII entities are detected below threshold", func() {
+			It("should filter out low confidence entities", func() {
+				mockPIIModel.classifyTokensResult = candle_binding.TokenClassificationResult{
+					Entities: []candle_binding.TokenEntity{
+						{
+							EntityType: "PERSON",
+							Text:       "John Doe",
+							Start:      0,
+							End:        8,
+							Confidence: 0.9,
+						},
+						{
+							EntityType: "EMAIL",
+							Text:       "john@example.com",
+							Start:      9,
+							End:        25,
+							Confidence: 0.5,
+						},
+					},
+				}
+
+				piiTypes, err := classifier.ClassifyPII("John Doe john@example.com")
+
+				Expect(err).ToNot(HaveOccurred())
+				Expect(piiTypes).To(ConsistOf("PERSON"))
+			})
+		})
+
+		Context("when no PII is detected", func() {
+			It("should return empty list", func() {
+				mockPIIModel.classifyTokensResult = candle_binding.TokenClassificationResult{
+					Entities: []candle_binding.TokenEntity{},
+				}
+
+				piiTypes, err := classifier.ClassifyPII("Some text")
+
+				Expect(err).ToNot(HaveOccurred())
+				Expect(piiTypes).To(BeEmpty())
+			})
+		})
+
+		Context("when model inference fails", func() {
+			It("should return error", func() {
+				mockPIIModel.classifyTokensError = errors.New("PII model inference failed")
+
+				piiTypes, err := classifier.ClassifyPII("Some text")
+
+				Expect(err).To(HaveOccurred())
+				Expect(err.Error()).To(ContainSubstring("PII token classification error"))
+				Expect(piiTypes).To(BeNil())
+			})
+		})
+	})
+
+	Describe("AnalyzeContentForPII", func() {
+		Context("when some texts contain PII", func() {
+			It("should return detailed analysis for each text", func() {
+				mockPIIModel.responseMap = make(map[string]PIIInferenceResponse)
+				mockPIIModel.responseMap["Alice Smith"] = PIIInferenceResponse{
+					classifyTokensResult: candle_binding.TokenClassificationResult{
+						Entities: []candle_binding.TokenEntity{
+							{
+								EntityType: "PERSON",
+								Text:       "Alice",
+								Start:      0,
+								End:        5,
+								Confidence: 0.9,
+							},
+						},
+					},
+					classifyTokensError: nil,
+				}
+
+				mockPIIModel.responseMap["No PII here"] = PIIInferenceResponse{}
+
+				contentList := []string{"Alice Smith", "No PII here"}
+				hasPII, results, err := classifier.AnalyzeContentForPII(contentList)
+
+				Expect(err).ToNot(HaveOccurred())
+				Expect(hasPII).To(BeTrue())
+				Expect(results).To(HaveLen(2))
+				Expect(results[0].HasPII).To(BeTrue())
+				Expect(results[0].Entities).To(HaveLen(1))
+				Expect(results[0].Entities[0].EntityType).To(Equal("PERSON"))
+				Expect(results[0].Entities[0].Text).To(Equal("Alice"))
+				Expect(results[1].HasPII).To(BeFalse())
+				Expect(results[1].Entities).To(BeEmpty())
+			})
+		})
+
+		Context("when model inference fails", func() {
+			It("should return false for hasPII and empty results", func() {
+				mockPIIModel.classifyTokensError = errors.New("model failed")
+
+				contentList := []string{"Text 1", "Text 2"}
+				hasPII, results, err := classifier.AnalyzeContentForPII(contentList)
+
+				Expect(err).ToNot(HaveOccurred())
+				Expect(hasPII).To(BeFalse())
+				Expect(results).To(BeEmpty())
+			})
+		})
+	})
+})