Skip to content

Commit a160bfa

Browse files
authored
perf: optimize FindSimilarTools by early pruning (#248)
Signed-off-by: cryo <[email protected]>
1 parent e1ec60c commit a160bfa

File tree

1 file changed

+14
-12
lines changed

1 file changed

+14
-12
lines changed

src/semantic-router/pkg/tools/tools.go

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -143,14 +143,17 @@ func (db *ToolsDatabase) FindSimilarTools(query string, topK int) ([]openai.Chat
143143
dotProduct += queryEmbedding[i] * entry.Embedding[i]
144144
}
145145

146-
results = append(results, SimilarityResult{
147-
Entry: entry,
148-
Similarity: dotProduct,
149-
})
150-
151146
// Debug logging to see similarity scores
152147
observability.Debugf("Tool '%s' similarity score: %.4f (threshold: %.4f)",
153148
entry.Tool.Function.Name, dotProduct, db.similarityThreshold)
149+
150+
// Only consider if above threshold
151+
if dotProduct >= db.similarityThreshold {
152+
results = append(results, SimilarityResult{
153+
Entry: entry,
154+
Similarity: dotProduct,
155+
})
156+
}
154157
}
155158

156159
// No results found
@@ -164,13 +167,12 @@ func (db *ToolsDatabase) FindSimilarTools(query string, topK int) ([]openai.Chat
164167
})
165168

166169
// Select top-k tools that meet the threshold
167-
var selectedTools []openai.ChatCompletionToolParam
168-
for i := 0; i < len(results) && i < topK; i++ {
169-
if results[i].Similarity >= db.similarityThreshold {
170-
selectedTools = append(selectedTools, results[i].Entry.Tool)
171-
observability.Infof("Selected tool: %s (similarity=%.4f)",
172-
results[i].Entry.Tool.Function.Name, results[i].Similarity)
173-
}
170+
limit := min(topK, len(results))
171+
selectedTools := make([]openai.ChatCompletionToolParam, 0, limit)
172+
for i := range limit {
173+
selectedTools = append(selectedTools, results[i].Entry.Tool)
174+
observability.Infof("Selected tool: %s (similarity=%.4f)",
175+
results[i].Entry.Tool.Function.Name, results[i].Similarity)
174176
}
175177

176178
observability.Infof("Found %d similar tools for query: %s", len(selectedTools), query)

0 commit comments

Comments
 (0)