Skip to content

Commit 38ffba0

Browse files
committed
Refactor MilvusCache to use Milvus Search API
Replaces manual similarity calculation and query-based retrieval in FindSimilar with Milvus's Search API for more efficient and accurate similarity search. Updates index creation to use the new HNSW index API. Improves cache hit/miss logic and error handling.

Signed-off-by: Srinivas A <[email protected]>
1 parent b7888be commit 38ffba0

File tree

1 file changed

+61
-94
lines changed

1 file changed

+61
-94
lines changed

src/semantic-router/pkg/cache/milvus_cache.go

Lines changed: 61 additions & 94 deletions
Original file line number | Diff line number | Diff line change
@@ -307,19 +307,11 @@ func (c *MilvusCache) createCollection() error {
307307
return err
308308
}
309309

310-
// Create index
311-
indexParams := map[string]string{
312-
"index_type": c.config.Collection.Index.Type,
313-
"metric_type": c.config.Collection.VectorField.MetricType,
314-
"params": fmt.Sprintf(`{"M": %d, "efConstruction": %d}`,
315-
c.config.Collection.Index.Params.M,
316-
c.config.Collection.Index.Params.EfConstruction),
317-
}
318-
319-
observability.Debugf("MilvusCache.createCollection: creating index for %d-dimensional vectors", actualDimension)
320-
321310
// Create index with updated API
322-
index := entity.NewGenericIndex(c.config.Collection.VectorField.Name, entity.IndexType(c.config.Collection.Index.Type), indexParams)
311+
index, err := entity.NewIndexHNSW(entity.MetricType(c.config.Collection.VectorField.MetricType), c.config.Collection.Index.Params.EfConstruction, c.config.Collection.Index.Params.M)
312+
if err != nil {
313+
return fmt.Errorf("failed to create HNSW index: %w", err)
314+
}
323315
if err := c.client.CreateIndex(ctx, c.collectionName, c.config.Collection.VectorField.Name, index, false); err != nil {
324316
return err
325317
}
@@ -517,112 +509,87 @@ func (c *MilvusCache) FindSimilar(model string, query string) ([]byte, bool, err
517509

518510
ctx := context.Background()
519511

520-
// Query for completed entries with the same model
521-
// Using Query approach for comprehensive similarity search
522-
queryExpr := fmt.Sprintf("model == \"%s\" && response_body != \"\"", model)
523-
observability.Debugf("MilvusCache.FindSimilar: querying with expr: %s (embedding_dim: %d)",
524-
queryExpr, len(queryEmbedding))
525-
526-
// Use Query to get all matching entries, then compute similarity manually
527-
results, err := c.client.Query(ctx, c.collectionName, []string{}, queryExpr,
528-
[]string{"query", "response_body", c.config.Collection.VectorField.Name})
512+
// Define search parameters
513+
searchParam, err := entity.NewIndexHNSWSearchParam(c.config.Search.Params.Ef)
514+
if err != nil {
515+
return nil, false, fmt.Errorf("failed to create search parameters: %w", err)
516+
}
517+
518+
// Use Milvus Search for efficient similarity search
519+
searchResult, err := c.client.Search(
520+
ctx,
521+
c.collectionName,
522+
[]string{},
523+
fmt.Sprintf("model == \"%s\" && response_body != \"\"", model),
524+
[]string{"response_body"},
525+
[]entity.Vector{entity.FloatVector(queryEmbedding)},
526+
c.config.Collection.VectorField.Name,
527+
entity.MetricType(c.config.Collection.VectorField.MetricType),
528+
c.config.Search.TopK,
529+
searchParam,
530+
)
529531

530532
if err != nil {
531-
observability.Debugf("MilvusCache.FindSimilar: query failed: %v", err)
533+
observability.Debugf("MilvusCache.FindSimilar: search failed: %v", err)
532534
atomic.AddInt64(&c.missCount, 1)
533535
metrics.RecordCacheOperation("milvus", "find_similar", "error", time.Since(start).Seconds())
534536
metrics.RecordCacheMiss()
535537
return nil, false, nil
536538
}
537539

538-
if len(results) == 0 {
540+
if len(searchResult) == 0 || searchResult[0].ResultCount == 0 {
539541
atomic.AddInt64(&c.missCount, 1)
540-
observability.Debugf("MilvusCache.FindSimilar: no entries found with responses")
542+
observability.Debugf("MilvusCache.FindSimilar: no entries found")
541543
metrics.RecordCacheOperation("milvus", "find_similar", "miss", time.Since(start).Seconds())
542544
metrics.RecordCacheMiss()
543545
return nil, false, nil
544546
}
545547

546-
// Calculate semantic similarity for each candidate
547-
bestSimilarity := float32(-1.0)
548-
var bestResponse string
549-
550-
// Find columns by type instead of assuming order
551-
var queryColumn *entity.ColumnVarChar
552-
var responseColumn *entity.ColumnVarChar
553-
var embeddingColumn *entity.ColumnFloatVector
548+
bestScore := searchResult[0].Scores[0]
549+
if bestScore < c.similarityThreshold {
550+
atomic.AddInt64(&c.missCount, 1)
551+
observability.Debugf("MilvusCache.FindSimilar: CACHE MISS - best_similarity=%.4f < threshold=%.4f",
552+
bestScore, c.similarityThreshold)
553+
observability.LogEvent("cache_miss", map[string]interface{}{
554+
"backend": "milvus",
555+
"best_similarity": bestScore,
556+
"threshold": c.similarityThreshold,
557+
"model": model,
558+
"collection": c.collectionName,
559+
})
560+
metrics.RecordCacheOperation("milvus", "find_similar", "miss", time.Since(start).Seconds())
561+
metrics.RecordCacheMiss()
562+
return nil, false, nil
563+
}
554564

555-
for _, col := range results {
556-
switch typedCol := col.(type) {
557-
case *entity.ColumnVarChar:
558-
if typedCol.Name() == "query" {
559-
queryColumn = typedCol
560-
} else if typedCol.Name() == "response_body" {
561-
responseColumn = typedCol
562-
}
563-
case *entity.ColumnFloatVector:
564-
if typedCol.Name() == c.config.Collection.VectorField.Name {
565-
embeddingColumn = typedCol
566-
}
567-
}
565+
// Cache Hit
566+
var responseBody []byte
567+
responseBodyColumn, ok := searchResult[0].Fields[0].(*entity.ColumnVarChar)
568+
if ok && responseBodyColumn.Len() > 0 {
569+
responseBody = []byte(responseBodyColumn.Data()[0])
568570
}
569571

570-
if queryColumn == nil || responseColumn == nil || embeddingColumn == nil {
571-
observability.Debugf("MilvusCache.FindSimilar: missing required columns in results")
572+
if responseBody == nil {
573+
observability.Debugf("MilvusCache.FindSimilar: cache hit but response_body is missing or not a string")
572574
atomic.AddInt64(&c.missCount, 1)
573575
metrics.RecordCacheOperation("milvus", "find_similar", "error", time.Since(start).Seconds())
574576
metrics.RecordCacheMiss()
575577
return nil, false, nil
576578
}
577579

578-
for i := 0; i < queryColumn.Len(); i++ {
579-
storedEmbedding := embeddingColumn.Data()[i]
580-
581-
// Calculate dot product similarity score
582-
var similarity float32
583-
for j := 0; j < len(queryEmbedding) && j < len(storedEmbedding); j++ {
584-
similarity += queryEmbedding[j] * storedEmbedding[j]
585-
}
586-
587-
if similarity > bestSimilarity {
588-
bestSimilarity = similarity
589-
bestResponse = responseColumn.Data()[i]
590-
}
591-
}
592-
593-
observability.Debugf("MilvusCache.FindSimilar: best similarity=%.4f, threshold=%.4f (checked %d entries)",
594-
bestSimilarity, c.similarityThreshold, queryColumn.Len())
595-
596-
if bestSimilarity >= c.similarityThreshold {
597-
atomic.AddInt64(&c.hitCount, 1)
598-
observability.Debugf("MilvusCache.FindSimilar: CACHE HIT - similarity=%.4f >= threshold=%.4f, response_size=%d bytes",
599-
bestSimilarity, c.similarityThreshold, len(bestResponse))
600-
observability.LogEvent("cache_hit", map[string]interface{}{
601-
"backend": "milvus",
602-
"similarity": bestSimilarity,
603-
"threshold": c.similarityThreshold,
604-
"model": model,
605-
"collection": c.collectionName,
606-
})
607-
metrics.RecordCacheOperation("milvus", "find_similar", "hit", time.Since(start).Seconds())
608-
metrics.RecordCacheHit()
609-
return []byte(bestResponse), true, nil
610-
}
611-
612-
atomic.AddInt64(&c.missCount, 1)
613-
observability.Debugf("MilvusCache.FindSimilar: CACHE MISS - best_similarity=%.4f < threshold=%.4f",
614-
bestSimilarity, c.similarityThreshold)
615-
observability.LogEvent("cache_miss", map[string]interface{}{
616-
"backend": "milvus",
617-
"best_similarity": bestSimilarity,
618-
"threshold": c.similarityThreshold,
619-
"model": model,
620-
"collection": c.collectionName,
621-
"entries_checked": queryColumn.Len(),
580+
atomic.AddInt64(&c.hitCount, 1)
581+
observability.Debugf("MilvusCache.FindSimilar: CACHE HIT - similarity=%.4f >= threshold=%.4f, response_size=%d bytes",
582+
bestScore, c.similarityThreshold, len(responseBody))
583+
observability.LogEvent("cache_hit", map[string]interface{}{
584+
"backend": "milvus",
585+
"similarity": bestScore,
586+
"threshold": c.similarityThreshold,
587+
"model": model,
588+
"collection": c.collectionName,
622589
})
623-
metrics.RecordCacheOperation("milvus", "find_similar", "miss", time.Since(start).Seconds())
624-
metrics.RecordCacheMiss()
625-
return nil, false, nil
590+
metrics.RecordCacheOperation("milvus", "find_similar", "hit", time.Since(start).Seconds())
591+
metrics.RecordCacheHit()
592+
return responseBody, true, nil
626593
}
627594

628595
// Close releases all resources held by the cache

0 commit comments

Comments
 (0)