Skip to content

Commit b0344a6

Browse files
committed
update embedding setting in config (vllm-project#489)
Signed-off-by: Huamin Chen <[email protected]>
1 parent 3d3fe4a commit b0344a6

File tree

16 files changed: +225 additions, -156 deletions.

candle-binding/Cargo.lock

Lines changed: 108 additions & 131 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

candle-binding/Cargo.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,9 @@ flash-attn = ["candle-flash-attn"]
1818

1919
[dependencies]
2020
anyhow = { version = "1", features = ["backtrace"] }
21-
candle-core = "0.8.4"
22-
candle-nn = "0.8.4"
23-
candle-transformers = "0.8.4"
21+
candle-core = { version = "0.8.4", features = ["cuda"] }
22+
candle-nn = { version = "0.8.4", features = ["cuda"] }
23+
candle-transformers = { version = "0.8.4", features = ["cuda"] }
2424
# Flash Attention 2 (optional, requires CUDA)
2525
# Reference: https://github.com/huggingface/candle/tree/main/candle-flash-attn
2626
candle-flash-attn = { version = "0.8.4", optional = true }

candle-binding/src/model_architectures/embedding/qwen3_embedding.rs

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1326,7 +1326,7 @@ impl Qwen3Attention {
13261326
q: &Tensor,
13271327
k: &Tensor,
13281328
v: &Tensor,
1329-
attention_mask: Option<&Tensor>,
1329+
_attention_mask: Option<&Tensor>,
13301330
) -> UnifiedResult<Tensor> {
13311331
// Flash Attention 2 implementation using candle-flash-attn
13321332
//
@@ -1363,8 +1363,8 @@ impl Qwen3Attention {
13631363
&q_flash,
13641364
&k_flash,
13651365
&v_flash,
1366-
self.scale as f32, // softmax scaling factor
1367-
false, // causal: false (Qwen3-Embedding is non-causal)
1366+
self.scaling as f32, // softmax scaling factor
1367+
false, // causal: false (Qwen3-Embedding is non-causal)
13681368
)
13691369
.map_err(|e| UnifiedError::Processing {
13701370
operation: "Flash Attention 2: flash_attn".to_string(),
@@ -1975,15 +1975,11 @@ impl Qwen3EmbeddingModel {
19751975
#[cfg(not(feature = "flash-attn"))]
19761976
{
19771977
if config.max_position_embeddings > 8192 {
1978-
eprintln!("⚠️ WARNING: Flash Attention 2 not enabled!");
1978+
eprintln!("ℹ️ Note: Using standard attention");
19791979
eprintln!(
1980-
" For {}K sequence length, performance may degrade:",
1980+
" Sequence length: {}K tokens",
19811981
config.max_position_embeddings / 1024
19821982
);
1983-
eprintln!(" - Memory usage: +40% (estimated)");
1984-
eprintln!(" - Inference speed: -50% (estimated)");
1985-
eprintln!(" Official recommendation: Compile with --features flash-attn");
1986-
eprintln!(" Reference: https://github.com/qwenlm/qwen3-embedding#usage");
19871983
}
19881984
}
19891985

config/config.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@ semantic_cache:
1010
max_entries: 1000 # Only applies to memory backend
1111
ttl_seconds: 3600
1212
eviction_policy: "fifo"
13+
# Embedding model for semantic similarity matching
14+
# Options: "bert" (fast, 384-dim), "qwen3" (high quality, 1024-dim, 32K context), "gemma" (balanced, 768-dim, 8K context)
15+
# Default: "bert" (fastest, lowest memory)
16+
embedding_model: "bert"
1317

1418
tools:
1519
enabled: true

deploy/kubernetes/config.yaml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,10 @@ semantic_cache:
99
similarity_threshold: 0.8
1010
max_entries: 1000 # Only applies to memory backend
1111
ttl_seconds: 3600
12-
eviction_policy: "fifo"
12+
eviction_policy: "fifo"
13+
# Embedding model for semantic similarity matching
14+
# Options: "bert" (fast, 384-dim), "qwen3" (high quality, 1024-dim, 32K context), "gemma" (balanced, 768-dim, 8K context)
15+
embedding_model: "bert" # Default: BERT (fastest, lowest memory for Kubernetes)
1316

1417
tools:
1518
enabled: true

deploy/openshift/config-openshift.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@ semantic_cache:
1010
max_entries: 1000 # Only applies to memory backend
1111
ttl_seconds: 3600
1212
eviction_policy: "fifo"
13+
# Embedding model for semantic similarity matching
14+
# Options: "bert" (fast, 384-dim), "qwen3" (high quality, 1024-dim, 32K context), "gemma" (balanced, 768-dim, 8K context)
15+
embedding_model: "bert" # Default: BERT (fastest, lowest memory for OpenShift)
1316

1417
tools:
1518
enabled: true

src/semantic-router/pkg/apis/vllm.ai/v1alpha1/filter_types.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,12 @@ type SemanticCacheConfig struct {
117117
// +kubebuilder:default=memory
118118
Backend *string `json:"backend,omitempty"`
119119

120+
// EmbeddingModel selects the embedding model used for semantic similarity matching (one of bert, qwen3, gemma; defaults to bert).
121+
// +optional
122+
// +kubebuilder:validation:Enum=bert;qwen3;gemma
123+
// +kubebuilder:default=bert
124+
EmbeddingModel *string `json:"embeddingModel,omitempty"`
125+
120126
// BackendConfig defines backend-specific configuration
121127
// +optional
122128
BackendConfig map[string]string `json:"backendConfig,omitempty"`

src/semantic-router/pkg/cache/cache_factory.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,14 +24,16 @@ func NewCacheBackend(config CacheConfig) (CacheBackend, error) {
2424
switch config.BackendType {
2525
case InMemoryCacheType, "":
2626
// Use in-memory cache as the default backend
27-
observability.Debugf("Creating in-memory cache backend - MaxEntries: %d, TTL: %ds, Threshold: %.3f",
28-
config.MaxEntries, config.TTLSeconds, config.SimilarityThreshold)
27+
observability.Debugf("Creating in-memory cache backend - MaxEntries: %d, TTL: %ds, Threshold: %.3f, EmbeddingModel: %s",
28+
config.MaxEntries, config.TTLSeconds, config.SimilarityThreshold, config.EmbeddingModel)
29+
2930
options := InMemoryCacheOptions{
3031
Enabled: config.Enabled,
3132
SimilarityThreshold: config.SimilarityThreshold,
3233
MaxEntries: config.MaxEntries,
3334
TTLSeconds: config.TTLSeconds,
3435
EvictionPolicy: config.EvictionPolicy,
36+
EmbeddingModel: config.EmbeddingModel,
3537
}
3638
return NewInMemoryCache(options), nil
3739

src/semantic-router/pkg/cache/cache_interface.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,4 +96,8 @@ type CacheConfig struct {
9696

9797
// BackendConfigPath points to backend-specific configuration files
9898
BackendConfigPath string `yaml:"backend_config_path,omitempty"`
99+
100+
// EmbeddingModel specifies which embedding model to use
101+
// Options: "bert" (default), "qwen3", "gemma"
102+
EmbeddingModel string `yaml:"embedding_model,omitempty"`
99103
}

src/semantic-router/pkg/cache/cache_test.go

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ var _ = Describe("Cache Package", func() {
5050
SimilarityThreshold: 0.8,
5151
MaxEntries: 1000,
5252
TTLSeconds: 3600,
53+
EmbeddingModel: "bert",
5354
}
5455

5556
backend, err := cache.NewCacheBackend(config)
@@ -65,6 +66,7 @@ var _ = Describe("Cache Package", func() {
6566
SimilarityThreshold: 0.8,
6667
MaxEntries: 1000,
6768
TTLSeconds: 3600,
69+
EmbeddingModel: "bert",
6870
}
6971

7072
backend, err := cache.NewCacheBackend(config)
@@ -80,6 +82,7 @@ var _ = Describe("Cache Package", func() {
8082
SimilarityThreshold: 0.8,
8183
MaxEntries: 500,
8284
TTLSeconds: 1800,
85+
EmbeddingModel: "bert",
8386
}
8487

8588
backend, err := cache.NewCacheBackend(config)
@@ -141,6 +144,7 @@ development:
141144
SimilarityThreshold: 0.85,
142145
TTLSeconds: 7200,
143146
BackendConfigPath: milvusConfigPath,
147+
EmbeddingModel: "bert",
144148
}
145149

146150
backend, err := cache.NewCacheBackend(config)
@@ -168,6 +172,7 @@ development:
168172
SimilarityThreshold: 0.8,
169173
TTLSeconds: 3600,
170174
BackendConfigPath: milvusConfigPath,
175+
EmbeddingModel: "bert",
171176
}
172177

173178
backend, err := cache.NewCacheBackend(config)
@@ -184,6 +189,7 @@ development:
184189
Enabled: true,
185190
SimilarityThreshold: 0.8,
186191
TTLSeconds: 3600,
192+
EmbeddingModel: "bert",
187193
}
188194

189195
backend, err := cache.NewCacheBackend(config)
@@ -201,6 +207,7 @@ development:
201207
SimilarityThreshold: -0.8, // invalid
202208
MaxEntries: 10,
203209
TTLSeconds: -1, // invalid
210+
EmbeddingModel: "bert",
204211
}
205212

206213
backend, err := cache.NewCacheBackend(config)
@@ -220,6 +227,7 @@ development:
220227
SimilarityThreshold: 0.8,
221228
MaxEntries: 1000,
222229
TTLSeconds: 3600,
230+
EmbeddingModel: "bert",
223231
EvictionPolicy: "lru",
224232
}
225233

@@ -246,6 +254,7 @@ development:
246254
SimilarityThreshold: 1.5, // Invalid: > 1.0
247255
MaxEntries: 1000,
248256
TTLSeconds: 3600,
257+
EmbeddingModel: "bert",
249258
}
250259

251260
err := cache.ValidateCacheConfig(config)
@@ -260,6 +269,7 @@ development:
260269
SimilarityThreshold: -0.1, // Invalid: < 0.0
261270
MaxEntries: 1000,
262271
TTLSeconds: 3600,
272+
EmbeddingModel: "bert",
263273
}
264274

265275
err := cache.ValidateCacheConfig(config)
@@ -274,6 +284,7 @@ development:
274284
SimilarityThreshold: 0.8,
275285
MaxEntries: 1000,
276286
TTLSeconds: -1, // Invalid: negative TTL
287+
EmbeddingModel: "bert",
277288
}
278289

279290
err := cache.ValidateCacheConfig(config)
@@ -288,6 +299,7 @@ development:
288299
SimilarityThreshold: 0.8,
289300
MaxEntries: -1, // Invalid: negative max entries
290301
TTLSeconds: 3600,
302+
EmbeddingModel: "bert",
291303
}
292304

293305
err := cache.ValidateCacheConfig(config)
@@ -302,6 +314,7 @@ development:
302314
SimilarityThreshold: 0.8,
303315
MaxEntries: 1000,
304316
TTLSeconds: 3600,
317+
EmbeddingModel: "bert",
305318
EvictionPolicy: "random", // unsupported
306319
}
307320

@@ -316,6 +329,7 @@ development:
316329
Enabled: true,
317330
SimilarityThreshold: 0.8,
318331
TTLSeconds: 3600,
332+
EmbeddingModel: "bert",
319333
// BackendConfigPath is missing
320334
}
321335

