Skip to content

Commit 9db4775

Browse files
committed
rebuild index upon restart
Signed-off-by: Huamin Chen <[email protected]>
1 parent 60dac1f commit 9db4775

File tree

4 files changed

+271
-343
lines changed

4 files changed

+271
-343
lines changed

src/semantic-router/pkg/cache/hybrid_cache.go

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,9 @@ type HybridCacheOptions struct {
9898

9999
// Milvus settings
100100
MilvusConfigPath string
101+
102+
// Startup settings
103+
DisableRebuildOnStartup bool // Skip rebuilding HNSW index from Milvus on startup (default: false, meaning rebuild IS enabled)
101104
}
102105

103106
// NewHybridCache creates a new hybrid cache instance
@@ -153,6 +156,26 @@ func NewHybridCache(options HybridCacheOptions) (*HybridCache, error) {
153156
observability.Infof("Hybrid cache initialized: HNSW(M=%d, ef=%d), maxMemory=%d",
154157
options.HNSWM, options.HNSWEfConstruction, options.MaxMemoryEntries)
155158

159+
// Rebuild HNSW index from Milvus on startup (enabled by default)
160+
// This ensures the in-memory index is populated after a restart
161+
// Set DisableRebuildOnStartup=true to skip this step (not recommended for production)
162+
if !options.DisableRebuildOnStartup {
163+
observability.Infof("Hybrid cache: rebuilding HNSW index from Milvus...")
164+
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
165+
defer cancel()
166+
167+
if err := cache.RebuildFromMilvus(ctx); err != nil {
168+
observability.Warnf("Hybrid cache: failed to rebuild HNSW index from Milvus: %v", err)
169+
observability.Warnf("Hybrid cache: continuing with empty HNSW index")
170+
// Don't fail initialization, just log warning and continue with empty index
171+
} else {
172+
observability.Infof("Hybrid cache: HNSW index rebuild complete")
173+
}
174+
} else {
175+
observability.Warnf("Hybrid cache: skipping HNSW index rebuild (DisableRebuildOnStartup=true)")
176+
observability.Warnf("Hybrid cache: index will be empty until entries are added")
177+
}
178+
156179
return cache, nil
157180
}
158181

@@ -161,6 +184,83 @@ func (h *HybridCache) IsEnabled() bool {
161184
return h.enabled
162185
}
163186

