Skip to content

Commit 108374b

Browse files
authored
feat: enforce milvus dial timeout if set (#503)
Signed-off-by: cryo <[email protected]>
1 parent 0c2a6c1 commit 108374b

File tree

2 files changed

+49
-1
lines changed

2 files changed

+49
-1
lines changed

src/semantic-router/pkg/cache/cache_test.go

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package cache_test
22

33
import (
4+
"fmt"
45
"os"
56
"path/filepath"
67
"strings"
@@ -177,6 +178,44 @@ development:
177178
})
178179
})
179180

181+
Context("Milvus connection timeouts", func() {
182+
It("should respect connection timeout when endpoint is unreachable", func() {
183+
unreachableConfigPath := filepath.Join(tempDir, "milvus-unreachable.yaml")
184+
unreachableHost := "10.255.255.1" // unroutable address to simulate a hanging dial
185+
unreachableConfig := fmt.Sprintf(`
186+
connection:
187+
host: "%s"
188+
port: 19530
189+
database: "test_cache"
190+
timeout: 1
191+
`, unreachableHost)
192+
193+
err := os.WriteFile(unreachableConfigPath, []byte(unreachableConfig), 0o644)
194+
Expect(err).NotTo(HaveOccurred())
195+
196+
done := make(chan struct{})
197+
var cacheErr error
198+
199+
go func() {
200+
defer GinkgoRecover()
201+
_, cacheErr = cache.NewMilvusCache(cache.MilvusCacheOptions{
202+
Enabled: true,
203+
SimilarityThreshold: 0.85,
204+
TTLSeconds: 60,
205+
ConfigPath: unreachableConfigPath,
206+
})
207+
close(done)
208+
}()
209+
210+
Eventually(done, 2*time.Second, 100*time.Millisecond).Should(BeClosed())
211+
Expect(cacheErr).To(HaveOccurred())
212+
Expect(cacheErr.Error()).To(Or(
213+
ContainSubstring("context deadline exceeded"),
214+
ContainSubstring("timeout"),
215+
))
216+
})
217+
})
218+
180219
Context("with unsupported backend type", func() {
181220
It("should return error for unsupported backend type", func() {
182221
config := cache.CacheConfig{

src/semantic-router/pkg/cache/milvus_cache.go

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,16 @@ func NewMilvusCache(options MilvusCacheOptions) (*MilvusCache, error) {
138138
// Establish connection to Milvus server
139139
connectionString := fmt.Sprintf("%s:%d", config.Connection.Host, config.Connection.Port)
140140
observability.Debugf("MilvusCache: connecting to Milvus at %s", connectionString)
141-
milvusClient, err := client.NewGrpcClient(context.Background(), connectionString)
141+
dialCtx := context.Background()
142+
var cancel context.CancelFunc
143+
if config.Connection.Timeout > 0 {
144+
// If a timeout is specified, apply it to the connection context
145+
timeout := time.Duration(config.Connection.Timeout) * time.Second
146+
dialCtx, cancel = context.WithTimeout(dialCtx, timeout)
147+
defer cancel()
148+
observability.Debugf("MilvusCache: connection timeout set to %s", timeout)
149+
}
150+
milvusClient, err := client.NewGrpcClient(dialCtx, connectionString)
142151
if err != nil {
143152
observability.Debugf("MilvusCache: failed to connect: %v", err)
144153
return nil, fmt.Errorf("failed to create Milvus client: %w", err)

0 commit comments

Comments
 (0)