Skip to content

Commit f1ecc20

Browse files
committed
rebuild index upon restart
Signed-off-by: Huamin Chen <[email protected]>
1 parent 2d06b40 commit f1ecc20

File tree

4 files changed

+271
-270
lines changed

4 files changed

+271
-270
lines changed

src/semantic-router/pkg/cache/hybrid_cache.go

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,9 @@ type HybridCacheOptions struct {
9898

9999
// Milvus settings
100100
MilvusConfigPath string
101+
102+
// Startup settings
103+
DisableRebuildOnStartup bool // Skip rebuilding HNSW index from Milvus on startup (default: false, meaning rebuild IS enabled)
101104
}
102105

103106
// NewHybridCache creates a new hybrid cache instance
@@ -153,6 +156,26 @@ func NewHybridCache(options HybridCacheOptions) (*HybridCache, error) {
153156
observability.Infof("Hybrid cache initialized: HNSW(M=%d, ef=%d), maxMemory=%d",
154157
options.HNSWM, options.HNSWEfConstruction, options.MaxMemoryEntries)
155158

159+
// Rebuild HNSW index from Milvus on startup (enabled by default)
160+
// This ensures the in-memory index is populated after a restart
161+
// Set DisableRebuildOnStartup=true to skip this step (not recommended for production)
162+
if !options.DisableRebuildOnStartup {
163+
observability.Infof("Hybrid cache: rebuilding HNSW index from Milvus...")
164+
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
165+
defer cancel()
166+
167+
if err := cache.RebuildFromMilvus(ctx); err != nil {
168+
observability.Warnf("Hybrid cache: failed to rebuild HNSW index from Milvus: %v", err)
169+
observability.Warnf("Hybrid cache: continuing with empty HNSW index")
170+
// Don't fail initialization, just log warning and continue with empty index
171+
} else {
172+
observability.Infof("Hybrid cache: HNSW index rebuild complete")
173+
}
174+
} else {
175+
observability.Warnf("Hybrid cache: skipping HNSW index rebuild (DisableRebuildOnStartup=true)")
176+
observability.Warnf("Hybrid cache: index will be empty until entries are added")
177+
}
178+
156179
return cache, nil
157180
}
158181

@@ -161,6 +184,83 @@ func (h *HybridCache) IsEnabled() bool {
161184
return h.enabled
162185
}
163186

