diff --git a/.github/workflows/vllm_ascend_test.yaml b/.github/workflows/vllm_ascend_test.yaml
index 0c0deed9a0..b0220624e3 100644
--- a/.github/workflows/vllm_ascend_test.yaml
+++ b/.github/workflows/vllm_ascend_test.yaml
@@ -81,7 +81,7 @@ jobs:
       VLLM_USE_MODELSCOPE: True
     strategy:
       matrix:
-        vllm_version: [main, v0.10.0]
+        vllm_version: [main]
     steps:
       - name: Install packages
         run: |
@@ -137,7 +137,7 @@ jobs:
       max-parallel: 2
       matrix:
         os: [linux-aarch64-a2-1]
-        vllm_version: [main, v0.10.0]
+        vllm_version: [main]
     name: singlecard e2e test
     runs-on: ${{ matrix.os }}
     container:
@@ -185,9 +185,6 @@
         run: |
           pip install -r requirements-dev.txt
           pip install -v -e .
-          if [[ "${{ matrix.vllm_version }}" == "v0.10.0" ]]; then
-            pip install "transformers<4.54.0"
-          fi

       - name: Run e2e test
         env:
@@ -222,7 +219,7 @@ jobs:
       max-parallel: 2
       matrix:
         os: [linux-aarch64-a2-2]
-        vllm_version: [main, v0.10.0]
+        vllm_version: [main]
     name: multicard e2e test
     runs-on: ${{ matrix.os }}
     container:
@@ -270,9 +267,6 @@
         run: |
           pip install -r requirements-dev.txt
           pip install -v -e .
-          if [[ "${{ matrix.vllm_version }}" == "v0.10.0" ]]; then
-            pip install "transformers<4.54.0"
-          fi

       - name: Run vllm-project/vllm-ascend test
         env:
diff --git a/.github/workflows/vllm_ascend_test_310p.yaml b/.github/workflows/vllm_ascend_test_310p.yaml
index 2bd1d2db87..a3d3cae94d 100644
--- a/.github/workflows/vllm_ascend_test_310p.yaml
+++ b/.github/workflows/vllm_ascend_test_310p.yaml
@@ -53,7 +53,7 @@ jobs:
       max-parallel: 2
       matrix:
         os: [linux-aarch64-310p-1, linux-aarch64-310p-4]
-        vllm_version: [main, v0.10.0]
+        vllm_version: [main]
     name: 310p e2e test
     runs-on: ${{ matrix.os }}
     container:
diff --git a/tests/ut/core/test_scheduler.py b/tests/ut/core/test_scheduler.py
index 74aa8b84da..b572629f66 100644
--- a/tests/ut/core/test_scheduler.py
+++ b/tests/ut/core/test_scheduler.py
@@ -50,7 +50,7 @@ def create_requests(
             request_id=f"{i}",
             prompt_token_ids=[i] * num_tokens,
             sampling_params=sampling_params,
-            multi_modal_inputs=mm_inputs,
+            multi_modal_kwargs=mm_inputs,
             multi_modal_placeholders=mm_position,
             multi_modal_hashes=None,
             eos_token_id=EOS_TOKEN_ID,
diff --git a/tests/ut/kv_connector/test_remote_decode_lifecycle.py b/tests/ut/kv_connector/test_remote_decode_lifecycle.py
index 0a337437d0..bf44c0fdc8 100644
--- a/tests/ut/kv_connector/test_remote_decode_lifecycle.py
+++ b/tests/ut/kv_connector/test_remote_decode_lifecycle.py
@@ -25,7 +25,6 @@
                     create_model_runner_output, create_request,
                     create_scheduler, create_vllm_config)
-from vllm_ascend.utils import vllm_version_is


 def test_basic_lifecycle():
@@ -103,13 +102,10 @@ def test_basic_lifecycle():

     # (3b): execute_model()
     model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
