Skip to content

Commit 17eaaef

Browse files
[Bugfix] Fix RuntimeError: Index put requires the source and destination dtypes match (#22065)
Signed-off-by: chaunceyjiang <[email protected]>
1 parent 3303f13 commit 17eaaef

File tree

2 files changed

+106
-2
lines changed

2 files changed

+106
-2
lines changed
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
4+
import base64
5+
import io
6+
import json
7+
8+
import openai # use the official client for correctness check
9+
import pytest
10+
import pytest_asyncio
11+
import torch
12+
from transformers import AutoConfig
13+
14+
from tests.conftest import ImageTestAssets
15+
from tests.utils import RemoteOpenAIServer
16+
17+
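# NOTE(review): the test in this file sends `image_embeds` content parts, so
# the model presumably has to be a multimodal chat model whose embedding size
# matches the test assets -- confirm before swapping in another model.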
18+
MODEL_NAME = "llava-hf/llava-1.5-7b-hf"
19+
CONFIG = AutoConfig.from_pretrained(MODEL_NAME)
20+
MAXIMUM_IMAGES = 2
21+
22+
23+
@pytest.fixture(scope="module")
def default_image_embeds_server_args() -> list[str]:
    """Return the command-line arguments for the server under test.

    The list carries, in order: ``--dtype bfloat16``,
    ``--max-model-len 2048``, ``--max-num-seqs 4``, ``--enforce-eager`` and
    ``--limit-mm-per-prompt`` with a JSON object that limits ``image`` items
    to ``MAXIMUM_IMAGES``.
    """
    # Serialized once up front so the flag/value pairs below stay readable.
    mm_limits = json.dumps({"image": MAXIMUM_IMAGES})

    cli_args = ["--dtype", "bfloat16"]
    cli_args += ["--max-model-len", "2048"]
    cli_args += ["--max-num-seqs", "4"]
    cli_args.append("--enforce-eager")
    cli_args += ["--limit-mm-per-prompt", mm_limits]
    return cli_args
36+
37+
38+
@pytest.fixture(scope="module")
def server_with_image_embeds(default_image_embeds_server_args):
    """Yield a ``RemoteOpenAIServer`` for ``MODEL_NAME``, one per module.

    The server object is used as a context manager, so it is exited when
    pytest tears the module-scoped fixture down.
    """
    server_ctx = RemoteOpenAIServer(
        MODEL_NAME,
        default_image_embeds_server_args,
    )
    with server_ctx as running_server:
        yield running_server
43+
44+
45+
@pytest_asyncio.fixture
async def client_with_image_embeds(server_with_image_embeds):
    """Yield an async client obtained from the shared server fixture.

    The client is used as an async context manager, so it is exited once the
    requesting test has finished.
    """
    client_ctx = server_with_image_embeds.get_async_client()
    async with client_ctx as openai_client:
        yield openai_client
49+
50+
51+
def encode_image_embedding_to_base64(image_embedding) -> str:
    """Serialize an image embedding and return it as base64 text.

    The embedding is written with ``torch.save`` into an in-memory buffer;
    the resulting bytes are base64-encoded and decoded to a ``str``.
    """
    with io.BytesIO() as stream:
        torch.save(image_embedding, stream)
        # getvalue() returns the whole buffer regardless of stream position.
        raw_bytes = stream.getvalue()
    return base64.b64encode(raw_bytes).decode('utf-8')
61+
62+
63+
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
# torch.half is an alias of torch.float16, so listing both only repeats the
# same case. bfloat16 is included to cover the dtype the server is launched
# with ("--dtype bfloat16"), alongside two dtypes that differ from it.
@pytest.mark.parametrize("dtype",
                         [torch.float16, torch.bfloat16, torch.float32])
async def test_completions_with_image_embeds(
    client_with_image_embeds: openai.AsyncOpenAI,
    model_name: str,
    image_assets: ImageTestAssets,
    dtype: torch.dtype,
):
    """Check that a chat completion accepts a base64 ``image_embeds`` part.

    The first image asset's embeddings are cast to ``dtype``, serialized to
    base64 and sent as an ``image_embeds`` content part next to a text
    prompt. The test passes when the first choice contains a non-empty
    string.

    Args:
        client_with_image_embeds: Async OpenAI client bound to the server.
        model_name: Model identifier passed to the completion request.
        image_assets: Test assets; only the first entry is used.
        dtype: Dtype the embeddings are cast to before serialization.
    """
    # Test case: Single image embeds input
    image_embeds = image_assets[0].image_embeds.to(dtype=dtype)
    base64_image_embedding = encode_image_embedding_to_base64(image_embeds)

    # Adjacent string literals are concatenated verbatim, so the separating
    # space has to be written explicitly at the end of the first fragment.
    prompt_text = ("Describe these images separately. For each image, "
                   "reply with a short sentence (no more than 10 words).")

    chat_completion = await client_with_image_embeds.chat.completions.create(
        messages=[
            {
                "role": "system",
                "content": "You are a helpful assistant."
            },
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": prompt_text,
                    },
                    {
                        "type": "image_embeds",
                        "image_embeds": base64_image_embedding,
                    },
                ],
            },
        ],
        model=model_name,
    )

    reply = chat_completion.choices[0].message.content
    assert reply is not None
    assert isinstance(reply, str)
    assert len(reply) > 0

vllm/model_executor/models/utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -401,7 +401,7 @@ def merge_multimodal_embeddings_from_map(
401401
"""
402402
flattened_embeddings = _flatten_embeddings(multimodal_embeddings)
403403
inputs_embeds[placeholder_map.dest] = flattened_embeddings[
404-
placeholder_map.src]
404+
placeholder_map.src].to(dtype=inputs_embeds.dtype)
405405
return inputs_embeds
406406

407407

@@ -421,7 +421,8 @@ def _merge_multimodal_embeddings(
421421
flattened = _flatten_embeddings(multimodal_embeddings)
422422
try:
423423
# This is equivalent to: inputs_embeds[is_multimodal] = flattened.
424-
inputs_embeds.masked_scatter_(is_multimodal.unsqueeze(-1), flattened)
424+
inputs_embeds.masked_scatter_(is_multimodal.unsqueeze(-1),
425+
flattened.to(dtype=inputs_embeds.dtype))
425426
except RuntimeError as e:
426427
num_expected_tokens = is_multimodal.sum().item()
427428
assert isinstance(num_expected_tokens, int)

0 commit comments

Comments
 (0)