diff --git a/tests/e2e/multicard/test_torchair_graph_mode.py b/tests/e2e/multicard/test_torchair_graph_mode.py
index 71d33f0c82..5e95b9dcfc 100644
--- a/tests/e2e/multicard/test_torchair_graph_mode.py
+++ b/tests/e2e/multicard/test_torchair_graph_mode.py
@@ -21,8 +21,10 @@
 """
 import os
 from typing import Dict
+from unittest.mock import patch
 
 from tests.e2e.conftest import VllmRunner
+from vllm_ascend.ascend_forward_context import _get_fused_moe_state
 
 os.environ["PYTORCH_NPU_ALLOC_CONF"] = "max_split_size_mb:256"
 
@@ -162,3 +164,67 @@ def test_e2e_pangu_with_torchair():
         },
     }
     _pangu_torchair_test_fixture(additional_config)
+
+
+def _qwen_torchair_test_fixture(
+        model,
+        tp,
+        enable_expert_parallel,
+):
+    # The current access control does not support 16 cards,
+    # so the MC2 operator in Qwen's graph mode cannot run.
+    # Once 16-card support is available,
+    # this e2e can be switched to graph mode.
+    example_prompts = [
+        "Hello, my name is",
+        "The president of the United States is",
+        "The capital of France is",
+        "The future of AI is",
+    ]
+
+    additional_config = {
+        "torchair_graph_config": {
+            "enabled": False,
+        },
+        "ascend_scheduler_config": {
+            "enabled": True,
+        },
+        "refresh": True,
+    }
+
+    with VllmRunner(
+            model,
+            dtype="half",
+            tensor_parallel_size=tp,
+            distributed_executor_backend="mp",
+            enforce_eager=True,
+            additional_config=additional_config,
+            enable_expert_parallel=enable_expert_parallel,
+    ) as vllm_model:
+        # use greedy sampling to make sure the generated results are fixed
+        vllm_output = vllm_model.generate_greedy(example_prompts, 5)
+
+    # NOTE: the golden_results below are kept for reference only. This test
+    # does not compare them against the generated text; it only checks that
+    # the number of generated outputs matches the number of prompts and
+    # prints the outputs for inspection.
+    golden_results = [
+        'Hello, my name is Remempondeprecatedmiot忱',
+        'The president of the United States is Remem下的一个 rever ceremoni Segnali',
+        'The capital of France is Rememvoud administrativ Remem投',
+        'The future of AI isotope Segnali Zoeken精细化 supus',
+    ]
+
+    assert len(golden_results) == len(vllm_output)
+    for i in range(len(vllm_output)):
+        print(f"Generated text: {vllm_output[i][1]!r}")
+
+
+def test_e2e_qwen3_moe_with_torchair():
+
+    def stubbed_get_state(ep_size, with_prefill, is_deepseek_v3_r1):
+        return _get_fused_moe_state(16, with_prefill, is_deepseek_v3_r1)
+
+    with patch('vllm_ascend.ascend_forward_context._get_fused_moe_state',
+               side_effect=stubbed_get_state):
+        _qwen_torchair_test_fixture("Qwen/Qwen3-30B-A3B", 2, True)
diff --git a/tests/ut/models/test_qwen3_moe.py b/tests/ut/models/test_qwen3_moe.py
index 71be045a64..cf521e4dc9 100644
--- a/tests/ut/models/test_qwen3_moe.py
+++ b/tests/ut/models/test_qwen3_moe.py
@@ -12,11 +12,15 @@
 # limitations under the License.
 # This file is a part of the vllm-ascend project.
# +import math +import unittest import pytest +import torch from vllm.model_executor.models.qwen3_moe import Qwen3MoeForCausalLM -from vllm_ascend.models.qwen3_moe import CustomQwen3MoeForCausalLM +from vllm_ascend.models.qwen3_moe import (CustomQwen3MoeAttention, + CustomQwen3MoeForCausalLM) class TestCustomQwen3MoeForCausalLM: @@ -44,3 +48,51 @@ def test_packed_modules_mapping_structure(self): ] } assert CustomQwen3MoeForCausalLM.packed_modules_mapping == expected_mapping + + +class DummyRMSNorm: + + def __init__(self, dim: int, eps: float = 1e-6): + self.dim = dim + self.eps = eps + + def __call__(self, x): + mean_sq = x.pow(2).mean(dim=-1, keepdim=True) + denom = (mean_sq + self.eps).sqrt() + return x / denom + + +class TestCustomQwen3MoeAttention(unittest.TestCase): + + def setUp(self): + self.batch = 2 + self.seq_len = 3 + self.q_size = 8 + self.kv_size = 8 + self.head_dim = 4 + self.rms_eps = 1e-6 + + total_dim = self.q_size + 2 * self.kv_size + + self.qkv = torch.arange(self.batch * self.seq_len * total_dim, + dtype=torch.float32).reshape( + self.batch, self.seq_len, total_dim) + + def test_constant_input_normalization(self): + ones_qkv = torch.ones((1, 1, self.q_size + 2 * self.kv_size), + dtype=torch.float32) + + q_norm = DummyRMSNorm(self.head_dim, self.rms_eps) + k_norm = DummyRMSNorm(self.head_dim, self.rms_eps) + q, k, v = CustomQwen3MoeAttention.normalize_qkv( + ones_qkv, self.q_size, self.kv_size, self.head_dim, q_norm, k_norm) + + norm_val = 1.0 / math.sqrt(1.0 + self.rms_eps) + + expected_q = torch.full((1, 1, self.q_size), norm_val) + expected_k = torch.full((1, 1, self.kv_size), norm_val) + expected_v = torch.ones((1, 1, self.kv_size), dtype=torch.float32) + + self.assertTrue(torch.allclose(q, expected_q, atol=1e-6)) + self.assertTrue(torch.allclose(k, expected_k, atol=1e-6)) + self.assertTrue(torch.equal(v, expected_v)) diff --git a/tests/ut/ops/test_rotary_embedding.py b/tests/ut/ops/test_rotary_embedding.py index 3b388e09f2..a97ca1b13d 100644 --- a/tests/ut/ops/test_rotary_embedding.py +++ b/tests/ut/ops/test_rotary_embedding.py @@ -1,12 +1,19 @@ import math +from unittest import mock from unittest.mock import MagicMock, patch +import pytest import torch +import torch_npu from tests.ut.base import TestBase +from vllm_ascend.ops.rotary_embedding import __set_cos_sin_cache # noqa E402 +from vllm_ascend.ops.rotary_embedding import \ + __set_cos_sin_cache as raw__set_cos_sin_cache from vllm_ascend.ops.rotary_embedding import (custom_rotary_embedding_enabled, native_rope_deepseek_forward, - rope_forward_oot, rotate_half, + rope_forward, rope_forward_oot, + rotate_half, yarn_find_correction_dim, yarn_get_mscale) @@ -312,3 +319,113 @@ def test_scale_greater_than_1(self): expected, places=6, msg=f"Failed for scale={scale}, mscale={mscale}") + + +class MockRotaryEmbedding: + + def __init__(self, base, rotary_dim, max_position_embeddings): + self.base = base + self.rotary_dim = rotary_dim + self.max_position_embeddings = max_position_embeddings + + +@pytest.fixture +def dummy_module(): + return MockRotaryEmbedding(base=10000.0, + rotary_dim=64, + max_position_embeddings=512) + + +class TestSetCosSinCache: + + def test_set_cos_sin_cache_generates_real_tensors(self, dummy_module): + calls = [] + + def fake_register_buffer(name, tensor, persistent=True): + setattr(dummy_module, name, tensor) + calls.append(name) + + dummy_module.register_buffer = fake_register_buffer + seq_len = 128 + device = torch.device("cpu") + dtype = torch.float32 + + raw__set_cos_sin_cache(dummy_module, 
seq_len, device, dtype) + + assert calls == ['inv_freq', 'cos', 'sin'] + + assert isinstance(dummy_module.inv_freq, torch.Tensor) + assert dummy_module.inv_freq.shape == (dummy_module.rotary_dim // 2, ) + assert dummy_module.inv_freq.device == device + assert dummy_module.inv_freq.dtype == torch.float32 + + expected_shape = (dummy_module.max_position_embeddings, + dummy_module.rotary_dim) + for name in ('cos', 'sin'): + buf = getattr(dummy_module, name) + assert isinstance(buf, torch.Tensor) + assert buf.shape == expected_shape + assert buf.device == device + assert buf.dtype == torch.float32 + + +class DummyConfig: + + class TorchairGraphConfig: + enabled = True + + torchair_graph_config = TorchairGraphConfig() + + +class DummyModel: + + def __init__(self, head_size, max_pos): + self.head_size = head_size + self.max_position_embeddings = max_pos + self.cos = torch.randn(max_pos, head_size) + self.sin = torch.randn(max_pos, head_size) + + def embed(self, positions, weight): + B, S = positions.shape + return torch.ones(B, S, self.head_size) * 0.5 + + +@mock.patch("vllm_ascend.ops.rotary_embedding.get_ascend_config", + return_value=DummyConfig()) +@mock.patch.object(torch_npu, "npu_apply_rotary_pos_emb") +@mock.patch("vllm_ascend.ops.rotary_embedding.__set_cos_sin_cache") +def test_rope_forward_output_shape(mock_set_cache, mock_npu_apply, + mock_get_ascend_config): + batch_size = 2 + seq_len = 4 + num_heads = 3 + head_size = 5 + + q = torch.randn(batch_size, seq_len, num_heads * head_size) + k = torch.randn_like(q) + + positions = torch.arange(seq_len).unsqueeze(0).repeat(batch_size, 1) + + model = DummyModel(head_size=head_size, max_pos=100) + + def fake_apply_rotary(q_in, k_in, cos, sin): + return q_in, k_in + + mock_npu_apply.side_effect = fake_apply_rotary + + q_out, k_out = rope_forward( + model, + positions=positions, + query=q, + key=k, + offsets=None, + is_neox_style_override=None, + max_seq_len=None, + is_prefill=False, # no rope_forward_oot + is_qwen_torchair=True, # go rotary + ) + + assert q_out.shape == (batch_size, 1, seq_len, num_heads * head_size) + assert k_out.shape == (batch_size, 1, seq_len, num_heads * head_size) + + mock_set_cache.assert_not_called() diff --git a/tests/ut/test_ascend_config.py b/tests/ut/test_ascend_config.py index ec00c0d965..49b9abeabc 100644 --- a/tests/ut/test_ascend_config.py +++ b/tests/ut/test_ascend_config.py @@ -232,7 +232,7 @@ def test_check_ascend_config_wrong_case(self): def test_check_torchair_supported(self): test_cases = [('deepseek_v3', True), ('PanguProMoE', True), - ('qwen', False), ('llama', False)] + ('qwen', True), ('llama', False)] for model_type, expected_output in test_cases: self.assertEqual(_check_torchair_supported(model_type), expected_output) diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py index 659f4415f7..91b9cb9be9 100644 --- a/vllm_ascend/ascend_config.py +++ b/vllm_ascend/ascend_config.py @@ -17,7 +17,7 @@ from vllm.logger import logger -TORCHAIR_MODEL_LIST = ["deepseek", "pangu", "kimi_k2"] +TORCHAIR_MODEL_LIST = ["deepseek", "pangu", "kimi_k2", "qwen"] def _check_torchair_supported(model_type: str): @@ -159,7 +159,7 @@ def check_ascend_config(vllm_config, enforce_eager): else: # torchair_graph case if ascend_config.torchair_graph_config.enabled: - # torchair_graph is supported for deepseek/pangu model only. + # torchair_graph is supported for deepseek/pangu/qwen model only. 
if vllm_config.model_config: model_type = vllm_config.model_config.hf_config.model_type if not _check_torchair_supported(model_type): diff --git a/vllm_ascend/attention/attention_v1_torchair.py b/vllm_ascend/attention/attention_v1_torchair.py index 4d84bac976..f95fe1ea55 100644 --- a/vllm_ascend/attention/attention_v1_torchair.py +++ b/vllm_ascend/attention/attention_v1_torchair.py @@ -378,8 +378,9 @@ def forward( shape = [batch_size * seq_len, num_heads, head_size] """ num_tokens = query.shape[0] - use_kv_cache_quant = kv_cache is not None and kv_cache[0].numel( - ) > 0 and kv_cache[0].dtype == torch.int8 + use_kv_cache_quant = (kv_cache is not None and len(kv_cache) > 0 + and kv_cache[0].numel() > 0 + and kv_cache[0].dtype == torch.int8) if output is None: output = torch.empty(num_tokens, self.num_heads, diff --git a/vllm_ascend/models/__init__.py b/vllm_ascend/models/__init__.py index f47e821b34..8d12fa741a 100644 --- a/vllm_ascend/models/__init__.py +++ b/vllm_ascend/models/__init__.py @@ -59,3 +59,6 @@ def register_model(): ModelRegistry.register_model( "PanguProMoEForCausalLM", "vllm_ascend.models.pangu_moe:PanguProMoEForCausalLM") + + ModelRegistry.register_model( + "Qwen2ForCausalLM", "vllm_ascend.models.qwen2:CustomQwen2ForCausalLM") diff --git a/vllm_ascend/models/qwen2.py b/vllm_ascend/models/qwen2.py new file mode 100644 index 0000000000..fa9cd7be1d --- /dev/null +++ b/vllm_ascend/models/qwen2.py @@ -0,0 +1,364 @@ +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. 
+ +from collections.abc import Iterable +from typing import Any, List, Optional, Union + +import torch +import torch.nn.functional as F +import vllm +import vllm.envs as envs +from torch import nn +from transformers import Qwen2Config +from vllm.attention import AttentionMetadata, AttentionType +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import (get_pp_group, tensor_model_parallel_all_gather, + tensor_model_parallel_reduce_scatter) +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP +from vllm.model_executor.models.qwen2 import Qwen2Attention # noqa: F401 +from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM # noqa: F401 +from vllm.model_executor.models.qwen2 import Qwen2MLP, Qwen2Model +from vllm.model_executor.models.utils import (AutoWeightsLoader, + PPMissingLayer, maybe_prefix) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from vllm_ascend.ascend_config import get_ascend_config +from vllm_ascend.attention.attention_v1 import AscendAttentionState + + +def all_gather_and_maybe_unpad( + hidden_states: torch.Tensor, + pad_size: int, +) -> torch.Tensor: + hidden_states = tensor_model_parallel_all_gather(hidden_states, 0) + if pad_size > 0: + return hidden_states[:-pad_size, :] + return hidden_states + + +def maybe_pad_and_reduce_scatter( + hidden_states: torch.Tensor, + pad_size: int, +) -> torch.Tensor: + if pad_size > 0: + hidden_states = F.pad(hidden_states, (0, 0, 0, pad_size)) + hidden_states = tensor_model_parallel_reduce_scatter(hidden_states, 0) + return hidden_states + + +class CustomQwen2Attention(Qwen2Attention): + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + max_position: int = 4096 * 32, + rope_theta: float = 10000, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + rope_scaling: Optional[tuple] = None, + prefix: str = "", + attn_type: str = AttentionType.DECODER, + dual_chunk_attention_config: Optional[dict[str, Any]] = None, + ) -> None: + super().__init__( + hidden_size=hidden_size, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + max_position=max_position, + rope_theta=rope_theta, + cache_config=cache_config, + quant_config=quant_config, + rope_scaling=rope_scaling, + prefix=prefix, + attn_type=attn_type, + dual_chunk_attention_config=dual_chunk_attention_config) + ascend_config = get_ascend_config() + self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: Optional[torch.Tensor] = None, + attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + if self.torchair_graph_enabled and attn_metadata is not None and attn_metadata.attn_state == AscendAttentionState.DecodeOnly: + q, k = self.rotary_emb(positions, + q, + k, + is_prefill=False, + is_qwen_torchair=True) + forward_kwargs = {} + if envs.VLLM_USE_V1: + 
output_shape = q.shape + output = torch.empty(output_shape, + dtype=q.dtype, + device=q.device) + forward_kwargs['output'] = output + + attn_output = self.attn.impl.forward(self.attn, + q, + k, + v, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + trace_flag=False, + **forward_kwargs) + output, _ = self.o_proj(attn_output) + return output + else: + if type(self.rotary_emb) is RotaryEmbedding: + q, k = self.rotary_emb(positions, q, k, is_qwen_torchair=True) + else: + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + output, _ = self.o_proj(attn_output) + return output + + +class CustomQwen2DecoderLayer(nn.Module): + + def __init__( + self, + config: Qwen2Config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + # Requires transformers > 4.32.0 + rope_theta = getattr(config, "rope_theta", 1000000) + rope_scaling = getattr(config, "rope_scaling", None) + dual_chunk_attention_config = getattr(config, + "dual_chunk_attention_config", + None) + + # By default, Qwen2 uses causal attention as it is a decoder-only model. + # You can override the HF config with `is_causal=False` to enable + # bidirectional attention, which is used in some embedding models + # (e.g. Alibaba-NLP/gte-Qwen2-7B-instruct) + if getattr(config, "is_causal", True): + attn_type = AttentionType.DECODER + else: + attn_type = AttentionType.ENCODER_ONLY + + self.self_attn = CustomQwen2Attention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + max_position=config.max_position_embeddings, + num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + cache_config=cache_config, + quant_config=quant_config, + rope_scaling=rope_scaling, + prefix=f"{prefix}.self_attn", + attn_type=attn_type, + dual_chunk_attention_config=dual_chunk_attention_config, + ) + self.mlp = Qwen2MLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + kv_cache: Optional[torch.Tensor] = None, + attn_metadata: Optional[AttentionMetadata] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + + hidden_states = self.self_attn(positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +@support_torch_compile( + dynamic_arg_dims={ + "input_ids": 0, + # positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl, + # otherwise (seq_len, ). 
+ "positions": -1, + "intermediate_tensors": 0, + "inputs_embeds": 0, + }) +class CustomQwen2Model(Qwen2Model): + + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + decoder_layer_type: type[nn.Module] = CustomQwen2DecoderLayer): + super().__init__(vllm_config=vllm_config, + prefix=prefix, + decoder_layer_type=decoder_layer_type) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: Optional[List[torch.Tensor]] = None, + attn_metadata: Optional[AttentionMetadata] = None, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + kv_cache = kv_caches[i - self.start_layer] \ + if kv_caches is not None else None + hidden_states, residual = layer(positions, + hidden_states, + residual, + kv_cache=kv_cache, + attn_metadata=attn_metadata) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class CustomQwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): + # add `CustomQwen2Model` to init self.model + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.config = config + self.lora_config = lora_config + + self.quant_config = quant_config + self.model = CustomQwen2Model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + + if get_pp_group().is_last_rank: + if config.tie_word_embeddings: + self.lm_head = self.model.embed_tokens + else: + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix( + prefix, "lm_head")) + else: + self.lm_head = PPMissingLayer() + + self.logits_processor = LogitsProcessor(config.vocab_size) + + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: Optional[List[torch.Tensor]] = None, + attn_metadata: Optional[AttentionMetadata] = None, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors, + inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def load_weights(self, weights: 
Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader( + self, + skip_prefixes=(["lm_head."] + if self.config.tie_word_embeddings else None), + ) + return loader.load_weights(weights) + + +vllm.model_executor.models.qwen2.Qwen2ForCausalLM = CustomQwen2ForCausalLM diff --git a/vllm_ascend/models/qwen3_moe.py b/vllm_ascend/models/qwen3_moe.py index 29ab675525..dd4a592d65 100644 --- a/vllm_ascend/models/qwen3_moe.py +++ b/vllm_ascend/models/qwen3_moe.py @@ -16,12 +16,13 @@ # limitations under the License. # Adapted from vllm/model_executor/models/qwen3_moe.py # This file is a part of the vllm-ascend project. - -from typing import Optional, Union +from typing import Any, List, Optional, Union import torch +import vllm.envs as envs from torch import nn from transformers import PretrainedConfig +from vllm.attention import Attention, AttentionMetadata from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, CompilationLevel, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size @@ -30,9 +31,12 @@ from vllm.forward_context import get_forward_context from vllm.model_executor.layers.fused_moe.layer import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import ReplicatedLinear +from vllm.model_executor.layers.linear import (QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.models.interfaces import (MixtureOfExperts, @@ -47,6 +51,8 @@ make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) from vllm.sequence import IntermediateTensors +from vllm_ascend.ascend_config import get_ascend_config +from vllm_ascend.attention.attention_v1 import AscendAttentionState from vllm_ascend.ops.fused_moe import AscendFusedMoE from vllm_ascend.ops.sequence_parallel import (MetadataForPadding, init_metadata_for_sp) @@ -125,6 +131,137 @@ def forward( return hidden_states +class CustomQwen3MoeAttention(Qwen3MoeAttention): + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[dict[str, Any]] = None, + max_position_embeddings: int = 8192, + head_dim: Optional[int] = None, + rms_norm_eps: float = 1e-06, + qkv_bias: bool = False, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + nn.Module.__init__(self) + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. 
+ assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = head_dim or (hidden_size // self.total_num_heads) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.qkv_proj = QKVParallelLinear(hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=qkv_bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj") + + self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj") + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + ) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn") + + self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) + self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) + ascend_config = get_ascend_config() + self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled + + @staticmethod + def normalize_qkv(qkv: torch.Tensor, q_size: int, kv_size: int, + head_dim: int, q_norm, k_norm): + q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1) + + q_by_head = q.view(*q.shape[:-1], q.shape[-1] // head_dim, head_dim) + q_by_head = q_norm(q_by_head) + q = q_by_head.view(q.shape) + + k_by_head = k.view(*k.shape[:-1], k.shape[-1] // head_dim, head_dim) + k_by_head = k_norm(k_by_head) + k = k_by_head.view(k.shape) + + return q, k, v + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: Optional[torch.Tensor] = None, + attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = self.normalize_qkv(qkv, self.q_size, self.kv_size, + self.head_dim, self.q_norm, self.k_norm) + + if (self.torchair_graph_enabled and attn_metadata is not None and + attn_metadata.attn_state == AscendAttentionState.DecodeOnly): + q, k = self.rotary_emb(positions, + q, + k, + is_prefill=False, + is_qwen_torchair=True) + forward_kwargs = {} + if envs.VLLM_USE_V1: + output_shape = q.shape + output = torch.empty(output_shape, + dtype=q.dtype, + device=q.device) + forward_kwargs['output'] = output + + attn_output = self.attn.impl.forward(self.attn, + q, + k, + v, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + trace_flag=False, + **forward_kwargs) + output, _ = self.o_proj(attn_output) + return output + else: + q, k = self.rotary_emb(positions, q, k, is_qwen_torchair=True) + attn_output = self.attn(q, k, v) + output, _ = self.o_proj(attn_output) + return output + + class CustomQwen3MoeDecoderLayer(Qwen3MoeDecoderLayer): def __init__( @@ -142,7 +279,7 @@ def __init__( rope_scaling = getattr(config, "rope_scaling", None) max_position_embeddings = getattr(config, "max_position_embeddings", 8192) - self.self_attn = Qwen3MoeAttention( + self.self_attn = CustomQwen3MoeAttention( hidden_size=self.hidden_size, num_heads=config.num_attention_heads, num_kv_heads=config.num_key_value_heads, @@ -197,6 +334,8 @@ def forward( positions: torch.Tensor, hidden_states: torch.Tensor, residual: Optional[torch.Tensor], + kv_cache: Optional[torch.Tensor] = None, + attn_metadata: 
Optional[AttentionMetadata] = None, _metadata_for_padding: Optional[MetadataForPadding] = None, ) -> torch.Tensor: @@ -224,6 +363,8 @@ def forward( hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, ) if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill: @@ -280,6 +421,8 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, + kv_caches: Optional[List[torch.Tensor]] = None, + attn_metadata: Optional[AttentionMetadata] = None, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, _metadata_for_padding: Optional[MetadataForPadding] = None, @@ -300,6 +443,9 @@ def forward( positions, hidden_states, residual, + kv_caches[i - + self.start_layer] if kv_caches is not None else None, + attn_metadata, _metadata_for_padding=_metadata_for_padding) if not get_pp_group().is_last_rank: return IntermediateTensors({ @@ -378,11 +524,14 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, + kv_caches: Optional[List[torch.Tensor]] = None, + attn_metadata: Optional[AttentionMetadata] = None, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: _metadata_for_padding = init_metadata_for_sp( input_ids, self.enable_sequence_parallelism) - hidden_states = self.model(input_ids, positions, intermediate_tensors, + hidden_states = self.model(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors, inputs_embeds, _metadata_for_padding) return hidden_states diff --git a/vllm_ascend/ops/rotary_embedding.py b/vllm_ascend/ops/rotary_embedding.py index 3dd91ea63f..dd9b9c928d 100644 --- a/vllm_ascend/ops/rotary_embedding.py +++ b/vllm_ascend/ops/rotary_embedding.py @@ -19,6 +19,8 @@ from typing import Optional, Tuple import torch +import torch.nn.functional as F +import torch_npu from vllm.model_executor.layers.rotary_embedding import ( DeepseekScalingRotaryEmbedding, RotaryEmbedding) @@ -37,9 +39,11 @@ def rope_forward_oot( query: torch.Tensor, key: torch.Tensor, offsets: Optional[torch.Tensor] = None, - is_neox_style_override: Optional[bool] = None + is_neox_style_override: Optional[bool] = None, + is_qwen_torchair: Optional[bool] = False, ) -> Tuple[torch.Tensor, torch.Tensor]: - if get_ascend_config().torchair_graph_config.enabled: + if get_ascend_config( + ).torchair_graph_config.enabled and not is_qwen_torchair: return self.forward_native( positions, query, @@ -47,7 +51,6 @@ def rope_forward_oot( offsets, ) - import torch_npu query_shape, key_shape = query.shape, key.shape if self.cos_sin_cache.device != query.device: self.cos_sin_cache = self.cos_sin_cache.to(query.device) @@ -246,6 +249,98 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): self.register_buffer("sin_cached", sin_cached, persistent=False) +def __set_cos_sin_cache(self, seq_len, device, dtype): + inv_freq = 1.0 / (self.base**(torch.arange( + 0, self.rotary_dim, 2, device=device, dtype=torch.float32) * + (1 / self.rotary_dim))) + self.register_buffer("inv_freq", inv_freq) + + t = torch.arange(self.max_position_embeddings, + device=self.inv_freq.device, + dtype=torch.float32) + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos", emb.cos().to(dtype=dtype), persistent=False) + self.register_buffer("sin", emb.sin().to(dtype=dtype), persistent=False) + self.embed = F.embedding + 
+ +def qwen_rope_init_func( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: float, + is_neox_style: bool, + dtype: torch.dtype, +) -> None: + super(RotaryEmbedding, self).__init__() + self.head_size = head_size + self.rotary_dim = rotary_dim + self.max_position_embeddings = max_position_embeddings + self.base = base + self.is_neox_style = is_neox_style + self.dtype = dtype + + cache = self._compute_cos_sin_cache() + cache = cache.to(dtype) + self.cos_sin_cache: torch.Tensor # type: ignore[misc] + self.register_buffer("cos_sin_cache", cache, persistent=False) + if get_ascend_config().torchair_graph_config.enabled: + __set_cos_sin_cache(self, + seq_len=max_position_embeddings, + device="npu", + dtype=dtype) + + +def rope_forward( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: Optional[torch.Tensor] = None, + is_neox_style_override: Optional[bool] = None, + max_seq_len: Optional[int] = None, + is_prefill: Optional[bool] = True, + is_qwen_torchair: Optional[bool] = False, +): + if (not get_ascend_config().torchair_graph_config.enabled + or not is_qwen_torchair or is_prefill): + return rope_forward_oot(self, positions, query, key, offsets, + is_neox_style_override, + is_qwen_torchair) # type: ignore + + if max_seq_len is not None and torch.gt(max_seq_len, + self.max_position_embeddings): + __set_cos_sin_cache(self, + seq_len=max_seq_len, + device=query.device, + dtype=torch.float32) + + # bsnd/bnsd + if positions is not None: + cos = self.embed(positions, self.cos) + sin = self.embed(positions, self.sin) + self.cos_embed = cos + self.sin_embed = sin + else: + cos = self.cos_embed + sin = self.sin_embed + + query = query.view(*query.shape[:-1], -1, self.head_size).contiguous() + key = key.view(*key.shape[:-1], -1, self.head_size).contiguous() + + cos = cos.unsqueeze(-2).unsqueeze(-2) + sin = sin.unsqueeze(-2).unsqueeze(-2) + + query = query.unsqueeze(1) + key = key.unsqueeze(1) + + q_embed, k_embed = torch_npu.npu_apply_rotary_pos_emb(query, key, cos, sin) + return q_embed.flatten(-2), k_embed.flatten(-2) + + def deepseek_rope_init_func( self, head_size: int, @@ -283,7 +378,8 @@ def deepseek_rope_init_func( device="npu") -RotaryEmbedding.forward_oot = rope_forward_oot +RotaryEmbedding.__init__ = qwen_rope_init_func +RotaryEmbedding.forward_oot = rope_forward # Note: we adopt the native huggingface deepseek rope initialization code from # https://huggingface.co/deepseek-ai/DeepSeek-V3-0324/blob/main/modeling_deepseek.py for
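A minimal pure-PyTorch sketch of the per-head normalization that CustomQwen3MoeAttention.normalize_qkv performs, assuming unweighted RMSNorm semantics (x / sqrt(mean(x^2) + eps)) as in the DummyRMSNorm used by tests/ut/models/test_qwen3_moe.py; it is not the fused vLLM RMSNorm layer. With a constant input of 1.0 every element normalizes to 1/sqrt(1 + eps), which is the value the unit test checks.

import torch


def rms_norm(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Unweighted RMSNorm over the last dimension.
    return x / torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)


def normalize_qkv_sketch(qkv: torch.Tensor, q_size: int, kv_size: int,
                         head_dim: int):
    # Split the fused QKV projection, then normalize q and k head by head;
    # v passes through unchanged.
    q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1)
    q = rms_norm(q.view(*q.shape[:-1], -1, head_dim)).view(q.shape)
    k = rms_norm(k.view(*k.shape[:-1], -1, head_dim)).view(k.shape)
    return q, k, v


qkv = torch.ones(1, 1, 24)  # q_size=8, kv_size=8, head_dim=4, all ones
q, k, v = normalize_qkv_sketch(qkv, 8, 8, 4)
print(q[0, 0, 0].item())                    # ~= 1 / sqrt(1 + 1e-6)
print(torch.equal(v, torch.ones(1, 1, 8)))  # True, v is untouched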
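The new rope_forward path caches cos/sin as cat((freqs, freqs), dim=-1) per position and hands reshaped q/k to torch_npu.npu_apply_rotary_pos_emb. Below is a CPU-only sketch of the neox-style rotation those cos/sin tables are built for, assuming the standard rotate-half formulation; the NPU kernel call and its 4-D layout are device-specific and not reproduced here.

import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Neox-style rotation: (x1, x2) -> (-x2, x1) over the two halves of head_dim.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_reference(q, k, cos, sin):
    # cos/sin have shape [num_tokens, head_dim]; broadcast over the head axis.
    cos, sin = cos.unsqueeze(-2), sin.unsqueeze(-2)
    return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin


head_dim, num_tokens, num_heads = 8, 4, 2
base = 10000.0
# Same inverse-frequency construction as __set_cos_sin_cache in the patch.
inv_freq = 1.0 / (base**(torch.arange(0, head_dim, 2, dtype=torch.float32) /
                         head_dim))
t = torch.arange(num_tokens, dtype=torch.float32)
emb = torch.cat([torch.outer(t, inv_freq)] * 2, dim=-1)  # [num_tokens, head_dim]

q = torch.randn(num_tokens, num_heads, head_dim)
k = torch.randn(num_tokens, num_heads, head_dim)
q_rot, k_rot = apply_rotary_reference(q, k, emb.cos(), emb.sin())
print(q_rot.shape, k_rot.shape)  # torch.Size([4, 2, 8]) twice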