-    if vllm_version_is("0.10.0"):
-        model_runner_output.finished_sending = [request_id]
-    else:
-        from vllm.v1.worker.kv_connector_model_runner_mixin import \
-            KVConnectorOutput  # type: ignore # noqa
-        model_runner_output.kv_connector_output = KVConnectorOutput(
-            finished_sending=[request_id])
+    from vllm.v1.worker.kv_connector_model_runner_mixin import \
+        KVConnectorOutput  # type: ignore # noqa
+    model_runner_output.kv_connector_output = KVConnectorOutput(
+        finished_sending=[request_id])

     # (3c): update_from_output()
     scheduler.update_from_output(scheduler_output, model_runner_output)
@@ -164,13 +160,10 @@ def test_prefix_cache_lifecycle():
     scheduler_output = scheduler.schedule()
     scheduler.schedule()
     model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
-    if vllm_version_is("0.10.0"):
-        model_runner_output.finished_sending = [request_remote.request_id]
-    else:
-        from vllm.v1.worker.kv_connector_model_runner_mixin import \
-            KVConnectorOutput  # noqa
-        model_runner_output.kv_connector_output = KVConnectorOutput(
-            finished_sending=[request_remote.request_id])
+    from vllm.v1.worker.kv_connector_model_runner_mixin import \
+        KVConnectorOutput  # noqa
+    model_runner_output.kv_connector_output = KVConnectorOutput(
+        finished_sending=[request_remote.request_id])
     scheduler.update_from_output(scheduler_output, model_runner_output)
     _ = scheduler.schedule()
     assert_scheduler_empty(scheduler)
diff --git a/tests/ut/kv_connector/test_remote_prefill_lifecycle.py b/tests/ut/kv_connector/test_remote_prefill_lifecycle.py
index cb070ad74d..867dafb294 100644
--- a/tests/ut/kv_connector/test_remote_prefill_lifecycle.py
+++ b/tests/ut/kv_connector/test_remote_prefill_lifecycle.py
@@ -25,7 +25,6 @@
                     create_model_runner_output, create_request,
                     create_scheduler, create_vllm_config)
-from vllm_ascend.utils import vllm_version_is


 def test_basic_lifecycle():
@@ -91,13 +90,10 @@ def test_basic_lifecycle():

     # (2b): forward(): request finishes recv.
     model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
