diff --git a/config/config.yaml b/config/config.yaml
index fe41998b..2f722822 100644
--- a/config/config.yaml
+++ b/config/config.yaml
@@ -112,7 +112,6 @@ classifier:
     threshold: 0.7
     use_cpu: true
     pii_mapping_path: "models/pii_classifier_modernbert-base_presidio_token_model/pii_type_mapping.json"
-  load_aware: false
 categories:
   - name: business
     use_reasoning: false
diff --git a/deploy/kubernetes/config.yaml b/deploy/kubernetes/config.yaml
index 9470b1ce..358fb0cd 100644
--- a/deploy/kubernetes/config.yaml
+++ b/deploy/kubernetes/config.yaml
@@ -78,7 +78,6 @@ classifier:
     threshold: 0.7
     use_cpu: true
     pii_mapping_path: "models/pii_classifier_modernbert-base_presidio_token_model/pii_type_mapping.json"
-  load_aware: false
 categories:
   - name: business
     model_scores:
diff --git a/src/semantic-router/pkg/config/config.go b/src/semantic-router/pkg/config/config.go
index 3053ff00..fe1318e1 100644
--- a/src/semantic-router/pkg/config/config.go
+++ b/src/semantic-router/pkg/config/config.go
@@ -34,7 +34,6 @@ type RouterConfig struct {
            UseCPU         bool   `yaml:"use_cpu"`
            PIIMappingPath string `yaml:"pii_mapping_path"`
        } `yaml:"pii_model"`
-       LoadAware bool `yaml:"load_aware"`
    } `yaml:"classifier"`

    // Categories for routing queries
diff --git a/src/semantic-router/pkg/config/config_test.go b/src/semantic-router/pkg/config/config_test.go
index 106f9a6f..fc01cee8 100644
--- a/src/semantic-router/pkg/config/config_test.go
+++ b/src/semantic-router/pkg/config/config_test.go
@@ -60,7 +60,6 @@ classifier:
     use_cpu: true
     use_modernbert: false
     pii_mapping_path: "/path/to/pii.json"
-  load_aware: true

 categories:
   - name: "general"
@@ -138,7 +137,6 @@ tools:
            // Verify classifier config
            Expect(cfg.Classifier.CategoryModel.ModelID).To(Equal("test-category-model"))
            Expect(cfg.Classifier.CategoryModel.UseModernBERT).To(BeTrue())
-           Expect(cfg.Classifier.LoadAware).To(BeTrue())

            // Verify categories
            Expect(cfg.Categories).To(HaveLen(1))