187+
// RebuildFromMilvus rebuilds the in-memory HNSW index from persistent Milvus storage
188+
// This is called on startup to recover the index after a restart
189+
func (h *HybridCache) RebuildFromMilvus(ctx context.Context) error {
190+
if !h.enabled {
191+
return nil
192+
}
193+
194+
start := time.Now()
195+
observability.Infof("HybridCache.RebuildFromMilvus: starting HNSW index rebuild from Milvus")
196+
197+
// Query all entries from Milvus
198+
requestIDs, embeddings, err := h.milvusCache.GetAllEntries(ctx)
199+
if err != nil {
200+
return fmt.Errorf("failed to get entries from Milvus: %w", err)
201+
}
202+
203+
if len(requestIDs) == 0 {
204+
observability.Infof("HybridCache.RebuildFromMilvus: no entries to rebuild, starting with empty index")
205+
return nil
206+
}
207+
208+
observability.Infof("HybridCache.RebuildFromMilvus: rebuilding HNSW index with %d entries", len(requestIDs))
209+
210+
// Lock for the entire rebuild process
211+
h.mu.Lock()
212+
defer h.mu.Unlock()
213+
214+
// Clear existing index
215+
h.embeddings = make([][]float32, 0, len(embeddings))
216+
h.idMap = make(map[int]string)
217+
h.hnswIndex = newHNSWIndex(h.hnswIndex.M, h.hnswIndex.efConstruction)
218+
219+
// Rebuild HNSW index with progress logging
220+
batchSize := 1000
221+
for i, embedding := range embeddings {
222+
// Check memory limits
223+
if len(h.embeddings) >= h.maxMemoryEntries {
224+
observability.Warnf("HybridCache.RebuildFromMilvus: reached max memory entries (%d), stopping rebuild at %d/%d",
225+
h.maxMemoryEntries, i, len(embeddings))
226+
break
227+
}
228+
229+
// Add to HNSW
230+
entryIndex := len(h.embeddings)
231+
h.embeddings = append(h.embeddings, embedding)
232+
h.idMap[entryIndex] = requestIDs[i]
233+
h.addNodeHybrid(entryIndex, embedding)
234+
235+
// Progress logging for large datasets
236+
if (i+1)%batchSize == 0 {
237+
elapsed := time.Since(start)
238+
rate := float64(i+1) / elapsed.Seconds()
239+
remaining := len(embeddings) - (i + 1)
240+
eta := time.Duration(float64(remaining)/rate) * time.Second
241+
observability.Infof("HybridCache.RebuildFromMilvus: progress %d/%d (%.1f%%, %.0f entries/sec, ETA: %v)",
242+
i+1, len(embeddings), float64(i+1)/float64(len(embeddings))*100, rate, eta)
243+
}
244+
}
245+
246+
elapsed := time.Since(start)
247+
rate := float64(len(h.embeddings)) / elapsed.Seconds()
248+
observability.Infof("HybridCache.RebuildFromMilvus: rebuild complete - %d entries in %v (%.0f entries/sec)",
249+
len(h.embeddings), elapsed, rate)
250+
251+
observability.LogEvent("hybrid_cache_rebuilt", map[string]interface{}{
252+
"backend": "hybrid",
253+
"entries_loaded": len(h.embeddings),
254+
"entries_in_milvus": len(embeddings),
255+
"duration_seconds": elapsed.Seconds(),
256+
"entries_per_sec": rate,
257+
})
258+
259+
metrics.UpdateCacheEntries("hybrid", len(h.embeddings))
260+
261+
return nil
262+
}
263+
164264
// AddPendingRequest stores a request awaiting its response
165265
func (h *HybridCache) AddPendingRequest(requestID string, model string, query string, requestBody []byte) error {
166266
start := time.Now()

src/semantic-router/pkg/cache/milvus_cache.go

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -748,6 +748,80 @@ func (c *MilvusCache) FindSimilar(model string, query string) ([]byte, bool, err
748748
return responseBody, true, nil
749749
}
750750

751+
// GetAllEntries retrieves all entries from Milvus for HNSW index rebuilding
752+
// Returns slices of request_ids and embeddings for efficient bulk loading
753+
func (c *MilvusCache) GetAllEntries(ctx context.Context) ([]string, [][]float32, error) {
754+
start := time.Now()
755+
756+
if !c.enabled {
757+
return nil, nil, fmt.Errorf("milvus cache is not enabled")
758+
}
759+
760+
observability.Infof("MilvusCache.GetAllEntries: querying all entries for HNSW rebuild")
761+
762+
// Query all entries with embeddings and request_ids
763+
// Filter to only get entries with complete responses (not pending)
764+
queryResult, err := c.client.Query(
765+
ctx,
766+
c.collectionName,
767+
[]string{}, // Empty partitions means search all
768+
"response_body != \"\"", // Only get complete entries
769+
[]string{"request_id", c.config.Collection.VectorField.Name}, // Get IDs and embeddings
770+
)
771+
772+
if err != nil {
773+
observability.Warnf("MilvusCache.GetAllEntries: query failed: %v", err)
774+
return nil, nil, fmt.Errorf("milvus query all failed: %w", err)
775+
}
776+
777+
if len(queryResult) < 2 {
778+
observability.Infof("MilvusCache.GetAllEntries: no entries found or incomplete result")
779+
return []string{}, [][]float32{}, nil
780+
}
781+
782+
// Extract request IDs (first column)
783+
requestIDColumn, ok := queryResult[0].(*entity.ColumnVarChar)
784+
if !ok {
785+
return nil, nil, fmt.Errorf("unexpected request_id column type: %T", queryResult[0])
786+
}
787+
788+
// Extract embeddings (second column)
789+
embeddingColumn, ok := queryResult[1].(*entity.ColumnFloatVector)
790+
if !ok {
791+
return nil, nil, fmt.Errorf("unexpected embedding column type: %T", queryResult[1])
792+
}
793+
794+
if requestIDColumn.Len() != embeddingColumn.Len() {
795+
return nil, nil, fmt.Errorf("column length mismatch: request_ids=%d, embeddings=%d",
796+
requestIDColumn.Len(), embeddingColumn.Len())
797+
}
798+
799+
entryCount := requestIDColumn.Len()
800+
requestIDs := make([]string, entryCount)
801+
802+
// Extract request IDs from column
803+
for i := 0; i < entryCount; i++ {
804+
requestID, err := requestIDColumn.ValueByIdx(i)
805+
if err != nil {
806+
return nil, nil, fmt.Errorf("failed to get request_id at index %d: %w", i, err)
807+
}
808+
requestIDs[i] = requestID
809+
}
810+
811+
// Extract embeddings directly from column data
812+
embeddings := embeddingColumn.Data()
813+
if len(embeddings) != entryCount {
814+
return nil, nil, fmt.Errorf("embedding data length mismatch: got %d, expected %d",
815+
len(embeddings), entryCount)
816+
}
817+
818+
elapsed := time.Since(start)
819+
observability.Infof("MilvusCache.GetAllEntries: loaded %d entries in %v (%.0f entries/sec)",
820+
entryCount, elapsed, float64(entryCount)/elapsed.Seconds())
821+
822+
return requestIDs, embeddings, nil
823+
}
824+
751825
// GetByID retrieves a document from Milvus by its request ID
752826
// This is much more efficient than FindSimilar when you already know the ID
753827
// Used by hybrid cache to fetch documents after local HNSW search

0 commit comments

Comments
 (0)