Commit 4904e53

fake0fan and herotai214 authored

[Bugfix] SharedStorage Connector for V1 PD multimodal (#21611)

Signed-off-by: fake0fan <[email protected]>
Signed-off-by: herotai214 <[email protected]>
Co-authored-by: herotai214 <[email protected]>
1 parent 004203e commit 4904e53
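
This fix addresses V1 prefill/decode (PD) disaggregation with multimodal inputs: the connector previously keyed stored KV blocks on prompt token IDs alone, so prompts whose token sequences are identical (same-size images produce the same placeholder tokens) but whose images differ could collide on the same storage folder. The change threads each request's mm_hashes through the connector metadata and mixes them into the folder-name digest. A minimal standalone sketch of that keying idea; the storage_key helper is hypothetical and stands in for _generate_foldername_debug in the diff below:

import hashlib

def storage_key(token_bytes: bytes, mm_hashes: list[str]) -> str:
    # Fold the multimodal hashes into the hashed bytes so that the images,
    # and their order, change the resulting storage key.
    if mm_hashes:
        token_bytes += "-".join(mm_hashes).encode("utf-8")
    return hashlib.md5(token_bytes, usedforsecurity=False).hexdigest()

# Same placeholder tokens, different images -> different storage folders.
assert storage_key(b"tokens", ["img_a"]) != storage_key(b"tokens", ["img_b"])
# Same images in swapped order -> a different folder as well.
assert (storage_key(b"tokens", ["img_a", "img_b"])
        != storage_key(b"tokens", ["img_b", "img_a"]))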

File tree

2 files changed: +244 -12 lines changed
Lines changed: 215 additions & 0 deletions (new file)

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import asdict
from typing import NamedTuple

from PIL import Image

from vllm import LLM, EngineArgs, SamplingParams
from vllm.assets.image import ImageAsset
from vllm.config import KVTransferConfig
from vllm.multimodal.utils import encode_image_base64

MODEL_NAME = "Qwen/Qwen2.5-VL-3B-Instruct"

SAMPLING_PARAMS = SamplingParams(temperature=0.0, top_k=1, max_tokens=128)

TEXT_PROMPTS = [
    "What's in the image(s)? Around 30 words. What's special in 2nd image?",
    "The future of AI is",
]


class InputCase(NamedTuple):
    text: str
    img: list[Image.Image]
    expected_len: int
    info: str


def _check_path_len(path):
    """Return the current number of entries under the path."""
    return len(list(path.iterdir()))


def _list_path(path):
    """Return the list of folder names (the generated hashes) under the path."""
    return list(path.iterdir())


def run_test(tmp_path, processor, llm: LLM, question: str,
             image_urls: list[Image.Image], expected_len: int, info: str):
    """
    Run one individual test: process the prompt and generate output for one
    set of inputs, then check that the number of entries in the storage path
    matches the expected length.

    `info` describes the details or purpose of the individual test.
    """
    print(f"***info: {info}***")
    print(
        f"**Expected storage path length after llm generate: {expected_len}**")
    process_prompt(processor, llm, question, image_urls)

    print(f"Current storage path length: {_check_path_len(tmp_path)}")
    print(f"Hashes under the storage path: {_list_path(tmp_path)}")

    assert _check_path_len(tmp_path) == expected_len, (
        f"Expected storage path length {expected_len}, "
        f"but got {_check_path_len(tmp_path)} instead. Info: {info}")


def process_prompt(processor, llm: LLM, question: str,
                   image_urls: list[Image.Image]):
    """
    Form the prompt from the text and image inputs, then generate output
    with the LLM.
    """
    placeholders = [{
        "type": "image_url",
        "image_url": {
            "url": f"data:image;base64,{encode_image_base64(image_pil)}"
        }
    } for image_pil in image_urls]

    messages = [
        {
            "role": "system",
            "content": "You are a helpful assistant."
        },
        {
            "role": "user",
            "content": [
                *placeholders,
                {
                    "type": "text",
                    "text": question
                },
            ],
        },
    ]

    prompt = processor.apply_chat_template(messages,
                                           tokenize=False,
                                           add_generation_prompt=True)

    outputs = llm.generate(
        {
            "prompt":
            prompt,
            **({
                "multi_modal_data": {
                    "image": [*image_urls]
                }
            } if image_urls else {})
        },
        sampling_params=SAMPLING_PARAMS,
    )

    print("-" * 50)
    print("Output:")
    for o in outputs:
        generated_text = o.outputs[0].text
        print(generated_text)
        print("-" * 50)


def test_shared_storage_connector_hashes(tmp_path):
    """
    Tests that SharedStorageConnector saves KV to storage locations with
    proper hashes that are unique for inputs with identical text but
    different images (of the same size), or the same multiple images in
    different orders.
    """
    # Use tmp_path as the storage path for the KV cache
    print(f"KV storage path at: {str(tmp_path)}")

    # Configure the SharedStorageConnector
    kv_transfer_config = KVTransferConfig(
        kv_connector="SharedStorageConnector",
        kv_role="kv_both",
        kv_connector_extra_config={"shared_storage_path": str(tmp_path)})

    engine_args = EngineArgs(
        model=MODEL_NAME,
        max_model_len=8192,
        max_num_seqs=1,
        kv_transfer_config=kv_transfer_config,
        limit_mm_per_prompt={"image": 2},
    )

    # don't put this import at the top level
    # it will call torch.cuda.device_count()
    from transformers import AutoProcessor  # noqa: F401

    # Create processor to handle the chat prompt
    processor = AutoProcessor.from_pretrained(MODEL_NAME)

    # Prepare images for the tests
    # Resize to the same size to check hash correctness
    image_1 = ImageAsset("stop_sign").pil_image.resize((1280, 720))
    image_2 = ImageAsset("cherry_blossom").pil_image.resize((1280, 720))

    # Make sure that they are not the same picture
    assert image_1 != image_2, "The images should not be identical"

    # Create the LLM instance
    engine_args = asdict(engine_args)
    llm = LLM(**engine_args)

    # Prepare the input cases
    input_cases = [
        InputCase(text=TEXT_PROMPTS[0],
                  img=[image_1],
                  expected_len=1,
                  info="image_1 single input the first time."),
        InputCase(text=TEXT_PROMPTS[0],
                  img=[image_2],
                  expected_len=2,
                  info=("image_2 single input the first time. "
                        "It has the same pixel size as image_1, yet it "
                        "should still form a new unique hash.")),
        InputCase(text=TEXT_PROMPTS[0],
                  img=[image_1],
                  expected_len=2,
                  info=("image_1 single input the 2nd time. "
                        "It should not form another new hash.")),
        InputCase(text=TEXT_PROMPTS[0],
                  img=[image_2],
                  expected_len=2,
                  info=("image_2 single input the 2nd time. "
                        "It should not form another new hash.")),
        InputCase(text=TEXT_PROMPTS[0],
                  img=[image_1, image_2],
                  expected_len=3,
                  info="image_1 with image_2 input the first time."),
        InputCase(text=TEXT_PROMPTS[0],
                  img=[image_2, image_1],
                  expected_len=4,
                  info="The image order is swapped. Should form a new hash."),
        InputCase(text=TEXT_PROMPTS[0],
                  img=[image_1, image_2],
                  expected_len=4,
                  info=("[image_1, image_2] input the 2nd time. "
                        "It should not form another new hash.")),
        InputCase(text=TEXT_PROMPTS[0],
                  img=[image_2, image_1],
                  expected_len=4,
                  info=("[image_2, image_1] input the 2nd time. "
                        "It should not form another new hash.")),
        InputCase(text=TEXT_PROMPTS[0],
                  img=[],
                  expected_len=5,
                  info="Pure text input test as a case-control"),
        InputCase(text=TEXT_PROMPTS[0],
                  img=[],
                  expected_len=5,
                  info="Identical pure text input as a case-control"),
        InputCase(text=TEXT_PROMPTS[1],
                  img=[],
                  expected_len=6,
                  info="Another pure text input as a case-control"),
    ]

    # Run tests
    for case_id, (text, img, expected_len, info) in enumerate(input_cases):
        print("\n", "=" * 25, f"Below running input case: {case_id}", "=" * 25)
        run_test(tmp_path, processor, llm, text, img, expected_len, info)

    print("All tests passed successfully!")

vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py

Lines changed: 29 additions & 12 deletions
@@ -32,10 +32,11 @@ class ReqMeta:
     slot_mapping: torch.Tensor
     # Is store or load
     is_store: bool
+    mm_hashes: list[str]

     @staticmethod
     def make_meta(token_ids: list[int], block_ids: list[int], block_size: int,
-                  is_store: bool) -> "ReqMeta":
+                  is_store: bool, mm_hashes: list[str]) -> "ReqMeta":
         valid_num_tokens = align_to_block_size(len(token_ids), block_size)
         token_ids_tensor = torch.tensor(token_ids)[:valid_num_tokens]
         block_ids_tensor = torch.tensor(block_ids)
@@ -48,6 +49,7 @@ def make_meta(token_ids: list[int], block_ids: list[int], block_size: int,
             token_ids=token_ids_tensor,
             slot_mapping=slot_mapping,
             is_store=is_store,
+            mm_hashes=mm_hashes,
         )


@@ -64,9 +66,11 @@ def add_request(
         block_ids: list[int],
         block_size: int,
         is_store: bool,
+        mm_hashes: list[str],
     ) -> None:
         self.requests.append(
-            ReqMeta.make_meta(token_ids, block_ids, block_size, is_store))
+            ReqMeta.make_meta(token_ids, block_ids, block_size, is_store,
+                              mm_hashes))


 class SharedStorageConnector(KVConnectorBase_V1):
@@ -169,7 +173,7 @@ def inject_kv_into_layer(
                     forward_context.virtual_engine]

                 filename = self._generate_filename_debug(
-                    layer_name, request.token_ids)
+                    layer_name, request.token_ids, request.mm_hashes)
                 kv_cache = safetensors.torch.load_file(
                     filename)["kv_cache"].cuda()
                 inject_kv_into_layer(kv_cache_layer, kv_cache,
@@ -221,7 +225,7 @@ def extract_kv_from_layer(
         for request in connector_metadata.requests:
             if request.is_store:
                 filename = self._generate_filename_debug(
-                    layer_name, request.token_ids)
+                    layer_name, request.token_ids, request.mm_hashes)
                 kv_cache = extract_kv_from_layer(kv_layer,
                                                  request.slot_mapping)
                 tensors = {"kv_cache": kv_cache.detach().cpu()}
@@ -299,7 +303,8 @@ def build_connector_meta(
                 meta.add_request(token_ids=new_req.prompt_token_ids,
                                  block_ids=new_req.block_ids[0],
                                  block_size=self._block_size,
-                                 is_store=False)
+                                 is_store=False,
+                                 mm_hashes=new_req.mm_hashes)
                 total_need_load += 1
             else:
                 # NOTE: here, we set the store and load being exclusive,
@@ -310,7 +315,8 @@
                 meta.add_request(token_ids=new_req.prompt_token_ids,
                                  block_ids=new_req.block_ids[0],
                                  block_size=self._block_size,
-                                 is_store=True)
+                                 is_store=True,
+                                 mm_hashes=new_req.mm_hashes)

         cached_reqs = scheduler_output.scheduled_cached_reqs
         for i, req_id in enumerate(cached_reqs.req_ids):
@@ -338,7 +344,8 @@
                 meta.add_request(token_ids=token_ids,
                                  block_ids=block_ids,
                                  block_size=self._block_size,
-                                 is_store=False)
+                                 is_store=False,
+                                 mm_hashes=request.mm_hashes)
                 total_need_load += 1

         assert total_need_load == len(self._requests_need_load)
@@ -359,20 +366,28 @@ def _found_match_for_request(
             len(request.prompt_token_ids) - 1, self._block_size)
         foldername = self._generate_foldername_debug(torch.tensor(
             request.prompt_token_ids)[:num_tokens_to_check],
+                                                     request.mm_hashes,
                                                      create_folder=False)
         return os.path.exists(foldername)

     def _generate_foldername_debug(
         self,
-        input_ids: torch.Tensor,
+        token_ids: torch.Tensor,
+        mm_hashes: list[str],
         create_folder=False,
     ) -> str:
         """Generate a folder name based on the hash of the bytes of the input
         ids.
         """
-        input_ids_bytes = input_ids.numpy().tobytes()
-        input_ids_hash = hashlib.md5(input_ids_bytes,
+        token_bytes = token_ids.numpy().tobytes()
+        # Add mm_hashes to the bytes being hashed to avoid path traversal and
+        # to create a canonical key.
+        if mm_hashes:
+            mm_str = "-".join(mm_hashes)
+            token_bytes += mm_str.encode('utf-8')
+        input_ids_hash = hashlib.md5(token_bytes,
                                      usedforsecurity=False).hexdigest()
+
         foldername = os.path.join(self._storage_path, input_ids_hash)
         if create_folder:
             os.makedirs(foldername, exist_ok=True)
@@ -381,12 +396,14 @@ def _generate_foldername_debug(
     def _generate_filename_debug(
         self,
         layer_name: str,
-        input_ids: torch.Tensor,
+        token_ids: torch.Tensor,
+        mm_hashes: list[str],
     ) -> str:
         """Generate a file name based on the layer name and the hash
         of the bytes of the input ids.
         """
-        foldername = self._generate_foldername_debug(input_ids,
+        foldername = self._generate_foldername_debug(token_ids,
+                                                     mm_hashes=mm_hashes,
                                                      create_folder=True)
         return os.path.join(foldername, f"{layer_name}.safetensors")
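
The new comment in _generate_foldername_debug notes that folding mm_hashes into the digest also avoids path traversal: the hash strings never become path components, so a hostile value cannot escape the storage root. A small illustration with hypothetical values:

import hashlib
import os

storage_path = "/tmp/kv_storage"
evil = "../../etc/passwd"  # hypothetical hostile mm_hash value

# Unsafe alternative: using the raw string as a folder name escapes the root.
print(os.path.normpath(os.path.join(storage_path, evil)))  # /etc/passwd

# The connector's approach: hash the bytes, then join only the hex digest.
digest = hashlib.md5(evil.encode("utf-8"), usedforsecurity=False).hexdigest()
print(os.path.join(storage_path, digest))  # flat child of /tmp/kv_storage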
