vllm_ascend/models/qwen3_next.py: 6 additions, 1 deletion
@@ -618,14 +618,19 @@ def forward(
positions: torch.Tensor = None,
**kwargs: object,
):
self_attention_output = torch.empty_like(hidden_states)
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
forward_context = get_forward_context()
if forward_context.sp_enabled:
tp_size = get_tensor_model_parallel_world_size()
chunk_size = (hidden_states.shape[0] + forward_context.pad_size) // tp_size
self_attention_output = self_attention_output[:chunk_size, :]
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)

self_attention_output = torch.empty_like(hidden_states)
if self.layer_type == "linear_attention":
self.linear_attn(
hidden_states=hidden_states,
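For intuition, a small sketch of the chunk-size arithmetic introduced in the sp_enabled branch above. It assumes forward_context.pad_size pads the token count up to a multiple of tp_size, which is how the forward context is expected to be populated; this is an illustration, not the vllm-ascend implementation.

```python
# Sketch of the chunk_size computation used when sequence parallelism is
# enabled. Assumes pad_size pads num_tokens up to a multiple of tp_size
# (an assumption about how the forward context is filled).
def sp_chunk_size(num_tokens: int, tp_size: int) -> int:
    pad_size = (-num_tokens) % tp_size         # tokens added so the total splits evenly
    return (num_tokens + pad_size) // tp_size  # per-rank slice of the padded sequence


# self_attention_output is then sliced to this per-rank length:
# self_attention_output = self_attention_output[:chunk_size, :]
assert sp_chunk_size(10, tp_size=4) == 3   # 10 tokens padded to 12, 3 per rank
assert sp_chunk_size(12, tp_size=4) == 3
```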
vllm_ascend/ops/layernorm.py: 48 additions, 1 deletion
@@ -19,7 +19,7 @@

import torch
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.layernorm import RMSNorm, GemmaRMSNorm


def _addrmsnorm_forward_oot(
@@ -130,3 +130,50 @@ def forward_oot(
x, residual = super().forward_oot(x, residual)
return x.add_(self.bias), residual
return cast(torch.Tensor, super().forward_oot(x)).add_(self.bias)


class AscendGemmaRMSNorm(GemmaRMSNorm):
"""RMS normalization for Gemma.

Two differences from the above RMSNorm:
1. x * (1 + w) instead of x * w.
2. (x * w).to(orig_dtype) instead of x.to(orig_dtype) * w.
"""

@staticmethod
def forward_static(
weight: torch.Tensor,
variance_epsilon: float,
x: torch.Tensor,
residual: Optional[torch.Tensor],
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
"""PyTorch-native implementation equivalent to forward()."""
orig_dtype = x.dtype
if residual is not None:
residual = torch.ops.vllm.maybe_chunk_residual(x, residual)
if orig_dtype == torch.float16:
x = x + residual.float()
else:
x = x + residual
residual = x

x = x.float()
variance = x.pow(2).mean(dim=-1, keepdim=True)
x = x * torch.rsqrt(variance + variance_epsilon)
# Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
# See https://github.com/huggingface/transformers/pull/29402
x = x * (1.0 + weight.float())
x = x.to(orig_dtype)
return x if residual is None else (x, residual)

@staticmethod
def forward_oot(
self,
variance_epsilon: float,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
"""PyTorch-native implementation equivalent to forward()."""
return self.forward_static(self.weight.data, self.variance_epsilon, x,
residual)
Reviewer comment (Contributor) on lines +169 to +178, severity: critical

The forward_oot method in AscendGemmaRMSNorm is incorrectly defined as a static method with a mismatched signature. It should be an instance method whose signature matches the forward method it overrides, i.e. (self, x, residual=None). As written, the arguments shift at runtime: because @staticmethod suppresses instance binding, a call such as self.forward_oot(x, residual) passes x into the self slot and residual into variance_epsilon, leaving the real x parameter unfilled. variance_epsilon should be read from self.variance_epsilon instead of being taken as a parameter.

Suggested change
@staticmethod
def forward_oot(
self,
variance_epsilon: float,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
"""PyTorch-native implementation equivalent to forward()."""
return self.forward_static(self.weight.data, self.variance_epsilon, x,
residual)
def forward_oot(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
"""PyTorch-native implementation equivalent to forward()."""
return self.forward_static(self.weight.data, self.variance_epsilon, x,
residual)
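To make the failure mode concrete, here is a minimal standalone sketch (the Norm class below is a hypothetical stand-in, not vllm code) of how a @staticmethod that still lists self in its signature shifts the arguments when called through an instance:

```python
import torch


class Norm:
    """Hypothetical stand-in for AscendGemmaRMSNorm, only for illustration."""

    def __init__(self) -> None:
        self.variance_epsilon = 1e-6

    @staticmethod
    def forward_oot(self, variance_epsilon, x, residual=None):
        # @staticmethod means no instance is bound, so a call through an
        # instance shifts every positional argument one slot to the left.
        return x


norm = Norm()
x = torch.ones(4)
residual = torch.zeros(4)

try:
    # x lands in `self`, residual lands in `variance_epsilon`, and the
    # real `x` parameter is never filled, so Python raises a TypeError.
    norm.forward_oot(x, residual)
except TypeError as err:
    print(err)  # missing 1 required positional argument: 'x'
```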


vllm_ascend/ops/linear.py: 9 additions, 10 deletions
@@ -34,8 +34,7 @@
QuantizationConfig
from vllm.model_executor.utils import set_weight_attrs

from vllm_ascend.ops.linear_op import (get_column_parallel_op,
get_row_parallel_op)
from vllm_ascend.ops.linear_op import get_parallel_op


# TODO(realliujiaxu): Remove this class after linear of vllm supports custom comm group
@@ -100,8 +99,8 @@ def __init__(
return_bias: bool = True,
disable_tp: bool = False,
):
self.custom_op, _, tp_size = get_column_parallel_op(
disable_tp, prefix, self)
self.custom_op, _, tp_size = get_parallel_op(
disable_tp, prefix, self, "column")
# TODO(realliujiaxu): Replace the initialization code below with super().__init__ after linear of vllm supports custom comm group
self.hidden_size = hidden_size
self.head_size = head_size
@@ -173,8 +172,8 @@ def __init__(
return_bias: bool = True,
disable_tp: bool = False,
):
self.custom_op, self.tp_rank, self.tp_size = get_column_parallel_op(
disable_tp, prefix, self)
self.custom_op, self.tp_rank, self.tp_size = get_parallel_op(
disable_tp, prefix, self, "column")
# TODO(realliujiaxu): Replace the initialization code below with super().__init__ after linear of vllm supports custom comm group
self.output_sizes = output_sizes
assert all(output_size % self.tp_size == 0
@@ -222,8 +221,8 @@ def __init__(
return_bias: bool = True,
disable_tp: bool = False,
):
self.custom_op, self.tp_rank, self.tp_size = get_row_parallel_op(
disable_tp, prefix, self)
self.custom_op, self.tp_rank, self.tp_size = get_parallel_op(
disable_tp, prefix, self, "row")
# TODO(realliujiaxu): Replace the initialization code below with super().__init__ after linear of vllm supports custom comm group
# Divide the weight matrix along the first dimension.
self.input_size_per_partition = divide(input_size, self.tp_size)
@@ -304,8 +303,8 @@ def __init__(
return_bias: bool = True,
disable_tp: bool = False,
):
self.custom_op, self.tp_rank, self.tp_size = get_column_parallel_op(
disable_tp, prefix, self)
self.custom_op, self.tp_rank, self.tp_size = get_parallel_op(
disable_tp, prefix, self, "column")
# TODO(realliujiaxu): Replace the initialization code below with super().__init__ after linear of vllm supports custom comm group
self.input_size_per_partition = input_size
self.output_size_per_partition = divide(output_size, self.tp_size)
vllm_ascend/ops/linear_op.py: 53 additions, 78 deletions
@@ -98,9 +98,10 @@ def apply(self, input_)

class CustomColumnParallelOp(CustomTensorParallelOp):

def __init__(self, layer):
def __init__(self, layer, skip_first_layer=False):
super().__init__(layer)
self.gather_output = None
self.skip_first_layer = skip_first_layer

def update_attrs(self):
super().update_attrs()
@@ -153,7 +154,7 @@ def apply_impl(
return output, output_bias


class SequenceMergedColumnParallelOp(CustomColumnParallelOp):
class SequenceColumnParallelOp(CustomColumnParallelOp):

def apply_impl(
self, input_: torch.Tensor
@@ -169,42 +170,13 @@ def apply_impl(
# Matrix multiply.
assert self.quant_method is not None

input_ = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(input_, True)
output_parallel = self.quant_method.apply(self.layer, input_, bias)

if self.gather_output:
# All-gather across the partitions.
output = self.comm_group.all_gather(output_parallel)
input_ = None
if self.skip_first_layer:
layer_num = self.prefix.split('.')[2]
input_ = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
input_, layer_num != '0')
else:
output = output_parallel
output_bias = self.bias if self.skip_bias_add else None
return output, output_bias


class SequenceQKVParallelOp(CustomColumnParallelOp):

def __init__(self, layer, prefix):
super().__init__(layer)
self.prefix = prefix

def apply_impl(
self, input_: torch.Tensor
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
"""Linear layer with column parallelism.

Implemented multiple optimization projects for dense models, such as FlashComm and
communication-computation fusion.
"""

bias = self.bias if not self.skip_bias_add else None

# Matrix multiply.
assert self.quant_method is not None

layer_num = self.prefix.split('.')[2]

input_ = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
input_, layer_num != '0')
input_ = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(input_, True)
output_parallel = self.quant_method.apply(self.layer, input_, bias)

if self.gather_output:
@@ -216,6 +188,7 @@ def apply_impl(
return output, output_bias



class MLPRowParallelOp(CustomRowParallelOp):

def __init__(self, layer):
@@ -365,10 +338,6 @@ def update_attrs(self):

class SequenceRowParallelOp(CustomRowParallelOp):

def __init__(self, layer, prefix):
super().__init__(layer)
self.prefix = prefix

def apply_impl(
self, input_: torch.Tensor
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
@@ -405,50 +374,56 @@ def update_attrs(self):
self.reduce_results = self.layer.reduce_results


def get_column_parallel_op(
disable_tp, prefix, layer
) -> Tuple[Optional[Union[MLPColumnParallelOp, SequenceMergedColumnParallelOp,
SequenceQKVParallelOp]], int, int]:
if disable_tp:
return None, 0, 1
def _get_column_parallel_op(
prefix, layer
) -> Optional[Union[MLPColumnParallelOp, SequenceColumnParallelOp]]:
if mlp_tp_enable() and "gate_up_proj" in prefix:
return MLPColumnParallelOp(layer)
if enable_sp():
if "gate_up_proj" in prefix:
return SequenceColumnParallelOp(layer)
if "in_proj" in prefix:
return SequenceColumnParallelOp(layer, True)
if "qkv_proj" in prefix or "conv1d" in prefix:
return SequenceColumnParallelOp(layer, True)

custom_op: Optional[Union[
MLPColumnParallelOp,
SequenceMergedColumnParallelOp,
SequenceQKVParallelOp,
]] = None
if "gate_up_proj" in prefix and mlp_tp_enable():
custom_op = MLPColumnParallelOp(layer)
elif "gate_up_proj" in prefix and enable_sp():
custom_op = SequenceMergedColumnParallelOp(layer)
elif enable_sp():
custom_op = SequenceQKVParallelOp(layer, prefix)
return None

if custom_op is not None:
return custom_op, custom_op.tp_rank, custom_op.tp_size

return None, get_tp_group().rank_in_group, get_tp_group().world_size
def _get_row_parallel_op(
prefix, layer
) -> Optional[Union[MLPRowParallelOp, OProjRowParallelOp,
MatmulAllreduceRowParallelOp,
SequenceRowParallelOp]]:
if "down_proj" in prefix and mlp_tp_enable():
return MLPRowParallelOp(layer)
if "o_proj" in prefix and oproj_tp_enable():
return OProjRowParallelOp(layer)
if matmul_allreduce_enable():
return MatmulAllreduceRowParallelOp(layer)
if enable_sp():
if "o_proj" in prefix or "out_proj" in prefix:
return SequenceRowParallelOp(layer)

return None

def get_row_parallel_op(
disable_tp, prefix, layer
) -> Tuple[Optional[Union[MLPRowParallelOp, OProjRowParallelOp,
MatmulAllreduceRowParallelOp,
SequenceRowParallelOp]], int, int]:

def get_parallel_op(disable_tp, prefix, layer, direct):
if disable_tp:
return None, 0, 1

custom_op: Optional[Union[MLPRowParallelOp, OProjRowParallelOp,
MatmulAllreduceRowParallelOp,
SequenceRowParallelOp]] = None
if "down_proj" in prefix and mlp_tp_enable():
custom_op = MLPRowParallelOp(layer)
elif "o_proj" in prefix and oproj_tp_enable():
custom_op = OProjRowParallelOp(layer)
elif matmul_allreduce_enable():
custom_op = MatmulAllreduceRowParallelOp(layer)
elif enable_sp():
custom_op = SequenceRowParallelOp(layer, prefix)
custom_op: Optional[Union[
MLPColumnParallelOp,
SequenceColumnParallelOp,
MLPRowParallelOp,
OProjRowParallelOp,
MatmulAllreduceRowParallelOp,
SequenceRowParallelOp
]] = None
if direct == "row":
custom_op = _get_row_parallel_op(prefix, layer)

if direct == "column":
custom_op = _get_column_parallel_op(prefix, layer)

if custom_op is not None:
return custom_op, custom_op.tp_rank, custom_op.tp_size
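As a side note, a minimal sketch of the layer-index gating that skip_first_layer relies on in SequenceColumnParallelOp. The example prefixes assume the usual "model.layers.<i>.<submodule>" layout; real prefixes come from the model definition, so treat them as assumptions.

```python
# Sketch of the gating implied by layer_num = self.prefix.split('.')[2]
# and maybe_all_gather_and_maybe_unpad(input_, layer_num != '0') above.
def gathers_input(prefix: str) -> bool:
    layer_num = prefix.split('.')[2]
    # The gather (and unpad) only happens when the flag is True,
    # i.e. for every layer except the first one.
    return layer_num != '0'


assert gathers_input("model.layers.0.linear_attn.in_proj") is False
assert gathers_input("model.layers.3.self_attn.qkv_proj") is True
```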
vllm_ascend/utils.py: 4 additions, 1 deletion
@@ -500,7 +500,9 @@ def register_ascend_customop(vllm_config: Optional[VllmConfig] = None):
from vllm_ascend.ops.activation import AscendQuickGELU, AscendSiluAndMul
from vllm_ascend.ops.common_fused_moe import (AscendFusedMoE,
AscendSharedFusedMoE)
from vllm_ascend.ops.layernorm import AscendQuantRMSNorm, AscendRMSNorm
from vllm_ascend.ops.layernorm import (AscendQuantRMSNorm,
AscendRMSNorm,
AscendGemmaRMSNorm)
from vllm_ascend.ops.linear import (AscendColumnParallelLinear,
AscendMergedColumnParallelLinear,
AscendQKVParallelLinear,
@@ -525,6 +527,7 @@ def register_ascend_customop(vllm_config: Optional[VllmConfig] = None):
"ParallelLMHead": AscendParallelLMHead,
"LogitsProcessor": AscendLogitsProcessor,
"RMSNorm": AscendRMSNorm,
"GemmaRMSNorm": AscendGemmaRMSNorm,
"FusedMoE": AscendFusedMoE,
"SharedFusedMoE": AscendSharedFusedMoE,
"MultiHeadLatentAttention": AscendMultiHeadLatentAttention,