diff --git a/vllm_ascend/models/qwen3_next.py b/vllm_ascend/models/qwen3_next.py
index 175a529abf..4a7a7a74df 100644
--- a/vllm_ascend/models/qwen3_next.py
+++ b/vllm_ascend/models/qwen3_next.py
@@ -618,14 +618,19 @@ def forward(
         positions: torch.Tensor = None,
         **kwargs: object,
     ):
+        self_attention_output = torch.empty_like(hidden_states)
         if residual is None:
             residual = hidden_states
             hidden_states = self.input_layernorm(hidden_states)
+            forward_context = get_forward_context()
+            if forward_context.sp_enabled:
+                tp_size = get_tensor_model_parallel_world_size()
+                chunk_size = (hidden_states.shape[0] + forward_context.pad_size) // tp_size
+                self_attention_output = self_attention_output[:chunk_size, :]
         else:
             hidden_states, residual = self.input_layernorm(
                 hidden_states, residual)
-        self_attention_output = torch.empty_like(hidden_states)
 
         if self.layer_type == "linear_attention":
             self.linear_attn(
                 hidden_states=hidden_states,
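For intuition, here is a minimal, self-contained sketch of the chunk sizing introduced in the hunk above. The concrete numbers, and the assumption that `forward_context.pad_size` pads the token count up to a multiple of the tensor-parallel world size, are illustrative and not taken from this patch:

```python
import torch

# Assumed example values (not from the patch).
num_tokens, tp_size = 30, 4
pad_size = -num_tokens % tp_size                  # pad 30 -> 32, so pad_size = 2
chunk_size = (num_tokens + pad_size) // tp_size   # (30 + 2) // 4 = 8

hidden_states = torch.randn(num_tokens, 2048)
# With sequence parallelism enabled, each rank keeps only its shard of the
# pre-allocated attention-output buffer.
self_attention_output = torch.empty_like(hidden_states)[:chunk_size, :]
assert self_attention_output.shape == (8, 2048)
```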
diff --git a/vllm_ascend/ops/layernorm.py b/vllm_ascend/ops/layernorm.py
index da48362f46..7c335ff558 100644
--- a/vllm_ascend/ops/layernorm.py
+++ b/vllm_ascend/ops/layernorm.py
@@ -19,7 +19,7 @@
 import torch
 from vllm.forward_context import get_forward_context
 
-from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm
 
 
 def _addrmsnorm_forward_oot(
@@ -130,3 +130,47 @@ def forward_oot(
             x, residual = super().forward_oot(x, residual)
             return x.add_(self.bias), residual
         return cast(torch.Tensor, super().forward_oot(x)).add_(self.bias)
+
+
+class AscendGemmaRMSNorm(GemmaRMSNorm):
+    """RMS normalization for Gemma.
+
+    Two differences from the above RMSNorm:
+        1. x * (1 + w) instead of x * w.
+        2. (x * w).to(orig_dtype) instead of x.to(orig_dtype) * w.
+    """
+
+    @staticmethod
+    def forward_static(
+        weight: torch.Tensor,
+        variance_epsilon: float,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor],
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
+        """PyTorch-native implementation equivalent to forward()."""
+        orig_dtype = x.dtype
+        if residual is not None:
+            residual = torch.ops.vllm.maybe_chunk_residual(x, residual)
+            if orig_dtype == torch.float16:
+                x = x + residual.float()
+            else:
+                x = x + residual
+            residual = x
+
+        x = x.float()
+        variance = x.pow(2).mean(dim=-1, keepdim=True)
+        x = x * torch.rsqrt(variance + variance_epsilon)
+        # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
+        # See https://github.com/huggingface/transformers/pull/29402
+        x = x * (1.0 + weight.float())
+        x = x.to(orig_dtype)
+        return x if residual is None else (x, residual)
+
+    def forward_oot(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
+        """Ascend out-of-tree forward; delegates to forward_static()."""
+        return self.forward_static(self.weight.data, self.variance_epsilon, x,
+                                   residual)
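As a quick illustration of the two differences called out in the `AscendGemmaRMSNorm` docstring, here is a small, framework-free comparison; shapes and values are arbitrary and the helper names are made up for this sketch:

```python
import torch


def gemma_style(x: torch.Tensor, w: torch.Tensor, eps: float = 1e-6):
    # Gemma: scale by (1 + w) and downcast only after the multiply.
    orig_dtype = x.dtype
    x = x.float()
    x = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    return (x * (1.0 + w.float())).to(orig_dtype)


def llama_style(x: torch.Tensor, w: torch.Tensor, eps: float = 1e-6):
    # Llama: downcast first, then scale by w.
    orig_dtype = x.dtype
    x = x.float()
    x = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    return x.to(orig_dtype) * w


x = torch.randn(2, 8, dtype=torch.float16)
w = torch.randn(8, dtype=torch.float16)
# Both the scale term and the point of the fp16 downcast differ, so the two
# variants generally do not produce identical outputs.
print((gemma_style(x, w) - llama_style(x, w)).abs().max())
```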
diff --git a/vllm_ascend/ops/linear.py b/vllm_ascend/ops/linear.py
index 51399cc7fa..81ef10a60d 100644
--- a/vllm_ascend/ops/linear.py
+++ b/vllm_ascend/ops/linear.py
@@ -34,8 +34,7 @@
     QuantizationConfig
 from vllm.model_executor.utils import set_weight_attrs
 
-from vllm_ascend.ops.linear_op import (get_column_parallel_op,
-                                       get_row_parallel_op)
+from vllm_ascend.ops.linear_op import get_parallel_op
 
 
 # TODO(realliujiaxu): Remove this class after linear of vllm supports custom comm group
@@ -100,8 +99,8 @@ def __init__(
         return_bias: bool = True,
         disable_tp: bool = False,
     ):
-        self.custom_op, _, tp_size = get_column_parallel_op(
-            disable_tp, prefix, self)
+        self.custom_op, _, tp_size = get_parallel_op(
+            disable_tp, prefix, self, "column")
         # TODO(realliujiaxu): Replace the initialization code below with super().__init__ after linear of vllm supports custom comm group
         self.hidden_size = hidden_size
         self.head_size = head_size
@@ -173,8 +172,8 @@ def __init__(
         return_bias: bool = True,
         disable_tp: bool = False,
     ):
-        self.custom_op, self.tp_rank, self.tp_size = get_column_parallel_op(
-            disable_tp, prefix, self)
+        self.custom_op, self.tp_rank, self.tp_size = get_parallel_op(
+            disable_tp, prefix, self, "column")
         # TODO(realliujiaxu): Replace the initialization code below with super().__init__ after linear of vllm supports custom comm group
         self.output_sizes = output_sizes
         assert all(output_size % self.tp_size == 0
@@ -222,8 +221,8 @@ def __init__(
         return_bias: bool = True,
         disable_tp: bool = False,
     ):
-        self.custom_op, self.tp_rank, self.tp_size = get_row_parallel_op(
-            disable_tp, prefix, self)
+        self.custom_op, self.tp_rank, self.tp_size = get_parallel_op(
+            disable_tp, prefix, self, "row")
         # TODO(realliujiaxu): Replace the initialization code below with super().__init__ after linear of vllm supports custom comm group
         # Divide the weight matrix along the first dimension.
         self.input_size_per_partition = divide(input_size, self.tp_size)
@@ -304,8 +303,8 @@ def __init__(
         return_bias: bool = True,
         disable_tp: bool = False,
     ):
-        self.custom_op, self.tp_rank, self.tp_size = get_column_parallel_op(
-            disable_tp, prefix, self)
+        self.custom_op, self.tp_rank, self.tp_size = get_parallel_op(
+            disable_tp, prefix, self, "column")
         # TODO(realliujiaxu): Replace the initialization code below with super().__init__ after linear of vllm supports custom comm group
         self.input_size_per_partition = input_size
         self.output_size_per_partition = divide(output_size, self.tp_size)
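For reference, a dependency-free mirror of the routing that `get_parallel_op` performs in linear_op.py below, assuming sequence parallelism is enabled and the other feature flags (`mlp_tp_enable`, `oproj_tp_enable`, `matmul_allreduce_enable`) are off; the prefixes are made-up examples:

```python
def route(prefix: str, direct: str) -> str:
    # Simplified: only the enable_sp() branches are modelled here.
    if direct == "column":
        if "gate_up_proj" in prefix:
            return "SequenceColumnParallelOp"
        if any(k in prefix for k in ("qkv_proj", "in_proj", "conv1d")):
            return "SequenceColumnParallelOp(skip_first_layer=True)"
    if direct == "row":
        if "o_proj" in prefix or "out_proj" in prefix:
            return "SequenceRowParallelOp"
    return "no specialization (default TP group)"


assert route("model.layers.0.self_attn.qkv_proj",
             "column") == "SequenceColumnParallelOp(skip_first_layer=True)"
assert route("model.layers.3.mlp.gate_up_proj",
             "column") == "SequenceColumnParallelOp"
assert route("model.layers.3.self_attn.o_proj", "row") == "SequenceRowParallelOp"
```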
diff --git a/vllm_ascend/ops/linear_op.py b/vllm_ascend/ops/linear_op.py
index 57044f58ae..6c75398d9c 100644
--- a/vllm_ascend/ops/linear_op.py
+++ b/vllm_ascend/ops/linear_op.py
@@ -98,9 +98,10 @@ def apply(self, input_):
 
 class CustomColumnParallelOp(CustomTensorParallelOp):
 
-    def __init__(self, layer):
+    def __init__(self, layer, skip_first_layer=False):
         super().__init__(layer)
         self.gather_output = None
+        self.skip_first_layer = skip_first_layer
 
     def update_attrs(self):
         super().update_attrs()
@@ -153,7 +154,7 @@ def apply_impl(
         return output, output_bias
 
 
-class SequenceMergedColumnParallelOp(CustomColumnParallelOp):
+class SequenceColumnParallelOp(CustomColumnParallelOp):
 
     def apply_impl(
         self, input_: torch.Tensor
@@ -169,42 +170,13 @@ def apply_impl(
 
         # Matrix multiply.
         assert self.quant_method is not None
-        input_ = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(input_, True)
-        output_parallel = self.quant_method.apply(self.layer, input_, bias)
-
-        if self.gather_output:
-            # All-gather across the partitions.
-            output = self.comm_group.all_gather(output_parallel)
+        if self.skip_first_layer:
+            layer_num = self.prefix.split('.')[2]
+            input_ = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
+                input_, layer_num != '0')
         else:
-            output = output_parallel
-        output_bias = self.bias if self.skip_bias_add else None
-        return output, output_bias
-
-
-class SequenceQKVParallelOp(CustomColumnParallelOp):
-
-    def __init__(self, layer, prefix):
-        super().__init__(layer)
-        self.prefix = prefix
-
-    def apply_impl(
-        self, input_: torch.Tensor
-    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
-        """Linear layer with column parallelism.
-
-        Implemented multiple optimization projects for dense models, such as FlashComm and
-        communication-computation fusion.
-        """
-
-        bias = self.bias if not self.skip_bias_add else None
-
-        # Matrix multiply.
-        assert self.quant_method is not None
-
-        layer_num = self.prefix.split('.')[2]
-
-        input_ = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
-            input_, layer_num != '0')
+            input_ = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
+                input_, True)
         output_parallel = self.quant_method.apply(self.layer, input_, bias)
 
         if self.gather_output:
@@ -216,6 +188,7 @@ def apply_impl(
         return output, output_bias
 
 
+
 class MLPRowParallelOp(CustomRowParallelOp):
 
     def __init__(self, layer):
@@ -365,10 +338,6 @@ def update_attrs(self):
 
 class SequenceRowParallelOp(CustomRowParallelOp):
 
-    def __init__(self, layer, prefix):
-        super().__init__(layer)
-        self.prefix = prefix
-
     def apply_impl(
         self, input_: torch.Tensor
     ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
@@ -405,50 +374,56 @@ def update_attrs(self):
         self.reduce_results = self.layer.reduce_results
 
 
-def get_column_parallel_op(
-    disable_tp, prefix, layer
-) -> Tuple[Optional[Union[MLPColumnParallelOp, SequenceMergedColumnParallelOp,
-                          SequenceQKVParallelOp]], int, int]:
-    if disable_tp:
-        return None, 0, 1
+def _get_column_parallel_op(
+    prefix, layer
+) -> Optional[Union[MLPColumnParallelOp, SequenceColumnParallelOp]]:
+    if mlp_tp_enable() and "gate_up_proj" in prefix:
+        return MLPColumnParallelOp(layer)
+    if enable_sp():
+        if "gate_up_proj" in prefix:
+            return SequenceColumnParallelOp(layer)
+        if "in_proj" in prefix:
+            return SequenceColumnParallelOp(layer, True)
+        if "qkv_proj" in prefix or "conv1d" in prefix:
+            return SequenceColumnParallelOp(layer, True)
 
-    custom_op: Optional[Union[
-        MLPColumnParallelOp,
-        SequenceMergedColumnParallelOp,
-        SequenceQKVParallelOp,
-    ]] = None
-    if "gate_up_proj" in prefix and mlp_tp_enable():
-        custom_op = MLPColumnParallelOp(layer)
-    elif "gate_up_proj" in prefix and enable_sp():
-        custom_op = SequenceMergedColumnParallelOp(layer)
-    elif enable_sp():
-        custom_op = SequenceQKVParallelOp(layer, prefix)
+    return None
 
-    if custom_op is not None:
-        return custom_op, custom_op.tp_rank, custom_op.tp_size
-    return None, get_tp_group().rank_in_group, get_tp_group().world_size
+def _get_row_parallel_op(
+    prefix, layer
+) -> Optional[Union[MLPRowParallelOp, OProjRowParallelOp,
+                    MatmulAllreduceRowParallelOp,
+                    SequenceRowParallelOp]]:
+    if "down_proj" in prefix and mlp_tp_enable():
+        return MLPRowParallelOp(layer)
+    if "o_proj" in prefix and oproj_tp_enable():
+        return OProjRowParallelOp(layer)
+    if matmul_allreduce_enable():
+        return MatmulAllreduceRowParallelOp(layer)
+    if enable_sp():
+        if "o_proj" in prefix or "out_proj" in prefix:
+            return SequenceRowParallelOp(layer)
+    return None
 
 
-def get_row_parallel_op(
-    disable_tp, prefix, layer
-) -> Tuple[Optional[Union[MLPRowParallelOp, OProjRowParallelOp,
-                          MatmulAllreduceRowParallelOp,
-                          SequenceRowParallelOp]], int, int]:
+
+def get_parallel_op(disable_tp, prefix, layer, direct):
     if disable_tp:
         return None, 0, 1
-
-    custom_op: Optional[Union[MLPRowParallelOp, OProjRowParallelOp,
-                              MatmulAllreduceRowParallelOp,
-                              SequenceRowParallelOp]] = None
-    if "down_proj" in prefix and mlp_tp_enable():
-        custom_op = MLPRowParallelOp(layer)
-    elif "o_proj" in prefix and oproj_tp_enable():
-        custom_op = OProjRowParallelOp(layer)
-    elif matmul_allreduce_enable():
-        custom_op = MatmulAllreduceRowParallelOp(layer)
-    elif enable_sp():
-        custom_op = SequenceRowParallelOp(layer, prefix)
+    custom_op: Optional[Union[
+        MLPColumnParallelOp,
+        SequenceColumnParallelOp,
+        MLPRowParallelOp,
+        OProjRowParallelOp,
+        MatmulAllreduceRowParallelOp,
+        SequenceRowParallelOp
+    ]] = None
+    if direct == "row":
+        custom_op = _get_row_parallel_op(prefix, layer)
+
+    if direct == "column":
+        custom_op = _get_column_parallel_op(prefix, layer)
 
     if custom_op is not None:
         return custom_op, custom_op.tp_rank, custom_op.tp_size
     return None, get_tp_group().rank_in_group, get_tp_group().world_size
diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py
index 570756fd0c..24936d4a9e 100644
--- a/vllm_ascend/utils.py
+++ b/vllm_ascend/utils.py
@@ -500,7 +500,9 @@ def register_ascend_customop(vllm_config: Optional[VllmConfig] = None):
     from vllm_ascend.ops.activation import AscendQuickGELU, AscendSiluAndMul
     from vllm_ascend.ops.common_fused_moe import (AscendFusedMoE,
                                                   AscendSharedFusedMoE)
-    from vllm_ascend.ops.layernorm import AscendQuantRMSNorm, AscendRMSNorm
+    from vllm_ascend.ops.layernorm import (AscendGemmaRMSNorm,
+                                           AscendQuantRMSNorm,
+                                           AscendRMSNorm)
     from vllm_ascend.ops.linear import (AscendColumnParallelLinear,
                                         AscendMergedColumnParallelLinear,
                                         AscendQKVParallelLinear,
@@ -525,6 +527,7 @@ def register_ascend_customop(vllm_config: Optional[VllmConfig] = None):
         "ParallelLMHead": AscendParallelLMHead,
         "LogitsProcessor": AscendLogitsProcessor,
         "RMSNorm": AscendRMSNorm,
+        "GemmaRMSNorm": AscendGemmaRMSNorm,
         "FusedMoE": AscendFusedMoE,
         "SharedFusedMoE": AscendSharedFusedMoE,
         "MultiHeadLatentAttention": AscendMultiHeadLatentAttention,
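Finally, a small check of the first-layer gating that `SequenceColumnParallelOp` relies on in linear_op.py above; the `model.layers.<i>...` prefix layout used here is an assumed example:

```python
def gathers_input(prefix: str) -> bool:
    # Mirrors `prefix.split('.')[2] != '0'`: only layers after the first one
    # all-gather (and unpad) their sequence-parallel input before the matmul.
    return prefix.split('.')[2] != '0'


assert gathers_input("model.layers.0.self_attn.qkv_proj") is False
assert gathers_input("model.layers.7.linear_attn.in_proj") is True
```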