Commit 1b32d3c

[Quickfix] update CachedRequestState as NewRequestData changed
Signed-off-by: MengqingCao <[email protected]>
1 parent 103654c · commit 1b32d3c
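
Upstream vLLM changed NewRequestData: the mm_inputs field is now mm_kwargs, and the multimodal helpers moved from MultiModalKwargs / group_mm_inputs_by_modality to MultiModalKwargsItem / group_mm_kwargs_by_modality. This quickfix renames the matching CachedRequestState field and version-gates every consumer so the plugin keeps working on vLLM 0.10.0 as well as on newer vLLM. A minimal sketch of the gating idiom used throughout the diff below, assuming vllm_version_is is the helper exported by vllm_ascend.utils:

# Sketch only: pick the multimodal API that matches the installed vLLM.
from vllm_ascend.utils import vllm_version_is  # assumed location of the helper

if not vllm_version_is("0.10.0"):
    # Newer vLLM: per-item kwargs plus an item-aware grouping helper.
    from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange
    from vllm.multimodal.utils import group_mm_kwargs_by_modality
else:
    # vLLM 0.10.0: dict-style kwargs and the older grouping helper.
    from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
    from vllm.multimodal.utils import group_mm_inputs_by_modality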

File tree

  vllm_ascend/worker/model_runner_v1.py
  vllm_ascend/worker/npu_input_batch.py

2 files changed: +118 -57 lines changed

vllm_ascend/worker/model_runner_v1.py

Lines changed: 117 additions & 56 deletions
@@ -51,8 +51,6 @@
 from vllm.model_executor.models.interfaces import supports_transcription
 from vllm.model_executor.models.interfaces_base import (
     VllmModelForPooling, is_pooling_model, is_text_generation_model)
-from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
-from vllm.multimodal.utils import group_mm_inputs_by_modality
 from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import SamplingType
 from vllm.sequence import IntermediateTensors
@@ -93,9 +91,14 @@
 from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch
 
 if not vllm_version_is("0.10.0"):
+    from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange
+    from vllm.multimodal.utils import group_mm_kwargs_by_modality
     from vllm.tasks import GenerationTask, SupportedTask
     from vllm.v1.worker.kv_connector_model_runner_mixin import \
         KVConnectorOutput
+else:
+    from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
+    from vllm.multimodal.utils import group_mm_inputs_by_modality
 
 if TYPE_CHECKING:
     import xgrammar as xgr  # type: ignore[import-untyped]
@@ -475,20 +478,34 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
                 model = cast(VllmModelForPooling, self.model)
                 to_update = model.pooler.get_pooling_updates(task)
                 to_update.apply(pooling_params)
-
-            self.requests[req_id] = CachedRequestState(
-                req_id=req_id,
-                prompt_token_ids=new_req_data.prompt_token_ids,
-                mm_inputs=new_req_data.mm_inputs,
-                mm_positions=new_req_data.mm_positions,
-                sampling_params=sampling_params,
-                pooling_params=new_req_data.pooling_params,
-                generator=generator,
-                block_ids=new_req_data.block_ids,
-                num_computed_tokens=new_req_data.num_computed_tokens,
-                output_token_ids=[],
-                lora_request=new_req_data.lora_request,
-            )
+            if vllm_version_is("0.10.0"):
+                self.requests[req_id] = CachedRequestState(
+                    req_id=req_id,
+                    prompt_token_ids=new_req_data.prompt_token_ids,
+                    mm_kwargs=new_req_data.mm_inputs,
+                    mm_positions=new_req_data.mm_positions,
+                    sampling_params=sampling_params,
+                    pooling_params=new_req_data.pooling_params,
+                    generator=generator,
+                    block_ids=new_req_data.block_ids,
+                    num_computed_tokens=new_req_data.num_computed_tokens,
+                    output_token_ids=[],
+                    lora_request=new_req_data.lora_request,
+                )
+            else:
+                self.requests[req_id] = CachedRequestState(
+                    req_id=req_id,
+                    prompt_token_ids=new_req_data.prompt_token_ids,
+                    mm_kwargs=new_req_data.mm_kwargs,
+                    mm_positions=new_req_data.mm_positions,
+                    sampling_params=sampling_params,
+                    pooling_params=new_req_data.pooling_params,
+                    generator=generator,
+                    block_ids=new_req_data.block_ids,
+                    num_computed_tokens=new_req_data.num_computed_tokens,
+                    output_token_ids=[],
+                    lora_request=new_req_data.lora_request,
+                )
 
             # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
             if self.uses_mrope:
@@ -497,21 +514,39 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
                 second_per_grid_ts = []
                 audio_feature_lengths = []
                 use_audio_in_video = False
-                for mm_input in self.requests[req_id].mm_inputs:
-                    if mm_input.get("image_grid_thw") is not None:
-                        image_grid_thw.extend(
-                            mm_input["image_grid_thw"].tolist())
-                    if mm_input.get("video_grid_thw") is not None:
-                        video_grid_thw.extend(
-                            mm_input["video_grid_thw"].tolist())
-                    if mm_input.get("second_per_grid_ts") is not None:
-                        second_per_grid_ts.extend(
-                            mm_input["second_per_grid_ts"])
-                    if mm_input.get("audio_feature_lengths") is not None:
-                        audio_feature_lengths.extend(
-                            mm_input["audio_feature_lengths"])
-                    if mm_input.get("use_audio_in_video") is True:
-                        use_audio_in_video = True
+                if vllm_version_is("0.10.0"):
+                    for mm_input in self.requests[req_id].mm_kwargs:
+                        if mm_input.get("image_grid_thw") is not None:
+                            image_grid_thw.extend(
+                                mm_input["image_grid_thw"].tolist())
+                        if mm_input.get("video_grid_thw") is not None:
+                            video_grid_thw.extend(
+                                mm_input["video_grid_thw"].tolist())
+                        if mm_input.get("second_per_grid_ts") is not None:
+                            second_per_grid_ts.extend(
+                                mm_input["second_per_grid_ts"])
+                        if mm_input.get("audio_feature_lengths") is not None:
+                            audio_feature_lengths.extend(
+                                mm_input["audio_feature_lengths"])
+                        if mm_input.get("use_audio_in_video") is True:
+                            use_audio_in_video = True
+                else:
+                    for item in self.requests[req_id].mm_kwargs:
+                        mm_input = item.require_data()
+                        if mm_input.get("image_grid_thw") is not None:
+                            image_grid_thw.append(
+                                mm_input["image_grid_thw"].tolist())
+                        if mm_input.get("video_grid_thw") is not None:
+                            video_grid_thw.append(
+                                mm_input["video_grid_thw"].tolist())
+                        if mm_input.get("second_per_grid_ts") is not None:
+                            second_per_grid_ts.append(
+                                mm_input["second_per_grid_ts"])
+                        if mm_input.get("audio_feature_lengths") is not None:
+                            audio_feature_lengths.append(
+                                mm_input["audio_feature_lengths"])
+                        if mm_input.get("use_audio_in_video") is True:
+                            use_audio_in_video = True
 
                 hf_config = self.model_config.hf_config
 
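In the else branch above, CachedRequestState.mm_kwargs holds MultiModalKwargsItem objects rather than plain kwargs dicts, so each item is unwrapped with require_data() before the grid and audio metadata are read. A stripped-down sketch of that access pattern, using a hypothetical req_state in place of self.requests[req_id] and only the calls that appear in the hunk:

# Sketch only: per-item M-RoPE metadata access on newer vLLM.
# `req_state` is a hypothetical CachedRequestState instance.
image_grid_thw = []
for item in req_state.mm_kwargs:    # list of MultiModalKwargsItem
    mm_input = item.require_data()  # dict-like view of this item's kwargs
    if mm_input.get("image_grid_thw") is not None:
        image_grid_thw.append(mm_input["image_grid_thw"].tolist())
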
@@ -912,13 +947,16 @@ def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
             return
 
         # Batch the multi-modal inputs.
-        mm_inputs = list[MultiModalKwargs]()
+        if vllm_version_is("0.10.0"):
+            mm_kwargs = list[MultiModalKwargs]()
+        else:
+            mm_kwargs = list[MultiModalKwargsItem]()
         req_ids_pos = list[tuple[str, int, PlaceholderRange]]()
         for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
             req_state = self.requests[req_id]
 
             for mm_input_id in encoder_input_ids:
-                mm_inputs.append(req_state.mm_inputs[mm_input_id])
+                mm_kwargs.append(req_state.mm_kwargs[mm_input_id])
                 req_ids_pos.append(
                     (req_id, mm_input_id, req_state.mm_positions[mm_input_id]))
 
@@ -929,31 +967,54 @@ def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
         # in the same batch while still being able to benefit from batching
         # multimodal inputs. The proper solution should be reordering the
         # encoder outputs.
-        grouped_mm_inputs_list = group_mm_inputs_by_modality(mm_inputs)
-
         encoder_outputs = []
-        for grouped_mm_inputs in grouped_mm_inputs_list:
-            batched_mm_inputs = MultiModalKwargs.batch(grouped_mm_inputs)
-            batched_mm_inputs = MultiModalKwargs.as_kwargs(batched_mm_inputs,
-                                                           device=self.device)
-
-            # Run the encoder.
-            # `curr_group_outputs` is either of the following:
-            # 1. A tensor of shape (num_items, feature_size, hidden_size)
-            # in case feature_size is fixed across all multimodal items.
-            # 2. A list or tuple (length: num_items) of tensors, each of shape
-            # (feature_size, hidden_size) in case the feature size is dynamic
-            # depending on the input multimodal items.
-            curr_group_outputs = self.model.get_multimodal_embeddings(
-                **batched_mm_inputs)
-
-            sanity_check_mm_encoder_outputs(
-                curr_group_outputs,
-                expected_num_items=len(grouped_mm_inputs),
-            )
+        if vllm_version_is("0.10.0"):
+            grouped_mm_inputs_list = group_mm_inputs_by_modality(mm_kwargs)
+
+            for grouped_mm_inputs in grouped_mm_inputs_list:
+                batched_mm_inputs = MultiModalKwargs.batch(grouped_mm_inputs)
+                batched_mm_inputs = MultiModalKwargs.as_kwargs(
+                    batched_mm_inputs, device=self.device)
+                # Run the encoder.
+                # `curr_group_outputs` is either of the following:
+                # 1. A tensor of shape (num_items, feature_size, hidden_size)
+                # in case feature_size is fixed across all multimodal items.
+                # 2. A list or tuple (length: num_items) of tensors, each of shape
+                # (feature_size, hidden_size) in case the feature size is dynamic
+                # depending on the input multimodal items.
+                curr_group_outputs = self.model.get_multimodal_embeddings(
+                    **batched_mm_inputs)
+
+                sanity_check_mm_encoder_outputs(
+                    curr_group_outputs,
+                    expected_num_items=len(grouped_mm_inputs),
+                )
+
+                for output in curr_group_outputs:
+                    encoder_outputs.append(output)
+        else:
+            for _, num_items, mm_kwargs_group in group_mm_kwargs_by_modality(
+                    mm_kwargs,
+                    device=self.device,
+                    pin_memory=True,
+            ):
+                # Run the encoder.
+                # `curr_group_outputs` is either of the following:
+                # 1. A tensor of shape (num_items, feature_size, hidden_size)
+                # in case feature_size is fixed across all multimodal items.
+                # 2. A list or tuple (length: num_items) of tensors, each of shape
+                # (feature_size, hidden_size) in case the feature size is dynamic
+                # depending on the input multimodal items.
+                curr_group_outputs = self.model.get_multimodal_embeddings(
+                    **mm_kwargs_group)
+
+                sanity_check_mm_encoder_outputs(
+                    curr_group_outputs,
+                    expected_num_items=num_items,
+                )
 
-            for output in curr_group_outputs:
-                encoder_outputs.append(output)
+                for output in curr_group_outputs:
+                    encoder_outputs.append(output)
 
         # Cache the encoder outputs.
         for (req_id, input_id, pos_info), output in zip(
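
On newer vLLM the manual MultiModalKwargs.batch / MultiModalKwargs.as_kwargs step goes away: group_mm_kwargs_by_modality consumes the flat item list, takes device and pin_memory arguments, and yields tuples whose second and third elements are the item count and the ready-to-use kwargs group (the first element is unused here, as in the diff). A condensed sketch of the resulting loop, equivalent to the else branch above and assumed to run inside _execute_mm_encoder where self is the model runner:

# Condensed sketch of the newer-vLLM encoder loop shown above.
encoder_outputs = []
for _, num_items, mm_kwargs_group in group_mm_kwargs_by_modality(
        mm_kwargs, device=self.device, pin_memory=True):
    curr_group_outputs = self.model.get_multimodal_embeddings(**mm_kwargs_group)
    sanity_check_mm_encoder_outputs(curr_group_outputs,
                                    expected_num_items=num_items)
    encoder_outputs.extend(curr_group_outputs)  # same effect as the per-output append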

vllm_ascend/worker/npu_input_batch.py

Lines changed: 1 addition & 1 deletion
@@ -43,7 +43,7 @@ class CachedRequestState:
 
     req_id: str
     prompt_token_ids: list[int]
-    mm_inputs: list[MultiModalKwargs]
+    mm_kwargs: list[MultiModalKwargs]
     mm_positions: list[PlaceholderRange]
     sampling_params: Optional[SamplingParams]
     pooling_params: Optional[PoolingParams]
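
On the input-batch side the change is just the field rename: anything that previously read req_state.mm_inputs now reads req_state.mm_kwargs, as the lookup req_state.mm_kwargs[mm_input_id] in _execute_mm_encoder above shows. A tiny self-contained illustration with a hypothetical, trimmed-down stand-in for the dataclass:

from dataclasses import dataclass, field
from typing import Any

# Hypothetical stand-in used only to illustrate the renamed attribute;
# the real CachedRequestState carries the full field list shown in the diff.
@dataclass
class MiniCachedRequestState:
    req_id: str
    mm_kwargs: list[Any] = field(default_factory=list)  # was `mm_inputs`

state = MiniCachedRequestState(req_id="req-0",
                               mm_kwargs=[{"image_grid_thw": None}])
first_item = state.mm_kwargs[0]  # consumers index by scheduled mm input id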
