
Commit 78863f8

[BugFix] Add support for loading prompt embeds tensors serialized on unavailable devices and sparse tensors (#22962)
Signed-off-by: Andrew Sansom <[email protected]>
Parent: 5157827
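In short: torch.load re-materializes tensors on the device they were saved from, so an embedding serialized on "cuda:0" fails to load on a CPU-only API server, and a tensor saved in a sparse layout arrives still sparse, which the engine cannot consume. A minimal standalone sketch of both behaviors (plain PyTorch only, nothing here is vLLM API):

import io

import torch

# A client might serialize prompt embeddings straight off the GPU and in a
# sparse layout to shrink the payload; we emulate that on CPU here so the
# sketch runs anywhere (imagine an extra .to("cuda") before saving).
embeds = torch.randn(4, 8).to_sparse_coo()
buffer = io.BytesIO()
torch.save(embeds, buffer)
buffer.seek(0)

# map_location="cpu" is what makes the load device-agnostic: without it,
# a tensor saved from "cuda:0" raises a RuntimeError on a CUDA-less host.
tensor = torch.load(buffer, weights_only=True,
                    map_location=torch.device("cpu"))

# Loading preserves the sparse layout, so the server must densify before
# handing the tensor to the engine.
assert tensor.layout == torch.sparse_coo
tensor = tensor.to_dense()
assert tensor.layout == torch.strided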

File tree

2 files changed: +53 −2 lines
  tests/entrypoints/openai/test_prompt_validation.py
  vllm/entrypoints/openai/serving_engine.py

tests/entrypoints/openai/test_prompt_validation.py

Lines changed: 49 additions & 0 deletions
@@ -1,10 +1,16 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

+import io
+
 # imports for guided decoding tests
 import openai
+import pybase64
 import pytest
 import regex as re
+import torch
+
+from vllm.entrypoints.openai.serving_engine import OpenAIServing

 from ...utils import RemoteOpenAIServer


@@ -42,3 +48,46 @@ async def test_out_of_vocab_token_ids():
                              prompt=[999999],
                              max_tokens=5,
                              temperature=0.0)
+
+
+@pytest.mark.parametrize("dtype",
+                         [torch.float32, torch.bfloat16, torch.float16])
+@pytest.mark.parametrize(
+    "layout",
+    [torch.strided, torch.sparse_coo, torch.sparse_csc, torch.sparse_csr])
+@pytest.mark.parametrize("seq_len", [2, 10])
+@pytest.mark.parametrize("hidden_size", [2, 10])
+def test_load_prompt_embeds(dtype: torch.dtype, layout: torch.layout,
+                            seq_len: int, hidden_size: int):
+    # construct arbitrary tensors of various dtypes, layouts, and sizes.
+    # We need to check against different layouts to make sure that if a user
+    # uses sparse tensors to reduce the transmission size of prompt embeddings,
+    # we must cast them to dense/strided before passing them into the engine.
+    # We don't use non-CPU tensors in this test to avoid preemptively
+    # initializing cuda and break other tests in the suite that fork processes.
+    # We also need to make sure that we only use devices that are actually
+    # available in the environment the test is running on. For simplicity,
+    # we just test against CPU.
+    tensor = torch.randn((seq_len, hidden_size), dtype=dtype)
+    if layout == torch.strided:
+        tensor = tensor.contiguous()
+    elif layout == torch.sparse_coo:
+        tensor = tensor.to_sparse_coo()
+    elif layout == torch.sparse_csc:
+        tensor = tensor.to_sparse_csc()
+    elif layout == torch.sparse_csr:
+        tensor = tensor.to_sparse_csr()
+
+    buffer = io.BytesIO()
+    torch.save(tensor, buffer)
+    buffer.seek(0)
+    encoded_tensor = pybase64.b64encode(buffer.getvalue())
+
+    loaded_prompt_embeds = OpenAIServing._load_prompt_embeds(encoded_tensor)
+    assert len(loaded_prompt_embeds) == 1
+    loaded_tensor = loaded_prompt_embeds[0]["prompt_embeds"]
+    assert loaded_tensor.device.type == "cpu"
+    assert loaded_tensor.layout == torch.strided
+    torch.testing.assert_close(loaded_tensor,
+                               tensor.to("cpu").to_dense(),
+                               equal_nan=True)
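The new test exercises one encoded tensor at a time; per the helper's Optional[Union[bytes, list[bytes]]] signature it should also accept a batch. A hedged sketch of that batch path (assuming the list branch mirrors the single-tensor path shown above, which the signature implies but this test does not exercise; _load_prompt_embeds is a private helper, called directly here only for illustration):

import io

import pybase64
import torch

from vllm.entrypoints.openai.serving_engine import OpenAIServing


def encode(t: torch.Tensor) -> bytes:
    # Same serialization the test uses: torch.save into a buffer, then
    # base64 so the bytes can travel inside a JSON request body.
    buf = io.BytesIO()
    torch.save(t, buf)
    return pybase64.b64encode(buf.getvalue())


# Two prompts' worth of embeddings, one dense and one sparse.
batch = [
    encode(torch.randn(4, 8)),
    encode(torch.randn(4, 8).to_sparse_coo()),
]
loaded = OpenAIServing._load_prompt_embeds(batch)
assert len(loaded) == 2
# Regardless of how they were sent, both come back dense on CPU.
assert all(p["prompt_embeds"].layout == torch.strided for p in loaded)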

vllm/entrypoints/openai/serving_engine.py

Lines changed: 4 additions & 2 deletions
@@ -1006,21 +1006,23 @@ async def _generate_with_builtin_tools(
             # OPTIMIZATION
             priority = orig_priority - 1

+    @staticmethod
     def _load_prompt_embeds(
-        self,
         prompt_embeds: Optional[Union[bytes, list[bytes]]],
         truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
     ) -> list[EmbedsPrompt]:

         def _load_and_validate_embed(embed: bytes) -> EmbedsPrompt:
             tensor = torch.load(io.BytesIO(
                 pybase64.b64decode(embed, validate=True)),
-                                weights_only=True)
+                                weights_only=True,
+                                map_location=torch.device("cpu"))
             assert isinstance(tensor, torch.Tensor) and tensor.dtype in (
                 torch.float32,
                 torch.bfloat16,
                 torch.float16,
             )
+            tensor = tensor.to_dense()
             if tensor.dim() > 2:
                 tensor = tensor.squeeze(0)
             assert tensor.dim() == 2