Skip to content

Commit ffe1b09

Browse files
committed
test: add more test cases and refactor SelectBestModelForCategory/SelectBestModelFromList for testability
Signed-off-by: Alex Wang <[email protected]>
1 parent 464ed6c commit ffe1b09

File tree

2 files changed

+660
-217
lines changed

2 files changed

+660
-217
lines changed

src/semantic-router/pkg/utils/classification/classifier.go

Lines changed: 56 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -446,14 +446,7 @@ func (c *Classifier) ClassifyAndSelectBestModel(query string) string {
446446

447447
// SelectBestModelForCategory selects the best model from a category based on score and TTFT
448448
func (c *Classifier) SelectBestModelForCategory(categoryName string) string {
449-
var cat *config.Category
450-
for i, category := range c.Config.Categories {
451-
if strings.EqualFold(category.Name, categoryName) {
452-
cat = &c.Config.Categories[i]
453-
break
454-
}
455-
}
456-
449+
cat := c.findCategory(categoryName)
457450
if cat == nil {
458451
log.Printf("Could not find matching category %s in config, using default model", categoryName)
459452
return c.Config.DefaultModel
@@ -462,30 +455,7 @@ func (c *Classifier) SelectBestModelForCategory(categoryName string) string {
462455
c.ModelLoadLock.Lock()
463456
defer c.ModelLoadLock.Unlock()
464457

465-
bestModel := ""
466-
bestScore := -1.0
467-
bestQuality := 0.0
468-
469-
if c.Config.Classifier.LoadAware {
470-
c.forEachModelScore(cat, func(modelScore config.ModelScore) {
471-
quality := modelScore.Score
472-
model := modelScore.Model
473-
baseTTFT := c.ModelTTFT[model]
474-
load := c.ModelLoad[model]
475-
estTTFT := baseTTFT * (1 + float64(load))
476-
if estTTFT == 0 {
477-
estTTFT = 1
478-
}
479-
score := quality / estTTFT
480-
c.updateBestModel(score, quality, model, &bestScore, &bestQuality, &bestModel)
481-
})
482-
} else {
483-
c.forEachModelScore(cat, func(modelScore config.ModelScore) {
484-
quality := modelScore.Score
485-
model := modelScore.Model
486-
c.updateBestModel(quality, quality, model, &bestScore, &bestQuality, &bestModel)
487-
})
488-
}
458+
bestModel, bestScore, bestQuality := c.selectBestModelInternal(cat, nil)
489459

490460
if bestModel == "" {
491461
log.Printf("No models found for category %s, using default model", categoryName)
@@ -497,6 +467,55 @@ func (c *Classifier) SelectBestModelForCategory(categoryName string) string {
497467
return bestModel
498468
}
499469

470+
// findCategory finds the category configuration by name (case-insensitive)
471+
func (c *Classifier) findCategory(categoryName string) *config.Category {
472+
for i, category := range c.Config.Categories {
473+
if strings.EqualFold(category.Name, categoryName) {
474+
return &c.Config.Categories[i]
475+
}
476+
}
477+
return nil
478+
}
479+
480+
// calculateModelScore calculates the combined score and quality for a model
481+
func (c *Classifier) calculateModelScore(modelScore config.ModelScore) (float64, float64) {
482+
quality := modelScore.Score
483+
model := modelScore.Model
484+
485+
if !c.Config.Classifier.LoadAware {
486+
return quality, quality
487+
}
488+
489+
baseTTFT := c.ModelTTFT[model]
490+
load := c.ModelLoad[model]
491+
estTTFT := baseTTFT * (1 + float64(load))
492+
if estTTFT == 0 {
493+
estTTFT = 1 // avoid div by zero
494+
}
495+
score := quality / estTTFT
496+
return score, quality
497+
}
498+
499+
// selectBestModelInternal performs the core model selection logic
500+
//
501+
// modelFilter is optional - if provided, only models passing the filter will be considered
502+
func (c *Classifier) selectBestModelInternal(cat *config.Category, modelFilter func(string) bool) (string, float64, float64) {
503+
bestModel := ""
504+
bestScore := -1.0
505+
bestQuality := 0.0
506+
507+
c.forEachModelScore(cat, func(modelScore config.ModelScore) {
508+
model := modelScore.Model
509+
if modelFilter != nil && !modelFilter(model) {
510+
return
511+
}
512+
score, quality := c.calculateModelScore(modelScore)
513+
c.updateBestModel(score, quality, model, &bestScore, &bestQuality, &bestModel)
514+
})
515+
516+
return bestModel, bestScore, bestQuality
517+
}
518+
500519
// forEachModelScore traverses the ModelScores document of the category and executes the callback for each element.
501520
func (c *Classifier) forEachModelScore(cat *config.Category, fn func(modelScore config.ModelScore)) {
502521
for _, modelScore := range cat.ModelScores {
@@ -510,15 +529,7 @@ func (c *Classifier) SelectBestModelFromList(candidateModels []string, categoryN
510529
return c.Config.DefaultModel
511530
}
512531

513-
// Find the category configuration
514-
var cat *config.Category
515-
for i, category := range c.Config.Categories {
516-
if strings.EqualFold(category.Name, categoryName) {
517-
cat = &c.Config.Categories[i]
518-
break
519-
}
520-
}
521-
532+
cat := c.findCategory(categoryName)
522533
if cat == nil {
523534
// Return first candidate if category not found
524535
return candidateModels[0]
@@ -527,31 +538,10 @@ func (c *Classifier) SelectBestModelFromList(candidateModels []string, categoryN
527538
c.ModelLoadLock.Lock()
528539
defer c.ModelLoadLock.Unlock()
529540

530-
bestModel := ""
531-
bestScore := -1.0
532-
bestQuality := 0.0
533-
534-
filteredFn := func(modelScore config.ModelScore) {
535-
model := modelScore.Model
536-
if !slices.Contains(candidateModels, model) {
537-
return
538-
}
539-
quality := modelScore.Score
540-
if c.Config.Classifier.LoadAware {
541-
baseTTFT := c.ModelTTFT[model]
542-
load := c.ModelLoad[model]
543-
estTTFT := baseTTFT * (1 + float64(load))
544-
if estTTFT == 0 {
545-
estTTFT = 1 // avoid div by zero
546-
}
547-
score := quality / estTTFT
548-
c.updateBestModel(score, quality, model, &bestScore, &bestQuality, &bestModel)
549-
} else {
550-
c.updateBestModel(quality, quality, model, &bestScore, &bestQuality, &bestModel)
551-
}
552-
}
553-
554-
c.forEachModelScore(cat, filteredFn)
541+
bestModel, bestScore, bestQuality := c.selectBestModelInternal(cat,
542+
func(model string) bool {
543+
return slices.Contains(candidateModels, model)
544+
})
555545

556546
if bestModel == "" {
557547
log.Printf("No suitable model found from candidates for category %s, using first candidate", categoryName)

0 commit comments

Comments
 (0)