@@ -3,6 +3,7 @@ package classification
33import (
44 "fmt"
55 "log"
6+ "slices"
67 "strings"
78 "sync"
89 "time"
@@ -466,35 +467,24 @@ func (c *Classifier) SelectBestModelForCategory(categoryName string) string {
466467 bestQuality := 0.0
467468
468469 if c .Config .Classifier .LoadAware {
469- // Load-aware: combine accuracy and TTFT
470- for _ , modelScore := range cat .ModelScores {
470+ c .forEachModelScore (cat , func (modelScore config.ModelScore ) {
471471 quality := modelScore .Score
472472 model := modelScore .Model
473-
474473 baseTTFT := c .ModelTTFT [model ]
475474 load := c .ModelLoad [model ]
476475 estTTFT := baseTTFT * (1 + float64 (load ))
477476 if estTTFT == 0 {
478- estTTFT = 1 // avoid div by zero
477+ estTTFT = 1
479478 }
480479 score := quality / estTTFT
481- if score > bestScore {
482- bestScore = score
483- bestModel = model
484- bestQuality = quality
485- }
486- }
480+ c .updateBestModel (score , quality , model , & bestScore , & bestQuality , & bestModel )
481+ })
487482 } else {
488- // Not load-aware: pick the model with the highest accuracy only
489- for _ , modelScore := range cat .ModelScores {
483+ c .forEachModelScore (cat , func (modelScore config.ModelScore ) {
490484 quality := modelScore .Score
491485 model := modelScore .Model
492- if quality > bestScore {
493- bestScore = quality
494- bestModel = model
495- bestQuality = quality
496- }
497- }
486+ c .updateBestModel (quality , quality , model , & bestScore , & bestQuality , & bestModel )
487+ })
498488 }
499489
500490 if bestModel == "" {
@@ -507,6 +497,13 @@ func (c *Classifier) SelectBestModelForCategory(categoryName string) string {
507497 return bestModel
508498}
509499
500+ // forEachModelScore 遍历 category 的 ModelScores 并对每个元素执行回调
501+ func (c * Classifier ) forEachModelScore (cat * config.Category , fn func (modelScore config.ModelScore )) {
502+ for _ , modelScore := range cat .ModelScores {
503+ fn (modelScore )
504+ }
505+ }
506+
510507// SelectBestModelFromList selects the best model from a list of candidate models for a given category
511508func (c * Classifier ) SelectBestModelFromList (candidateModels []string , categoryName string ) string {
512509 if len (candidateModels ) == 0 {
@@ -534,49 +531,28 @@ func (c *Classifier) SelectBestModelFromList(candidateModels []string, categoryN
534531 bestScore := - 1.0
535532 bestQuality := 0.0
536533
537- if c .Config .Classifier .LoadAware {
538- // Load-aware: combine accuracy and TTFT
539- for _ , modelScore := range cat .ModelScores {
540- model := modelScore .Model
541-
542- // Check if this model is in the candidate list
543- if ! c .contains (candidateModels , model ) {
544- continue
545- }
546-
547- quality := modelScore .Score
534+ filteredFn := func (modelScore config.ModelScore ) {
535+ model := modelScore .Model
536+ if ! slices .Contains (candidateModels , model ) {
537+ return
538+ }
539+ quality := modelScore .Score
540+ if c .Config .Classifier .LoadAware {
548541 baseTTFT := c .ModelTTFT [model ]
549542 load := c .ModelLoad [model ]
550543 estTTFT := baseTTFT * (1 + float64 (load ))
551544 if estTTFT == 0 {
552545 estTTFT = 1 // avoid div by zero
553546 }
554547 score := quality / estTTFT
555- if score > bestScore {
556- bestScore = score
557- bestModel = model
558- bestQuality = quality
559- }
560- }
561- } else {
562- // Not load-aware: pick the model with the highest accuracy only
563- for _ , modelScore := range cat .ModelScores {
564- model := modelScore .Model
565-
566- // Check if this model is in the candidate list
567- if ! c .contains (candidateModels , model ) {
568- continue
569- }
570-
571- quality := modelScore .Score
572- if quality > bestScore {
573- bestScore = quality
574- bestModel = model
575- bestQuality = quality
576- }
548+ c .updateBestModel (score , quality , model , & bestScore , & bestQuality , & bestModel )
549+ } else {
550+ c .updateBestModel (quality , quality , model , & bestScore , & bestQuality , & bestModel )
577551 }
578552 }
579553
554+ c .forEachModelScore (cat , filteredFn )
555+
580556 if bestModel == "" {
581557 log .Printf ("No suitable model found from candidates for category %s, using first candidate" , categoryName )
582558 return candidateModels [0 ]
@@ -619,12 +595,11 @@ func (c *Classifier) DecrementModelLoad(model string) {
619595 }
620596}
621597
622- // contains checks if a slice contains a string
623- func (c * Classifier ) contains ( slice [] string , item string ) bool {
624- for _ , s := range slice {
625- if s == item {
626- return true
627- }
598+ // updateBestModel updates the best model, score, and quality if the new score is better.
599+ func (c * Classifier ) updateBestModel ( score , quality float64 , model string , bestScore * float64 , bestQuality * float64 , bestModel * string ) {
600+ if score > * bestScore {
601+ * bestScore = score
602+ * bestModel = model
603+ * bestQuality = quality
628604 }
629- return false
630605}
0 commit comments