Skip to content

Commit c83d109

Browse files
committed
update embedding setting in config (#489)
Signed-off-by: Huamin Chen <[email protected]>
1 parent 2bee957 commit c83d109

File tree

16 files changed

+226
-156
lines changed

16 files changed

+226
-156
lines changed

candle-binding/Cargo.lock

Lines changed: 108 additions & 131 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

candle-binding/Cargo.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@ flash-attn = ["candle-flash-attn"]
1919

2020
[dependencies]
2121
anyhow = { version = "1", features = ["backtrace"] }
22-
candle-core = "0.8.4"
23-
candle-nn = "0.8.4"
24-
candle-transformers = "0.8.4"
22+
candle-core = { version = "0.8.4", features = ["cuda"] }
23+
candle-nn = { version = "0.8.4", features = ["cuda"] }
24+
candle-transformers = { version = "0.8.4", features = ["cuda"] }
2525
# Flash Attention 2 (optional, requires CUDA)
2626
# Reference: https://github.com/huggingface/candle/tree/main/candle-flash-attn
2727
candle-flash-attn = { version = "0.8.4", optional = true }

candle-binding/src/model_architectures/embedding/qwen3_embedding.rs

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1326,7 +1326,7 @@ impl Qwen3Attention {
13261326
q: &Tensor,
13271327
k: &Tensor,
13281328
v: &Tensor,
1329-
attention_mask: Option<&Tensor>,
1329+
_attention_mask: Option<&Tensor>,
13301330
) -> UnifiedResult<Tensor> {
13311331
// Flash Attention 2 implementation using candle-flash-attn
13321332
//
@@ -1363,8 +1363,8 @@ impl Qwen3Attention {
13631363
&q_flash,
13641364
&k_flash,
13651365
&v_flash,
1366-
self.scale as f32, // softmax scaling factor
1367-
false, // causal: false (Qwen3-Embedding is non-causal)
1366+
self.scaling as f32, // softmax scaling factor
1367+
false, // causal: false (Qwen3-Embedding is non-causal)
13681368
)
13691369
.map_err(|e| UnifiedError::Processing {
13701370
operation: "Flash Attention 2: flash_attn".to_string(),
@@ -1975,15 +1975,11 @@ impl Qwen3EmbeddingModel {
19751975
#[cfg(not(feature = "flash-attn"))]
19761976
{
19771977
if config.max_position_embeddings > 8192 {
1978-
eprintln!("⚠️ WARNING: Flash Attention 2 not enabled!");
1978+
eprintln!("ℹ️ Note: Using standard attention");
19791979
eprintln!(
1980-
" For {}K sequence length, performance may degrade:",
1980+
" Sequence length: {}K tokens",
19811981
config.max_position_embeddings / 1024
19821982
);
1983-
eprintln!(" - Memory usage: +40% (estimated)");
1984-
eprintln!(" - Inference speed: -50% (estimated)");
1985-
eprintln!(" Official recommendation: Compile with --features flash-attn");
1986-
eprintln!(" Reference: https://github.com/qwenlm/qwen3-embedding#usage");
19871983
}
19881984
}
19891985

config/config.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,11 @@ semantic_cache:
1919
# Combines in-memory HNSW for fast search with Milvus for scalable storage
2020
# max_memory_entries: 100000 # Max entries in HNSW index (default: 100,000)
2121
# backend_config_path: "config/milvus.yaml" # Path to Milvus config
22+
23+
# Embedding model for semantic similarity matching
24+
# Options: "bert" (fast, 384-dim), "qwen3" (high quality, 1024-dim, 32K context), "gemma" (balanced, 768-dim, 8K context)
25+
# Default: "bert" (fastest, lowest memory)
26+
embedding_model: "bert"
2227

2328
tools:
2429
enabled: true

deploy/kubernetes/config.yaml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,10 @@ semantic_cache:
99
similarity_threshold: 0.8
1010
max_entries: 1000 # Only applies to memory backend
1111
ttl_seconds: 3600
12-
eviction_policy: "fifo"
12+
eviction_policy: "fifo"
13+
# Embedding model for semantic similarity matching
14+
# Options: "bert" (fast, 384-dim), "qwen3" (high quality, 1024-dim, 32K context), "gemma" (balanced, 768-dim, 8K context)
15+
embedding_model: "bert" # Default: BERT (fastest, lowest memory for Kubernetes)
1316

1417
tools:
1518
enabled: true

deploy/openshift/config-openshift.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@ semantic_cache:
1010
max_entries: 1000 # Only applies to memory backend
1111
ttl_seconds: 3600
1212
eviction_policy: "fifo"
13+
# Embedding model for semantic similarity matching
14+
# Options: "bert" (fast, 384-dim), "qwen3" (high quality, 1024-dim, 32K context), "gemma" (balanced, 768-dim, 8K context)
15+
embedding_model: "bert" # Default: BERT (fastest, lowest memory for OpenShift)
1316

1417
tools:
1518
enabled: true

src/semantic-router/pkg/apis/vllm.ai/v1alpha1/filter_types.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,12 @@ type SemanticCacheConfig struct {
117117
// +kubebuilder:default=memory
118118
Backend *string `json:"backend,omitempty"`
119119

120+
// EmbeddingModel defines which embedding model to use for semantic similarity
121+
// +optional
122+
// +kubebuilder:validation:Enum=bert;qwen3;gemma
123+
// +kubebuilder:default=bert
124+
EmbeddingModel *string `json:"embeddingModel,omitempty"`
125+
120126
// BackendConfig defines backend-specific configuration
121127
// +optional
122128
BackendConfig map[string]string `json:"backendConfig,omitempty"`

src/semantic-router/pkg/cache/cache_factory.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,9 @@ func NewCacheBackend(config CacheConfig) (CacheBackend, error) {
2424
switch config.BackendType {
2525
case InMemoryCacheType, "":
2626
// Use in-memory cache as the default backend
27-
observability.Debugf("Creating in-memory cache backend - MaxEntries: %d, TTL: %ds, Threshold: %.3f, UseHNSW: %t",
28-
config.MaxEntries, config.TTLSeconds, config.SimilarityThreshold, config.UseHNSW)
27+
observability.Debugf("Creating in-memory cache backend - MaxEntries: %d, TTL: %ds, Threshold: %.3f, UseHNSW: %t, EmbeddingModel: %s",
28+
config.MaxEntries, config.TTLSeconds, config.SimilarityThreshold, config.UseHNSW, config.EmbeddingModel)
29+
2930
options := InMemoryCacheOptions{
3031
Enabled: config.Enabled,
3132
SimilarityThreshold: config.SimilarityThreshold,
@@ -35,6 +36,7 @@ func NewCacheBackend(config CacheConfig) (CacheBackend, error) {
3536
UseHNSW: config.UseHNSW,
3637
HNSWM: config.HNSWM,
3738
HNSWEfConstruction: config.HNSWEfConstruction,
39+
EmbeddingModel: config.EmbeddingModel,
3840
}
3941
return NewInMemoryCache(options), nil
4042

src/semantic-router/pkg/cache/cache_interface.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,4 +116,8 @@ type CacheConfig struct {
116116

117117
// Hybrid cache specific settings
118118
MaxMemoryEntries int `yaml:"max_memory_entries,omitempty"` // Max entries in HNSW for hybrid cache
119+
120+
// EmbeddingModel specifies which embedding model to use
121+
// Options: "bert" (default), "qwen3", "gemma"
122+
EmbeddingModel string `yaml:"embedding_model,omitempty"`
119123
}

src/semantic-router/pkg/cache/cache_test.go

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ var _ = Describe("Cache Package", func() {
5151
SimilarityThreshold: 0.8,
5252
MaxEntries: 1000,
5353
TTLSeconds: 3600,
54+
EmbeddingModel: "bert",
5455
}
5556

5657
backend, err := cache.NewCacheBackend(config)
@@ -66,6 +67,7 @@ var _ = Describe("Cache Package", func() {
6667
SimilarityThreshold: 0.8,
6768
MaxEntries: 1000,
6869
TTLSeconds: 3600,
70+
EmbeddingModel: "bert",
6971
}
7072

7173
backend, err := cache.NewCacheBackend(config)
@@ -81,6 +83,7 @@ var _ = Describe("Cache Package", func() {
8183
SimilarityThreshold: 0.8,
8284
MaxEntries: 500,
8385
TTLSeconds: 1800,
86+
EmbeddingModel: "bert",
8487
}
8588

8689
backend, err := cache.NewCacheBackend(config)
@@ -142,6 +145,7 @@ development:
142145
SimilarityThreshold: 0.85,
143146
TTLSeconds: 7200,
144147
BackendConfigPath: milvusConfigPath,
148+
EmbeddingModel: "bert",
145149
}
146150

147151
backend, err := cache.NewCacheBackend(config)
@@ -169,6 +173,7 @@ development:
169173
SimilarityThreshold: 0.8,
170174
TTLSeconds: 3600,
171175
BackendConfigPath: milvusConfigPath,
176+
EmbeddingModel: "bert",
172177
}
173178

174179
backend, err := cache.NewCacheBackend(config)
@@ -223,6 +228,7 @@ connection:
223228
Enabled: true,
224229
SimilarityThreshold: 0.8,
225230
TTLSeconds: 3600,
231+
EmbeddingModel: "bert",
226232
}
227233

228234
backend, err := cache.NewCacheBackend(config)
@@ -240,6 +246,7 @@ connection:
240246
SimilarityThreshold: -0.8, // invalid
241247
MaxEntries: 10,
242248
TTLSeconds: -1, // invalid
249+
EmbeddingModel: "bert",
243250
}
244251

245252
backend, err := cache.NewCacheBackend(config)
@@ -259,6 +266,7 @@ connection:
259266
SimilarityThreshold: 0.8,
260267
MaxEntries: 1000,
261268
TTLSeconds: 3600,
269+
EmbeddingModel: "bert",
262270
EvictionPolicy: "lru",
263271
}
264272

@@ -285,6 +293,7 @@ connection:
285293
SimilarityThreshold: 1.5, // Invalid: > 1.0
286294
MaxEntries: 1000,
287295
TTLSeconds: 3600,
296+
EmbeddingModel: "bert",
288297
}
289298

290299
err := cache.ValidateCacheConfig(config)
@@ -299,6 +308,7 @@ connection:
299308
SimilarityThreshold: -0.1, // Invalid: < 0.0
300309
MaxEntries: 1000,
301310
TTLSeconds: 3600,
311+
EmbeddingModel: "bert",
302312
}
303313

304314
err := cache.ValidateCacheConfig(config)
@@ -313,6 +323,7 @@ connection:
313323
SimilarityThreshold: 0.8,
314324
MaxEntries: 1000,
315325
TTLSeconds: -1, // Invalid: negative TTL
326+
EmbeddingModel: "bert",
316327
}
317328

318329
err := cache.ValidateCacheConfig(config)
@@ -327,6 +338,7 @@ connection:
327338
SimilarityThreshold: 0.8,
328339
MaxEntries: -1, // Invalid: negative max entries
329340
TTLSeconds: 3600,
341+
EmbeddingModel: "bert",
330342
}
331343

332344
err := cache.ValidateCacheConfig(config)
@@ -341,6 +353,7 @@ connection:
341353
SimilarityThreshold: 0.8,
342354
MaxEntries: 1000,
343355
TTLSeconds: 3600,
356+
EmbeddingModel: "bert",
344357
EvictionPolicy: "random", // unsupported
345358
}
346359

@@ -355,6 +368,7 @@ connection:
355368
Enabled: true,
356369
SimilarityThreshold: 0.8,
357370
TTLSeconds: 3600,
371+
EmbeddingModel: "bert",
358372
// BackendConfigPath is missing
359373
}
360374

@@ -369,6 +383,7 @@ connection:
369383
Enabled: true,
370384
SimilarityThreshold: 0.8,
371385
TTLSeconds: 3600,
386+
EmbeddingModel: "bert",
372387
BackendConfigPath: "/nonexistent/milvus.yaml",
373388
}
374389

