Skip to content

Commit d78e0e7

Browse files
committed
refactor: remove quality
Signed-off-by: Alex Wang <[email protected]>
1 parent 76d48e7 commit d78e0e7

File tree

2 files changed

+19
-28
lines changed

2 files changed

+19
-28
lines changed

src/semantic-router/pkg/utils/classification/classifier.go

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -468,15 +468,14 @@ func (c *Classifier) SelectBestModelForCategory(categoryName string) string {
468468
return c.Config.DefaultModel
469469
}
470470

471-
bestModel, bestScore, bestQuality := c.selectBestModelInternal(cat, nil)
471+
bestModel, bestScore := c.selectBestModelInternal(cat, nil)
472472

473473
if bestModel == "" {
474474
log.Printf("No models found for category %s, using default model", categoryName)
475475
return c.Config.DefaultModel
476476
}
477477

478-
log.Printf("Selected model %s for category %s with quality %.4f and combined score %.4e",
479-
bestModel, categoryName, bestQuality, bestScore)
478+
log.Printf("Selected model %s for category %s with score %.4f", bestModel, categoryName, bestScore)
480479
return bestModel
481480
}
482481

@@ -493,20 +492,19 @@ func (c *Classifier) findCategory(categoryName string) *config.Category {
493492
// selectBestModelInternal performs the core model selection logic
494493
//
495494
// modelFilter is optional - if provided, only models passing the filter will be considered
496-
func (c *Classifier) selectBestModelInternal(cat *config.Category, modelFilter func(string) bool) (string, float64, float64) {
495+
func (c *Classifier) selectBestModelInternal(cat *config.Category, modelFilter func(string) bool) (string, float64) {
497496
bestModel := ""
498497
bestScore := -1.0
499-
bestQuality := 0.0
500498

501499
c.forEachModelScore(cat, func(modelScore config.ModelScore) {
502500
model := modelScore.Model
503501
if modelFilter != nil && !modelFilter(model) {
504502
return
505503
}
506-
c.updateBestModel(modelScore.Score, modelScore.Score, model, &bestScore, &bestQuality, &bestModel)
504+
c.updateBestModel(modelScore.Score, model, &bestScore, &bestModel)
507505
})
508506

509-
return bestModel, bestScore, bestQuality
507+
return bestModel, bestScore
510508
}
511509

512510
// forEachModelScore traverses the ModelScores document of the category and executes the callback for each element.
@@ -528,7 +526,7 @@ func (c *Classifier) SelectBestModelFromList(candidateModels []string, categoryN
528526
return candidateModels[0]
529527
}
530528

531-
bestModel, bestScore, bestQuality := c.selectBestModelInternal(cat,
529+
bestModel, bestScore := c.selectBestModelInternal(cat,
532530
func(model string) bool {
533531
return slices.Contains(candidateModels, model)
534532
})
@@ -538,8 +536,7 @@ func (c *Classifier) SelectBestModelFromList(candidateModels []string, categoryN
538536
return candidateModels[0]
539537
}
540538

541-
log.Printf("Selected best model %s for category %s with quality %.4f and combined score %.4e",
542-
bestModel, categoryName, bestQuality, bestScore)
539+
log.Printf("Selected best model %s for category %s with score %.4f", bestModel, categoryName, bestScore)
543540
return bestModel
544541
}
545542

@@ -559,11 +556,10 @@ func (c *Classifier) GetModelsForCategory(categoryName string) []string {
559556
return models
560557
}
561558

562-
// updateBestModel updates the best model, score, and quality if the new score is better.
563-
func (c *Classifier) updateBestModel(score, quality float64, model string, bestScore *float64, bestQuality *float64, bestModel *string) {
559+
// updateBestModel updates the best model, score if the new score is better.
560+
func (c *Classifier) updateBestModel(score float64, model string, bestScore *float64, bestModel *string) {
564561
if score > *bestScore {
565562
*bestScore = score
566563
*bestModel = model
567-
*bestQuality = quality
568564
}
569565
}

src/semantic-router/pkg/utils/classification/classifier_test.go

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -264,10 +264,9 @@ var _ = Describe("category classification and model selection", func() {
264264
},
265265
}
266266

267-
bestModel, score, quality := classifier.selectBestModelInternal(cat, nil)
267+
bestModel, score := classifier.selectBestModelInternal(cat, nil)
268268

269269
Expect(bestModel).To(Equal("model-a"))
270-
Expect(quality).To(BeNumerically("~", 0.9, 0.001))
271270
Expect(score).To(BeNumerically("~", 0.9, 0.001))
272271
})
273272

@@ -284,10 +283,9 @@ var _ = Describe("category classification and model selection", func() {
284283
return model == "model-b" || model == "model-c"
285284
}
286285

287-
bestModel, score, quality := classifier.selectBestModelInternal(cat, filter)
286+
bestModel, score := classifier.selectBestModelInternal(cat, filter)
288287

289288
Expect(bestModel).To(Equal("model-b"))
290-
Expect(quality).To(BeNumerically("~", 0.8, 0.001))
291289
Expect(score).To(BeNumerically("~", 0.8, 0.001))
292290
})
293291

