Skip to content

Commit ad93419

Browse files
committed
refactor: remove load aware algorithm
Signed-off-by: Alex Wang <[email protected]>
1 parent 4207c9a commit ad93419

File tree

12 files changed

+18
-155
lines changed

12 files changed

+18
-155
lines changed

config/config.yaml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,6 @@ classifier:
9090
threshold: 0.7
9191
use_cpu: true
9292
pii_mapping_path: "models/pii_classifier_modernbert-base_presidio_token_model/pii_type_mapping.json"
93-
load_aware: false
9493
categories:
9594
- name: business
9695
use_reasoning: false

deploy/kubernetes/config.yaml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,6 @@ classifier:
7878
threshold: 0.7
7979
use_cpu: true
8080
pii_mapping_path: "models/pii_classifier_modernbert-base_presidio_token_model/pii_type_mapping.json"
81-
load_aware: false
8281
categories:
8382
- name: business
8483
model_scores:

src/semantic-router/pkg/config/config.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@ type RouterConfig struct {
3333
UseCPU bool `yaml:"use_cpu"`
3434
PIIMappingPath string `yaml:"pii_mapping_path"`
3535
} `yaml:"pii_model"`
36-
LoadAware bool `yaml:"load_aware"`
3736
} `yaml:"classifier"`
3837

3938
// Categories for routing queries

src/semantic-router/pkg/config/config_test.go

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,6 @@ classifier:
6060
use_cpu: true
6161
use_modernbert: false
6262
pii_mapping_path: "/path/to/pii.json"
63-
load_aware: true
6463
6564
categories:
6665
- name: "general"
@@ -138,7 +137,6 @@ tools:
138137
// Verify classifier config
139138
Expect(cfg.Classifier.CategoryModel.ModelID).To(Equal("test-category-model"))
140139
Expect(cfg.Classifier.CategoryModel.UseModernBERT).To(BeTrue())
141-
Expect(cfg.Classifier.LoadAware).To(BeTrue())
142140

143141
// Verify categories
144142
Expect(cfg.Categories).To(HaveLen(1))