@@ -397,6 +412,7 @@ connection:
397412
SimilarityThreshold: 1.0, // Valid: maximum threshold
398413
MaxEntries: 10000,
399414
TTLSeconds: 86400,
415+
EmbeddingModel: "bert",
400416
}
401417

402418
err := cache.ValidateCacheConfig(config)
@@ -455,6 +471,7 @@ connection:
455471
SimilarityThreshold: 0.8,
456472
MaxEntries: 100,
457473
TTLSeconds: 300,
474+
EmbeddingModel: "bert",
458475
}
459476
inMemoryCache = cache.NewInMemoryCache(options)
460477
})
@@ -481,6 +498,7 @@ connection:
481498
SimilarityThreshold: 0.8,
482499
MaxEntries: 100,
483500
TTLSeconds: 300,
501+
EmbeddingModel: "bert",
484502
}
485503
disabledCache := cache.NewInMemoryCache(disabledOptions)
486504
defer disabledCache.Close()
@@ -548,6 +566,7 @@ connection:
548566
SimilarityThreshold: 0.8,
549567
MaxEntries: 100,
550568
TTLSeconds: 1,
569+
EmbeddingModel: "bert",
551570
})
552571

553572
err := inMemoryCache.AddPendingRequest("expired-request-id", "test-model", "stale query", []byte("request"))
@@ -571,6 +590,7 @@ connection:
571590
SimilarityThreshold: 0.99, // Very high threshold
572591
MaxEntries: 100,
573592
TTLSeconds: 300,
593+
EmbeddingModel: "bert",
574594
}
575595
highThresholdCache := cache.NewInMemoryCache(highThresholdOptions)
576596
defer highThresholdCache.Close()
@@ -621,6 +641,7 @@ connection:
621641
SimilarityThreshold: 0.1,
622642
MaxEntries: 10,
623643
TTLSeconds: 1,
644+
EmbeddingModel: "bert",
624645
})
625646
defer ttlCache.Close()
626647

@@ -660,6 +681,7 @@ connection:
660681
SimilarityThreshold: 0.8,
661682
MaxEntries: 100,
662683
TTLSeconds: 300,
684+
EmbeddingModel: "bert",
663685
}
664686
disabledCache := cache.NewInMemoryCache(disabledOptions)
665687
defer disabledCache.Close()
@@ -703,6 +725,7 @@ connection:
703725
SimilarityThreshold: 0.9,
704726
MaxEntries: 2000,
705727
TTLSeconds: 7200,
728+
EmbeddingModel: "bert",
706729
BackendConfigPath: "config/cache/milvus.yaml",
707730
}
708731

0 commit comments

Comments
 (0)