@@ -446,14 +446,7 @@ func (c *Classifier) ClassifyAndSelectBestModel(query string) string {
446446
447447// SelectBestModelForCategory selects the best model from a category based on score and TTFT
448448func (c * Classifier ) SelectBestModelForCategory (categoryName string ) string {
449- var cat * config.Category
450- for i , category := range c .Config .Categories {
451- if strings .EqualFold (category .Name , categoryName ) {
452- cat = & c .Config .Categories [i ]
453- break
454- }
455- }
456-
449+ cat := c .findCategory (categoryName )
457450 if cat == nil {
458451 log .Printf ("Could not find matching category %s in config, using default model" , categoryName )
459452 return c .Config .DefaultModel
@@ -462,30 +455,7 @@ func (c *Classifier) SelectBestModelForCategory(categoryName string) string {
462455 c .ModelLoadLock .Lock ()
463456 defer c .ModelLoadLock .Unlock ()
464457
465- bestModel := ""
466- bestScore := - 1.0
467- bestQuality := 0.0
468-
469- if c .Config .Classifier .LoadAware {
470- c .forEachModelScore (cat , func (modelScore config.ModelScore ) {
471- quality := modelScore .Score
472- model := modelScore .Model
473- baseTTFT := c .ModelTTFT [model ]
474- load := c .ModelLoad [model ]
475- estTTFT := baseTTFT * (1 + float64 (load ))
476- if estTTFT == 0 {
477- estTTFT = 1
478- }
479- score := quality / estTTFT
480- c .updateBestModel (score , quality , model , & bestScore , & bestQuality , & bestModel )
481- })
482- } else {
483- c .forEachModelScore (cat , func (modelScore config.ModelScore ) {
484- quality := modelScore .Score
485- model := modelScore .Model
486- c .updateBestModel (quality , quality , model , & bestScore , & bestQuality , & bestModel )
487- })
488- }
458+ bestModel , bestScore , bestQuality := c .selectBestModelInternal (cat , nil )
489459
490460 if bestModel == "" {
491461 log .Printf ("No models found for category %s, using default model" , categoryName )
@@ -497,6 +467,55 @@ func (c *Classifier) SelectBestModelForCategory(categoryName string) string {
497467 return bestModel
498468}
499469
470+ // findCategory finds the category configuration by name (case-insensitive)
471+ func (c * Classifier ) findCategory (categoryName string ) * config.Category {
472+ for i , category := range c .Config .Categories {
473+ if strings .EqualFold (category .Name , categoryName ) {
474+ return & c .Config .Categories [i ]
475+ }
476+ }
477+ return nil
478+ }
479+
480+ // calculateModelScore calculates the combined score and quality for a model
481+ func (c * Classifier ) calculateModelScore (modelScore config.ModelScore ) (float64 , float64 ) {
482+ quality := modelScore .Score
483+ model := modelScore .Model
484+
485+ if ! c .Config .Classifier .LoadAware {
486+ return quality , quality
487+ }
488+
489+ baseTTFT := c .ModelTTFT [model ]
490+ load := c .ModelLoad [model ]
491+ estTTFT := baseTTFT * (1 + float64 (load ))
492+ if estTTFT == 0 {
493+ estTTFT = 1 // avoid div by zero
494+ }
495+ score := quality / estTTFT
496+ return score , quality
497+ }
498+
499+ // selectBestModelInternal performs the core model selection logic
500+ //
501+ // modelFilter is optional - if provided, only models passing the filter will be considered
502+ func (c * Classifier ) selectBestModelInternal (cat * config.Category , modelFilter func (string ) bool ) (string , float64 , float64 ) {
503+ bestModel := ""
504+ bestScore := - 1.0
505+ bestQuality := 0.0
506+
507+ c .forEachModelScore (cat , func (modelScore config.ModelScore ) {
508+ model := modelScore .Model
509+ if modelFilter != nil && ! modelFilter (model ) {
510+ return
511+ }
512+ score , quality := c .calculateModelScore (modelScore )
513+ c .updateBestModel (score , quality , model , & bestScore , & bestQuality , & bestModel )
514+ })
515+
516+ return bestModel , bestScore , bestQuality
517+ }
518+
500519// forEachModelScore traverses the ModelScores document of the category and executes the callback for each element.
501520func (c * Classifier ) forEachModelScore (cat * config.Category , fn func (modelScore config.ModelScore )) {
502521 for _ , modelScore := range cat .ModelScores {
@@ -510,15 +529,7 @@ func (c *Classifier) SelectBestModelFromList(candidateModels []string, categoryN
510529 return c .Config .DefaultModel
511530 }
512531
513- // Find the category configuration
514- var cat * config.Category
515- for i , category := range c .Config .Categories {
516- if strings .EqualFold (category .Name , categoryName ) {
517- cat = & c .Config .Categories [i ]
518- break
519- }
520- }
521-
532+ cat := c .findCategory (categoryName )
522533 if cat == nil {
523534 // Return first candidate if category not found
524535 return candidateModels [0 ]
@@ -527,31 +538,10 @@ func (c *Classifier) SelectBestModelFromList(candidateModels []string, categoryN
527538 c .ModelLoadLock .Lock ()
528539 defer c .ModelLoadLock .Unlock ()
529540
530- bestModel := ""
531- bestScore := - 1.0
532- bestQuality := 0.0
533-
534- filteredFn := func (modelScore config.ModelScore ) {
535- model := modelScore .Model
536- if ! slices .Contains (candidateModels , model ) {
537- return
538- }
539- quality := modelScore .Score
540- if c .Config .Classifier .LoadAware {
541- baseTTFT := c .ModelTTFT [model ]
542- load := c .ModelLoad [model ]
543- estTTFT := baseTTFT * (1 + float64 (load ))
544- if estTTFT == 0 {
545- estTTFT = 1 // avoid div by zero
546- }
547- score := quality / estTTFT
548- c .updateBestModel (score , quality , model , & bestScore , & bestQuality , & bestModel )
549- } else {
550- c .updateBestModel (quality , quality , model , & bestScore , & bestQuality , & bestModel )
551- }
552- }
553-
554- c .forEachModelScore (cat , filteredFn )
541+ bestModel , bestScore , bestQuality := c .selectBestModelInternal (cat ,
542+ func (model string ) bool {
543+ return slices .Contains (candidateModels , model )
544+ })
555545
556546 if bestModel == "" {
557547 log .Printf ("No suitable model found from candidates for category %s, using first candidate" , categoryName )
0 commit comments