Skip to content

Commit 6f9a8f0

Browse files
committed
feat: optimize classifier and add ttft unit test
Signed-off-by: yuluo-yx <[email protected]>
1 parent ea956c0 commit 6f9a8f0

File tree

2 files changed

+89
-58
lines changed

2 files changed

+89
-58
lines changed

src/semantic-router/pkg/utils/classification/classifier.go

Lines changed: 33 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package classification
33
import (
44
"fmt"
55
"log"
6+
"slices"
67
"strings"
78
"sync"
89
"time"
@@ -466,35 +467,24 @@ func (c *Classifier) SelectBestModelForCategory(categoryName string) string {
466467
bestQuality := 0.0
467468

468469
if c.Config.Classifier.LoadAware {
469-
// Load-aware: combine accuracy and TTFT
470-
for _, modelScore := range cat.ModelScores {
470+
c.forEachModelScore(cat, func(modelScore config.ModelScore) {
471471
quality := modelScore.Score
472472
model := modelScore.Model
473-
474473
baseTTFT := c.ModelTTFT[model]
475474
load := c.ModelLoad[model]
476475
estTTFT := baseTTFT * (1 + float64(load))
477476
if estTTFT == 0 {
478-
estTTFT = 1 // avoid div by zero
477+
estTTFT = 1
479478
}
480479
score := quality / estTTFT
481-
if score > bestScore {
482-
bestScore = score
483-
bestModel = model
484-
bestQuality = quality
485-
}
486-
}
480+
c.updateBestModel(score, quality, model, &bestScore, &bestQuality, &bestModel)
481+
})
487482
} else {
488-
// Not load-aware: pick the model with the highest accuracy only
489-
for _, modelScore := range cat.ModelScores {
483+
c.forEachModelScore(cat, func(modelScore config.ModelScore) {
490484
quality := modelScore.Score
491485
model := modelScore.Model
492-
if quality > bestScore {
493-
bestScore = quality
494-
bestModel = model
495-
bestQuality = quality
496-
}
497-
}
486+
c.updateBestModel(quality, quality, model, &bestScore, &bestQuality, &bestModel)
487+
})
498488
}
499489

500490
if bestModel == "" {
@@ -507,6 +497,13 @@ func (c *Classifier) SelectBestModelForCategory(categoryName string) string {
507497
return bestModel
508498
}
509499

500+
// forEachModelScore 遍历 category 的 ModelScores 并对每个元素执行回调
501+
func (c *Classifier) forEachModelScore(cat *config.Category, fn func(modelScore config.ModelScore)) {
502+
for _, modelScore := range cat.ModelScores {
503+
fn(modelScore)
504+
}
505+
}
506+
510507
// SelectBestModelFromList selects the best model from a list of candidate models for a given category
511508
func (c *Classifier) SelectBestModelFromList(candidateModels []string, categoryName string) string {
512509
if len(candidateModels) == 0 {
@@ -534,49 +531,28 @@ func (c *Classifier) SelectBestModelFromList(candidateModels []string, categoryN
534531
bestScore := -1.0
535532
bestQuality := 0.0
536533

537-
if c.Config.Classifier.LoadAware {
538-
// Load-aware: combine accuracy and TTFT
539-
for _, modelScore := range cat.ModelScores {
540-
model := modelScore.Model
541-
542-
// Check if this model is in the candidate list
543-
if !c.contains(candidateModels, model) {
544-
continue
545-
}
546-
547-
quality := modelScore.Score
534+
filteredFn := func(modelScore config.ModelScore) {
535+
model := modelScore.Model
536+
if !slices.Contains(candidateModels, model) {
537+
return
538+
}
539+
quality := modelScore.Score
540+
if c.Config.Classifier.LoadAware {
548541
baseTTFT := c.ModelTTFT[model]
549542
load := c.ModelLoad[model]
550543
estTTFT := baseTTFT * (1 + float64(load))
551544
if estTTFT == 0 {
552545
estTTFT = 1 // avoid div by zero
553546
}
554547
score := quality / estTTFT
555-
if score > bestScore {
556-
bestScore = score
557-
bestModel = model
558-
bestQuality = quality
559-
}
560-
}
561-
} else {
562-
// Not load-aware: pick the model with the highest accuracy only
563-
for _, modelScore := range cat.ModelScores {
564-
model := modelScore.Model
565-
566-
// Check if this model is in the candidate list
567-
if !c.contains(candidateModels, model) {
568-
continue
569-
}
570-
571-
quality := modelScore.Score
572-
if quality > bestScore {
573-
bestScore = quality
574-
bestModel = model
575-
bestQuality = quality
576-
}
548+
c.updateBestModel(score, quality, model, &bestScore, &bestQuality, &bestModel)
549+
} else {
550+
c.updateBestModel(quality, quality, model, &bestScore, &bestQuality, &bestModel)
577551
}
578552
}
579553

554+
c.forEachModelScore(cat, filteredFn)
555+
580556
if bestModel == "" {
581557
log.Printf("No suitable model found from candidates for category %s, using first candidate", categoryName)
582558
return candidateModels[0]
@@ -619,12 +595,11 @@ func (c *Classifier) DecrementModelLoad(model string) {
619595
}
620596
}
621597

622-
// contains checks if a slice contains a string
623-
func (c *Classifier) contains(slice []string, item string) bool {
624-
for _, s := range slice {
625-
if s == item {
626-
return true
627-
}
598+
// updateBestModel updates the best model, score, and quality if the new score is better.
599+
func (c *Classifier) updateBestModel(score, quality float64, model string, bestScore *float64, bestQuality *float64, bestModel *string) {
600+
if score > *bestScore {
601+
*bestScore = score
602+
*bestModel = model
603+
*bestQuality = quality
628604
}
629-
return false
630605
}
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
package ttft
2+
3+
import (
4+
"testing"
5+
6+
"github.com/vllm-project/semantic-router/semantic-router/pkg/config"
7+
)
8+
9+
func TestComputeBaseTTFT(t *testing.T) {
10+
11+
gpuConfig := config.GPUConfig{
12+
FLOPS: 1e12, // 1 TFLOP
13+
HBM: 1e11, // 100 GB/s
14+
}
15+
calculator := NewCalculator(gpuConfig)
16+
17+
routerCfg := &config.RouterConfig{}
18+
// Mock config methods if needed, or set up fields so that
19+
// GetModelParamCount, GetModelBatchSize, GetModelContextSize return defaults
20+
21+
ttft := calculator.ComputeBaseTTFT("test-model", routerCfg)
22+
if ttft <= 0 {
23+
t.Errorf("Expected TTFT > 0, got %f", ttft)
24+
}
25+
}
26+
27+
func TestInitializeModelTTFT(t *testing.T) {
28+
gpuConfig := config.GPUConfig{
29+
FLOPS: 1e12,
30+
HBM: 1e11,
31+
}
32+
calculator := NewCalculator(gpuConfig)
33+
34+
// Minimal mock config with two categories and models
35+
routerCfg := &config.RouterConfig{
36+
Categories: []config.Category{
37+
{
38+
ModelScores: []config.ModelScore{
39+
{Model: "model-a", Score: 0.9},
40+
{Model: "model-b", Score: 0.8},
41+
},
42+
},
43+
},
44+
DefaultModel: "model-default",
45+
}
46+
47+
modelTTFT := calculator.InitializeModelTTFT(routerCfg)
48+
if len(modelTTFT) != 3 {
49+
t.Errorf("Expected 3 models in TTFT map, got %d", len(modelTTFT))
50+
}
51+
for model, ttft := range modelTTFT {
52+
if ttft <= 0 {
53+
t.Errorf("Model %s has non-positive TTFT: %f", model, ttft)
54+
}
55+
}
56+
}

0 commit comments

Comments
 (0)