diff --git a/src/semantic-router/pkg/extproc/request_handler.go b/src/semantic-router/pkg/extproc/request_handler.go
index 0bfa4431..a9323e45 100644
--- a/src/semantic-router/pkg/extproc/request_handler.go
+++ b/src/semantic-router/pkg/extproc/request_handler.go
@@ -370,9 +370,6 @@ func (r *OpenAIRouter) handleModelRouting(openAIRequest *openai.ChatCompletionNe
            effortForMetrics := r.getReasoningEffort(categoryName)
            metrics.RecordReasoningDecision(categoryName, matchedModel, useReasoning, effortForMetrics)

-           // Track the model load for the selected model
-           r.Classifier.IncrementModelLoad(matchedModel)
-
            // Track the model routing change
            metrics.RecordModelRouting(originalModel, matchedModel)

diff --git a/src/semantic-router/pkg/extproc/response_handler.go b/src/semantic-router/pkg/extproc/response_handler.go
index a4cd3290..c39ad73a 100644
--- a/src/semantic-router/pkg/extproc/response_handler.go
+++ b/src/semantic-router/pkg/extproc/response_handler.go
@@ -52,7 +52,6 @@ func (r *OpenAIRouter) handleResponseBody(v *ext_proc.ProcessingRequest_Response
                float64(completionTokens),
            )
            metrics.RecordModelCompletionLatency(ctx.RequestModel, completionLatency.Seconds())
-           r.Classifier.DecrementModelLoad(ctx.RequestModel)

            // Compute and record cost if pricing is configured
            if r.Config != nil {
diff --git a/src/semantic-router/pkg/extproc/router.go b/src/semantic-router/pkg/extproc/router.go
index fd397a4d..152d1a06 100644
--- a/src/semantic-router/pkg/extproc/router.go
+++ b/src/semantic-router/pkg/extproc/router.go
@@ -131,8 +131,7 @@ func NewOpenAIRouter(configPath string) (*OpenAIRouter, error) {
    // Create utility components
    piiChecker := pii.NewPolicyChecker(cfg, cfg.ModelConfig)

-   modelTTFT := make(map[string]float64) // Empty TTFT map since load balancing is disabled
-   classifier := classification.NewClassifier(cfg, categoryMapping, piiMapping, jailbreakMapping, modelTTFT)
+   classifier := classification.NewClassifier(cfg, categoryMapping, piiMapping, jailbreakMapping)

    // Create global classification service for API access
    services.NewClassificationService(classifier, cfg)
diff --git a/src/semantic-router/pkg/extproc/security_test.go b/src/semantic-router/pkg/extproc/security_test.go
index 0ad441aa..d2812846 100644
--- a/src/semantic-router/pkg/extproc/security_test.go
+++ b/src/semantic-router/pkg/extproc/security_test.go
@@ -52,7 +52,7 @@ var _ = Describe("Security Checks", func() {
            },
        }
        router.PIIChecker = pii.NewPolicyChecker(cfg, cfg.ModelConfig)
-       router.Classifier = classification.NewClassifier(cfg, router.Classifier.CategoryMapping, router.Classifier.PIIMapping, nil, router.Classifier.ModelTTFT)
+       router.Classifier = classification.NewClassifier(cfg, router.Classifier.CategoryMapping, router.Classifier.PIIMapping, nil)
    })

    It("should allow requests with no PII", func() {
@@ -97,7 +97,7 @@
        piiMapping, err := classification.LoadPIIMapping(cfg.Classifier.PIIModel.PIIMappingPath)
        Expect(err).NotTo(HaveOccurred())

-       router.Classifier = classification.NewClassifier(cfg, router.Classifier.CategoryMapping, piiMapping, nil, router.Classifier.ModelTTFT)
+       router.Classifier = classification.NewClassifier(cfg, router.Classifier.CategoryMapping, piiMapping, nil)
    })

    Describe("ClassifyPII method", func() {
@@ -339,7 +339,7 @@
        piiMapping, err := classification.LoadPIIMapping(cfg.Classifier.PIIModel.PIIMappingPath)
        Expect(err).NotTo(HaveOccurred())

-       router.Classifier = classification.NewClassifier(cfg, router.Classifier.CategoryMapping, piiMapping, nil, router.Classifier.ModelTTFT)
+       router.Classifier = classification.NewClassifier(cfg, router.Classifier.CategoryMapping, piiMapping, nil)
    })

    Describe("Error handling and edge cases", func() {
@@ -524,7 +524,7 @@
            IdxToLabel: map[string]string{"0": "benign", "1": "jailbreak"},
        }

-       router.Classifier = classification.NewClassifier(cfg, router.Classifier.CategoryMapping, router.Classifier.PIIMapping, jailbreakMapping, router.Classifier.ModelTTFT)
+       router.Classifier = classification.NewClassifier(cfg, router.Classifier.CategoryMapping, router.Classifier.PIIMapping, jailbreakMapping)
    })

    It("should process potential jailbreak attempts", func() {
diff --git a/src/semantic-router/pkg/extproc/test_utils_test.go b/src/semantic-router/pkg/extproc/test_utils_test.go
index c65e68a7..80d1b9ad 100644
--- a/src/semantic-router/pkg/extproc/test_utils_test.go
+++ b/src/semantic-router/pkg/extproc/test_utils_test.go
@@ -95,7 +95,6 @@ func CreateTestConfig() *config.RouterConfig {
            UseCPU         bool   `yaml:"use_cpu"`
            PIIMappingPath string `yaml:"pii_mapping_path"`
        } `yaml:"pii_model"`
-       LoadAware bool `yaml:"load_aware"`
    }{
        CategoryModel: struct {
            ModelID string `yaml:"model_id"`
@@ -119,7 +118,6 @@
            UseCPU:         true,
            PIIMappingPath: "../../../../models/pii_classifier_modernbert-base_presidio_token_model/pii_type_mapping.json",
        },
-       LoadAware: true,
    },
    Categories: []config.Category{
        {
@@ -220,11 +218,7 @@ func CreateTestRouter(cfg *config.RouterConfig) (*extproc.OpenAIRouter, error) {
    toolsDatabase := tools.NewToolsDatabase(toolsOptions)

    // Create classifier
-   modelTTFT := map[string]float64{
-       "model-a": 2.5,
-       "model-b": 1.8,
-   }
-   classifier := classification.NewClassifier(cfg, categoryMapping, piiMapping, nil, modelTTFT)
+   classifier := classification.NewClassifier(cfg, categoryMapping, piiMapping, nil)

    // Create PII checker
    piiChecker := pii.NewPolicyChecker(cfg, cfg.ModelConfig)
diff --git a/src/semantic-router/pkg/utils/classification/classifier.go b/src/semantic-router/pkg/utils/classification/classifier.go
index c0eb7d0e..40a3ad01 100644
--- a/src/semantic-router/pkg/utils/classification/classifier.go
+++ b/src/semantic-router/pkg/utils/classification/classifier.go
@@ -5,7 +5,6 @@ import (
    "log"
    "slices"
    "strings"
-   "sync"
    "time"

    candle_binding "github.com/vllm-project/semantic-router/candle-binding"
@@ -37,6 +36,40 @@ func createCategoryInference(useModernBERT bool) CategoryInference {
    return &LinearCategoryInference{}
 }

+type JailbreakInitializer interface {
+   Init(modelID string, useCPU bool, numClasses ...int) error
+}
+
+type LinearJailbreakInitializer struct{}
+
+func (c *LinearJailbreakInitializer) Init(modelID string, useCPU bool, numClasses ...int) error {
+   err := candle_binding.InitJailbreakClassifier(modelID, numClasses[0], useCPU)
+   if err != nil {
+       return fmt.Errorf("failed to initialize jailbreak classifier: %w", err)
+   }
+   log.Printf("Initialized linear jailbreak classifier with %d classes", numClasses[0])
+   return nil
+}
+
+type ModernBertJailbreakInitializer struct{}
+
+func (c *ModernBertJailbreakInitializer) Init(modelID string, useCPU bool, numClasses ...int) error {
+   err := candle_binding.InitModernBertJailbreakClassifier(modelID, useCPU)
+   if err != nil {
+       return fmt.Errorf("failed to initialize ModernBERT jailbreak classifier: %w", err)
+   }
+   log.Printf("Initialized ModernBERT jailbreak classifier (classes auto-detected from model)")
+   return nil
+}
+
+// createJailbreakInitializer creates the appropriate jailbreak initializer based on configuration
+func createJailbreakInitializer(useModernBERT bool) JailbreakInitializer {
+   if useModernBERT {
+       return &ModernBertJailbreakInitializer{}
+   }
+   return &LinearJailbreakInitializer{}
+}
+
 type JailbreakInference interface {
    Classify(text string) (candle_binding.ClassResult, error)
 }
@@ -105,35 +138,31 @@ type PIIAnalysisResult struct {
 // Classifier handles text classification, model selection, and jailbreak detection functionality
 type Classifier struct {
    // Dependencies
-   categoryInference  CategoryInference
-   jailbreakInference JailbreakInference
-   piiInference       PIIInference
+   categoryInference    CategoryInference
+   jailbreakInitializer JailbreakInitializer
+   jailbreakInference   JailbreakInference
+   piiInference         PIIInference

    Config           *config.RouterConfig
    CategoryMapping  *CategoryMapping
    PIIMapping       *PIIMapping
    JailbreakMapping *JailbreakMapping

-   // Model selection fields
-   ModelLoad     map[string]int
-   ModelLoadLock sync.Mutex
-   ModelTTFT     map[string]float64
    // Jailbreak detection state
    JailbreakInitialized bool
 }

 // NewClassifier creates a new classifier with model selection and jailbreak detection capabilities
-func NewClassifier(cfg *config.RouterConfig, categoryMapping *CategoryMapping, piiMapping *PIIMapping, jailbreakMapping *JailbreakMapping, modelTTFT map[string]float64) *Classifier {
+func NewClassifier(cfg *config.RouterConfig, categoryMapping *CategoryMapping, piiMapping *PIIMapping, jailbreakMapping *JailbreakMapping) *Classifier {
    return &Classifier{
-       categoryInference:  createCategoryInference(cfg.Classifier.CategoryModel.UseModernBERT),
-       jailbreakInference: createJailbreakInference(cfg.PromptGuard.UseModernBERT),
-       piiInference:       createPIIInference(),
+       categoryInference:    createCategoryInference(cfg.Classifier.CategoryModel.UseModernBERT),
+       jailbreakInitializer: createJailbreakInitializer(cfg.PromptGuard.UseModernBERT),
+       jailbreakInference:   createJailbreakInference(cfg.PromptGuard.UseModernBERT),
+       piiInference:         createPIIInference(),
        Config:               cfg,
        CategoryMapping:      categoryMapping,
        PIIMapping:           piiMapping,
        JailbreakMapping:     jailbreakMapping,
-       ModelLoad:            make(map[string]int),
-       ModelTTFT:            modelTTFT,
        JailbreakInitialized: false,
    }
 }
@@ -149,21 +178,8 @@ func (c *Classifier) InitializeJailbreakClassifier() error {
        return fmt.Errorf("not enough jailbreak types for classification, need at least 2, got %d", numClasses)
    }

-   var err error
-   if c.Config.PromptGuard.UseModernBERT {
-       // Initialize ModernBERT jailbreak classifier
-       err = candle_binding.InitModernBertJailbreakClassifier(c.Config.PromptGuard.ModelID, c.Config.PromptGuard.UseCPU)
-       if err != nil {
-           return fmt.Errorf("failed to initialize ModernBERT jailbreak classifier: %w", err)
-       }
-       log.Printf("Initialized ModernBERT jailbreak classifier (classes auto-detected from model)")
-   } else {
-       // Initialize linear jailbreak classifier
-       err = candle_binding.InitJailbreakClassifier(c.Config.PromptGuard.ModelID, numClasses, c.Config.PromptGuard.UseCPU)
-       if err != nil {
-           return fmt.Errorf("failed to initialize jailbreak classifier: %w", err)
-       }
-       log.Printf("Initialized linear jailbreak classifier with %d classes", numClasses)
+   if err := c.jailbreakInitializer.Init(c.Config.PromptGuard.ModelID, c.Config.PromptGuard.UseCPU, numClasses); err != nil {
+       return err
    }

    c.JailbreakInitialized = true
@@ -446,57 +462,51 @@ func (c *Classifier) ClassifyAndSelectBestModel(query string) string {
 // SelectBestModelForCategory selects the best model from a category based on score and TTFT
 func (c *Classifier) SelectBestModelForCategory(categoryName string) string {
-   var cat *config.Category
-   for i, category := range c.Config.Categories {
-       if strings.EqualFold(category.Name, categoryName) {
-           cat = &c.Config.Categories[i]
-           break
-       }
-   }
-
+   cat := c.findCategory(categoryName)
    if cat == nil {
        log.Printf("Could not find matching category %s in config, using default model", categoryName)
        return c.Config.DefaultModel
    }

-   c.ModelLoadLock.Lock()
-   defer c.ModelLoadLock.Unlock()
-
-   bestModel := ""
-   bestScore := -1.0
-   bestQuality := 0.0
-
-   if c.Config.Classifier.LoadAware {
-       c.forEachModelScore(cat, func(modelScore config.ModelScore) {
-           quality := modelScore.Score
-           model := modelScore.Model
-           baseTTFT := c.ModelTTFT[model]
-           load := c.ModelLoad[model]
-           estTTFT := baseTTFT * (1 + float64(load))
-           if estTTFT == 0 {
-               estTTFT = 1
-           }
-           score := quality / estTTFT
-           c.updateBestModel(score, quality, model, &bestScore, &bestQuality, &bestModel)
-       })
-   } else {
-       c.forEachModelScore(cat, func(modelScore config.ModelScore) {
-           quality := modelScore.Score
-           model := modelScore.Model
-           c.updateBestModel(quality, quality, model, &bestScore, &bestQuality, &bestModel)
-       })
-   }
+   bestModel, bestScore := c.selectBestModelInternal(cat, nil)

    if bestModel == "" {
        log.Printf("No models found for category %s, using default model", categoryName)
        return c.Config.DefaultModel
    }

-   log.Printf("Selected model %s for category %s with quality %.4f and combined score %.4e",
-       bestModel, categoryName, bestQuality, bestScore)
+   log.Printf("Selected model %s for category %s with score %.4f", bestModel, categoryName, bestScore)
    return bestModel
 }

+// findCategory finds the category configuration by name (case-insensitive)
+func (c *Classifier) findCategory(categoryName string) *config.Category {
+   for i, category := range c.Config.Categories {
+       if strings.EqualFold(category.Name, categoryName) {
+           return &c.Config.Categories[i]
+       }
+   }
+   return nil
+}
+
+// selectBestModelInternal performs the core model selection logic
+//
+// modelFilter is optional - if provided, only models passing the filter will be considered
+func (c *Classifier) selectBestModelInternal(cat *config.Category, modelFilter func(string) bool) (string, float64) {
+   bestModel := ""
+   bestScore := -1.0
+
+   c.forEachModelScore(cat, func(modelScore config.ModelScore) {
+       model := modelScore.Model
+       if modelFilter != nil && !modelFilter(model) {
+           return
+       }
+       c.updateBestModel(modelScore.Score, model, &bestScore, &bestModel)
+   })
+
+   return bestModel, bestScore
+}
+
 // forEachModelScore traverses the ModelScores document of the category and executes the callback for each element.
 func (c *Classifier) forEachModelScore(cat *config.Category, fn func(modelScore config.ModelScore)) {
    for _, modelScore := range cat.ModelScores {
@@ -510,56 +520,23 @@ func (c *Classifier) SelectBestModelFromList(candidateModels []string, categoryN
        return c.Config.DefaultModel
    }

-   // Find the category configuration
-   var cat *config.Category
-   for i, category := range c.Config.Categories {
-       if strings.EqualFold(category.Name, categoryName) {
-           cat = &c.Config.Categories[i]
-           break
-       }
-   }
-
+   cat := c.findCategory(categoryName)
    if cat == nil {
        // Return first candidate if category not found
        return candidateModels[0]
    }

-   c.ModelLoadLock.Lock()
-   defer c.ModelLoadLock.Unlock()
-
-   bestModel := ""
-   bestScore := -1.0
-   bestQuality := 0.0
-
-   filteredFn := func(modelScore config.ModelScore) {
-       model := modelScore.Model
-       if !slices.Contains(candidateModels, model) {
-           return
-       }
-       quality := modelScore.Score
-       if c.Config.Classifier.LoadAware {
-           baseTTFT := c.ModelTTFT[model]
-           load := c.ModelLoad[model]
-           estTTFT := baseTTFT * (1 + float64(load))
-           if estTTFT == 0 {
-               estTTFT = 1 // avoid div by zero
-           }
-           score := quality / estTTFT
-           c.updateBestModel(score, quality, model, &bestScore, &bestQuality, &bestModel)
-       } else {
-           c.updateBestModel(quality, quality, model, &bestScore, &bestQuality, &bestModel)
-       }
-   }
-
-   c.forEachModelScore(cat, filteredFn)
+   bestModel, bestScore := c.selectBestModelInternal(cat,
+       func(model string) bool {
+           return slices.Contains(candidateModels, model)
+       })

    if bestModel == "" {
        log.Printf("No suitable model found from candidates for category %s, using first candidate", categoryName)
        return candidateModels[0]
    }

-   log.Printf("Selected best model %s for category %s with quality %.4f and combined score %.4e",
-       bestModel, categoryName, bestQuality, bestScore)
+   log.Printf("Selected best model %s for category %s with score %.4f", bestModel, categoryName, bestScore)
    return bestModel
 }

@@ -579,27 +556,10 @@ func (c *Classifier) GetModelsForCategory(categoryName string) []string {
    return models
 }

-// IncrementModelLoad increments the load counter for a model
-func (c *Classifier) IncrementModelLoad(model string) {
-   c.ModelLoadLock.Lock()
-   defer c.ModelLoadLock.Unlock()
-   c.ModelLoad[model]++
-}
-
-// DecrementModelLoad decrements the load counter for a model
-func (c *Classifier) DecrementModelLoad(model string) {
-   c.ModelLoadLock.Lock()
-   defer c.ModelLoadLock.Unlock()
-   if c.ModelLoad[model] > 0 {
-       c.ModelLoad[model]--
-   }
-}
-
-// updateBestModel updates the best model, score, and quality if the new score is better.
-func (c *Classifier) updateBestModel(score, quality float64, model string, bestScore *float64, bestQuality *float64, bestModel *string) {
+// updateBestModel updates the best model and score if the new score is better.
+func (c *Classifier) updateBestModel(score float64, model string, bestScore *float64, bestModel *string) {
    if score > *bestScore {
        *bestScore = score
        *bestModel = model
-       *bestQuality = quality
    }
 }
diff --git a/src/semantic-router/pkg/utils/classification/classifier_test.go b/src/semantic-router/pkg/utils/classification/classifier_test.go
index 32682d28..8a9a5a17 100644
--- a/src/semantic-router/pkg/utils/classification/classifier_test.go
+++ b/src/semantic-router/pkg/utils/classification/classifier_test.go
@@ -25,7 +25,7 @@ func (m *MockCategoryInference) Classify(text string) (candle_binding.ClassResul
    return m.classifyResult, m.classifyError
 }

-var _ = Describe("ClassifyCategory", func() {
+var _ = Describe("category classification and model selection", func() {
    var (
        classifier        *Classifier
        mockCategoryModel *MockCategoryInference
@@ -46,202 +46,552 @@
        }
    })

-   Context("when classification succeeds with high confidence", func() {
-       It("should return the correct category", func() {
-           mockCategoryModel.classifyResult = candle_binding.ClassResult{
-               Class:      2,
-               Confidence: 0.95,
-           }
+   Describe("classify category", func() {
+       Context("when category mapping is not initialized", func() {
+           It("should return error", func() {
+               classifier.CategoryMapping = nil
+               _, _, err := classifier.ClassifyCategory("Some text")
+               Expect(err).To(HaveOccurred())
+               Expect(err.Error()).To(ContainSubstring("category mapping not initialized"))
+           })
+       })

-           category, score, err := classifier.ClassifyCategory("This is about politics")
+       Context("when classification succeeds with high confidence", func() {
+           It("should return the correct category", func() {
+               mockCategoryModel.classifyResult = candle_binding.ClassResult{
+                   Class:      2,
+                   Confidence: 0.95,
+               }

-           Expect(err).ToNot(HaveOccurred())
-           Expect(category).To(Equal("politics"))
-           Expect(score).To(BeNumerically("~", 0.95, 0.001))
+               category, score, err := classifier.ClassifyCategory("This is about politics")
+
+               Expect(err).ToNot(HaveOccurred())
+               Expect(category).To(Equal("politics"))
+               Expect(score).To(BeNumerically("~", 0.95, 0.001))
+           })
+       })
+
+       Context("when classification confidence is below threshold", func() {
+           It("should return empty category", func() {
+               mockCategoryModel.classifyResult = candle_binding.ClassResult{
+                   Class:      0,
+                   Confidence: 0.3,
+               }
+
+               category, score, err := classifier.ClassifyCategory("Ambiguous text")
+
+               Expect(err).ToNot(HaveOccurred())
+               Expect(category).To(Equal(""))
+               Expect(score).To(BeNumerically("~", 0.3, 0.001))
+           })
+       })
+
+       Context("when model inference fails", func() {
+           It("should return empty category with zero score", func() {
+               mockCategoryModel.classifyError = errors.New("model inference failed")
+
+               category, score, err := classifier.ClassifyCategory("Some text")
+
+               Expect(err).To(HaveOccurred())
+               Expect(err.Error()).To(ContainSubstring("classification error"))
+               Expect(category).To(Equal(""))
+               Expect(score).To(BeNumerically("~", 0.0, 0.001))
+           })
+       })
+
+       Context("when input is empty or invalid", func() {
+           It("should handle empty text gracefully", func() {
+
mockCategoryModel.classifyResult = candle_binding.ClassResult{ + Class: 0, + Confidence: 0.8, + } + + category, score, err := classifier.ClassifyCategory("") + + Expect(err).ToNot(HaveOccurred()) + Expect(category).To(Equal("technology")) + Expect(score).To(BeNumerically("~", 0.8, 0.001)) + }) + }) + + Context("when class index is not found in category mapping", func() { + It("should handle invalid category mapping gracefully", func() { + mockCategoryModel.classifyResult = candle_binding.ClassResult{ + Class: 9, + Confidence: 0.8, + } + + category, score, err := classifier.ClassifyCategory("Some text") + + Expect(err).ToNot(HaveOccurred()) + Expect(category).To(Equal("")) + Expect(score).To(BeNumerically("~", 0.8, 0.001)) + }) }) }) - Context("when classification confidence is below threshold", func() { - It("should return empty category", func() { - mockCategoryModel.classifyResult = candle_binding.ClassResult{ - Class: 0, - Confidence: 0.3, - } + BeforeEach(func() { + classifier.Config.Categories = []config.Category{ + { + Name: "technology", + ModelScores: []config.ModelScore{ + {Model: "model-a", Score: 0.9}, + {Model: "model-b", Score: 0.8}, + }, + }, + { + Name: "sports", + ModelScores: []config.ModelScore{}, + }, + } + classifier.Config.DefaultModel = "default-model" + }) + + Describe("select best model for category", func() { + It("should return the best model", func() { + model := classifier.SelectBestModelForCategory("technology") + Expect(model).To(Equal("model-a")) + }) - category, score, err := classifier.ClassifyCategory("Ambiguous text") + Context("when category is not found", func() { + It("should return the default model", func() { + model := classifier.SelectBestModelForCategory("non-existent-category") + Expect(model).To(Equal("default-model")) + }) + }) - Expect(err).ToNot(HaveOccurred()) - Expect(category).To(Equal("")) - Expect(score).To(BeNumerically("~", 0.3, 0.001)) + Context("when no best model is found", func() { + It("should return the default model", func() { + model := classifier.SelectBestModelForCategory("sports") + Expect(model).To(Equal("default-model")) + }) }) }) - Context("when model inference fails", func() { - It("should return empty category with zero score", func() { - mockCategoryModel.classifyError = errors.New("model inference failed") + Describe("select best model from list", func() { + It("should return the best model", func() { + model := classifier.SelectBestModelFromList([]string{"model-a"}, "technology") + Expect(model).To(Equal("model-a")) + }) + + Context("when candidate models are empty", func() { + It("should return the default model", func() { + model := classifier.SelectBestModelFromList([]string{}, "technology") + Expect(model).To(Equal("default-model")) + }) + }) - category, score, err := classifier.ClassifyCategory("Some text") + Context("when category is not found", func() { + It("should return the first candidate model", func() { + model := classifier.SelectBestModelFromList([]string{"model-a"}, "non-existent-category") + Expect(model).To(Equal("model-a")) + }) + }) - Expect(err).To(HaveOccurred()) - Expect(err.Error()).To(ContainSubstring("classification error")) - Expect(category).To(Equal("")) - Expect(score).To(BeNumerically("~", 0.0, 0.001)) + Context("when the model is not in the candidate models", func() { + It("should return the first candidate model", func() { + model := classifier.SelectBestModelFromList([]string{"model-c"}, "technology") + Expect(model).To(Equal("model-c")) + }) }) }) - Context("when input is empty or 
invalid", func() { - It("should handle empty text gracefully", func() { + Describe("classify and select best model", func() { + It("should return the best model", func() { mockCategoryModel.classifyResult = candle_binding.ClassResult{ Class: 0, - Confidence: 0.8, + Confidence: 0.9, } + model := classifier.ClassifyAndSelectBestModel("Some text") + Expect(model).To(Equal("model-a")) + }) - category, score, err := classifier.ClassifyCategory("") + Context("when the categories are empty", func() { + It("should return the default model", func() { + classifier.Config.Categories = nil + model := classifier.ClassifyAndSelectBestModel("Some text") + Expect(model).To(Equal("default-model")) + }) + }) - Expect(err).ToNot(HaveOccurred()) - Expect(category).To(Equal("technology")) - Expect(score).To(BeNumerically("~", 0.8, 0.001)) + Context("when the classification fails", func() { + It("should return the default model", func() { + mockCategoryModel.classifyError = errors.New("classification failed") + model := classifier.ClassifyAndSelectBestModel("Some text") + Expect(model).To(Equal("default-model")) + }) + }) + + Context("when the category name is empty", func() { + It("should return the default model", func() { + mockCategoryModel.classifyResult = candle_binding.ClassResult{ + Class: 9, + Confidence: 0.9, + } + model := classifier.ClassifyAndSelectBestModel("Some text") + Expect(model).To(Equal("default-model")) + }) }) }) - Context("when class index is not found in category mapping", func() { - It("should handle invalid category mapping gracefully", func() { - mockCategoryModel.classifyResult = candle_binding.ClassResult{ - Class: 9, - Confidence: 0.8, - } + Describe("internal helper methods", func() { + type row struct { + query string + want *config.Category + } + + DescribeTable("find category", + func(r row) { + cat := classifier.findCategory(r.query) + if r.want == nil { + Expect(cat).To(BeNil()) + } else { + Expect(cat.Name).To(Equal(r.want.Name)) + } + }, + Entry("should find category case-insensitively", row{query: "TECHNOLOGY", want: &config.Category{Name: "technology"}}), + Entry("should return nil for non-existent category", row{query: "non-existent", want: nil}), + ) + + Describe("select best model internal", func() { + + It("should select best model without filter", func() { + cat := &config.Category{ + Name: "test", + ModelScores: []config.ModelScore{ + {Model: "model-a", Score: 0.9}, + {Model: "model-b", Score: 0.8}, + }, + } - category, score, err := classifier.ClassifyCategory("Some text") + bestModel, score := classifier.selectBestModelInternal(cat, nil) - Expect(err).ToNot(HaveOccurred()) - Expect(category).To(Equal("")) - Expect(score).To(BeNumerically("~", 0.8, 0.001)) + Expect(bestModel).To(Equal("model-a")) + Expect(score).To(BeNumerically("~", 0.9, 0.001)) + }) + + It("should select best model with filter", func() { + cat := &config.Category{ + Name: "test", + ModelScores: []config.ModelScore{ + {Model: "model-a", Score: 0.9}, + {Model: "model-b", Score: 0.8}, + {Model: "model-c", Score: 0.7}, + }, + } + filter := func(model string) bool { + return model == "model-b" || model == "model-c" + } + + bestModel, score := classifier.selectBestModelInternal(cat, filter) + + Expect(bestModel).To(Equal("model-b")) + Expect(score).To(BeNumerically("~", 0.8, 0.001)) + }) + + It("should return empty when no models match filter", func() { + cat := &config.Category{ + Name: "test", + ModelScores: []config.ModelScore{ + {Model: "model-a", Score: 0.9}, + {Model: "model-b", Score: 0.8}, + 
}, + } + filter := func(model string) bool { + return model == "non-existent-model" + } + + bestModel, score := classifier.selectBestModelInternal(cat, filter) + + Expect(bestModel).To(Equal("")) + Expect(score).To(BeNumerically("~", -1.0, 0.001)) + }) + + It("should return empty when category has no models", func() { + cat := &config.Category{ + Name: "test", + ModelScores: []config.ModelScore{}, + } + + bestModel, score := classifier.selectBestModelInternal(cat, nil) + + Expect(bestModel).To(Equal("")) + Expect(score).To(BeNumerically("~", -1.0, 0.001)) + }) }) }) }) -type MockJailbreakInference struct { +type MockJailbreakInferenceResponse struct { classifyResult candle_binding.ClassResult classifyError error } +type MockJailbreakInference struct { + MockJailbreakInferenceResponse + responseMap map[string]MockJailbreakInferenceResponse +} + +func (m *MockJailbreakInference) setMockResponse(text string, class int, confidence float32, err error) { + m.responseMap[text] = MockJailbreakInferenceResponse{ + classifyResult: candle_binding.ClassResult{ + Class: class, + Confidence: confidence, + }, + classifyError: err, + } +} + func (m *MockJailbreakInference) Classify(text string) (candle_binding.ClassResult, error) { + if response, exists := m.responseMap[text]; exists { + return response.classifyResult, response.classifyError + } return m.classifyResult, m.classifyError } -var _ = Describe("CheckForJailbreak", func() { +type MockJailbreakInitializer struct { + InitError error +} + +func (m *MockJailbreakInitializer) Init(modelID string, useCPU bool, numClasses ...int) error { + return m.InitError +} + +var _ = Describe("initialize jailbreak classifier", func() { var ( - classifier *Classifier - mockJailbreakModel *MockJailbreakInference + classifier *Classifier + mockJailbreakInitializer *MockJailbreakInitializer ) BeforeEach(func() { - mockJailbreakModel = &MockJailbreakInference{} + mockJailbreakInitializer = &MockJailbreakInitializer{InitError: nil} cfg := &config.RouterConfig{} cfg.PromptGuard.Enabled = true cfg.PromptGuard.ModelID = "test-model" cfg.PromptGuard.JailbreakMappingPath = "test-mapping" cfg.PromptGuard.Threshold = 0.7 - classifier = &Classifier{ - jailbreakInference: mockJailbreakModel, - Config: cfg, + jailbreakInitializer: mockJailbreakInitializer, + Config: cfg, JailbreakMapping: &JailbreakMapping{ LabelToIdx: map[string]int{"jailbreak": 0, "benign": 1}, IdxToLabel: map[string]string{"0": "jailbreak", "1": "benign"}, }, - JailbreakInitialized: true, + JailbreakInitialized: false, } }) - Context("when jailbreak is detected with high confidence", func() { - It("should return true with jailbreak type", func() { - mockJailbreakModel.classifyResult = candle_binding.ClassResult{ - Class: 0, - Confidence: 0.9, - } - - isJailbreak, jailbreakType, confidence, err := classifier.CheckForJailbreak("This is a jailbreak attempt") + It("should initialize jailbreak classifier", func() { + err := classifier.InitializeJailbreakClassifier() + Expect(err).ToNot(HaveOccurred()) + Expect(classifier.JailbreakInitialized).To(BeTrue()) + }) + Context("when jailbreak mapping is not initialized", func() { + It("should return nil", func() { + classifier.JailbreakMapping = nil + err := classifier.InitializeJailbreakClassifier() Expect(err).ToNot(HaveOccurred()) - Expect(isJailbreak).To(BeTrue()) - Expect(jailbreakType).To(Equal("jailbreak")) - Expect(confidence).To(BeNumerically("~", 0.9, 0.001)) + Expect(classifier.JailbreakInitialized).To(BeFalse()) }) }) - Context("when text is benign with high 
confidence", func() { - It("should return false with benign type", func() { - mockJailbreakModel.classifyResult = candle_binding.ClassResult{ - Class: 1, - Confidence: 0.9, + Context("when not enough jailbreak types", func() { + It("should return error", func() { + classifier.JailbreakMapping = &JailbreakMapping{ + LabelToIdx: map[string]int{"jailbreak": 0}, + IdxToLabel: map[string]string{"0": "jailbreak"}, } - isJailbreak, jailbreakType, confidence, err := classifier.CheckForJailbreak("This is a normal question") + err := classifier.InitializeJailbreakClassifier() - Expect(err).ToNot(HaveOccurred()) - Expect(isJailbreak).To(BeFalse()) - Expect(jailbreakType).To(Equal("benign")) - Expect(confidence).To(BeNumerically("~", 0.9, 0.001)) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("not enough jailbreak types for classification")) + Expect(classifier.JailbreakInitialized).To(BeFalse()) }) }) - Context("when jailbreak confidence is below threshold", func() { - It("should return false even if classified as jailbreak", func() { - mockJailbreakModel.classifyResult = candle_binding.ClassResult{ - Class: 0, - Confidence: 0.5, - } + Context("when initialize jailbreak classifier fails", func() { + It("should return error", func() { + mockJailbreakInitializer.InitError = errors.New("initialize jailbreak classifier failed") - isJailbreak, jailbreakType, confidence, err := classifier.CheckForJailbreak("Ambiguous text") + err := classifier.InitializeJailbreakClassifier() - Expect(err).ToNot(HaveOccurred()) - Expect(isJailbreak).To(BeFalse()) - Expect(jailbreakType).To(Equal("jailbreak")) - Expect(confidence).To(BeNumerically("~", 0.5, 0.001)) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("initialize jailbreak classifier failed")) + Expect(classifier.JailbreakInitialized).To(BeFalse()) }) }) +}) - Context("when model inference fails", func() { - It("should return error", func() { - mockJailbreakModel.classifyError = errors.New("model inference failed") +var _ = Describe("jailbreak detection", func() { + var ( + classifier *Classifier + mockJailbreakModel *MockJailbreakInference + ) - isJailbreak, jailbreakType, confidence, err := classifier.CheckForJailbreak("Some text") + BeforeEach(func() { + mockJailbreakModel = &MockJailbreakInference{} + mockJailbreakModel.responseMap = make(map[string]MockJailbreakInferenceResponse) + cfg := &config.RouterConfig{} + cfg.PromptGuard.Enabled = true + cfg.PromptGuard.ModelID = "test-model" + cfg.PromptGuard.JailbreakMappingPath = "test-mapping" + cfg.PromptGuard.Threshold = 0.7 - Expect(err).To(HaveOccurred()) - Expect(err.Error()).To(ContainSubstring("jailbreak classification failed")) - Expect(isJailbreak).To(BeFalse()) - Expect(jailbreakType).To(Equal("")) - Expect(confidence).To(BeNumerically("~", 0.0, 0.001)) - }) + classifier = &Classifier{ + jailbreakInference: mockJailbreakModel, + Config: cfg, + JailbreakMapping: &JailbreakMapping{ + LabelToIdx: map[string]int{"jailbreak": 0, "benign": 1}, + IdxToLabel: map[string]string{"0": "jailbreak", "1": "benign"}, + }, + JailbreakInitialized: true, + } }) - Context("when class index is not found in jailbreak mapping", func() { - It("should return error for unknown class", func() { - mockJailbreakModel.classifyResult = candle_binding.ClassResult{ - Class: 9, - Confidence: 0.9, - } + Describe("check for jailbreak", func() { + Context("when jailbreak mapping is not initialized", func() { + It("should return false", func() { + classifier.JailbreakMapping = nil + 
isJailbreak, _, _, err := classifier.CheckForJailbreak("Some text") + Expect(err).ToNot(HaveOccurred()) + Expect(isJailbreak).To(BeFalse()) + }) + }) + + Context("when text is empty", func() { + It("should return false", func() { + isJailbreak, _, _, err := classifier.CheckForJailbreak("") + Expect(err).ToNot(HaveOccurred()) + Expect(isJailbreak).To(BeFalse()) + }) + }) - isJailbreak, jailbreakType, confidence, err := classifier.CheckForJailbreak("Some text") + Context("when jailbreak is detected with high confidence", func() { + It("should return true with jailbreak type", func() { + mockJailbreakModel.classifyResult = candle_binding.ClassResult{ + Class: 0, + Confidence: 0.9, + } + isJailbreak, jailbreakType, confidence, err := classifier.CheckForJailbreak("This is a jailbreak attempt") + Expect(err).ToNot(HaveOccurred()) + Expect(isJailbreak).To(BeTrue()) + Expect(jailbreakType).To(Equal("jailbreak")) + Expect(confidence).To(BeNumerically("~", 0.9, 0.001)) + }) + }) - Expect(err).To(HaveOccurred()) - Expect(err.Error()).To(ContainSubstring("unknown jailbreak class index")) - Expect(isJailbreak).To(BeFalse()) - Expect(jailbreakType).To(Equal("")) - Expect(confidence).To(BeNumerically("~", 0.0, 0.001)) + Context("when text is benign with high confidence", func() { + It("should return false with benign type", func() { + mockJailbreakModel.classifyResult = candle_binding.ClassResult{ + Class: 1, + Confidence: 0.9, + } + isJailbreak, jailbreakType, confidence, err := classifier.CheckForJailbreak("This is a normal question") + Expect(err).ToNot(HaveOccurred()) + Expect(isJailbreak).To(BeFalse()) + Expect(jailbreakType).To(Equal("benign")) + Expect(confidence).To(BeNumerically("~", 0.9, 0.001)) + }) + }) + + Context("when jailbreak confidence is below threshold", func() { + It("should return false even if classified as jailbreak", func() { + mockJailbreakModel.classifyResult = candle_binding.ClassResult{ + Class: 0, + Confidence: 0.5, + } + isJailbreak, jailbreakType, confidence, err := classifier.CheckForJailbreak("Ambiguous text") + Expect(err).ToNot(HaveOccurred()) + Expect(isJailbreak).To(BeFalse()) + Expect(jailbreakType).To(Equal("jailbreak")) + Expect(confidence).To(BeNumerically("~", 0.5, 0.001)) + }) + }) + + Context("when model inference fails", func() { + It("should return error", func() { + mockJailbreakModel.classifyError = errors.New("model inference failed") + isJailbreak, jailbreakType, confidence, err := classifier.CheckForJailbreak("Some text") + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("jailbreak classification failed")) + Expect(isJailbreak).To(BeFalse()) + Expect(jailbreakType).To(Equal("")) + Expect(confidence).To(BeNumerically("~", 0.0, 0.001)) + }) + }) + + Context("when class index is not found in jailbreak mapping", func() { + It("should return error for unknown class", func() { + mockJailbreakModel.classifyResult = candle_binding.ClassResult{ + Class: 9, + Confidence: 0.9, + } + isJailbreak, jailbreakType, confidence, err := classifier.CheckForJailbreak("Some text") + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("unknown jailbreak class index")) + Expect(isJailbreak).To(BeFalse()) + Expect(jailbreakType).To(Equal("")) + Expect(confidence).To(BeNumerically("~", 0.0, 0.001)) + }) + }) + }) + + Describe("analyze content for jailbreak", func() { + Context("when jailbreak mapping is not initialized", func() { + It("should return empty list", func() { + classifier.JailbreakMapping = nil + hasJailbreak, _, err := 
classifier.AnalyzeContentForJailbreak([]string{"Some text"}) + Expect(err).ToNot(HaveOccurred()) + Expect(hasJailbreak).To(BeFalse()) + }) + }) + + Context("when 5 texts in total, 1 has jailbreak, 1 has empty text, 1 has model inference failure", func() { + It("should return 3 results with correct analysis", func() { + mockJailbreakModel.setMockResponse("text0", 0, 0.9, errors.New("model inference failed")) + mockJailbreakModel.setMockResponse("text1", 0, 0.3, nil) + mockJailbreakModel.setMockResponse("text2", 1, 0.9, nil) + mockJailbreakModel.setMockResponse("text3", 0, 0.9, nil) + mockJailbreakModel.setMockResponse("", 0, 0.9, nil) + contentList := []string{"text0", "text1", "text2", "text3", ""} + hasJailbreak, results, err := classifier.AnalyzeContentForJailbreak(contentList) + Expect(err).ToNot(HaveOccurred()) + Expect(hasJailbreak).To(BeTrue()) + // only 3 results because the first and the last are skipped because of model inference failure and empty text + Expect(results).To(HaveLen(3)) + Expect(results[0].IsJailbreak).To(BeFalse()) + Expect(results[0].JailbreakType).To(Equal("jailbreak")) + Expect(results[0].Confidence).To(BeNumerically("~", 0.3, 0.001)) + Expect(results[1].IsJailbreak).To(BeFalse()) + Expect(results[1].JailbreakType).To(Equal("benign")) + Expect(results[1].Confidence).To(BeNumerically("~", 0.9, 0.001)) + Expect(results[2].IsJailbreak).To(BeTrue()) + Expect(results[2].JailbreakType).To(Equal("jailbreak")) + Expect(results[2].Confidence).To(BeNumerically("~", 0.9, 0.001)) + }) }) }) }) -type PIIInferenceResponse struct { +type MockPIIInferenceResponse struct { classifyTokensResult candle_binding.TokenClassificationResult classifyTokensError error } type MockPIIInference struct { - PIIInferenceResponse - responseMap map[string]PIIInferenceResponse + MockPIIInferenceResponse + responseMap map[string]MockPIIInferenceResponse +} + +func (m *MockPIIInference) setMockResponse(text string, entities []candle_binding.TokenEntity, err error) { + m.responseMap[text] = MockPIIInferenceResponse{ + classifyTokensResult: candle_binding.TokenClassificationResult{ + Entities: entities, + }, + classifyTokensError: err, + } } func (m *MockPIIInference) ClassifyTokens(text string, configPath string) (candle_binding.TokenClassificationResult, error) { @@ -251,7 +601,7 @@ func (m *MockPIIInference) ClassifyTokens(text string, configPath string) (candl return m.classifyTokensResult, m.classifyTokensError } -var _ = Describe("PIIClassification", func() { +var _ = Describe("PII detection", func() { var ( classifier *Classifier mockPIIModel *MockPIIInference @@ -259,6 +609,7 @@ var _ = Describe("PIIClassification", func() { BeforeEach(func() { mockPIIModel = &MockPIIInference{} + mockPIIModel.responseMap = make(map[string]MockPIIInferenceResponse) cfg := &config.RouterConfig{} cfg.Classifier.PIIModel.ModelID = "test-pii-model" cfg.Classifier.PIIModel.Threshold = 0.7 @@ -273,7 +624,24 @@ var _ = Describe("PIIClassification", func() { } }) - Describe("ClassifyPII", func() { + Describe("classify PII", func() { + Context("when PII mapping is not initialized", func() { + It("should return empty list", func() { + classifier.PIIMapping = nil + piiTypes, err := classifier.ClassifyPII("Some text") + Expect(err).ToNot(HaveOccurred()) + Expect(piiTypes).To(BeEmpty()) + }) + }) + + Context("when text is empty", func() { + It("should return empty list", func() { + piiTypes, err := classifier.ClassifyPII("") + Expect(err).ToNot(HaveOccurred()) + Expect(piiTypes).To(BeEmpty()) + }) + }) + Context("when PII 
entities are detected above threshold", func() { It("should return detected PII types", func() { mockPIIModel.classifyTokensResult = candle_binding.TokenClassificationResult{ @@ -356,73 +724,157 @@ var _ = Describe("PIIClassification", func() { }) }) - Describe("AnalyzeContentForPII", func() { - Context("when some texts contain PII", func() { - It("should return detailed analysis for each text", func() { - mockPIIModel.responseMap = make(map[string]PIIInferenceResponse) - mockPIIModel.responseMap["Alice Smith"] = PIIInferenceResponse{ - classifyTokensResult: candle_binding.TokenClassificationResult{ - Entities: []candle_binding.TokenEntity{ - { - EntityType: "PERSON", - Text: "Alice", - Start: 0, - End: 5, - Confidence: 0.9, - }, - }, - }, - classifyTokensError: nil, - } + Describe("analyze content for PII", func() { + Context("when PII mapping is not initialized", func() { + It("should return empty list", func() { + classifier.PIIMapping = nil + hasPII, _, err := classifier.AnalyzeContentForPII([]string{"Some text"}) + Expect(err).ToNot(HaveOccurred()) + Expect(hasPII).To(BeFalse()) + }) + }) - mockPIIModel.responseMap["No PII here"] = PIIInferenceResponse{} + Context("when 5 texts in total, 1 has PII, 1 has empty text, 1 has model inference failure", func() { + It("should return 3 results with correct analysis", func() { + mockPIIModel.setMockResponse("Bob", []candle_binding.TokenEntity{}, errors.New("model inference failed")) + mockPIIModel.setMockResponse("Lisa Smith", []candle_binding.TokenEntity{ + { + EntityType: "PERSON", + Text: "Lisa", + Start: 0, + End: 4, + Confidence: 0.3, + }, + }, nil) + mockPIIModel.setMockResponse("Alice Smith", []candle_binding.TokenEntity{ + { + EntityType: "PERSON", + Text: "Alice", + Start: 0, + End: 5, + Confidence: 0.9, + }, + }, nil) + mockPIIModel.setMockResponse("No PII here", []candle_binding.TokenEntity{}, nil) + mockPIIModel.setMockResponse("", []candle_binding.TokenEntity{}, nil) + contentList := []string{"Bob", "Lisa Smith", "Alice Smith", "No PII here", ""} - contentList := []string{"Alice Smith", "No PII here"} hasPII, results, err := classifier.AnalyzeContentForPII(contentList) Expect(err).ToNot(HaveOccurred()) Expect(hasPII).To(BeTrue()) - Expect(results).To(HaveLen(2)) - Expect(results[0].HasPII).To(BeTrue()) - Expect(results[0].Entities).To(HaveLen(1)) - Expect(results[0].Entities[0].EntityType).To(Equal("PERSON")) - Expect(results[0].Entities[0].Text).To(Equal("Alice")) - Expect(results[1].HasPII).To(BeFalse()) - Expect(results[1].Entities).To(BeEmpty()) + // only 3 results because the first and the last are skipped because of model inference failure and empty text + Expect(results).To(HaveLen(3)) + Expect(results[0].HasPII).To(BeFalse()) + Expect(results[0].Entities).To(BeEmpty()) + Expect(results[1].HasPII).To(BeTrue()) + Expect(results[1].Entities).To(HaveLen(1)) + Expect(results[1].Entities[0].EntityType).To(Equal("PERSON")) + Expect(results[1].Entities[0].Text).To(Equal("Alice")) + Expect(results[2].HasPII).To(BeFalse()) + Expect(results[2].Entities).To(BeEmpty()) }) }) + }) - Context("when model inference fails", func() { - It("should return false for hasPII and empty results", func() { - mockPIIModel.classifyTokensError = errors.New("model failed") + Describe("detect PII in content", func() { + Context("when 5 texts in total, 2 has PII, 1 has empty text, 1 has model inference failure", func() { + It("should return 2 detected PII types", func() { + mockPIIModel.setMockResponse("Bob", []candle_binding.TokenEntity{}, 
errors.New("model inference failed")) + mockPIIModel.setMockResponse("Lisa Smith", []candle_binding.TokenEntity{ + { + EntityType: "PERSON", + Text: "Lisa", + Start: 0, + End: 4, + Confidence: 0.8, + }, + }, nil) + mockPIIModel.setMockResponse("Alice Smith alice@example.com", []candle_binding.TokenEntity{ + { + EntityType: "PERSON", + Text: "Alice", + Start: 0, + End: 5, + Confidence: 0.9, + }, { + EntityType: "EMAIL", + Text: "alice@example.com", + Start: 12, + End: 29, + Confidence: 0.9, + }, + }, nil) + mockPIIModel.setMockResponse("No PII here", []candle_binding.TokenEntity{}, nil) + mockPIIModel.setMockResponse("", []candle_binding.TokenEntity{}, nil) + contentList := []string{"Bob", "Lisa Smith", "Alice Smith alice@example.com", "No PII here", ""} - contentList := []string{"Text 1", "Text 2"} - hasPII, results, err := classifier.AnalyzeContentForPII(contentList) + detectedPII := classifier.DetectPIIInContent(contentList) - Expect(err).ToNot(HaveOccurred()) - Expect(hasPII).To(BeFalse()) - Expect(results).To(BeEmpty()) + Expect(detectedPII).To(ConsistOf("PERSON", "EMAIL")) }) }) }) }) +var _ = Describe("get models for category", func() { + var c *Classifier + + BeforeEach(func() { + c = &Classifier{ + Config: &config.RouterConfig{ + Categories: []config.Category{ + { + Name: "Toxicity", + ModelScores: []config.ModelScore{ + {Model: "m1"}, {Model: "m2"}, + }, + }, + { + Name: "Toxicity", // duplicate name, should be ignored by "first wins" + ModelScores: []config.ModelScore{{Model: "mX"}}, + }, + { + Name: "Jailbreak", + ModelScores: []config.ModelScore{{Model: "jb1"}}, + }, + }, + }, + } + }) + + type row struct { + query string + want []string + } + + DescribeTable("lookup behavior", + func(r row) { + got := c.GetModelsForCategory(r.query) + Expect(got).To(Equal(r.want)) + }, + + Entry("case-insensitive match", row{query: "toxicity", want: []string{"m1", "m2"}}), + Entry("no match returns nil slice", row{query: "NotExist", want: nil}), + Entry("another category", row{query: "JAILBREAK", want: []string{"jb1"}}), + ) +}) + func TestUpdateBestModel(t *testing.T) { classifier := &Classifier{} bestScore := 0.5 - bestQuality := 0.5 bestModel := "old-model" - classifier.updateBestModel(0.8, 0.9, "new-model", &bestScore, &bestQuality, &bestModel) - if bestScore != 0.8 || bestQuality != 0.9 || bestModel != "new-model" { - t.Errorf("update: got bestScore=%v, bestQuality=%v, bestModel=%v", bestScore, bestQuality, bestModel) + classifier.updateBestModel(0.8, "new-model", &bestScore, &bestModel) + if bestScore != 0.8 || bestModel != "new-model" { + t.Errorf("update: got bestScore=%v, bestModel=%v", bestScore, bestModel) } - classifier.updateBestModel(0.7, 0.7, "another-model", &bestScore, &bestQuality, &bestModel) - if bestScore != 0.8 || bestQuality != 0.9 || bestModel != "new-model" { - t.Errorf("not update: got bestScore=%v, bestQuality=%v, bestModel=%v", bestScore, bestQuality, bestModel) + classifier.updateBestModel(0.7, "another-model", &bestScore, &bestModel) + if bestScore != 0.8 || bestModel != "new-model" { + t.Errorf("not update: got bestScore=%v, bestModel=%v", bestScore, bestModel) } } diff --git a/src/training/model_eval/result_to_config.py b/src/training/model_eval/result_to_config.py index 1267bada..e5c5c2c3 100644 --- a/src/training/model_eval/result_to_config.py +++ b/src/training/model_eval/result_to_config.py @@ -117,7 +117,6 @@ def generate_config_yaml(category_accuracies, similarity_threshold): "use_cpu": True, "pii_mapping_path": 
"models/pii_classifier_modernbert-base_presidio_token_model/pii_type_mapping.json", }, - "load_aware": False, }, "categories": [], "default_reasoning_effort": "medium", # Default reasoning effort level (low, medium, high)