Skip to content

Commit f9c8aa8

Browse files
authored
fix:hnsw heap polarity (#550)
* fix:hnsw heap polarity Signed-off-by: cryo <[email protected]> * fix: go-lint failures Signed-off-by: cryo <[email protected]> --------- Signed-off-by: cryo <[email protected]>
1 parent 860c4c8 commit f9c8aa8

File tree

2 files changed

+138
-22
lines changed

2 files changed

+138
-22
lines changed

src/semantic-router/pkg/cache/inmemory_cache.go

Lines changed: 20 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -771,8 +771,8 @@ func (h *HNSWIndex) searchKNN(queryEmbedding []float32, k, ef int, entries []Cac
771771
// searchLayer searches for nearest neighbors at a specific layer
772772
func (h *HNSWIndex) searchLayer(queryEmbedding []float32, entryPoint, ef, layer int, entries []CacheEntry) []int {
773773
visited := make(map[int]bool)
774-
candidates := newMaxHeap()
775-
results := newMinHeap()
774+
candidates := newMinHeap() // set of candidates, explore closest candidate first
775+
results := newMaxHeap() // dynamic list of found nearest neighbors, track current frontier, worst distance on top
776776

777777
// Calculate distance to entry point
778778
if entryPoint >= 0 && entryPoint < len(entries) {
@@ -785,11 +785,9 @@ func (h *HNSWIndex) searchLayer(queryEmbedding []float32, entryPoint, ef, layer
785785
for candidates.len() > 0 {
786786
currentIdx, currentDist := candidates.pop()
787787

788-
if results.len() > 0 {
789-
worstDist := results.peekDist()
790-
if currentDist > worstDist {
791-
break
792-
}
788+
// Stop once results is non-empty and the current candidate's distance is worse than the worst distance held in results
789+
if results.len() > 0 && currentDist > results.peekDist() {
790+
break
793791
}
794792

795793
// Fast O(1) lookup using nodeIndex map
@@ -881,25 +879,10 @@ func (h *minHeap) pop() (int, float32) {
881879
return result.index, result.dist
882880
}
883881

884-
func (h *minHeap) peekDist() float32 {
885-
if len(h.data) == 0 {
886-
return math.MaxFloat32
887-
}
888-
return h.data[0].dist
889-
}
890-
891882
// len reports how many elements the heap currently holds.
func (h *minHeap) len() int {
892883
return len(h.data)
893884
}
894885

895-
func (h *minHeap) items() []int {
896-
result := make([]int, len(h.data))
897-
for i, item := range h.data {
898-
result[i] = item.index
899-
}
900-
return result
901-
}
902-
903886
func (h *minHeap) bubbleUp(i int) {
904887
for i > 0 {
905888
parent := (i - 1) / 2
@@ -961,6 +944,21 @@ func (h *maxHeap) len() int {
961944
return len(h.data)
962945
}
963946

947+
// peekDist returns the dist of the root element (h.data[0]) without removing
// it. For an empty heap it returns math.MaxFloat32 as a sentinel instead of
// indexing into an empty slice.
func (h *maxHeap) peekDist() float32 {
948+
if len(h.data) == 0 {
949+
return math.MaxFloat32
950+
}
951+
return h.data[0].dist
952+
}
953+
954+
// items returns a newly allocated slice holding the index field of every
// element, in the order the elements sit in h.data (no sorting is applied).
func (h *maxHeap) items() []int {
955+
result := make([]int, len(h.data))
956+
for i, item := range h.data {
957+
result[i] = item.index
958+
}
959+
return result
960+
}
961+
964962
func (h *maxHeap) bubbleUp(i int) {
965963
for i > 0 {
966964
parent := (i - 1) / 2
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
package cache
2+
3+
import (
4+
"slices"
5+
"testing"
6+
)
7+
8+
// TestSearchLayerHeapManagement calls HNSWIndex.searchLayer on small,
// hand-built layer-0 graphs and checks which entry indices are returned.
// Each subtest is a regression fixture for one symptom of the earlier
// candidates/results heap mix-up described in its own comment.
func TestSearchLayerHeapManagement(t *testing.T) {
9+
t.Run("retains the closest neighbor when ef is saturated", func(t *testing.T) {
10+
// Regression fixture: with the previous max-heap candidates/min-heap results
11+
// mix, trimming to ef would evict the best element instead of the worst.
12+
queryEmbedding := []float32{1.0}
13+
14+
entries := []CacheEntry{
15+
{Embedding: []float32{0.1}}, // entry point has low similarity
16+
{Embedding: []float32{1.0}}, // neighbor is the true nearest
17+
}
18+
19+
// Layer-0 adjacency: node 0 and node 1 list each other as neighbors.
entryNode := &HNSWNode{
20+
entryIndex: 0,
21+
neighbors: map[int][]int{
22+
0: {1},
23+
},
24+
maxLayer: 0,
25+
}
26+
27+
neighborNode := &HNSWNode{
28+
entryIndex: 1,
29+
neighbors: map[int][]int{
30+
0: {0},
31+
},
32+
maxLayer: 0,
33+
}
34+
35+
index := &HNSWIndex{
36+
nodes: []*HNSWNode{entryNode, neighborNode},
37+
nodeIndex: map[int]*HNSWNode{
38+
0: entryNode,
39+
1: neighborNode,
40+
},
41+
entryPoint: 0,
42+
maxLayer: 0,
43+
efConstruction: 2,
44+
M: 1,
45+
Mmax: 1,
46+
Mmax0: 2,
47+
ml: 1,
48+
}
49+
50+
// Search layer 0 with ef = 1 (third argument is ef, fourth is the layer).
results := index.searchLayer(queryEmbedding, index.entryPoint, 1, 0, entries)
51+
52+
if !slices.Contains(results, 1) {
53+
t.Fatalf("expected results to contain best neighbor 1, got %v", results)
54+
}
55+
if slices.Contains(results, 0) {
56+
t.Fatalf("expected results to drop entry point 0 once ef trimmed, got %v", results)
57+
}
58+
})
59+
60+
t.Run("continues exploring even when next candidate looks worse", func(t *testing.T) {
61+
// Regression fixture: the break condition used the wrong polarity so the
62+
// search stopped before expanding the intermediate (worse) vertex, making
63+
// the actual best neighbor unreachable.
64+
queryEmbedding := []float32{1.0}
65+
66+
entries := []CacheEntry{
67+
{Embedding: []float32{0.2}}, // entry point
68+
{Embedding: []float32{0.05}}, // intermediate node with poor similarity
69+
{Embedding: []float32{1.0}}, // hidden best match
70+
}
71+
72+
// Layer-0 adjacency forms the chain 0 - 1 - 2: node 2 appears only in
// node 1's neighbor list, so it can be reached only by expanding node 1.
entryNode := &HNSWNode{
73+
entryIndex: 0,
74+
neighbors: map[int][]int{
75+
0: {1},
76+
},
77+
maxLayer: 0,
78+
}
79+
80+
intermediateNode := &HNSWNode{
81+
entryIndex: 1,
82+
neighbors: map[int][]int{
83+
0: {0, 2},
84+
},
85+
maxLayer: 0,
86+
}
87+
88+
bestNode := &HNSWNode{
89+
entryIndex: 2,
90+
neighbors: map[int][]int{
91+
0: {1},
92+
},
93+
maxLayer: 0,
94+
}
95+
96+
index := &HNSWIndex{
97+
nodes: []*HNSWNode{entryNode, intermediateNode, bestNode},
98+
nodeIndex: map[int]*HNSWNode{
99+
0: entryNode,
100+
1: intermediateNode,
101+
2: bestNode,
102+
},
103+
entryPoint: 0,
104+
maxLayer: 0,
105+
efConstruction: 2,
106+
M: 1,
107+
Mmax: 1,
108+
Mmax0: 2,
109+
ml: 1,
110+
}
111+
112+
// Search layer 0 with ef = 2 (third argument is ef, fourth is the layer).
results := index.searchLayer(queryEmbedding, index.entryPoint, 2, 0, entries)
113+
114+
if !slices.Contains(results, 2) {
115+
t.Fatalf("expected results to reach best neighbor 2 via intermediate node, got %v", results)
116+
}
117+
})
118+
}

0 commit comments

Comments
 (0)