Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
187 changes: 100 additions & 87 deletions src/semantic-router/pkg/utils/classification/classifier.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,40 @@ func createCategoryInference(useModernBERT bool) CategoryInference {
return &LinearCategoryInference{}
}

type JailbreakInitializer interface {
Init(modelID string, useCPU bool, numClasses ...int) error
}

type LinearJailbreakInitializer struct{}

func (c *LinearJailbreakInitializer) Init(modelID string, useCPU bool, numClasses ...int) error {
err := candle_binding.InitJailbreakClassifier(modelID, numClasses[0], useCPU)
if err != nil {
return fmt.Errorf("failed to initialize jailbreak classifier: %w", err)
}
log.Printf("Initialized linear jailbreak classifier with %d classes", numClasses[0])
return nil
}

type ModernBertJailbreakInitializer struct{}

func (c *ModernBertJailbreakInitializer) Init(modelID string, useCPU bool, numClasses ...int) error {
err := candle_binding.InitModernBertJailbreakClassifier(modelID, useCPU)
if err != nil {
return fmt.Errorf("failed to initialize ModernBERT jailbreak classifier: %w", err)
}
log.Printf("Initialized ModernBERT jailbreak classifier (classes auto-detected from model)")
return nil
}

// createJailbreakInitializer creates the appropriate jailbreak initializer based on configuration
func createJailbreakInitializer(useModernBERT bool) JailbreakInitializer {
if useModernBERT {
return &ModernBertJailbreakInitializer{}
}
return &LinearJailbreakInitializer{}
}