@@ -303,11 +301,10 @@ var _ = Describe("category classification and model selection", func() {
303301
return model == "non-existent-model"
304302
}
305303

306-
bestModel, score, quality := classifier.selectBestModelInternal(cat, filter)
304+
bestModel, score := classifier.selectBestModelInternal(cat, filter)
307305

308306
Expect(bestModel).To(Equal(""))
309307
Expect(score).To(BeNumerically("~", -1.0, 0.001))
310-
Expect(quality).To(BeNumerically("~", 0.0, 0.001))
311308
})
312309

313310
It("should return empty when category has no models", func() {
@@ -316,11 +313,10 @@ var _ = Describe("category classification and model selection", func() {
316313
ModelScores: []config.ModelScore{},
317314
}
318315

319-
bestModel, score, quality := classifier.selectBestModelInternal(cat, nil)
316+
bestModel, score := classifier.selectBestModelInternal(cat, nil)
320317

321318
Expect(bestModel).To(Equal(""))
322319
Expect(score).To(BeNumerically("~", -1.0, 0.001))
323-
Expect(quality).To(BeNumerically("~", 0.0, 0.001))
324320
})
325321
})
326322
})
@@ -869,17 +865,16 @@ func TestUpdateBestModel(t *testing.T) {
869865
classifier := &Classifier{}
870866

871867
bestScore := 0.5
872-
bestQuality := 0.5
873868
bestModel := "old-model"
874869

875-
classifier.updateBestModel(0.8, 0.9, "new-model", &bestScore, &bestQuality, &bestModel)
876-
if bestScore != 0.8 || bestQuality != 0.9 || bestModel != "new-model" {
877-
t.Errorf("update: got bestScore=%v, bestQuality=%v, bestModel=%v", bestScore, bestQuality, bestModel)
870+
classifier.updateBestModel(0.8, "new-model", &bestScore, &bestModel)
871+
if bestScore != 0.8 || bestModel != "new-model" {
872+
t.Errorf("update: got bestScore=%v, bestModel=%v", bestScore, bestModel)
878873
}
879874

880-
classifier.updateBestModel(0.7, 0.7, "another-model", &bestScore, &bestQuality, &bestModel)
881-
if bestScore != 0.8 || bestQuality != 0.9 || bestModel != "new-model" {
882-
t.Errorf("not update: got bestScore=%v, bestQuality=%v, bestModel=%v", bestScore, bestQuality, bestModel)
875+
classifier.updateBestModel(0.7, "another-model", &bestScore, &bestModel)
876+
if bestScore != 0.8 || bestModel != "new-model" {
877+
t.Errorf("not update: got bestScore=%v, bestModel=%v", bestScore, bestModel)
883878
}
884879
}
885880

0 commit comments

Comments
 (0)