src/semantic-router/pkg/extproc/request_handler.go

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -370,9 +370,6 @@ func (r *OpenAIRouter) handleModelRouting(openAIRequest *openai.ChatCompletionNe
370370
effortForMetrics := r.getReasoningEffort(categoryName)
371371
metrics.RecordReasoningDecision(categoryName, matchedModel, useReasoning, effortForMetrics)
372372

373-
// Track the model load for the selected model
374-
r.Classifier.IncrementModelLoad(matchedModel)
375-
376373
// Track the model routing change
377374
metrics.RecordModelRouting(originalModel, matchedModel)
378375

src/semantic-router/pkg/extproc/response_handler.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,6 @@ func (r *OpenAIRouter) handleResponseBody(v *ext_proc.ProcessingRequest_Response
5252
float64(completionTokens),
5353
)
5454
metrics.RecordModelCompletionLatency(ctx.RequestModel, completionLatency.Seconds())
55-
r.Classifier.DecrementModelLoad(ctx.RequestModel)
5655

5756
// Compute and record cost if pricing is configured
5857
if r.Config != nil {

src/semantic-router/pkg/extproc/router.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -131,8 +131,7 @@ func NewOpenAIRouter(configPath string) (*OpenAIRouter, error) {
131131

132132
// Create utility components
133133
piiChecker := pii.NewPolicyChecker(cfg, cfg.ModelConfig)
134-
modelTTFT := make(map[string]float64) // Empty TTFT map since load balancing is disabled
135-
classifier := classification.NewClassifier(cfg, categoryMapping, piiMapping, jailbreakMapping, modelTTFT)
134+
classifier := classification.NewClassifier(cfg, categoryMapping, piiMapping, jailbreakMapping)
136135

137136
// Create global classification service for API access
138137
services.NewClassificationService(classifier, cfg)

src/semantic-router/pkg/extproc/security_test.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ var _ = Describe("Security Checks", func() {
5252
},
5353
}
5454
router.PIIChecker = pii.NewPolicyChecker(cfg, cfg.ModelConfig)
55-
router.Classifier = classification.NewClassifier(cfg, router.Classifier.CategoryMapping, router.Classifier.PIIMapping, nil, router.Classifier.ModelTTFT)
55+
router.Classifier = classification.NewClassifier(cfg, router.Classifier.CategoryMapping, router.Classifier.PIIMapping, nil)
5656
})
5757

5858
It("should allow requests with no PII", func() {
@@ -97,7 +97,7 @@ var _ = Describe("Security Checks", func() {
9797
piiMapping, err := classification.LoadPIIMapping(cfg.Classifier.PIIModel.PIIMappingPath)
9898
Expect(err).NotTo(HaveOccurred())
9999

100-
router.Classifier = classification.NewClassifier(cfg, router.Classifier.CategoryMapping, piiMapping, nil, router.Classifier.ModelTTFT)
100+
router.Classifier = classification.NewClassifier(cfg, router.Classifier.CategoryMapping, piiMapping, nil)
101101
})
102102

103103
Describe("ClassifyPII method", func() {
@@ -339,7 +339,7 @@ var _ = Describe("Security Checks", func() {
339339
piiMapping, err := classification.LoadPIIMapping(cfg.Classifier.PIIModel.PIIMappingPath)
340340
Expect(err).NotTo(HaveOccurred())
341341

342-
router.Classifier = classification.NewClassifier(cfg, router.Classifier.CategoryMapping, piiMapping, nil, router.Classifier.ModelTTFT)
342+
router.Classifier = classification.NewClassifier(cfg, router.Classifier.CategoryMapping, piiMapping, nil)
343343
})
344344

345345
Describe("Error handling and edge cases", func() {
@@ -524,7 +524,7 @@ var _ = Describe("Security Checks", func() {
524524
IdxToLabel: map[string]string{"0": "benign", "1": "jailbreak"},
525525
}
526526

527-
router.Classifier = classification.NewClassifier(cfg, router.Classifier.CategoryMapping, router.Classifier.PIIMapping, jailbreakMapping, router.Classifier.ModelTTFT)
527+
router.Classifier = classification.NewClassifier(cfg, router.Classifier.CategoryMapping, router.Classifier.PIIMapping, jailbreakMapping)
528528
})
529529

530530
It("should process potential jailbreak attempts", func() {

src/semantic-router/pkg/extproc/test_utils_test.go

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,6 @@ func CreateTestConfig() *config.RouterConfig {
9595
UseCPU bool `yaml:"use_cpu"`
9696
PIIMappingPath string `yaml:"pii_mapping_path"`
9797
} `yaml:"pii_model"`
98-
LoadAware bool `yaml:"load_aware"`
9998
}{
10099
CategoryModel: struct {
101100
ModelID string `yaml:"model_id"`
@@ -119,7 +118,6 @@ func CreateTestConfig() *config.RouterConfig {
119118
UseCPU: true,
120119
PIIMappingPath: "../../../../models/pii_classifier_modernbert-base_presidio_token_model/pii_type_mapping.json",
121120
},
122-
LoadAware: true,
123121
},
124122
Categories: []config.Category{
125123
{
@@ -220,11 +218,7 @@ func CreateTestRouter(cfg *config.RouterConfig) (*extproc.OpenAIRouter, error) {
220218
toolsDatabase := tools.NewToolsDatabase(toolsOptions)
221219

222220
// Create classifier
223-
modelTTFT := map[string]float64{
224-
"model-a": 2.5,
225-
"model-b": 1.8,
226-
}
227-
classifier := classification.NewClassifier(cfg, categoryMapping, piiMapping, nil, modelTTFT)
221+
classifier := classification.NewClassifier(cfg, categoryMapping, piiMapping, nil)
228222

229223
// Create PII checker
230224
piiChecker := pii.NewPolicyChecker(cfg, cfg.ModelConfig)

src/semantic-router/pkg/utils/classification/classifier.go

Lines changed: 2 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ import (
55
"log"
66
"slices"
77
"strings"
8-
"sync"
98
"time"
109

1110
candle_binding "github.com/vllm-project/semantic-router/candle-binding"
@@ -148,16 +147,12 @@ type Classifier struct {
148147
CategoryMapping *CategoryMapping
149148
PIIMapping *PIIMapping
150149
JailbreakMapping *JailbreakMapping
151-
// Model selection fields
152-
ModelLoad map[string]int
153-
ModelLoadLock sync.Mutex
154-
ModelTTFT map[string]float64
155150
// Jailbreak detection state
156151
JailbreakInitialized bool
157152
}
158153

159154
// NewClassifier creates a new classifier with model selection and jailbreak detection capabilities
160-
func NewClassifier(cfg *config.RouterConfig, categoryMapping *CategoryMapping, piiMapping *PIIMapping, jailbreakMapping *JailbreakMapping, modelTTFT map[string]float64) *Classifier {
155+
func NewClassifier(cfg *config.RouterConfig, categoryMapping *CategoryMapping, piiMapping *PIIMapping, jailbreakMapping *JailbreakMapping) *Classifier {
161156
return &Classifier{
162157
categoryInference: createCategoryInference(cfg.Classifier.CategoryModel.UseModernBERT),
163158
jailbreakInitializer: createJailbreakInitializer(cfg.PromptGuard.UseModernBERT),
@@ -168,8 +163,6 @@ func NewClassifier(cfg *config.RouterConfig, categoryMapping *CategoryMapping, p
168163
CategoryMapping: categoryMapping,
169164
PIIMapping: piiMapping,
170165
JailbreakMapping: jailbreakMapping,
171-
ModelLoad: make(map[string]int),
172-
ModelTTFT: modelTTFT,
173166
JailbreakInitialized: false,
174167
}
175168
}
@@ -475,9 +468,6 @@ func (c *Classifier) SelectBestModelForCategory(categoryName string) string {
475468
return c.Config.DefaultModel
476469
}
477470

478-
c.ModelLoadLock.Lock()
479-
defer c.ModelLoadLock.Unlock()
480-
481471
bestModel, bestScore, bestQuality := c.selectBestModelInternal(cat, nil)
482472

483473
if bestModel == "" {
@@ -500,25 +490,6 @@ func (c *Classifier) findCategory(categoryName string) *config.Category {
500490
return nil
501491
}
502492

503-
// calculateModelScore calculates the combined score and quality for a model
504-
func (c *Classifier) calculateModelScore(modelScore config.ModelScore) (float64, float64) {
505-
quality := modelScore.Score
506-
model := modelScore.Model
507-
508-
if !c.Config.Classifier.LoadAware {
509-
return quality, quality
510-
}
511-
512-
baseTTFT := c.ModelTTFT[model]
513-
load := c.ModelLoad[model]
514-
estTTFT := baseTTFT * (1 + float64(load))
515-
if estTTFT == 0 {
516-
estTTFT = 1 // avoid div by zero
517-
}
518-
score := quality / estTTFT
519-
return score, quality
520-
}
521-
522493
// selectBestModelInternal performs the core model selection logic
523494
//
524495
// modelFilter is optional - if provided, only models passing the filter will be considered
@@ -532,8 +503,7 @@ func (c *Classifier) selectBestModelInternal(cat *config.Category, modelFilter f
532503
if modelFilter != nil && !modelFilter(model) {
533504
return
534505
}
535-
score, quality := c.calculateModelScore(modelScore)
536-
c.updateBestModel(score, quality, model, &bestScore, &bestQuality, &bestModel)
506+
c.updateBestModel(modelScore.Score, modelScore.Score, model, &bestScore, &bestQuality, &bestModel)
537507
})
538508

539509
return bestModel, bestScore, bestQuality
@@ -558,9 +528,6 @@ func (c *Classifier) SelectBestModelFromList(candidateModels []string, categoryN
558528
return candidateModels[0]
559529
}
560530

561-
c.ModelLoadLock.Lock()
562-
defer c.ModelLoadLock.Unlock()
563-
564531
bestModel, bestScore, bestQuality := c.selectBestModelInternal(cat,
565532
func(model string) bool {
566533
return slices.Contains(candidateModels, model)
@@ -592,22 +559,6 @@ func (c *Classifier) GetModelsForCategory(categoryName string) []string {
592559
return models
593560
}
594561

595-
// IncrementModelLoad increments the load counter for a model
596-
func (c *Classifier) IncrementModelLoad(model string) {
597-
c.ModelLoadLock.Lock()
598-
defer c.ModelLoadLock.Unlock()
599-
c.ModelLoad[model]++
600-
}
601-
602-
// DecrementModelLoad decrements the load counter for a model
603-
func (c *Classifier) DecrementModelLoad(model string) {
604-
c.ModelLoadLock.Lock()
605-
defer c.ModelLoadLock.Unlock()
606-
if c.ModelLoad[model] > 0 {
607-
c.ModelLoad[model]--
608-
}
609-
}
610-
611562
// updateBestModel updates the best model, score, and quality if the new score is better.
612563
func (c *Classifier) updateBestModel(score, quality float64, model string, bestScore *float64, bestQuality *float64, bestModel *string) {
613564
if score > *bestScore {

0 commit comments

Comments
 (0)