-    if vllm_version_is("0.10.0"):
-        model_runner_output.finished_recving = [request_id]
-    else:
-        from vllm.v1.worker.kv_connector_model_runner_mixin import \
-            KVConnectorOutput  # type: ignore # noqa
-        model_runner_output.kv_connector_output = KVConnectorOutput(
-            finished_recving=[request_id])
+    from vllm.v1.worker.kv_connector_model_runner_mixin import \
+        KVConnectorOutput  # type: ignore # noqa
+    model_runner_output.kv_connector_output = KVConnectorOutput(
+        finished_recving=[request_id])

     # (2c): update_from_output():
     engine_core_outputs = scheduler.update_from_output(scheduler_output,
@@ -211,13 +207,10 @@ def test_full_block_prompt():
     # # STEP (2): Recv.
     scheduler_output = scheduler.schedule()
     model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
-    if vllm_version_is("0.10.0"):
-        model_runner_output.finished_recving = [request_id]
-    else:
-        from vllm.v1.worker.kv_connector_model_runner_mixin import \
-            KVConnectorOutput  # type: ignore # noqa
-        model_runner_output.kv_connector_output = KVConnectorOutput(
-            finished_recving=[request_id])
+    from vllm.v1.worker.kv_connector_model_runner_mixin import \
+        KVConnectorOutput  # type: ignore # noqa
+    model_runner_output.kv_connector_output = KVConnectorOutput(
+        finished_recving=[request_id])
     scheduler.update_from_output(scheduler_output, model_runner_output)
     assert len(scheduler.waiting) == 1
     assert (request_id in scheduler.finished_recving_kv_req_ids)
diff --git a/tests/ut/kv_connector/utils.py b/tests/ut/kv_connector/utils.py
index 2c540b30f0..e696a7692f 100644
--- a/tests/ut/kv_connector/utils.py
+++ b/tests/ut/kv_connector/utils.py
@@ -157,7 +157,7 @@ def create_request(
         request_id=f"id-{request_id}",
         prompt_token_ids=prompt_token_ids,
         sampling_params=sampling_params,
-        multi_modal_inputs=None,
+        multi_modal_kwargs=None,
         multi_modal_placeholders=None,
         multi_modal_hashes=None,
         **({
@@ -187,19 +187,11 @@ def create_model_runner_output(

     # Make output data structure.
     extra_args = {}
-    if not vllm_version_is("0.10.0"):
-        from vllm.v1.worker.kv_connector_model_runner_mixin import \
-            KVConnectorOutput  # type: ignore # noqa
-        kv_connector_output = KVConnectorOutput(
-            finished_sending=finished_sending,
-            finished_recving=finished_recving)
-        extra_args = {"kv_connector_output": kv_connector_output}
-    else:
-        extra_args = {
-            "finished_sending": finished_sending,
-            "finished_recving": finished_recving,
-        }
-
+    from vllm.v1.worker.kv_connector_model_runner_mixin import \
+        KVConnectorOutput  # type: ignore # noqa
+    kv_connector_output = KVConnectorOutput(finished_sending=finished_sending,
+                                            finished_recving=finished_recving)
+    extra_args = {"kv_connector_output": kv_connector_output}
     return ModelRunnerOutput(
         req_ids=req_ids,
         req_id_to_index=req_id_to_index,
diff --git a/tests/ut/worker/test_input_batch.py b/tests/ut/worker/test_input_batch.py
index 7baee71226..685cf174a5 100644
--- a/tests/ut/worker/test_input_batch.py
+++ b/tests/ut/worker/test_input_batch.py
@@ -12,7 +12,7 @@ def mock_cached_request_state(req_id="1", prompt=[1, 2, 3], output=[4, 5, 6]):
     return CachedRequestState(
         req_id=req_id,
         prompt_token_ids=prompt,
-        mm_inputs=[],
+        mm_kwargs=[],
         mm_positions=[],
         sampling_params=SamplingParams(),
         pooling_params=None,
diff --git a/vllm_ascend/models/qwen2_5_vl.py b/vllm_ascend/models/qwen2_5_vl.py
index 4629f760eb..31ad2603a9 100644
--- a/vllm_ascend/models/qwen2_5_vl.py
+++ b/vllm_ascend/models/qwen2_5_vl.py
@@ -30,8 +30,7 @@
 from vllm.config import VllmConfig
 from vllm.distributed import parallel_state
 from vllm.distributed import utils as dist_utils
-from vllm.model_executor.layers.activation import (_ACTIVATION_REGISTRY,
-                                                   get_act_and_mul_fn)
+from vllm.model_executor.layers.activation import get_act_and_mul_fn
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -43,8 +42,6 @@
 from vllm.model_executor.models.utils import maybe_prefix
 from vllm.multimodal import MULTIMODAL_REGISTRY

-from vllm_ascend.utils import vllm_version_is
-
 MIN_PAD_SIZE = 64  # min_size to pad weight
 MAX_PAD_SIZE = 128  # max_size to pad weight

@@ -202,8 +199,6 @@ def __init__(
         )

         act_fn = get_act_and_mul_fn(vision_config.hidden_act)
-        if vllm_version_is("0.10.0"):
-            act_fn = _ACTIVATION_REGISTRY[vision_config.hidden_act]
         self.blocks = nn.ModuleList([
             AscendQwen2_5_VisionBlock(
                 dim=self.hidden_size,
@@ -303,12 +298,9 @@ def load_weights(self, weights: Iterable[Tuple[str,
             ("qkv_proj", "q_proj", "q"),
             ("qkv_proj", "k_proj", "k"),
             ("qkv_proj", "v_proj", "v"),
+            ("mlp.gate_up_proj.", "mlp.gate_proj.", 0),
+            ("mlp.gate_up_proj.", "mlp.up_proj.", 1),
         ]
-        if not vllm_version_is("0.10.0"):
-            stacked_params_mapping.extend([
-                ("mlp.gate_up_proj.", "mlp.gate_proj.", 0),
-                ("mlp.gate_up_proj.", "mlp.up_proj.", 1),
-            ])
         params_dict = dict(self.named_parameters(remove_duplicate=False))
         loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
diff --git a/vllm_ascend/models/qwen2_5_vl_without_padding.py b/vllm_ascend/models/qwen2_5_vl_without_padding.py
index 8877456a6d..8a1d92e146 100644
--- a/vllm_ascend/models/qwen2_5_vl_without_padding.py
+++ b/vllm_ascend/models/qwen2_5_vl_without_padding.py
@@ -30,8 +30,7 @@
 from vllm.config import VllmConfig
 from vllm.distributed import parallel_state
 from vllm.distributed import utils as dist_utils
-from vllm.model_executor.layers.activation import (_ACTIVATION_REGISTRY,
-                                                   get_act_and_mul_fn)
+from vllm.model_executor.layers.activation import get_act_and_mul_fn
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.models.qwen2_5_vl import (
@@ -43,7 +42,6 @@
 from vllm.multimodal import MULTIMODAL_REGISTRY

 from vllm_ascend.models.qwen2_5_vl import AscendQwen2_5_VisionRotaryEmbedding
-from vllm_ascend.utils import vllm_version_is


 class AscendQwen2_5_VisionAttention_Without_Padding(Qwen2_5_VisionAttention):
@@ -175,8 +173,6 @@ def __init__(
         )

         act_fn = get_act_and_mul_fn(vision_config.hidden_act)
-        if vllm_version_is("0.10.0"):
-            act_fn = _ACTIVATION_REGISTRY[vision_config.hidden_act]
         self.blocks = nn.ModuleList([
             AscendQwen2_5_VisionBlock_Without_Padding(
                 dim=self.hidden_size,
diff --git a/vllm_ascend/multistream/ms_split.py b/vllm_ascend/multistream/ms_split.py
index fd32a18abb..b7b356bed9 100644
--- a/vllm_ascend/multistream/ms_split.py
+++ b/vllm_ascend/multistream/ms_split.py
@@ -105,7 +105,7 @@ def model_input_split_v1_mla_attn(
     [block_table_pre,
      block_table_post] = split_attn_tensor_type(attn_metadata.block_tables,
                                                 seq_index)
-
+    assert attn_metadata.attn_mask is not None
     if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache or attn_metadata.attn_state == AscendAttentionState.PrefillCacheHit:
         # the attn_mla kernel in torch npu only accept 128*128 attn mask
         attn_mask_pre = attn_mask_post = attn_metadata.attn_mask
diff --git a/vllm_ascend/patch/platform/__init__.py b/vllm_ascend/patch/platform/__init__.py
index c6512f5fb3..c0772a8722 100644
--- a/vllm_ascend/patch/platform/__init__.py
+++ b/vllm_ascend/patch/platform/__init__.py
@@ -14,12 +14,5 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from vllm_ascend.utils import vllm_version_is
-
-# Import specific patches for different versions
-if vllm_version_is("0.10.0"):
-    from vllm_ascend.patch.platform import patch_0_10_0  # noqa: F401
-    from vllm_ascend.patch.platform import patch_common  # noqa: F401
-else:
-    from vllm_ascend.patch.platform import patch_common  # noqa: F401
-    from vllm_ascend.patch.platform import patch_main  # noqa: F401
+from vllm_ascend.patch.platform import patch_common  # noqa: F401
+from vllm_ascend.patch.platform import patch_main  # noqa: F401
diff --git a/vllm_ascend/patch/platform/patch_0_10_0/__init__.py b/vllm_ascend/patch/platform/patch_0_10_0/__init__.py
deleted file mode 100644
index 116c73c06c..0000000000
--- a/vllm_ascend/patch/platform/patch_0_10_0/__init__.py
+++ /dev/null
@@ -1,16 +0,0 @@
-#
-# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
-# This file is a part of the vllm-ascend project.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
diff --git a/vllm_ascend/patch/worker/__init__.py b/vllm_ascend/patch/worker/__init__.py
index a3e572b0e6..d294f14eb3 100644
--- a/vllm_ascend/patch/worker/__init__.py
+++ b/vllm_ascend/patch/worker/__init__.py
@@ -15,12 +15,5 @@
 # limitations under the License.
 #

-from vllm_ascend.utils import vllm_version_is
-
-# Import specific patches for different versions
-if vllm_version_is("0.10.0"):
-    from vllm_ascend.patch.worker import patch_0_10_0  # noqa: F401
-    from vllm_ascend.patch.worker import patch_common  # noqa: F401
-else:
-    from vllm_ascend.patch.worker import patch_common  # noqa: F401
-    from vllm_ascend.patch.worker import patch_main  # noqa: F401
+from vllm_ascend.patch.worker import patch_common  # noqa: F401
+from vllm_ascend.patch.worker import patch_main  # noqa: F401
diff --git a/vllm_ascend/patch/worker/patch_0_10_0/__init__.py b/vllm_ascend/patch/worker/patch_0_10_0/__init__.py
deleted file mode 100644
index d95e2e302d..0000000000
--- a/vllm_ascend/patch/worker/patch_0_10_0/__init__.py
+++ /dev/null
@@ -1,18 +0,0 @@
-#
-# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
-# This file is a part of the vllm-ascend project.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-import vllm_ascend.patch.worker.patch_0_10_0.patch_sampler_gather_logprobs  # noqa
diff --git a/vllm_ascend/patch/worker/patch_0_10_0/patch_sampler_gather_logprobs.py b/vllm_ascend/patch/worker/patch_0_10_0/patch_sampler_gather_logprobs.py
deleted file mode 100644
index 1e6b44ea8b..0000000000
--- a/vllm_ascend/patch/worker/patch_0_10_0/patch_sampler_gather_logprobs.py
+++ /dev/null
@@ -1,87 +0,0 @@
-#
-# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
-# SPDX-License-Identifier: Apache-2.0
-# This file is a part of the vllm-ascend project.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-import torch
-from vllm.platforms import current_platform
-from vllm.v1.outputs import LogprobsTensors
-from vllm.v1.sample.sampler import Sampler
-
-
-@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
-def batched_count_greater_than(x: torch.Tensor,
-                               values: torch.Tensor) -> torch.Tensor:
-    """
-    Counts elements in each row of x that are greater than the corresponding
-    value in values. Use torch.compile to generate an optimized kernel for
-    this function. otherwise, it will create additional copies of the input
-    tensors and cause memory issues.
-    Args:
-        x (torch.Tensor): A 2D tensor of shape (batch_size, n_elements).
-        values (torch.Tensor): A 2D tensor of shape (batch_size, 1).
-    Returns:
-        torch.Tensor: A 1D tensor of shape (batch_size,) with the counts.
-    """
-    return (x >= values).sum(-1)
-
-
-def gather_logprobs(
-    self,
-    logprobs: torch.Tensor,
-    num_logprobs: int,
-    token_ids: torch.Tensor,
-) -> LogprobsTensors:
-    """
-    Gather logprobs for topk and sampled/prompt token.
-
-    Args:
-      logprobs: (num tokens) x (vocab) tensor
-      num_logprobs: minimum number of logprobs to
-                    retain per token
-      token_ids: prompt tokens (if prompt logprobs)
-                 or sampled tokens (if sampled
-                 logprobs); 1D token ID tensor
-                 with (num tokens) elements
-                 Must be int64.
-
-    Returns:
-      Top-k int indices tensor, (num tokens) x (num_logprobs + 1)
-      Top-k float logprobs tensor, (num tokens) x (num_logprobs + 1)
-      Sampled token rank tensor, (num tokens)
-    """
-    assert token_ids.dtype == torch.int64
-    # Find the topK values.
-    topk_logprobs, topk_indices = torch.topk(logprobs, num_logprobs, dim=-1)
-
-    # Get with the logprob of the prompt or sampled token.
-    token_ids = token_ids.unsqueeze(-1)
-    token_logprobs = logprobs.gather(-1, token_ids)
-
-    # Compute the ranks of the actual token.
-    token_ranks = batched_count_greater_than(logprobs, token_logprobs)
-
-    # Concatenate together with the topk.
-    indices = torch.cat((token_ids, topk_indices), dim=1)
-    logprobs = torch.cat((token_logprobs, topk_logprobs), dim=1)
-
-    # Use int32 to reduce the tensor size.
-    indices = indices.to(torch.int32)
-
-    return LogprobsTensors(indices, logprobs, token_ranks)
-
-
-Sampler.gather_logprobs = gather_logprobs
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index a8355b3657..14f89bb5d6 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -51,11 +51,12 @@
 from vllm.model_executor.models.interfaces import supports_transcription
 from vllm.model_executor.models.interfaces_base import (
     VllmModelForPooling, is_pooling_model, is_text_generation_model)
-from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
-from vllm.multimodal.utils import group_mm_inputs_by_modality
+from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange
+from vllm.multimodal.utils import group_mm_kwargs_by_modality
 from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import SamplingType
 from vllm.sequence import IntermediateTensors
+from vllm.tasks import GenerationTask, SupportedTask
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
                         LazyLoader, cdiv)
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
@@ -66,6 +67,7 @@
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 from vllm.v1.spec_decode.ngram_proposer import NgramProposer
+from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorOutput
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
 from vllm.v1.worker.utils import (bind_kv_cache, gather_mm_placeholders,
                                   sanity_check_mm_encoder_outputs,
@@ -86,17 +88,11 @@
 from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
                                ProfileExecuteDuration, is_310p,
-                               maybe_converting_weight_acl_format,
-                               vllm_version_is)
+                               maybe_converting_weight_acl_format)
 from vllm_ascend.worker.eagle_proposer_v1 import EagleProposer
 from vllm_ascend.worker.mtp_proposer_v1 import MtpProposer
 from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch

-if not vllm_version_is("0.10.0"):
-    from vllm.tasks import GenerationTask, SupportedTask
-    from vllm.v1.worker.kv_connector_model_runner_mixin import \
-        KVConnectorOutput
-
 if TYPE_CHECKING:
     import xgrammar as xgr  # type: ignore[import-untyped]
     from vllm.v1.core.sched.output import SchedulerOutput
@@ -479,7 +475,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
                 self.requests[req_id] = CachedRequestState(
                     req_id=req_id,
                     prompt_token_ids=new_req_data.prompt_token_ids,
-                    mm_inputs=new_req_data.mm_inputs,
+                    mm_kwargs=new_req_data.mm_kwargs,
                     mm_positions=new_req_data.mm_positions,
                     sampling_params=sampling_params,
                     pooling_params=new_req_data.pooling_params,
@@ -497,18 +493,20 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
                 second_per_grid_ts = []
                 audio_feature_lengths = []
                 use_audio_in_video = False
-                for mm_input in self.requests[req_id].mm_inputs:
+
+                for item in self.requests[req_id].mm_kwargs:
+                    mm_input = item.require_data()
                     if mm_input.get("image_grid_thw") is not None:
-                        image_grid_thw.extend(
+                        image_grid_thw.append(
                             mm_input["image_grid_thw"].tolist())
                     if mm_input.get("video_grid_thw") is not None:
-                        video_grid_thw.extend(
+                        video_grid_thw.append(
                             mm_input["video_grid_thw"].tolist())
                     if mm_input.get("second_per_grid_ts") is not None:
-                        second_per_grid_ts.extend(
+                        second_per_grid_ts.append(
                             mm_input["second_per_grid_ts"])
                     if mm_input.get("audio_feature_lengths") is not None:
-                        audio_feature_lengths.extend(
+                        audio_feature_lengths.append(
                             mm_input["audio_feature_lengths"])
                     if mm_input.get("use_audio_in_video") is True:
                         use_audio_in_video = True
@@ -912,13 +910,13 @@ def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
             return

         # Batch the multi-modal inputs.
-        mm_inputs = list[MultiModalKwargs]()
+        mm_kwargs = list[MultiModalKwargsItem]()
         req_ids_pos = list[tuple[str, int, PlaceholderRange]]()
         for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
             req_state = self.requests[req_id]

             for mm_input_id in encoder_input_ids:
-                mm_inputs.append(req_state.mm_inputs[mm_input_id])
+                mm_kwargs.append(req_state.mm_kwargs[mm_input_id])
                 req_ids_pos.append(
                     (req_id, mm_input_id, req_state.mm_positions[mm_input_id]))
@@ -929,14 +927,12 @@ def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
         # in the same batch while still being able to benefit from batching
         # multimodal inputs. The proper solution should be reordering the
         # encoder outputs.
-        grouped_mm_inputs_list = group_mm_inputs_by_modality(mm_inputs)
-
         encoder_outputs = []
-        for grouped_mm_inputs in grouped_mm_inputs_list:
-            batched_mm_inputs = MultiModalKwargs.batch(grouped_mm_inputs)
-            batched_mm_inputs = MultiModalKwargs.as_kwargs(batched_mm_inputs,
-                                                           device=self.device)
-
+        for _, num_items, mm_kwargs_group in group_mm_kwargs_by_modality(
+                mm_kwargs,
+                device=self.device,
+                pin_memory=True,
+        ):
             # Run the encoder.
             # `curr_group_outputs` is either of the following:
             # 1. A tensor of shape (num_items, feature_size, hidden_size)
@@ -945,11 +941,11 @@
             # (feature_size, hidden_size) in case the feature size is dynamic
             # depending on the input multimodal items.
             curr_group_outputs = self.model.get_multimodal_embeddings(
-                **batched_mm_inputs)
+                **mm_kwargs_group)

             sanity_check_mm_encoder_outputs(
                 curr_group_outputs,
-                expected_num_items=len(grouped_mm_inputs),
+                expected_num_items=num_items,
             )

             for output in curr_group_outputs:
@@ -1604,12 +1600,7 @@ def _pool(
                 pooler_output.append(raw_output.data.cpu())
             else:
                 pooler_output.append(None)
-        extra_args = ({
-            "finished_sending": finished_sending,
-            "finished_recving": finished_recving
-        } if vllm_version_is("0.10.0") else {
-            "kv_connector_output": kv_connector_output
-        })
+        extra_args = ({"kv_connector_output": kv_connector_output})

         return ModelRunnerOutput(
             req_ids=self.input_batch.req_ids,
@@ -1645,15 +1636,14 @@ def execute_model(
              finished_recving) = (self._process_reqs(scheduler_output,
                                                      intermediate_tensors))
         kv_connector_output = None
-        if not vllm_version_is("0.10.0"):
-            if finished_sending is not None and finished_recving is not None:
-                kv_connector_output = KVConnectorOutput(
-                    finished_sending=finished_sending,
-                    finished_recving=finished_recving)
-            else:
-                kv_connector_output = None
-                finished_sending = None
-                finished_recving = None
+        if finished_sending is not None and finished_recving is not None:
+            kv_connector_output = KVConnectorOutput(
+                finished_sending=finished_sending,
+                finished_recving=finished_recving)
+        else:
+            kv_connector_output = None
+            finished_sending = None
+            finished_recving = None
         with ProfileExecuteDuration().capture_async("post process"):
             # Broadcast PP output for external_launcher (torchrun)
             # to make sure we are synced across pp ranks
@@ -1665,12 +1655,7 @@ def execute_model(
             if not get_pp_group().is_last_rank:
                 # For mid-pipeline stages, return the hidden states.
                 if not broadcast_pp_output:
-                    if kv_connector_output is not None:
-                        hidden_states.kv_connector_output = kv_connector_output
-                    else:
-                        #TODO: Remove this after we drop vllm v0.10.0
-                        hidden_states.finished_sending = finished_sending
-                        hidden_states.finished_recving = finished_recving
+                    hidden_states.kv_connector_output = kv_connector_output
                     return hidden_states
                 assert isinstance(hidden_states, IntermediateTensors)
                 get_pp_group().send_tensor_dict(
@@ -1815,12 +1800,7 @@ def execute_model(

             if has_kv_transfer_group():
                 get_kv_transfer_group().clear_connector_metadata()
-            extra_args = ({
-                "finished_sending": finished_sending,
-                "finished_recving": finished_recving
-            } if vllm_version_is("0.10.0") else {
-                "kv_connector_output": kv_connector_output
-            })
+            extra_args = ({"kv_connector_output": kv_connector_output})

             model_runner_output = ModelRunnerOutput(
                 req_ids=self.input_batch.req_ids,
diff --git a/vllm_ascend/worker/npu_input_batch.py b/vllm_ascend/worker/npu_input_batch.py
index d0acd04cd0..9b8132cc87 100644
--- a/vllm_ascend/worker/npu_input_batch.py
+++ b/vllm_ascend/worker/npu_input_batch.py
@@ -43,7 +43,7 @@ class CachedRequestState:

     req_id: str
     prompt_token_ids: list[int]
-    mm_inputs: list[MultiModalKwargs]
+    mm_kwargs: list[MultiModalKwargs]
     mm_positions: list[PlaceholderRange]
     sampling_params: Optional[SamplingParams]
     pooling_params: Optional[PoolingParams]
diff --git a/vllm_ascend/worker/worker_v1.py b/vllm_ascend/worker/worker_v1.py
index dfde5499e9..19ef2ef70d 100644
--- a/vllm_ascend/worker/worker_v1.py
+++ b/vllm_ascend/worker/worker_v1.py
@@ -34,6 +34,7 @@
 from vllm.logger import logger
 from vllm.lora.request import LoRARequest
 from vllm.sequence import IntermediateTensors
+from vllm.tasks import SupportedTask
 from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, GiB_bytes
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
@@ -45,12 +46,9 @@
 from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel
 from vllm_ascend.platform import NPUPlatform
 from vllm_ascend.utils import (init_ascend_soc_version, sleep_mode_enabled,
-                               try_register_lib, vllm_version_is)
+                               try_register_lib)
 from vllm_ascend.worker.model_runner_v1 import NPUModelRunner

-if not vllm_version_is("0.10.0"):
-    from vllm.tasks import SupportedTask
-

 class NPUWorker(WorkerBase):
@@ -209,26 +207,15 @@ def execute_model(
             if not has_kv_transfer_group():
                 return None

-            is_legacy = vllm_version_is("0.10.0")
-
-            if is_legacy:
-                finished_sending = output.finished_sending
-                finished_recving = output.finished_recving
-            else:
-                kv_connector_output = output.kv_connector_output
-                finished_sending = kv_connector_output.finished_sending
-                finished_recving = kv_connector_output.finished_recving
+            kv_connector_output = output.kv_connector_output
+            finished_sending = kv_connector_output.finished_sending
+            finished_recving = kv_connector_output.finished_recving

             if not finished_sending and not finished_recving:
                 return EMPTY_MODEL_RUNNER_OUTPUT

             new_output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
-
-            if is_legacy:
-                new_output.finished_sending = finished_sending
-                new_output.finished_recving = finished_recving
-            else:
-                new_output.kv_connector_output = kv_connector_output
+            new_output.kv_connector_output = kv_connector_output
             return new_output

         assert isinstance(output, ModelRunnerOutput)