@@ -330,6 +344,7 @@ development:
330344
Enabled: true,
331345
SimilarityThreshold: 0.8,
332346
TTLSeconds: 3600,
347+
EmbeddingModel: "bert",
333348
BackendConfigPath: "/nonexistent/milvus.yaml",
334349
}
335350

@@ -358,6 +373,7 @@ development:
358373
SimilarityThreshold: 1.0, // Valid: maximum threshold
359374
MaxEntries: 10000,
360375
TTLSeconds: 86400,
376+
EmbeddingModel: "bert",
361377
}
362378

363379
err := cache.ValidateCacheConfig(config)
@@ -416,6 +432,7 @@ development:
416432
SimilarityThreshold: 0.8,
417433
MaxEntries: 100,
418434
TTLSeconds: 300,
435+
EmbeddingModel: "bert",
419436
}
420437
inMemoryCache = cache.NewInMemoryCache(options)
421438
})
@@ -442,6 +459,7 @@ development:
442459
SimilarityThreshold: 0.8,
443460
MaxEntries: 100,
444461
TTLSeconds: 300,
462+
EmbeddingModel: "bert",
445463
}
446464
disabledCache := cache.NewInMemoryCache(disabledOptions)
447465
defer disabledCache.Close()
@@ -509,6 +527,7 @@ development:
509527
SimilarityThreshold: 0.8,
510528
MaxEntries: 100,
511529
TTLSeconds: 1,
530+
EmbeddingModel: "bert",
512531
})
513532