187+
// RebuildFromMilvus rebuilds the in-memory HNSW index from persistent Milvus storage
188+
// This is called on startup to recover the index after a restart
189+
func (h *HybridCache) RebuildFromMilvus(ctx context.Context) error {
190+
if !h.enabled {
191+
return nil
192+
}
193+
194+
start := time.Now()
195+
observability.Infof("HybridCache.RebuildFromMilvus: starting HNSW index rebuild from Milvus")
196+
197+
// Query all entries from Milvus
198+
requestIDs, embeddings, err := h.milvusCache.GetAllEntries(ctx)
199+
if err != nil {
200+
return fmt.Errorf("failed to get entries from Milvus: %w", err)
201+
}
202+
203+
if len(requestIDs) == 0 {
204+
observability.Infof("HybridCache.RebuildFromMilvus: no entries to rebuild, starting with empty index")
205+
return nil
206+
}
207+
208+
observability.Infof("HybridCache.RebuildFromMilvus: rebuilding HNSW index with %d entries", len(requestIDs))
209+
210+
// Lock for the entire rebuild process
211+
h.mu.Lock()
212+
defer h.mu.Unlock()
213+
214+
// Clear existing index
215+
h.embeddings = make([][]float32, 0, len(embeddings))
216+
h.idMap = make(map[int]string)
217+
h.hnswIndex = newHNSWIndex(h.hnswIndex.M, h.hnswIndex.efConstruction)
218+
219+
// Rebuild HNSW index with progress logging
220+
batchSize := 1000
221+
for i, embedding := range embeddings {
222+
// Check memory limits
223+
if len(h.embeddings) >= h.maxMemoryEntries {
224+
observability.Warnf("HybridCache.RebuildFromMilvus: reached max memory entries (%d), stopping rebuild at %d/%d",
225+
h.maxMemoryEntries, i, len(embeddings))
226+
break
227+
}
228+
229+
// Add to HNSW
230+
entryIndex := len(h.embeddings)
231+
h.embeddings = append(h.embeddings, embedding)
232+
h.idMap[entryIndex] = requestIDs[i]
233+
h.addNodeHybrid(entryIndex, embedding)
234+
235+
// Progress logging for large datasets
236+
if (i+1)%batchSize == 0 {
237+
elapsed := time.Since(start)
238+
rate := float64(i+1) / elapsed.Seconds()
239+
remaining := len(embeddings) - (i + 1)
240+
eta := time.Duration(float64(remaining)/rate) * time.Second
241+
observability.Infof("HybridCache.RebuildFromMilvus: progress %d/%d (%.1f%%, %.0f entries/sec, ETA: %v)",
242+
i+1, len(embeddings), float64(i+1)/float64(len(embeddings))*100, rate, eta)
243+
}
244+
}
245+
246+
elapsed := time.Since(start)
247+
rate := float64(len(h.embeddings)) / elapsed.Seconds()
248+
observability.Infof("HybridCache.RebuildFromMilvus: rebuild complete - %d entries in %v (%.0f entries/sec)",
249+
len(h.embeddings), elapsed, rate)
250+
251+
observability.LogEvent("hybrid_cache_rebuilt", map[string]interface{}{
252+
"backend": "hybrid",
253+
"entries_loaded": len(h.embeddings),
254+
"entries_in_milvus": len(embeddings),
255+
"duration_seconds": elapsed.Seconds(),
256+
"entries_per_sec": rate,
257+
})
258+
259+
metrics.UpdateCacheEntries("hybrid", len(h.embeddings))
260+
261+
return nil
262+
}
263+
164264
// AddPendingRequest stores a request awaiting its response
165265
func (h *HybridCache) AddPendingRequest(requestID string, model string, query string, requestBody []byte) error {
166266
start := time.Now()

src/semantic-router/pkg/cache/milvus_cache.go

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -753,6 +753,80 @@ func (c *MilvusCache) FindSimilarWithThreshold(model string, query string, thres
753753
return responseBody, true, nil
754754
}
755755

756+
// GetAllEntries retrieves all entries from Milvus for HNSW index rebuilding
757+
// Returns slices of request_ids and embeddings for efficient bulk loading
758+
func (c *MilvusCache) GetAllEntries(ctx context.Context) ([]string, [][]float32, error) {
759+
start := time.Now()
760+
761+
if !c.enabled {
762+
return nil, nil, fmt.Errorf("milvus cache is not enabled")
763+
}
764+
765+
observability.Infof("MilvusCache.GetAllEntries: querying all entries for HNSW rebuild")
766+
767+
// Query all entries with embeddings and request_ids
768+
// Filter to only get entries with complete responses (not pending)
769+
queryResult, err := c.client.Query(
770+
ctx,
771+
c.collectionName,
772+
[]string{}, // Empty partitions means search all
773+
"response_body != \"\"", // Only get complete entries
774+
[]string{"request_id", c.config.Collection.VectorField.Name}, // Get IDs and embeddings
775+
)
776+
777+
if err != nil {
778+
observability.Warnf("MilvusCache.GetAllEntries: query failed: %v", err)
779+
return nil, nil, fmt.Errorf("milvus query all failed: %w", err)
780+
}
781+
782+
if len(queryResult) < 2 {
783+
observability.Infof("MilvusCache.GetAllEntries: no entries found or incomplete result")
784+
return []string{}, [][]float32{}, nil
785+
}
786+
787+
// Extract request IDs (first column)
788+
requestIDColumn, ok := queryResult[0].(*entity.ColumnVarChar)
789+
if !ok {
790+
return nil, nil, fmt.Errorf("unexpected request_id column type: %T", queryResult[0])
791+
}
792+
793+
// Extract embeddings (second column)
794+
embeddingColumn, ok := queryResult[1].(*entity.ColumnFloatVector)
795+
if !ok {
796+
return nil, nil, fmt.Errorf("unexpected embedding column type: %T", queryResult[1])
797+
}
798+
799+
if requestIDColumn.Len() != embeddingColumn.Len() {
800+
return nil, nil, fmt.Errorf("column length mismatch: request_ids=%d, embeddings=%d",
801+
requestIDColumn.Len(), embeddingColumn.Len())
802+
}
803+
804+
entryCount := requestIDColumn.Len()
805+
requestIDs := make([]string, entryCount)
806+
807+
// Extract request IDs from column
808+
for i := 0; i < entryCount; i++ {
809+
requestID, err := requestIDColumn.ValueByIdx(i)
810+
if err != nil {
811+
return nil, nil, fmt.Errorf("failed to get request_id at index %d: %w", i, err)
812+
}
813+
requestIDs[i] = requestID
814+
}
815+
816+
// Extract embeddings directly from column data
817+
embeddings := embeddingColumn.Data()
818+
if len(embeddings) != entryCount {
819+
return nil, nil, fmt.Errorf("embedding data length mismatch: got %d, expected %d",
820+
len(embeddings), entryCount)
821+
}
822+
823+
elapsed := time.Since(start)
824+
observability.Infof("MilvusCache.GetAllEntries: loaded %d entries in %v (%.0f entries/sec)",
825+
entryCount, elapsed, float64(entryCount)/elapsed.Seconds())
826+
827+
return requestIDs, embeddings, nil
828+
}
829+
756830
// GetByID retrieves a document from Milvus by its request ID
757831
// This is much more efficient than FindSimilar when you already know the ID
758832
// Used by hybrid cache to fetch documents after local HNSW search

0 commit comments

Comments
 (0)