55 "log"
66 "slices"
77 "strings"
8- "sync"
98 "time"
109
1110 candle_binding "github.com/vllm-project/semantic-router/candle-binding"
@@ -148,16 +147,12 @@ type Classifier struct {
148147 CategoryMapping * CategoryMapping
149148 PIIMapping * PIIMapping
150149 JailbreakMapping * JailbreakMapping
151- // Model selection fields
152- ModelLoad map [string ]int
153- ModelLoadLock sync.Mutex
154- ModelTTFT map [string ]float64
155150 // Jailbreak detection state
156151 JailbreakInitialized bool
157152}
158153
159154// NewClassifier creates a new classifier with model selection and jailbreak detection capabilities
160- func NewClassifier (cfg * config.RouterConfig , categoryMapping * CategoryMapping , piiMapping * PIIMapping , jailbreakMapping * JailbreakMapping , modelTTFT map [ string ] float64 ) * Classifier {
155+ func NewClassifier (cfg * config.RouterConfig , categoryMapping * CategoryMapping , piiMapping * PIIMapping , jailbreakMapping * JailbreakMapping ) * Classifier {
161156 return & Classifier {
162157 categoryInference : createCategoryInference (cfg .Classifier .CategoryModel .UseModernBERT ),
163158 jailbreakInitializer : createJailbreakInitializer (cfg .PromptGuard .UseModernBERT ),
@@ -168,8 +163,6 @@ func NewClassifier(cfg *config.RouterConfig, categoryMapping *CategoryMapping, p
168163 CategoryMapping : categoryMapping ,
169164 PIIMapping : piiMapping ,
170165 JailbreakMapping : jailbreakMapping ,
171- ModelLoad : make (map [string ]int ),
172- ModelTTFT : modelTTFT ,
173166 JailbreakInitialized : false ,
174167 }
175168}
@@ -475,9 +468,6 @@ func (c *Classifier) SelectBestModelForCategory(categoryName string) string {
475468 return c .Config .DefaultModel
476469 }
477470
478- c .ModelLoadLock .Lock ()
479- defer c .ModelLoadLock .Unlock ()
480-
481471 bestModel , bestScore , bestQuality := c .selectBestModelInternal (cat , nil )
482472
483473 if bestModel == "" {
@@ -500,25 +490,6 @@ func (c *Classifier) findCategory(categoryName string) *config.Category {
500490 return nil
501491}
502492
503- // calculateModelScore calculates the combined score and quality for a model
504- func (c * Classifier ) calculateModelScore (modelScore config.ModelScore ) (float64 , float64 ) {
505- quality := modelScore .Score
506- model := modelScore .Model
507-
508- if ! c .Config .Classifier .LoadAware {
509- return quality , quality
510- }
511-
512- baseTTFT := c .ModelTTFT [model ]
513- load := c .ModelLoad [model ]
514- estTTFT := baseTTFT * (1 + float64 (load ))
515- if estTTFT == 0 {
516- estTTFT = 1 // avoid div by zero
517- }
518- score := quality / estTTFT
519- return score , quality
520- }
521-
522493// selectBestModelInternal performs the core model selection logic
523494//
524495// modelFilter is optional - if provided, only models passing the filter will be considered
@@ -532,8 +503,7 @@ func (c *Classifier) selectBestModelInternal(cat *config.Category, modelFilter f
532503 if modelFilter != nil && ! modelFilter (model ) {
533504 return
534505 }
535- score , quality := c .calculateModelScore (modelScore )
536- c .updateBestModel (score , quality , model , & bestScore , & bestQuality , & bestModel )
506+ c .updateBestModel (modelScore .Score , modelScore .Score , model , & bestScore , & bestQuality , & bestModel )
537507 })
538508
539509 return bestModel , bestScore , bestQuality
@@ -558,9 +528,6 @@ func (c *Classifier) SelectBestModelFromList(candidateModels []string, categoryN
558528 return candidateModels [0 ]
559529 }
560530
561- c .ModelLoadLock .Lock ()
562- defer c .ModelLoadLock .Unlock ()
563-
564531 bestModel , bestScore , bestQuality := c .selectBestModelInternal (cat ,
565532 func (model string ) bool {
566533 return slices .Contains (candidateModels , model )
@@ -592,22 +559,6 @@ func (c *Classifier) GetModelsForCategory(categoryName string) []string {
592559 return models
593560}
594561
595- // IncrementModelLoad increments the load counter for a model
596- func (c * Classifier ) IncrementModelLoad (model string ) {
597- c .ModelLoadLock .Lock ()
598- defer c .ModelLoadLock .Unlock ()
599- c .ModelLoad [model ]++
600- }
601-
602- // DecrementModelLoad decrements the load counter for a model
603- func (c * Classifier ) DecrementModelLoad (model string ) {
604- c .ModelLoadLock .Lock ()
605- defer c .ModelLoadLock .Unlock ()
606- if c .ModelLoad [model ] > 0 {
607- c .ModelLoad [model ]--
608- }
609- }
610-
611562// updateBestModel updates the best model, score, and quality if the new score is better.
612563func (c * Classifier ) updateBestModel (score , quality float64 , model string , bestScore * float64 , bestQuality * float64 , bestModel * string ) {
613564 if score > * bestScore {
0 commit comments