Skip to content

Commit fbf722c

Browse files
authored
[Frontend] support matryoshka representation / support embedding API dimensions (#16331)
1 parent e92d708 commit fbf722c

File tree

11 files changed

+253
-22
lines changed

11 files changed

+253
-22
lines changed
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
3+
from argparse import Namespace
4+
5+
from vllm import LLM, EngineArgs, PoolingParams
6+
from vllm.utils import FlexibleArgumentParser
7+
8+
9+
def main(args: Namespace):
    """Embed a batch of multilingual prompts and print truncated previews."""
    # Multilingual sample prompts.
    sample_prompts = [
        "Follow the white rabbit.",  # English
        "Sigue al conejo blanco.",  # Spanish
        "Suis le lapin blanc.",  # French
        "跟着白兔走。",  # Chinese
        "اتبع الأرنب الأبيض.",  # Arabic
        "Folge dem weißen Kaninchen.",  # German
    ]

    # Create an LLM.
    # You should pass task="embed" for embedding models
    model = LLM(**vars(args))

    # Generate embedding. The output is a list of EmbeddingRequestOutputs.
    # dimensions=32 requests Matryoshka-style truncation of the output.
    outputs = model.embed(sample_prompts,
                          pooling_params=PoolingParams(dimensions=32))

    # Print the outputs.
    print("\nGenerated Outputs:")
    print("-" * 60)
    for prompt, output in zip(sample_prompts, outputs):
        embeds = output.outputs.embedding
        # Show at most the first 16 components of each embedding.
        if len(embeds) > 16:
            preview = str(embeds[:16])[:-1] + ", ...]"
        else:
            preview = embeds
        print(f"Prompt: {prompt!r} \n"
              f"Embeddings: {preview} "
              f"(size={len(embeds)})")
        print("-" * 60)
40+
if __name__ == "__main__":
    # Build a CLI that accepts every vLLM engine argument.
    parser = FlexibleArgumentParser()
    parser = EngineArgs.add_cli_args(parser)
    # Set example specific arguments
    parser.set_defaults(
        model="jinaai/jina-embeddings-v3",
        task="embed",
        trust_remote_code=True,
    )
    main(parser.parse_args())

tests/conftest.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -960,19 +960,19 @@ def classify(self, prompts: list[str]) -> list[list[float]]:
960960
req_outputs = self.model.classify(prompts)
961961
return [req_output.outputs.probs for req_output in req_outputs]
962962

963-
def encode(self,
           prompts: list[str],
           images: Optional[PromptImageInput] = None,
           videos: Optional[PromptVideoInput] = None,
           audios: Optional[PromptAudioInput] = None,
           *args,
           **kwargs) -> list[list[float]]:
    """Embed the prompts and return one embedding vector per prompt.

    Extra positional/keyword arguments (e.g. ``pooling_params``) are
    forwarded unchanged to ``LLM.embed``.
    """
    model_inputs = self.get_inputs(prompts,
                                   images=images,
                                   videos=videos,
                                   audios=audios)
    results = self.model.embed(model_inputs, *args, **kwargs)
    return [result.outputs.embedding for result in results]
977977

978978
def score(
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
"""
3+
Run `pytest tests/entrypoints/openai/test_embedding_dimensions.py`.
4+
"""
5+
6+
from typing import NamedTuple
7+
8+
import openai
9+
import pytest
10+
11+
from vllm.entrypoints.openai.protocol import EmbeddingResponse
12+
13+
from ...utils import RemoteOpenAIServer
14+
15+
16+
class ModelInfo(NamedTuple):
    # Hugging Face model identifier.
    name: str
    # Whether the model supports Matryoshka (truncatable) embeddings.
    is_matryoshka: bool


MODELS = [
    ModelInfo("BAAI/bge-m3", False),
    ModelInfo("jinaai/jina-embeddings-v3", True),
]

# Three identical prompts so each request embeds a small batch.
input_texts = ["The chef prepared a delicious meal."] * 3
30+
31+
@pytest.mark.asyncio
@pytest.mark.parametrize("model", MODELS)
async def test_validating_dimensions(model: ModelInfo):
    """Validate the ``dimensions`` parameter of the embeddings API.

    Matryoshka models accept a reduced output dimension; other models
    must reject any explicit ``dimensions`` value. Invalid values must
    each produce a 400 (BadRequestError).
    """
    args = [
        "--task",
        "embed",
        # use half precision for speed and memory savings in CI environment
        "--dtype",
        "bfloat16",
        "--enforce-eager",
        "--max-model-len",
        "512",
        "--trust_remote_code"
    ]
    with RemoteOpenAIServer(model.name, args) as remote_server:
        client = remote_server.get_async_client()

        async def make_request(dimensions):
            embedding_response = await client.embeddings.create(
                model=model.name,
                input=input_texts,
                dimensions=dimensions,
                encoding_format="float",
            )
            embeddings = EmbeddingResponse.model_validate(
                embedding_response.model_dump(mode="json"))

            assert embeddings.id is not None
            assert len(embeddings.data) == 3
            assert len(embeddings.data[0].embedding) > 0
            assert embeddings.usage.completion_tokens == 0
            assert embeddings.usage.prompt_tokens > 0
            assert embeddings.usage.total_tokens > 0

            if dimensions is not None:
                assert len(embeddings.data[0].embedding) == dimensions

        if model.is_matryoshka:
            valid_dimensions: list[int | None] = [None, 16]
            invalid_dimensions: list[int | None] = [-1]
        else:
            valid_dimensions = [None]
            invalid_dimensions = [-1, 16]

        for dimensions in valid_dimensions:
            await make_request(dimensions)

        # Fix: check each invalid value in its own ``pytest.raises``
        # block. The original wrapped the loop in a single
        # ``pytest.raises``, so the first raising value aborted the loop
        # and the remaining invalid values were never exercised.
        for dimensions in invalid_dimensions:
            with pytest.raises(openai.BadRequestError):
                await make_request(dimensions)

tests/models/embedding/language/test_jina.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88

99
import pytest
1010

11-
from tests.models.embedding.utils import check_embeddings_close
11+
from tests.models.embedding.utils import check_embeddings_close, matryoshka_fy
12+
from vllm import PoolingParams
1213

1314
SCORING_MODELS = [
1415
"jinaai/jina-reranker-v2-base-multilingual", # Roberta
@@ -126,3 +127,40 @@ def test_embeddings(
126127
name_1="vllm",
127128
tol=1e-2,
128129
)
130+
131+
132+
@pytest.mark.parametrize("model", EMBEDDING_MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("dimensions", [16, 32])
def test_matryoshka(
    hf_runner,
    vllm_runner,
    model,
    dtype: str,
    dimensions: int,
    monkeypatch,
) -> None:
    """Compare truncated vLLM embeddings against the HF reference."""
    example_prompts = EMBEDDING_PROMPTS

    # Reference: full-size Sentence-Transformers embeddings, truncated
    # and re-normalized to the requested dimensionality.
    with hf_runner(
            model,
            dtype=dtype,
            is_sentence_transformer=True,
    ) as hf_model:
        hf_outputs = matryoshka_fy(
            hf_model.encode(example_prompts, task="text-matching"),
            dimensions)

    # vLLM truncates server-side via PoolingParams(dimensions=...).
    with vllm_runner(model, task="embed", dtype=dtype,
                     max_model_len=None) as vllm_model:
        vllm_outputs = vllm_model.encode(
            example_prompts,
            pooling_params=PoolingParams(dimensions=dimensions))

    check_embeddings_close(
        embeddings_0_lst=hf_outputs,
        embeddings_1_lst=vllm_outputs,
        name_0="hf",
        name_1="vllm",
        tol=1e-2,
    )

tests/models/embedding/utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,10 @@ def check_embeddings_close(
3030
f"\n{name_1}:\t{embeddings_1[:16]!r}")
3131

3232
assert sim >= 1 - tol, fail_msg
33+
34+
35+
def matryoshka_fy(tensor, dimensions):
    """Truncate embeddings to ``dimensions`` and re-normalize them.

    Mimics Matryoshka Representation Learning inference: keep only the
    first ``dimensions`` components along the last axis, then
    L2-normalize so the truncated vectors are unit length again.

    Args:
        tensor: array-like of embeddings (any rank; last axis is the
            embedding dimension).
        dimensions: number of leading components to keep.

    Returns:
        A ``torch.Tensor`` of the truncated, re-normalized embeddings.
    """
    tensor = torch.tensor(tensor)
    tensor = tensor[..., :dimensions]
    # dim=-1 normalizes the embedding axis for any input rank; the
    # original dim=1 only worked for 2-D batches. This also matches the
    # dim=-1 normalization used by the pooler.
    tensor = F.normalize(tensor, p=2, dim=-1)
    return tensor

vllm/config.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -583,6 +583,15 @@ def _init_pooler_config(
583583
if getattr(user_config, k) is None:
584584
setattr(user_config, k, v)
585585

586+
if self.is_matryoshka:
587+
if user_config.normalize is None:
588+
user_config.normalize = True
589+
elif not user_config.normalize:
590+
raise ValueError(
591+
"`normalize` must be enabled (set to True) "
592+
"for models that are compatible with "
593+
"Matryoshka Representation.")
594+
586595
return user_config
587596

588597
return None

vllm/entrypoints/llm.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -921,6 +921,11 @@ def encode(
921921
if pooling_params is None:
922922
# Use default pooling params.
923923
pooling_params = PoolingParams()
924+
elif isinstance(pooling_params, PoolingParams):
925+
pooling_params.verify(self.llm_engine.model_config)
926+
else:
927+
for pooling_param in pooling_params:
928+
pooling_param.verify(self.llm_engine.model_config)
924929

925930
self._validate_and_add_requests(
926931
prompts=parsed_prompts,
@@ -939,6 +944,8 @@ def embed(
939944
/,
940945
*,
941946
use_tqdm: bool = True,
947+
pooling_params: Optional[Union[PoolingParams,
948+
Sequence[PoolingParams]]] = None,
942949
lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
943950
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
944951
) -> list[EmbeddingRequestOutput]:
@@ -953,6 +960,8 @@ def embed(
953960
prompts: The prompts to the LLM. You may pass a sequence of prompts
954961
for batch inference. See :class:`~vllm.inputs.PromptType`
955962
for more details about the format of each prompts.
963+
pooling_params: The pooling parameters for pooling. If None, we
964+
use the default pooling parameters.
956965
use_tqdm: Whether to use tqdm to display the progress bar.
957966
lora_request: LoRA request to use for generation, if any.
958967
prompt_adapter_request: Prompt Adapter request to use for
@@ -968,6 +977,7 @@ def embed(
968977

969978
items = self.encode(prompts,
970979
use_tqdm=use_tqdm,
980+
pooling_params=pooling_params,
971981
lora_request=lora_request,
972982
prompt_adapter_request=prompt_adapter_request)
973983

vllm/entrypoints/openai/protocol.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1006,7 +1006,8 @@ class EmbeddingCompletionRequest(OpenAIBaseModel):
10061006
# doc: end-embedding-extra-params
10071007

10081008
def to_pooling_params(self):
    """Build PoolingParams carrying this request's ``dimensions`` setting."""
    # Forward the requested output dimensionality so Matryoshka models
    # can truncate their embeddings server-side.
    params = PoolingParams(
        dimensions=self.dimensions,
        additional_data=self.additional_data,
    )
    return params
10101011

10111012

10121013
class EmbeddingChatRequest(OpenAIBaseModel):
@@ -1068,7 +1069,8 @@ def check_generation_prompt(cls, data):
10681069
return data
10691070

10701071
def to_pooling_params(self):
    """Translate this chat-embedding request into PoolingParams."""
    # Pass ``dimensions`` through so the pooler can apply Matryoshka
    # truncation when the model supports it.
    params = PoolingParams(
        dimensions=self.dimensions,
        additional_data=self.additional_data,
    )
    return params
10721074

10731075

10741076
EmbeddingRequest = Union[EmbeddingCompletionRequest, EmbeddingChatRequest]

vllm/entrypoints/openai/serving_embedding.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -80,9 +80,6 @@ async def create_embedding(
8080
return error_check_ret
8181

8282
encoding_format = request.encoding_format
83-
if request.dimensions is not None:
84-
return self.create_error_response(
85-
"dimensions is currently not supported")
8683

8784
model_name = self._get_model_name(request.model)
8885
request_id = f"embd-{self._base_request_id(raw_request)}"
@@ -99,6 +96,13 @@ async def create_embedding(
9996
"greater than max_model_len."
10097
" Please, select a smaller truncation size.")
10198

99+
pooling_params = request.to_pooling_params()
100+
101+
try:
102+
pooling_params.verify(self.model_config)
103+
except ValueError as e:
104+
return self.create_error_response(str(e))
105+
102106
try:
103107
(
104108
lora_request,
@@ -146,8 +150,6 @@ async def create_embedding(
146150
# Schedule the request and get the result generator.
147151
generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
148152
try:
149-
pooling_params = request.to_pooling_params()
150-
151153
for i, engine_prompt in enumerate(engine_prompts):
152154
request_id_item = f"{request_id}-{i}"
153155

vllm/model_executor/layers/pooler.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def forward(
9797
pooling_metadata: PoolingMetadata,
9898
) -> PoolerOutput:
9999
pooled_data = self.extract_states(hidden_states, pooling_metadata)
100-
pooled_data = self.head(pooled_data)
100+
pooled_data = self.head(pooled_data, pooling_metadata)
101101
pooled_outputs = [self.build_output(data) for data in pooled_data]
102102
return PoolerOutput(outputs=pooled_outputs)
103103

@@ -217,14 +217,28 @@ def __init__(self, *, normalize: bool, softmax: bool) -> None:
217217
self.normalize = normalize
218218
self.softmax = softmax
219219

220-
def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor]):
220+
def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor],
221+
pooling_metadata: PoolingMetadata):
222+
223+
dimensions_list = [
224+
pooling_param.dimensions
225+
for _, pooling_param in pooling_metadata.seq_groups
226+
]
227+
if any(d is not None for d in dimensions_list):
228+
# change the output dimension
229+
assert len(pooled_data) == len(dimensions_list)
230+
pooled_data = [
231+
vecs if d is None else vecs[..., :d]
232+
for vecs, d in zip(pooled_data, dimensions_list)
233+
]
234+
221235
if self.normalize:
222236
if isinstance(pooled_data, list):
223237
pooled_data = [
224-
F.normalize(data, p=2, dim=1) for data in pooled_data
238+
F.normalize(data, p=2, dim=-1) for data in pooled_data
225239
]
226240
else:
227-
pooled_data = F.normalize(pooled_data, p=2, dim=1)
241+
pooled_data = F.normalize(pooled_data, p=2, dim=-1)
228242

229243
if self.softmax:
230244
if isinstance(pooled_data, list):

0 commit comments

Comments
 (0)