Skip to content

Commit 0e3d5f7

Browse files
committed
[Bugfix] Fix Dense module loading for sentence-transformers embedding models v4
Signed-off-by: FFFfff1FFFfff <[email protected]>
1 parent ea1292a commit 0e3d5f7

File tree

6 files changed

+542
-134
lines changed

6 files changed

+542
-134
lines changed

requirements/test.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -968,7 +968,6 @@ setuptools==77.0.3
968968
# via
969969
# lightning-utilities
970970
# pytablewriter
971-
# torch
972971
# triton
973972
shapely==2.1.1
974973
# via
Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
from typing import Any
4+
5+
import numpy as np
6+
import pytest
7+
from scipy.spatial.distance import cosine
8+
9+
from ...utils import EmbedModelInfo
10+
11+
12+
def _get_vllm_embeddings(vllm_runner, model_info: EmbedModelInfo,
13+
test_texts: list[str]):
14+
"""Get embeddings from vLLM."""
15+
vllm_extra_kwargs: dict[str, Any] = {}
16+
if model_info.architecture == "GteNewModel":
17+
vllm_extra_kwargs["hf_overrides"] = {"architectures": ["GteNewModel"]}
18+
19+
with vllm_runner(
20+
model_info.name,
21+
runner="pooling",
22+
max_model_len=None,
23+
trust_remote_code=True,
24+
**vllm_extra_kwargs,
25+
) as vllm_model:
26+
embeddings = vllm_model.encode(test_texts)
27+
28+
# Extract tensor/numpy data
29+
data = []
30+
for emb in embeddings:
31+
if hasattr(emb, "outputs"):
32+
data.append(emb.outputs.data.cpu().numpy())
33+
else:
34+
data.append(emb.cpu().numpy() if hasattr(emb, "cpu") else emb)
35+
return np.array(data)
36+
37+
38+
def _get_hf_embeddings(hf_runner, model_info: EmbedModelInfo,
39+
test_texts: list[str]):
40+
"""Get embeddings from HuggingFace ST interface."""
41+
with hf_runner(
42+
model_info.name,
43+
is_sentence_transformer=True,
44+
dtype="float32",
45+
) as hf_model:
46+
embeddings = hf_model.encode(test_texts)
47+
if hasattr(embeddings, "cpu"):
48+
return embeddings.cpu().numpy()
49+
return np.array(embeddings)
50+
51+
52+
# ST models with projector (Dense) layers
53+
ST_PROJECTOR_MODELS = [
54+
EmbedModelInfo(
55+
"TencentBAC/Conan-embedding-v1",
56+
architecture="BertModel",
57+
enable_test=True,
58+
),
59+
]
60+
61+
62+
@pytest.mark.parametrize("model_info", ST_PROJECTOR_MODELS)
63+
def test_st_projector_loading(vllm_runner, model_info: EmbedModelInfo) -> None:
64+
"""Ensure projector models load and output expected dim."""
65+
if not model_info.enable_test:
66+
pytest.skip("Skipping test.")
67+
68+
test_texts = ["This is a test sentence."]
69+
embeddings_data = _get_vllm_embeddings(vllm_runner, model_info, test_texts)
70+
71+
actual_dim = embeddings_data.shape[-1]
72+
expected_dim = 1792
73+
assert actual_dim == expected_dim, (
74+
f"Expected {expected_dim}, got {actual_dim}")
75+
76+
77+
@pytest.mark.parametrize("model_info", ST_PROJECTOR_MODELS)
78+
def test_compare_with_hf_dimensions(hf_runner, vllm_runner,
79+
model_info: EmbedModelInfo) -> None:
80+
"""Compare embedding dimensions between vLLM and HuggingFace."""
81+
if not model_info.enable_test:
82+
pytest.skip("Skipping test.")
83+
84+
test_texts = ["This is a test sentence for dimension comparison."]
85+
86+
vllm_data = _get_vllm_embeddings(vllm_runner, model_info, test_texts)
87+
hf_data = _get_hf_embeddings(hf_runner, model_info, test_texts)
88+
89+
vllm_dim = vllm_data.shape[-1]
90+
hf_dim = hf_data.shape[-1]
91+
92+
assert vllm_dim == hf_dim, ("Embedding dim mismatch: "
93+
f"vLLM {vllm_dim} vs HF {hf_dim}")
94+
print(f"✓ Embedding dimensions match: {vllm_dim}")
95+
96+
97+
@pytest.mark.parametrize("model_info", ST_PROJECTOR_MODELS)
98+
def test_embedding_numerical_similarity(hf_runner, vllm_runner,
99+
model_info: EmbedModelInfo) -> None:
100+
"""Numerical similarity between vLLM and HF embeddings."""
101+
if not model_info.enable_test:
102+
pytest.skip("Skipping test.")
103+
104+
test_texts = [
105+
"This is a test sentence for numerical comparison.",
106+
"Another sentence to verify embedding quality.",
107+
"机器学习是人工智能的一个重要分支。", # Chinese test
108+
]
109+
110+
vllm_data = _get_vllm_embeddings(vllm_runner, model_info, test_texts)
111+
hf_data = _get_hf_embeddings(hf_runner, model_info, test_texts)
112+
113+
assert vllm_data.shape == hf_data.shape, (
114+
"Shape mismatch: "
115+
f"vLLM {vllm_data.shape} vs HF {hf_data.shape}")
116+
117+
print(f"Embedding shape: {vllm_data.shape}")
118+
print(f"Embedding dimension: {vllm_data.shape[-1]}")
119+
120+
similarities = []
121+
for i, text in enumerate(test_texts):
122+
vllm_emb = vllm_data[i]
123+
hf_emb = hf_data[i]
124+
125+
similarity = 1 - cosine(vllm_emb, hf_emb)
126+
similarities.append(similarity)
127+
128+
preview = text[:50] + ("..." if len(text) > 50 else "")
129+
print(f"Text {i + 1}: '{preview}'")
130+
print(f" Cosine similarity: {similarity:.6f}")
131+
132+
min_similarity = 0.95
133+
assert similarity > min_similarity, (
134+
f"Text {i + 1} similarity too low: "
135+
f"{similarity:.6f} < {min_similarity}\n"
136+
f"vLLM norm: {np.linalg.norm(vllm_emb):.6f}, "
137+
f"HF norm: {np.linalg.norm(hf_emb):.6f}")
138+
139+
avg_similarity = np.mean(similarities)
140+
print(f"\nAverage cosine similarity: {avg_similarity:.6f}")
141+
142+
assert avg_similarity > 0.98, (
143+
f"Average similarity too low: {avg_similarity:.6f} < 0.98")
144+
print("✓ All numerical similarity tests passed!")
145+
146+
147+
@pytest.mark.parametrize("model_info", ST_PROJECTOR_MODELS)
148+
def test_embedding_quality_checks(vllm_runner,
149+
model_info: EmbedModelInfo) -> None:
150+
"""Basic quality checks: non-zero, non-constant, distinct."""
151+
if not model_info.enable_test:
152+
pytest.skip("Skipping test.")
153+
154+
test_texts = [
155+
"First test sentence.",
156+
"Second different sentence.",
157+
"Completely different content here.",
158+
]
159+
160+
embeddings_data = _get_vllm_embeddings(vllm_runner, model_info, test_texts)
161+
162+
print(f"Embeddings shape: {embeddings_data.shape}")
163+
164+
# Non-zero and non-constant
165+
for i, emb in enumerate(embeddings_data):
166+
norm = np.linalg.norm(emb)
167+
print(f"Embedding {i + 1} L2 norm: {norm:.6f}")
168+
assert norm > 1e-6, (
169+
f"Embedding {i + 1} too close to zero: norm={norm}")
170+
171+
std = np.std(emb)
172+
print(f"Embedding {i + 1} std: {std:.6f}")
173+
assert std > 1e-6, (
174+
f"Embedding {i + 1} too close to constant: std={std}")
175+
176+
# Different texts should differ
177+
for i in range(len(embeddings_data)):
178+
for j in range(i + 1, len(embeddings_data)):
179+
sim = 1 - cosine(embeddings_data[i], embeddings_data[j])
180+
print(f"Similarity between text {i + 1} and {j + 1}: {sim:.6f}")
181+
assert sim < 0.99, ("Embeddings too similar: "
182+
f"{i + 1} vs {j + 1} -> {sim:.6f}")
183+
184+
print("✓ All embedding quality checks passed!")

0 commit comments

Comments
 (0)