Skip to content

Commit 2d0338d

Browse files
authored
fix: correct HNSW frontier comparisons in hybrid cache (#587)
Signed-off-by: cryo <[email protected]>
1 parent 1d4986a commit 2d0338d

File tree

2 files changed

+95
-6
lines changed

2 files changed

+95
-6
lines changed

src/semantic-router/pkg/cache/cache_test.go

Lines changed: 89 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1514,6 +1514,95 @@ milvus:
15141514
}
15151515
}
15161516

1517+
// TestHybridCacheSearchLayerPrunesWeakerBranch ensures hybrid layer search skips candidates that are already worse than the frontier.
1518+
func TestHybridCacheSearchLayerPrunesWeakerBranch(t *testing.T) {
1519+
// Regression fixture: the buggy comparison let the frontier accept a much
1520+
// worse neighbor (node 3) even after ef was saturated. That re-opened the
1521+
// branch to node 4, so the search would walk every reachable node, hurting
1522+
// latency and risking a worse match. We wire an artificial edge (3→4) to
1523+
// isolate the pruning logic; production HNSW builders try to avoid such links.
1524+
embeddings := [][]float32{
1525+
{0.80}, // node 0: entry point
1526+
{0.79}, // node 1: near-tie neighbor
1527+
{0.78}, // node 2: another strong neighbor
1528+
{0.10}, // node 3: weak branch that should be pruned
1529+
{0.995}, // node 4: hidden best reachable only via node 3
1530+
}
1531+
1532+
nodes := []*HNSWNode{
1533+
{
1534+
entryIndex: 0,
1535+
neighbors: map[int][]int{
1536+
0: {1, 2, 3},
1537+
},
1538+
maxLayer: 0,
1539+
},
1540+
{
1541+
entryIndex: 1,
1542+
neighbors: map[int][]int{
1543+
0: {0},
1544+
},
1545+
maxLayer: 0,
1546+
},
1547+
{
1548+
entryIndex: 2,
1549+
neighbors: map[int][]int{
1550+
0: {0},
1551+
},
1552+
maxLayer: 0,
1553+
},
1554+
{
1555+
entryIndex: 3,
1556+
neighbors: map[int][]int{
1557+
0: {0, 4},
1558+
},
1559+
maxLayer: 0,
1560+
},
1561+
{
1562+
entryIndex: 4,
1563+
neighbors: map[int][]int{
1564+
0: {3},
1565+
},
1566+
maxLayer: 0,
1567+
},
1568+
}
1569+
1570+
nodeIndex := map[int]*HNSWNode{
1571+
0: nodes[0],
1572+
1: nodes[1],
1573+
2: nodes[2],
1574+
3: nodes[3],
1575+
4: nodes[4],
1576+
}
1577+
1578+
cache := &HybridCache{
1579+
hnswIndex: &HNSWIndex{
1580+
nodes: nodes,
1581+
nodeIndex: nodeIndex,
1582+
entryPoint: 0,
1583+
maxLayer: 0,
1584+
efConstruction: 4,
1585+
M: 4,
1586+
Mmax: 4,
1587+
Mmax0: 4,
1588+
ml: 1,
1589+
},
1590+
embeddings: embeddings,
1591+
idMap: map[int]string{},
1592+
}
1593+
1594+
results := cache.searchLayerHybrid([]float32{1}, 3, 0, []int{0}) // query {1}, ef=3, layer 0, starting from node 0
1595+
if len(results) != 3 {
1596+
t.Fatalf("expected frontier to keep three best neighbors, got %v", results)
1597+
}
1598+
if slices.Contains(results, 4) { // node 4 is linked only to node 3, so it appears only if node 3 was expanded
1599+
t.Fatalf("expected weaker branch to stay pruned, got %v", results)
1600+
}
1601+
if !slices.Contains(results, 1) {
1602+
t.Fatalf("expected best neighbor 1 to remain in results, got %v", results)
1603+
}
1604+
}
1605+
15171606
// BenchmarkHybridCacheAddEntry benchmarks adding entries to hybrid cache
15181607
func BenchmarkHybridCacheAddEntry(b *testing.B) {
15191608
if os.Getenv("MILVUS_URI") == "" {

src/semantic-router/pkg/cache/hybrid_cache.go

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -938,15 +938,15 @@ func (h *HybridCache) searchLayerHybrid(query []float32, ef int, layer int, entr
938938
if ep < 0 || ep >= len(h.embeddings) {
939939
continue
940940
}
941-
dist := -dotProduct(query, h.embeddings[ep])
941+
dist := -dotProduct(query, h.embeddings[ep]) // Negative product so that higher similarity = lower distance
942942
candidates.push(ep, dist)
943943
results.push(ep, dist)
944944
visited[ep] = true
945945
}
946946

947947
for len(candidates.data) > 0 {
948948
currentIdx, currentDist := candidates.pop()
949-
if len(results.data) > 0 && currentDist > -results.data[0].dist {
949+
if len(results.data) > 0 && currentDist > results.data[0].dist {
950950
break
951951
}
952952

@@ -964,7 +964,7 @@ func (h *HybridCache) searchLayerHybrid(query []float32, ef int, layer int, entr
964964

965965
dist := -dotProduct(query, h.embeddings[neighborID])
966966

967-
if len(results.data) < ef || dist < -results.data[0].dist {
967+
if len(results.data) < ef || dist < results.data[0].dist {
968968
candidates.push(neighborID, dist)
969969
results.push(neighborID, dist)
970970

@@ -1062,7 +1062,7 @@ func (h *HybridCache) searchLayerHybridWithEarlyStop(query []float32, ef int, la
10621062
if ep < 0 || ep >= len(h.embeddings) {
10631063
continue
10641064
}
1065-
dist := -dotProductSIMD(query, h.embeddings[ep])
1065+
dist := -dotProductSIMD(query, h.embeddings[ep]) // Negative product so that higher similarity = lower distance
10661066
candidates.push(ep, dist)
10671067
results.push(ep, dist)
10681068
visited[ep] = true
@@ -1075,7 +1075,7 @@ func (h *HybridCache) searchLayerHybridWithEarlyStop(query []float32, ef int, la
10751075

10761076
for len(candidates.data) > 0 {
10771077
currentIdx, currentDist := candidates.pop()
1078-
if len(results.data) > 0 && currentDist > -results.data[0].dist {
1078+
if len(results.data) > 0 && currentDist > results.data[0].dist {
10791079
break
10801080
}
10811081

@@ -1098,7 +1098,7 @@ func (h *HybridCache) searchLayerHybridWithEarlyStop(query []float32, ef int, la
10981098
return []int{neighborID}
10991099
}
11001100

1101-
if len(results.data) < ef || dist < -results.data[0].dist {
1101+
if len(results.data) < ef || dist < results.data[0].dist {
11021102
candidates.push(neighborID, dist)
11031103
results.push(neighborID, dist)
11041104

0 commit comments

Comments
 (0)