type JailbreakInference interface {
Classify(text string) (candle_binding.ClassResult, error)
}
Expand Down Expand Up @@ -105,9 +139,10 @@ type PIIAnalysisResult struct {
// Classifier handles text classification, model selection, and jailbreak detection functionality
type Classifier struct {
// Dependencies
categoryInference CategoryInference
jailbreakInference JailbreakInference
piiInference PIIInference
categoryInference CategoryInference
jailbreakInitializer JailbreakInitializer
jailbreakInference JailbreakInference
piiInference PIIInference

Config *config.RouterConfig
CategoryMapping *CategoryMapping
Expand All @@ -124,9 +159,10 @@ type Classifier struct {
// NewClassifier creates a new classifier with model selection and jailbreak detection capabilities
func NewClassifier(cfg *config.RouterConfig, categoryMapping *CategoryMapping, piiMapping *PIIMapping, jailbreakMapping *JailbreakMapping, modelTTFT map[string]float64) *Classifier {
return &Classifier{
categoryInference: createCategoryInference(cfg.Classifier.CategoryModel.UseModernBERT),
jailbreakInference: createJailbreakInference(cfg.PromptGuard.UseModernBERT),
piiInference: createPIIInference(),
categoryInference: createCategoryInference(cfg.Classifier.CategoryModel.UseModernBERT),
jailbreakInitializer: createJailbreakInitializer(cfg.PromptGuard.UseModernBERT),
jailbreakInference: createJailbreakInference(cfg.PromptGuard.UseModernBERT),
piiInference: createPIIInference(),

Config: cfg,
CategoryMapping: categoryMapping,
Expand All @@ -149,21 +185,8 @@ func (c *Classifier) InitializeJailbreakClassifier() error {
return fmt.Errorf("not enough jailbreak types for classification, need at least 2, got %d", numClasses)
}

var err error
if c.Config.PromptGuard.UseModernBERT {
// Initialize ModernBERT jailbreak classifier
err = candle_binding.InitModernBertJailbreakClassifier(c.Config.PromptGuard.ModelID, c.Config.PromptGuard.UseCPU)
if err != nil {
return fmt.Errorf("failed to initialize ModernBERT jailbreak classifier: %w", err)
}
log.Printf("Initialized ModernBERT jailbreak classifier (classes auto-detected from model)")
} else {
// Initialize linear jailbreak classifier
err = candle_binding.InitJailbreakClassifier(c.Config.PromptGuard.ModelID, numClasses, c.Config.PromptGuard.UseCPU)
if err != nil {
return fmt.Errorf("failed to initialize jailbreak classifier: %w", err)
}
log.Printf("Initialized linear jailbreak classifier with %d classes", numClasses)
if err := c.jailbreakInitializer.Init(c.Config.PromptGuard.ModelID, c.Config.PromptGuard.UseCPU, numClasses); err != nil {
return err
}

c.JailbreakInitialized = true
Expand Down Expand Up @@ -446,14 +469,7 @@ func (c *Classifier) ClassifyAndSelectBestModel(query string) string {

// SelectBestModelForCategory selects the best model from a category based on score and TTFT
func (c *Classifier) SelectBestModelForCategory(categoryName string) string {
var cat *config.Category
for i, category := range c.Config.Categories {
if strings.EqualFold(category.Name, categoryName) {
cat = &c.Config.Categories[i]
break
}
}

cat := c.findCategory(categoryName)
if cat == nil {
log.Printf("Could not find matching category %s in config, using default model", categoryName)
return c.Config.DefaultModel
Expand All @@ -462,30 +478,7 @@ func (c *Classifier) SelectBestModelForCategory(categoryName string) string {
c.ModelLoadLock.Lock()
defer c.ModelLoadLock.Unlock()

bestModel := ""
bestScore := -1.0
bestQuality := 0.0

if c.Config.Classifier.LoadAware {
c.forEachModelScore(cat, func(modelScore config.ModelScore) {
quality := modelScore.Score
model := modelScore.Model
baseTTFT := c.ModelTTFT[model]
load := c.ModelLoad[model]
estTTFT := baseTTFT * (1 + float64(load))
if estTTFT == 0 {
estTTFT = 1
}
score := quality / estTTFT
c.updateBestModel(score, quality, model, &bestScore, &bestQuality, &bestModel)
})
} else {
c.forEachModelScore(cat, func(modelScore config.ModelScore) {
quality := modelScore.Score
model := modelScore.Model
c.updateBestModel(quality, quality, model, &bestScore, &bestQuality, &bestModel)
})
}
bestModel, bestScore, bestQuality := c.selectBestModelInternal(cat, nil)

if bestModel == "" {
log.Printf("No models found for category %s, using default model", categoryName)
Expand All @@ -497,6 +490,55 @@ func (c *Classifier) SelectBestModelForCategory(categoryName string) string {
return bestModel
}

// findCategory finds the category configuration by name (case-insensitive)
func (c *Classifier) findCategory(categoryName string) *config.Category {
for i, category := range c.Config.Categories {
if strings.EqualFold(category.Name, categoryName) {
return &c.Config.Categories[i]
}
}
return nil
}

// calculateModelScore calculates the combined score and quality for a model
func (c *Classifier) calculateModelScore(modelScore config.ModelScore) (float64, float64) {
quality := modelScore.Score
model := modelScore.Model

if !c.Config.Classifier.LoadAware {
return quality, quality
}

baseTTFT := c.ModelTTFT[model]
load := c.ModelLoad[model]
estTTFT := baseTTFT * (1 + float64(load))
if estTTFT == 0 {
estTTFT = 1 // avoid div by zero
}
score := quality / estTTFT
return score, quality
}

// selectBestModelInternal performs the core model selection logic
//
// modelFilter is optional - if provided, only models passing the filter will be considered
func (c *Classifier) selectBestModelInternal(cat *config.Category, modelFilter func(string) bool) (string, float64, float64) {
bestModel := ""
bestScore := -1.0
bestQuality := 0.0

c.forEachModelScore(cat, func(modelScore config.ModelScore) {
model := modelScore.Model
if modelFilter != nil && !modelFilter(model) {
return
}
score, quality := c.calculateModelScore(modelScore)
c.updateBestModel(score, quality, model, &bestScore, &bestQuality, &bestModel)
})

return bestModel, bestScore, bestQuality
}

// forEachModelScore traverses the ModelScores document of the category and executes the callback for each element.
func (c *Classifier) forEachModelScore(cat *config.Category, fn func(modelScore config.ModelScore)) {
for _, modelScore := range cat.ModelScores {
Expand All @@ -510,15 +552,7 @@ func (c *Classifier) SelectBestModelFromList(candidateModels []string, categoryN
return c.Config.DefaultModel
}

// Find the category configuration
var cat *config.Category
for i, category := range c.Config.Categories {
if strings.EqualFold(category.Name, categoryName) {
cat = &c.Config.Categories[i]
break
}
}

cat := c.findCategory(categoryName)
if cat == nil {
// Return first candidate if category not found
return candidateModels[0]
Expand All @@ -527,31 +561,10 @@ func (c *Classifier) SelectBestModelFromList(candidateModels []string, categoryN
c.ModelLoadLock.Lock()
defer c.ModelLoadLock.Unlock()

bestModel := ""
bestScore := -1.0
bestQuality := 0.0

filteredFn := func(modelScore config.ModelScore) {
model := modelScore.Model
if !slices.Contains(candidateModels, model) {
return
}
quality := modelScore.Score
if c.Config.Classifier.LoadAware {
baseTTFT := c.ModelTTFT[model]
load := c.ModelLoad[model]
estTTFT := baseTTFT * (1 + float64(load))
if estTTFT == 0 {
estTTFT = 1 // avoid div by zero
}
score := quality / estTTFT
c.updateBestModel(score, quality, model, &bestScore, &bestQuality, &bestModel)
} else {
c.updateBestModel(quality, quality, model, &bestScore, &bestQuality, &bestModel)
}
}

c.forEachModelScore(cat, filteredFn)
bestModel, bestScore, bestQuality := c.selectBestModelInternal(cat,
func(model string) bool {
return slices.Contains(candidateModels, model)
})

if bestModel == "" {
log.Printf("No suitable model found from candidates for category %s, using first candidate", categoryName)
Expand Down
Loading
Loading