Skip to content

Commit 747b1fa

Browse files
committed
[Quickfix] update CachedRequestState as NewRequestData changed
Signed-off-by: MengqingCao <[email protected]>
1 parent 103654c commit 747b1fa

File tree

2 files changed

+19
-19
lines changed

2 files changed

+19
-19
lines changed

vllm_ascend/worker/model_runner_v1.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,8 @@
5151
from vllm.model_executor.models.interfaces import supports_transcription
5252
from vllm.model_executor.models.interfaces_base import (
5353
VllmModelForPooling, is_pooling_model, is_text_generation_model)
54-
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
55-
from vllm.multimodal.utils import group_mm_inputs_by_modality
54+
from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange
55+
from vllm.multimodal.utils import group_mm_kwargs_by_modality
5656
from vllm.pooling_params import PoolingParams
5757
from vllm.sampling_params import SamplingType
5858
from vllm.sequence import IntermediateTensors
@@ -479,7 +479,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
479479
self.requests[req_id] = CachedRequestState(
480480
req_id=req_id,
481481
prompt_token_ids=new_req_data.prompt_token_ids,
482-
mm_inputs=new_req_data.mm_inputs,
482+
mm_kwargs=new_req_data.mm_kwargs,
483483
mm_positions=new_req_data.mm_positions,
484484
sampling_params=sampling_params,
485485
pooling_params=new_req_data.pooling_params,
@@ -497,18 +497,19 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
497497
second_per_grid_ts = []
498498
audio_feature_lengths = []
499499
use_audio_in_video = False
500-
for mm_input in self.requests[req_id].mm_inputs:
500+
for item in self.requests[req_id].mm_kwargs:
501+
mm_input = item.require_data()
501502
if mm_input.get("image_grid_thw") is not None:
502-
image_grid_thw.extend(
503+
image_grid_thw.append(
503504
mm_input["image_grid_thw"].tolist())
504505
if mm_input.get("video_grid_thw") is not None:
505-
video_grid_thw.extend(
506+
video_grid_thw.append(
506507
mm_input["video_grid_thw"].tolist())
507508
if mm_input.get("second_per_grid_ts") is not None:
508-
second_per_grid_ts.extend(
509+
second_per_grid_ts.append(
509510
mm_input["second_per_grid_ts"])
510511
if mm_input.get("audio_feature_lengths") is not None:
511-
audio_feature_lengths.extend(
512+
audio_feature_lengths.append(
512513
mm_input["audio_feature_lengths"])
513514
if mm_input.get("use_audio_in_video") is True:
514515
use_audio_in_video = True
@@ -912,13 +913,13 @@ def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
912913
return
913914

914915
# Batch the multi-modal inputs.
915-
mm_inputs = list[MultiModalKwargs]()
916+
mm_kwargs = list[MultiModalKwargsItem]()
916917
req_ids_pos = list[tuple[str, int, PlaceholderRange]]()
917918
for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
918919
req_state = self.requests[req_id]
919920

920921
for mm_input_id in encoder_input_ids:
921-
mm_inputs.append(req_state.mm_inputs[mm_input_id])
922+
mm_kwargs.append(req_state.mm_kwargs[mm_input_id])
922923
req_ids_pos.append(
923924
(req_id, mm_input_id, req_state.mm_positions[mm_input_id]))
924925

@@ -929,14 +930,13 @@ def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
929930
# in the same batch while still being able to benefit from batching
930931
# multimodal inputs. The proper solution should be reordering the
931932
# encoder outputs.
932-
grouped_mm_inputs_list = group_mm_inputs_by_modality(mm_inputs)
933933

934934
encoder_outputs = []
935-
for grouped_mm_inputs in grouped_mm_inputs_list:
936-
batched_mm_inputs = MultiModalKwargs.batch(grouped_mm_inputs)
937-
batched_mm_inputs = MultiModalKwargs.as_kwargs(batched_mm_inputs,
938-
device=self.device)
939-
935+
for _, num_items, mm_kwargs_group in group_mm_kwargs_by_modality(
936+
mm_kwargs,
937+
device=self.device,
938+
pin_memory=True,
939+
):
940940
# Run the encoder.
941941
# `curr_group_outputs` is either of the following:
942942
# 1. A tensor of shape (num_items, feature_size, hidden_size)
@@ -945,11 +945,11 @@ def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
945945
# (feature_size, hidden_size) in case the feature size is dynamic
946946
# depending on the input multimodal items.
947947
curr_group_outputs = self.model.get_multimodal_embeddings(
948-
**batched_mm_inputs)
948+
**mm_kwargs_group)
949949

950950
sanity_check_mm_encoder_outputs(
951951
curr_group_outputs,
952-
expected_num_items=len(grouped_mm_inputs),
952+
expected_num_items=num_items,
953953
)
954954

955955
for output in curr_group_outputs:

vllm_ascend/worker/npu_input_batch.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ class CachedRequestState:
4343

4444
req_id: str
4545
prompt_token_ids: list[int]
46-
mm_inputs: list[MultiModalKwargs]
46+
mm_kwargs: list[MultiModalKwargs]
4747
mm_positions: list[PlaceholderRange]
4848
sampling_params: Optional[SamplingParams]
4949
pooling_params: Optional[PoolingParams]

0 commit comments

Comments (0)