
Commit 84cf78a

[Model] Pooling models default to using chunked prefill & prefix caching if supported. (#20930)
Signed-off-by: wang.yuqi <[email protected]>
1 parent 16fb668 commit 84cf78a
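
What this means in practice: a pooling model picks up prefix caching (and chunked prefill, where the pooling type allows it) without any cache flags. A minimal sketch, assuming a build that includes this commit; the model name and the llm_engine.cache_config access are borrowed from the tests below, and the embed() output access is illustrative.

from vllm import LLM

# Pooling models now default to prefix caching when the architecture
# supports it; no explicit enable_prefix_caching=True is needed.
llm = LLM(model="Qwen/Qwen3-Embedding-0.6B", runner="pooling")
print(llm.llm_engine.cache_config.enable_prefix_caching)

# Requests that share a prompt prefix can now reuse cached KV blocks.
outputs = llm.embed(["shared prefix: query one", "shared prefix: query two"])
print(len(outputs[0].outputs.embedding))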

31 files changed (+452, -261 lines)

tests/entrypoints/llm/test_classify.py

Lines changed: 6 additions & 0 deletions
@@ -65,3 +65,9 @@ def get_outputs(activation):
     assert torch.allclose(
         softmax(wo_activation), w_activation, atol=1e-2
     ), "w_activation should be close to activation(wo_activation)."
+
+
+def test_encode_api(llm: LLM):
+    err_msg = "pooling_task must be one of.+"
+    with pytest.raises(ValueError, match=err_msg):
+        llm.encode(prompts, use_tqdm=False)
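
The new test pins down the encode() validation path: on a model whose pooler does not support the requested task, the call raises instead of silently falling back. A standalone sketch of the same check, with the fixture's llm and prompts replaced by local stand-ins; the classifier checkpoint is the one used elsewhere in this PR's tests, and the classify() output access is an assumption about the current LLM API.

import pytest
from vllm import LLM

# Hypothetical stand-ins for the test fixtures.
llm = LLM(model="jason9693/Qwen2.5-1.5B-apeach", runner="pooling")
prompts = ["vLLM is a high-throughput inference engine."]

# classify() maps to a pooling task this model supports.
print(llm.classify(prompts)[0].outputs.probs)

# A bare encode() has no valid pooling task here and is rejected.
with pytest.raises(ValueError, match="pooling_task must be one of.+"):
    llm.encode(prompts, use_tqdm=False)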

tests/entrypoints/openai/test_classification.py

Lines changed: 15 additions & 0 deletions
@@ -211,3 +211,18 @@ async def get_outputs(activation):
     assert torch.allclose(
         F.softmax(wo_activation, dim=-1), w_activation, atol=1e-2
     ), "w_activation should be close to activation(wo_activation)."
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+def test_pooling(server: RemoteOpenAIServer, model_name: str):
+    # pooling api uses ALL pooling, which does not support chunked prefill.
+    response = requests.post(
+        server.url_for("pooling"),
+        json={
+            "model": model_name,
+            "input": "test",
+            "encoding_format": "float"
+        },
+    )
+    assert response.json()["error"]["type"] == "BadRequestError"

tests/models/language/pooling/mteb_utils.py

Lines changed: 10 additions & 2 deletions
@@ -177,9 +177,12 @@ def mteb_test_embed_models(hf_runner,
                      max_model_len=None,
                      **vllm_extra_kwargs) as vllm_model:
 
+        model_config = vllm_model.llm.llm_engine.model_config
+
         if model_info.architecture:
-            assert (model_info.architecture
-                    in vllm_model.llm.llm_engine.model_config.architectures)
+            assert model_info.architecture in model_config.architectures
+        assert (model_config._model_info.default_pooling_type ==
+                model_info.default_pooling_type)
 
         vllm_main_score = run_mteb_embed_task(VllmMtebEncoder(vllm_model),
                                               MTEB_EMBED_TASKS)
@@ -286,7 +289,12 @@ def mteb_test_rerank_models(hf_runner,
                      **vllm_extra_kwargs) as vllm_model:
 
         model_config = vllm_model.llm.llm_engine.model_config
+
+        if model_info.architecture:
+            assert (model_info.architecture in model_config.architectures)
         assert model_config.hf_config.num_labels == 1
+        assert (model_config._model_info.default_pooling_type ==
+                model_info.default_pooling_type)
 
         vllm_main_score = run_mteb_rerank(vllm_mteb_encoder(vllm_model),
                                           tasks=MTEB_RERANK_TASKS,
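
Both MTEB helpers now also pin the pooling type that vLLM resolved for the checkpoint. The same introspection can be done outside the harness; a short sketch, noting that _model_info is a private attribute, used here only because the updated assertions rely on it.

from vllm import LLM

llm = LLM(model="BAAI/bge-base-en", runner="pooling", max_model_len=512)
model_config = llm.llm_engine.model_config

# The registry records the model's default pooling type, e.g. "CLS"
# for BERT-style embedders and "LAST" for decoder-style ones.
print(model_config._model_info.default_pooling_type)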
Lines changed: 93 additions & 0 deletions
@@ -0,0 +1,93 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import pytest
+import torch
+from transformers import AutoModelForSequenceClassification
+
+from tests.models.language.pooling.embed_utils import (
+    run_embedding_correctness_test)
+
+
+@pytest.mark.parametrize(
+    "model",
+    ["jason9693/Qwen2.5-1.5B-apeach"],
+)
+@pytest.mark.parametrize("dtype", ["half"])
+def test_classify_models(
+    hf_runner,
+    vllm_runner,
+    example_prompts,
+    model: str,
+    dtype: str,
+) -> None:
+
+    example_prompts = example_prompts * 2
+
+    with vllm_runner(model,
+                     max_model_len=512,
+                     dtype=dtype,
+                     enable_prefix_caching=True) as vllm_model:
+        cache_config = vllm_model.llm.llm_engine.cache_config
+        assert cache_config.enable_prefix_caching
+        vllm_outputs = vllm_model.classify(example_prompts)
+
+    with hf_runner(model,
+                   dtype=dtype,
+                   auto_cls=AutoModelForSequenceClassification) as hf_model:
+        hf_outputs = hf_model.classify(example_prompts)
+
+    for hf_output, vllm_output in zip(hf_outputs, vllm_outputs):
+        hf_output = torch.tensor(hf_output)
+        vllm_output = torch.tensor(vllm_output)
+
+        assert torch.allclose(hf_output, vllm_output,
+                              1e-3 if dtype == "float" else 1e-2)
+
+
+@pytest.mark.parametrize(
+    "model",
+    ["Qwen/Qwen3-Embedding-0.6B"],
+)
+@pytest.mark.parametrize("dtype", ["half"])
+def test_embed_models(
+    hf_runner,
+    vllm_runner,
+    example_prompts,
+    model: str,
+    dtype: str,
+):
+    example_prompts = [str(s).strip() for s in example_prompts] * 2
+
+    with vllm_runner(
+            model,
+            runner="pooling",
+            max_model_len=None,
+            enable_prefix_caching=True,
+    ) as vllm_model:
+        cache_config = vllm_model.llm.llm_engine.cache_config
+        assert cache_config.enable_prefix_caching
+        vllm_outputs = vllm_model.embed(example_prompts)
+
+    with hf_runner(
+            model,
+            is_sentence_transformer=True,
+    ) as hf_model:
+        run_embedding_correctness_test(hf_model, example_prompts, vllm_outputs)
+
+
+@pytest.mark.parametrize(
+    "model",
+    [
+        "intfloat/e5-small",
+        "Alibaba-NLP/gte-Qwen2-1.5B-instruct",  # is_causal == False
+        "papluca/xlm-roberta-base-language-detection",
+    ])
+@pytest.mark.parametrize("dtype", ["half"])
+def test_non_causal_models(hf_runner, vllm_runner, example_prompts, model: str,
+                           dtype: str) -> None:
+    with vllm_runner(model,
+                     max_model_len=512,
+                     dtype=dtype,
+                     enable_prefix_caching=True) as vllm_model:
+        cache_config = vllm_model.llm.llm_engine.cache_config
+        assert not cache_config.enable_prefix_caching
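
The last test documents the fallback rather than the feature: a bidirectional (non-causal) encoder cannot reuse causal KV blocks, so even an explicit enable_prefix_caching=True is overridden. A minimal sketch of that behavior, assuming a build with this commit.

from vllm import LLM

# e5-small is a BERT-style bidirectional encoder; prefix caching is
# requested here but the engine is expected to disable it.
llm = LLM(model="intfloat/e5-small",
          max_model_len=512,
          enable_prefix_caching=True)
assert not llm.llm_engine.cache_config.enable_prefix_caching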

tests/models/language/pooling/test_baai.py

Lines changed: 61 additions & 56 deletions
@@ -2,73 +2,78 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import pytest
 
-from ...utils import EmbedModelInfo, RerankModelInfo
+from ...utils import (CLSPoolingEmbedModelInfo, CLSPoolingRerankModelInfo,
+                      EmbedModelInfo, LASTPoolingEmbedModelInfo,
+                      RerankModelInfo)
 from .embed_utils import correctness_test_embed_models
 from .mteb_utils import mteb_test_embed_models, mteb_test_rerank_models
 
 MODELS = [
     ########## BertModel
-    EmbedModelInfo("BAAI/bge-base-en",
-                   architecture="BertModel",
-                   enable_test=True),
-    EmbedModelInfo("BAAI/bge-base-zh",
-                   architecture="BertModel",
-                   enable_test=False),
-    EmbedModelInfo("BAAI/bge-small-en",
-                   architecture="BertModel",
-                   enable_test=False),
-    EmbedModelInfo("BAAI/bge-small-zh",
-                   architecture="BertModel",
-                   enable_test=False),
-    EmbedModelInfo("BAAI/bge-large-en",
-                   architecture="BertModel",
-                   enable_test=False),
-    EmbedModelInfo("BAAI/bge-large-zh",
-                   architecture="BertModel",
-                   enable_test=False),
-    EmbedModelInfo("BAAI/bge-large-zh-noinstruct",
-                   architecture="BertModel",
-                   enable_test=False),
-    EmbedModelInfo("BAAI/bge-base-en-v1.5",
-                   architecture="BertModel",
-                   enable_test=False),
-    EmbedModelInfo("BAAI/bge-base-zh-v1.5",
-                   architecture="BertModel",
-                   enable_test=False),
-    EmbedModelInfo("BAAI/bge-small-en-v1.5",
-                   architecture="BertModel",
-                   enable_test=False),
-    EmbedModelInfo("BAAI/bge-small-zh-v1.5",
-                   architecture="BertModel",
-                   enable_test=False),
-    EmbedModelInfo("BAAI/bge-large-en-v1.5",
-                   architecture="BertModel",
-                   enable_test=False),
-    EmbedModelInfo("BAAI/bge-large-zh-v1.5",
-                   architecture="BertModel",
-                   enable_test=False),
+    CLSPoolingEmbedModelInfo("BAAI/bge-base-en",
+                             architecture="BertModel",
+                             enable_test=True),
+    CLSPoolingEmbedModelInfo("BAAI/bge-base-zh",
+                             architecture="BertModel",
+                             enable_test=False),
+    CLSPoolingEmbedModelInfo("BAAI/bge-small-en",
+                             architecture="BertModel",
+                             enable_test=False),
+    CLSPoolingEmbedModelInfo("BAAI/bge-small-zh",
+                             architecture="BertModel",
+                             enable_test=False),
+    CLSPoolingEmbedModelInfo("BAAI/bge-large-en",
+                             architecture="BertModel",
+                             enable_test=False),
+    CLSPoolingEmbedModelInfo("BAAI/bge-large-zh",
+                             architecture="BertModel",
+                             enable_test=False),
+    CLSPoolingEmbedModelInfo("BAAI/bge-large-zh-noinstruct",
+                             architecture="BertModel",
+                             enable_test=False),
+    CLSPoolingEmbedModelInfo("BAAI/bge-base-en-v1.5",
+                             architecture="BertModel",
+                             enable_test=False),
+    CLSPoolingEmbedModelInfo("BAAI/bge-base-zh-v1.5",
+                             architecture="BertModel",
+                             enable_test=False),
+    CLSPoolingEmbedModelInfo("BAAI/bge-small-en-v1.5",
+                             architecture="BertModel",
+                             enable_test=False),
+    CLSPoolingEmbedModelInfo("BAAI/bge-small-zh-v1.5",
+                             architecture="BertModel",
+                             enable_test=False),
+    CLSPoolingEmbedModelInfo("BAAI/bge-large-en-v1.5",
+                             architecture="BertModel",
+                             enable_test=False),
+    CLSPoolingEmbedModelInfo("BAAI/bge-large-zh-v1.5",
+                             architecture="BertModel",
+                             enable_test=False),
     ########## XLMRobertaModel
-    EmbedModelInfo("BAAI/bge-m3",
-                   architecture="XLMRobertaModel",
-                   enable_test=True),
+    CLSPoolingEmbedModelInfo("BAAI/bge-m3",
+                             architecture="XLMRobertaModel",
+                             enable_test=True),
     ########## Qwen2Model
-    EmbedModelInfo("BAAI/bge-code-v1",
-                   architecture="Qwen2Model",
-                   dtype="float32",
-                   enable_test=True),
+    LASTPoolingEmbedModelInfo("BAAI/bge-code-v1",
+                              architecture="Qwen2Model",
+                              dtype="float32",
+                              enable_test=True),
 ]
 
 RERANK_MODELS = [
     ########## XLMRobertaForSequenceClassification
-    RerankModelInfo("BAAI/bge-reranker-base",
-                    architecture="XLMRobertaForSequenceClassification",
-                    enable_test=True),
-    RerankModelInfo("BAAI/bge-reranker-large",
-                    architecture="XLMRobertaForSequenceClassification",
-                    enable_test=False),
-    RerankModelInfo("BAAI/bge-reranker-v2-m3",
-                    architecture="XLMRobertaForSequenceClassification",
-                    enable_test=False)
+    CLSPoolingRerankModelInfo(
+        "BAAI/bge-reranker-base",
+        architecture="XLMRobertaForSequenceClassification",
+        enable_test=True),
+    CLSPoolingRerankModelInfo(
+        "BAAI/bge-reranker-large",
+        architecture="XLMRobertaForSequenceClassification",
+        enable_test=False),
+    CLSPoolingRerankModelInfo(
+        "BAAI/bge-reranker-v2-m3",
+        architecture="XLMRobertaForSequenceClassification",
+        enable_test=False)
 ]
 
 
tests/models/language/pooling/test_bge_reranker_v2_gemma.py

Lines changed: 4 additions & 4 deletions
@@ -8,12 +8,12 @@
 
 from tests.conftest import HfRunner
 
-from .mteb_utils import (RerankModelInfo, VllmMtebEncoder,
-                         mteb_test_rerank_models)
+from ...utils import LASTPoolingRerankModelInfo, RerankModelInfo
+from .mteb_utils import VllmMtebEncoder, mteb_test_rerank_models
 
 RERANK_MODELS = [
-    RerankModelInfo("BAAI/bge-reranker-v2-gemma",
-                    architecture="GemmaForSequenceClassification"),
+    LASTPoolingRerankModelInfo("BAAI/bge-reranker-v2-gemma",
+                               architecture="GemmaForSequenceClassification"),
 ]
 
 PROMPT = "Given a query A and a passage B, determine whether the passage contains an answer to the query by providing a prediction of either 'Yes' or 'No'."  # noqa: E501

tests/models/language/pooling/test_cross_encoder.py

Lines changed: 7 additions & 5 deletions
@@ -2,13 +2,15 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import pytest
 
-from .mteb_utils import RerankModelInfo, mteb_test_rerank_models
+from ...utils import (CLSPoolingRerankModelInfo, LASTPoolingRerankModelInfo,
+                      RerankModelInfo)
+from .mteb_utils import mteb_test_rerank_models
 
 RERANK_MODELS = [
-    RerankModelInfo("cross-encoder/ms-marco-TinyBERT-L-2-v2",
-                    architecture="BertForSequenceClassification"),
-    RerankModelInfo("tomaarsen/Qwen3-Reranker-0.6B-seq-cls",
-                    architecture="Qwen3ForSequenceClassification")
+    CLSPoolingRerankModelInfo("cross-encoder/ms-marco-TinyBERT-L-2-v2",
+                              architecture="BertForSequenceClassification"),
+    LASTPoolingRerankModelInfo("tomaarsen/Qwen3-Reranker-0.6B-seq-cls",
+                               architecture="Qwen3ForSequenceClassification")
 ]
 
 
