diff --git a/src/semantic-router/pkg/cache/cache_interface.go b/src/semantic-router/pkg/cache/cache_interface.go index b00e6edd..16a95dd1 100644 --- a/src/semantic-router/pkg/cache/cache_interface.go +++ b/src/semantic-router/pkg/cache/cache_interface.go @@ -20,6 +20,11 @@ type CacheBackend interface { // IsEnabled returns whether caching is currently active IsEnabled() bool + // CheckConnection verifies the cache backend connection is healthy + // Returns nil if the connection is healthy, error otherwise + // For local caches (in-memory), this may be a no-op + CheckConnection() error + // AddPendingRequest stores a request awaiting its response AddPendingRequest(requestID string, model string, query string, requestBody []byte) error diff --git a/src/semantic-router/pkg/cache/hybrid_cache.go b/src/semantic-router/pkg/cache/hybrid_cache.go index 3a7a452e..69c660a3 100644 --- a/src/semantic-router/pkg/cache/hybrid_cache.go +++ b/src/semantic-router/pkg/cache/hybrid_cache.go @@ -191,6 +191,21 @@ func (h *HybridCache) IsEnabled() bool { return h.enabled } +// CheckConnection verifies the cache backend connection is healthy +// For hybrid cache, this checks the Milvus connection +func (h *HybridCache) CheckConnection() error { + if !h.enabled { + return nil + } + + if h.milvusCache == nil { + return fmt.Errorf("milvus cache is not initialized") + } + + // Delegate to Milvus cache connection check + return h.milvusCache.CheckConnection() +} + // RebuildFromMilvus rebuilds the in-memory HNSW index from persistent Milvus storage // This is called on startup to recover the index after a restart func (h *HybridCache) RebuildFromMilvus(ctx context.Context) error { diff --git a/src/semantic-router/pkg/cache/inmemory_cache.go b/src/semantic-router/pkg/cache/inmemory_cache.go index 8f1dc555..a0e4d82e 100644 --- a/src/semantic-router/pkg/cache/inmemory_cache.go +++ b/src/semantic-router/pkg/cache/inmemory_cache.go @@ -131,6 +131,13 @@ func (c *InMemoryCache) IsEnabled() bool { return c.enabled } +// CheckConnection verifies the cache connection is healthy +// For in-memory cache, this is always healthy (no external connection) +func (c *InMemoryCache) CheckConnection() error { + // In-memory cache has no external connection to check + return nil +} + // generateEmbedding generates an embedding using the configured model func (c *InMemoryCache) generateEmbedding(text string) ([]float32, error) { // Normalize to lowercase for case-insensitive comparison @@ -1099,10 +1106,3 @@ func (h *maxHeap) bubbleDown(i int) { i = largest } } - -func min(a, b int) int { - if a < b { - return a - } - return b -} diff --git a/src/semantic-router/pkg/cache/milvus_cache.go b/src/semantic-router/pkg/cache/milvus_cache.go index 35eaa83f..3c67e136 100644 --- a/src/semantic-router/pkg/cache/milvus_cache.go +++ b/src/semantic-router/pkg/cache/milvus_cache.go @@ -154,7 +154,6 @@ func NewMilvusCache(options MilvusCacheOptions) (*MilvusCache, error) { logging.Debugf("MilvusCache: failed to connect: %v", err) return nil, fmt.Errorf("failed to create Milvus client: %w", err) } - logging.Debugf("MilvusCache: successfully connected to Milvus") cache := &MilvusCache{ client: milvusClient, @@ -165,6 +164,14 @@ func NewMilvusCache(options MilvusCacheOptions) (*MilvusCache, error) { enabled: options.Enabled, } + // Test connection using the new CheckConnection method + if err := cache.CheckConnection(); err != nil { + logging.Debugf("MilvusCache: connection check failed: %v", err) + milvusClient.Close() + return nil, err + } + logging.Debugf("MilvusCache: successfully connected to Milvus") + // Set up the collection for caching logging.Debugf("MilvusCache: initializing collection '%s'", config.Collection.Name) if err := cache.initializeCollection(); err != nil { @@ -392,6 +399,34 @@ func (c *MilvusCache) IsEnabled() bool { return c.enabled } +// CheckConnection verifies the Milvus connection is healthy +func (c *MilvusCache) CheckConnection() error { + if !c.enabled { + return nil + } + + if c.client == nil { + return fmt.Errorf("milvus client is not initialized") + } + + ctx := context.Background() + if c.config != nil && c.config.Connection.Timeout > 0 { + timeout := time.Duration(c.config.Connection.Timeout) * time.Second + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, timeout) + defer cancel() + } + + // Simple connection check - list collections to verify connectivity + // We don't check if specific collection exists here as it may not be created yet + _, err := c.client.ListCollections(ctx) + if err != nil { + return fmt.Errorf("milvus connection check failed: %w", err) + } + + return nil +} + // AddPendingRequest stores a request that is awaiting its response func (c *MilvusCache) AddPendingRequest(requestID string, model string, query string, requestBody []byte) error { start := time.Now() diff --git a/src/semantic-router/pkg/cache/redis_cache.go b/src/semantic-router/pkg/cache/redis_cache.go index aac25a45..f4c91b03 100644 --- a/src/semantic-router/pkg/cache/redis_cache.go +++ b/src/semantic-router/pkg/cache/redis_cache.go @@ -115,22 +115,6 @@ func NewRedisCache(options RedisCacheOptions) (*RedisCache, error) { Protocol: 2, // Use RESP2 protocol for compatibility }) - // Test connection - ctx := context.Background() - if config.Connection.Timeout > 0 { - timeout := time.Duration(config.Connection.Timeout) * time.Second - var cancel context.CancelFunc - ctx, cancel = context.WithTimeout(ctx, timeout) - defer cancel() - logging.Debugf("RedisCache: connection timeout set to %s", timeout) - } - - if err := redisClient.Ping(ctx).Err(); err != nil { - logging.Debugf("RedisCache: failed to connect: %v", err) - return nil, fmt.Errorf("failed to connect to Redis: %w", err) - } - logging.Debugf("RedisCache: successfully connected to Redis") - cache := &RedisCache{ client: redisClient, config: config, @@ -140,6 +124,13 @@ func NewRedisCache(options RedisCacheOptions) (*RedisCache, error) { enabled: options.Enabled, } + // Test connection using the new CheckConnection method + if err := cache.CheckConnection(); err != nil { + logging.Debugf("RedisCache: failed to connect: %v", err) + return nil, err + } + logging.Debugf("RedisCache: successfully connected to Redis") + // Set up the index for vector search logging.Debugf("RedisCache: initializing index '%s'", config.Index.Name) if err := cache.initializeIndex(); err != nil { @@ -350,6 +341,31 @@ func (c *RedisCache) IsEnabled() bool { return c.enabled } +// CheckConnection verifies the Redis connection is healthy +func (c *RedisCache) CheckConnection() error { + if !c.enabled { + return nil + } + + if c.client == nil { + return fmt.Errorf("redis client is not initialized") + } + + ctx := context.Background() + if c.config != nil && c.config.Connection.Timeout > 0 { + timeout := time.Duration(c.config.Connection.Timeout) * time.Second + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, timeout) + defer cancel() + } + + if err := c.client.Ping(ctx).Err(); err != nil { + return fmt.Errorf("redis connection check failed: %w", err) + } + + return nil +} + // AddPendingRequest stores a request that is awaiting its response func (c *RedisCache) AddPendingRequest(requestID string, model string, query string, requestBody []byte) error { start := time.Now()