514533
err := inMemoryCache.AddPendingRequest("expired-request-id", "test-model", "stale query", []byte("request"))
@@ -532,6 +551,7 @@ development:
532551
SimilarityThreshold: 0.99, // Very high threshold
533552
MaxEntries: 100,
534553
TTLSeconds: 300,
554+
EmbeddingModel: "bert",
535555
}
536556
highThresholdCache := cache.NewInMemoryCache(highThresholdOptions)
537557
defer highThresholdCache.Close()
@@ -582,6 +602,7 @@ development:
582602
SimilarityThreshold: 0.1,
583603
MaxEntries: 10,
584604
TTLSeconds: 1,
605+
EmbeddingModel: "bert",
585606
})
586607
defer ttlCache.Close()
587608

@@ -621,6 +642,7 @@ development:
621642
SimilarityThreshold: 0.8,
622643
MaxEntries: 100,
623644
TTLSeconds: 300,
645+
EmbeddingModel: "bert",
624646
}
625647
disabledCache := cache.NewInMemoryCache(disabledOptions)
626648
defer disabledCache.Close()
@@ -664,6 +686,7 @@ development:
664686
SimilarityThreshold: 0.9,
665687
MaxEntries: 2000,
666688
TTLSeconds: 7200,
689+
EmbeddingModel: "bert",
667690
BackendConfigPath: "config/cache/milvus.yaml",
668691
}
669692

0 commit comments

Comments
 (0)