Changes from 4 commits
7 changes: 6 additions & 1 deletion vllm_ascend/models/qwen3_next.py
@@ -618,14 +618,19 @@ def forward(
positions: torch.Tensor = None,
**kwargs: object,
):
self_attention_output = torch.empty_like(hidden_states)
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
forward_context = get_forward_context()
if forward_context.sp_enabled:
tp_size = get_tensor_model_parallel_world_size()
chunk_size = (hidden_states.shape[0] + forward_context.pad_size) // tp_size
self_attention_output = self_attention_output[:chunk_size, :]
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)

self_attention_output = torch.empty_like(hidden_states)
if self.layer_type == "linear_attention":
self.linear_attn(
hidden_states=hidden_states,
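Note on the slicing added above: with sequence parallelism enabled, each rank only needs its own chunk of the preallocated attention output, and the chunk length is the padded token count divided by the tensor-parallel world size. A minimal sketch of that arithmetic, using standalone names in place of the forward-context fields (`pad_size`, `tp_size`) referenced in the diff:

```python
import torch

def sp_chunk(hidden_states: torch.Tensor, pad_size: int, tp_size: int) -> torch.Tensor:
    # Illustrative only: mirrors the chunk-size computation in the hunk above.
    # In the real layer, pad_size comes from get_forward_context() and makes
    # (num_tokens + pad_size) divisible by the tensor-parallel world size.
    num_tokens = hidden_states.shape[0]
    chunk_size = (num_tokens + pad_size) // tp_size
    self_attention_output = torch.empty_like(hidden_states)
    return self_attention_output[:chunk_size, :]

# 10 tokens padded by 2 split evenly across 4 TP ranks -> 3 rows per rank.
x = torch.randn(10, 8)
print(sp_chunk(x, pad_size=2, tp_size=4).shape)  # torch.Size([3, 8])
```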
49 changes: 48 additions & 1 deletion vllm_ascend/ops/layernorm.py
@@ -19,7 +19,7 @@

import torch
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.layernorm import RMSNorm, GemmaRMSNorm


def _addrmsnorm_forward_oot(
@@ -130,3 +130,50 @@ def forward_oot(
x, residual = super().forward_oot(x, residual)
return x.add_(self.bias), residual
return cast(torch.Tensor, super().forward_oot(x)).add_(self.bias)


class AscendGemmaRMSNorm(GemmaRMSNorm):
"""RMS normalization for Gemma.

Two differences from the above RMSNorm:
1. x * (1 + w) instead of x * w.
2. (x * w).to(orig_dtype) instead of x.to(orig_dtype) * w.
"""

@staticmethod
def forward_static(
weight: torch.Tensor,
variance_epsilon: float,
x: torch.Tensor,
residual: Optional[torch.Tensor],
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
"""PyTorch-native implementation equivalent to forward()."""
orig_dtype = x.dtype
if residual is not None:
residual = torch.ops.vllm.maybe_chunk_residual(x, residual)
if orig_dtype == torch.float16:
x = x + residual.float()
else:
x = x + residual
residual = x

x = x.float()
variance = x.pow(2).mean(dim=-1, keepdim=True)
x = x * torch.rsqrt(variance + variance_epsilon)
# Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
# See https://github.com/huggingface/transformers/pull/29402
x = x * (1.0 + weight.float())
x = x.to(orig_dtype)
return x if residual is None else (x, residual)

@staticmethod
def forward_oot(
self,
variance_epsilon: float,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
"""PyTorch-native implementation equivalent to forward()."""
return self.forward_static(self.weight.data, self.variance_epsilon, x,
residual)
Comment on lines +169 to +178
Contributor (critical):
The forward_oot method in AscendGemmaRMSNorm is incorrectly defined as a static method with a mismatched signature. It should be an instance method, and its signature should match the forward method it's overriding, which is (self, x, residual=None). The current implementation will lead to incorrect argument passing at runtime, as the variance_epsilon parameter in the signature will receive the x tensor, and the x parameter will receive the residual tensor. The variance_epsilon should be accessed from self.variance_epsilon instead of being a parameter.

Suggested change (replace the decorated static method with an instance method):

    @staticmethod
    def forward_oot(
        self,
        variance_epsilon: float,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """PyTorch-native implementation equivalent to forward()."""
        return self.forward_static(self.weight.data, self.variance_epsilon, x,
                                   residual)

becomes:

    def forward_oot(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """PyTorch-native implementation equivalent to forward()."""
        return self.forward_static(self.weight.data, self.variance_epsilon, x,
                                   residual)
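As the docstring in AscendGemmaRMSNorm notes, the Gemma variant differs from plain RMSNorm in two ways: the weight is applied as (1 + w), and the cast back to the original dtype happens after the weight multiply rather than before. A standalone sketch, independent of vllm (helper names are local to this example), that makes both differences explicit:

```python
import torch

def gemma_rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Gemma style: compute in float32, scale by (1 + w), cast back last.
    orig_dtype = x.dtype
    x = x.float()
    x = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    return (x * (1.0 + weight.float())).to(orig_dtype)

def llama_rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Llama style: cast back to the original dtype first, then scale by w.
    orig_dtype = x.dtype
    x = x.float()
    x = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    return x.to(orig_dtype) * weight

x = torch.randn(2, 4, dtype=torch.float16)
w = torch.randn(4, dtype=torch.float16)
# Same math up to rounding; the cast order can change low-precision results.
print((gemma_rms_norm(x, w) - llama_rms_norm(x, 1.0 + w)).abs().max())
```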


19 changes: 9 additions & 10 deletions vllm_ascend/ops/linear.py
@@ -34,8 +34,7 @@
QuantizationConfig
from vllm.model_executor.utils import set_weight_attrs

from vllm_ascend.ops.linear_op import (get_column_parallel_op,
get_row_parallel_op)
from vllm_ascend.ops.linear_op import get_parallel_op


# TODO(realliujiaxu): Remove this class after linear of vllm supports custom comm group
@@ -100,8 +99,8 @@ def __init__(
return_bias: bool = True,
disable_tp: bool = False,
):
self.custom_op, _, tp_size = get_column_parallel_op(
disable_tp, prefix, self)
self.custom_op, _, tp_size = get_parallel_op(
disable_tp, prefix, self, "column")
# TODO(realliujiaxu): Replace the initialization code below with super().__init__ after linear of vllm supports custom comm group
self.hidden_size = hidden_size
self.head_size = head_size
@@ -173,8 +172,8 @@ def __init__(
return_bias: bool = True,
disable_tp: bool = False,
):
self.custom_op, self.tp_rank, self.tp_size = get_column_parallel_op(
disable_tp, prefix, self)
self.custom_op, self.tp_rank, self.tp_size = get_parallel_op(
disable_tp, prefix, self, "column")
# TODO(realliujiaxu): Replace the initialization code below with super().__init__ after linear of vllm supports custom comm group
self.output_sizes = output_sizes
assert all(output_size % self.tp_size == 0
@@ -222,8 +221,8 @@ def __init__(
return_bias: bool = True,
disable_tp: bool = False,
):
self.custom_op, self.tp_rank, self.tp_size = get_row_parallel_op(
disable_tp, prefix, self)
self.custom_op, self.tp_rank, self.tp_size = get_parallel_op(
disable_tp, prefix, self, "row")
# TODO(realliujiaxu): Replace the initialization code below with super().__init__ after linear of vllm supports custom comm group
# Divide the weight matrix along the first dimension.
self.input_size_per_partition = divide(input_size, self.tp_size)
@@ -304,8 +303,8 @@ def __init__(
return_bias: bool = True,
disable_tp: bool = False,
):
self.custom_op, self.tp_rank, self.tp_size = get_column_parallel_op(
disable_tp, prefix, self)
self.custom_op, self.tp_rank, self.tp_size = get_parallel_op(
disable_tp, prefix, self, "column")
# TODO(realliujiaxu): Replace the initialization code below with super().__init__ after linear of vllm supports custom comm group
self.input_size_per_partition = input_size
self.output_size_per_partition = divide(output_size, self.tp_size)
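All four constructor changes above follow the same pattern: one entry point, a direction string, and the same three-value unpack. A small sketch of that caller-side contract, with a stub standing in for vllm_ascend.ops.linear_op.get_parallel_op:

```python
from typing import Optional, Tuple

def get_parallel_op_stub(disable_tp: bool, prefix: str, layer: object,
                         direct: str) -> Tuple[Optional[object], int, int]:
    # Stub with the same return shape as the real dispatcher:
    # (custom_op or None, tp_rank, tp_size).
    if disable_tp:
        return None, 0, 1
    # The real implementation inspects `direct` ("column" or "row") and the
    # module prefix to pick a custom op; this stub just returns defaults.
    return None, 0, 1

class DemoLinear:
    def __init__(self, prefix: str, disable_tp: bool = False):
        # Same unpacking used by the Ascend linear layers in their __init__.
        self.custom_op, self.tp_rank, self.tp_size = get_parallel_op_stub(
            disable_tp, prefix, self, "column")

layer = DemoLinear("model.layers.0.mlp.gate_up_proj")
print(layer.custom_op, layer.tp_rank, layer.tp_size)  # None 0 1
```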
90 changes: 51 additions & 39 deletions vllm_ascend/ops/linear_op.py
@@ -98,9 +98,11 @@

class CustomColumnParallelOp(CustomTensorParallelOp):

def __init__(self, layer):
def __init__(self, layer, prefix, skip_first_layer=False):
super().__init__(layer)
self.gather_output = None
self.prefix = prefix
self.skip_first_layer = skip_first_layer

def update_attrs(self):
super().update_attrs()
@@ -133,7 +135,7 @@
class MLPColumnParallelOp(CustomColumnParallelOp):

def __init__(self, layer):
super().__init__(layer)

Check failure on line 138 in vllm_ascend/ops/linear_op.py (GitHub Actions / lint / pre-commit): Missing positional argument "prefix" in call to "__init__" [call-arg]

@property
def comm_group(self):
@@ -169,7 +171,13 @@
# Matrix multiply.
assert self.quant_method is not None

input_ = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(input_, True)
input_ = None
if self.skip_first_layer:
layer_num = self.prefix.split('.')[2]
input_ = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
input_, layer_num != '0')
else:
input_ = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(input_, True)
output_parallel = self.quant_method.apply(self.layer, input_, bias)

if self.gather_output:
@@ -183,10 +191,6 @@

class SequenceQKVParallelOp(CustomColumnParallelOp):

def __init__(self, layer, prefix):
super().__init__(layer)
self.prefix = prefix

def apply_impl(
self, input_: torch.Tensor
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
@@ -405,50 +409,58 @@
self.reduce_results = self.layer.reduce_results


def get_column_parallel_op(
def _get_column_parallel_op(
disable_tp, prefix, layer
) -> Tuple[Optional[Union[MLPColumnParallelOp, SequenceMergedColumnParallelOp,
SequenceQKVParallelOp]], int, int]:
Contributor (critical):
The signature of _get_column_parallel_op is (disable_tp, prefix, layer), but it is called as _get_column_parallel_op(prefix, layer) in get_parallel_op at line 463. This mismatch will cause a TypeError at runtime. The disable_tp parameter is not used within the function and should be removed from the signature to match the call site.

Suggested change (drop the unused disable_tp parameter):

def _get_column_parallel_op(
    disable_tp, prefix, layer
) -> Tuple[Optional[Union[MLPColumnParallelOp, SequenceMergedColumnParallelOp,
                          SequenceQKVParallelOp]], int, int]:

becomes:

def _get_column_parallel_op(
    prefix, layer
) -> Tuple[Optional[Union[MLPColumnParallelOp, SequenceMergedColumnParallelOp,
                          SequenceQKVParallelOp]], int, int]:

if disable_tp:
return None, 0, 1
if mlp_tp_enable() and "gate_up_proj" in prefix:
return MLPColumnParallelOp(layer)
if enable_sp():
if "gate_up_proj" in prefix:
return SequenceMergedColumnParallelOp(layer, prefix)
if "in_proj" in prefix:

Check failure on line 421 in vllm_ascend/ops/linear_op.py (GitHub Actions / lint / pre-commit): Incompatible return value type (got "MLPColumnParallelOp", expected "tuple[Union[MLPColumnParallelOp, SequenceMergedColumnParallelOp, SequenceQKVParallelOp, None], int, int]") [return-value]
return SequenceMergedColumnParallelOp(layer, prefix, True)
if "qkv_proj" in prefix or "conv1d" in prefix:
return SequenceQKVParallelOp(layer, prefix, True)

Check failure on line 424 in vllm_ascend/ops/linear_op.py (GitHub Actions / lint / pre-commit): Incompatible return value type (got "SequenceMergedColumnParallelOp", expected "tuple[Union[MLPColumnParallelOp, SequenceMergedColumnParallelOp, SequenceQKVParallelOp, None], int, int]") [return-value]

custom_op: Optional[Union[
MLPColumnParallelOp,
SequenceMergedColumnParallelOp,
SequenceQKVParallelOp,
]] = None
if "gate_up_proj" in prefix and mlp_tp_enable():
custom_op = MLPColumnParallelOp(layer)
elif "gate_up_proj" in prefix and enable_sp():
custom_op = SequenceMergedColumnParallelOp(layer)
elif enable_sp():
custom_op = SequenceQKVParallelOp(layer, prefix)
return None

Check failure on line 426 in vllm_ascend/ops/linear_op.py (GitHub Actions / lint / pre-commit): Incompatible return value type (got "SequenceMergedColumnParallelOp", expected "tuple[Union[MLPColumnParallelOp, SequenceMergedColumnParallelOp, SequenceQKVParallelOp, None], int, int]") [return-value]

if custom_op is not None:
return custom_op, custom_op.tp_rank, custom_op.tp_size

Check failure on line 428 in vllm_ascend/ops/linear_op.py (GitHub Actions / lint / pre-commit): Incompatible return value type (got "SequenceQKVParallelOp", expected "tuple[Union[MLPColumnParallelOp, SequenceMergedColumnParallelOp, SequenceQKVParallelOp, None], int, int]") [return-value]
return None, get_tp_group().rank_in_group, get_tp_group().world_size


def get_row_parallel_op(
disable_tp, prefix, layer
def _get_row_parallel_op(
prefix, layer

Check failure on line 430 in vllm_ascend/ops/linear_op.py (GitHub Actions / lint / pre-commit): Incompatible return value type (got "None", expected "tuple[Union[MLPColumnParallelOp, SequenceMergedColumnParallelOp, SequenceQKVParallelOp, None], int, int]") [return-value]
) -> Tuple[Optional[Union[MLPRowParallelOp, OProjRowParallelOp,
MatmulAllreduceRowParallelOp,
SequenceRowParallelOp]], int, int]:
if "down_proj" in prefix and mlp_tp_enable():
return MLPRowParallelOp(layer)
if "o_proj" in prefix and oproj_tp_enable():
return OProjRowParallelOp(layer)
if matmul_allreduce_enable():
return MatmulAllreduceRowParallelOp(layer)

Check failure on line 439 in vllm_ascend/ops/linear_op.py (GitHub Actions / lint / pre-commit): Incompatible return value type (got "MLPRowParallelOp", expected "tuple[Union[MLPRowParallelOp, OProjRowParallelOp, MatmulAllreduceRowParallelOp, SequenceRowParallelOp, None], int, int]") [return-value]
if enable_sp():
if "o_proj" in prefix or "out_proj" in prefix:

Check failure on line 441 in vllm_ascend/ops/linear_op.py (GitHub Actions / lint / pre-commit): Incompatible return value type (got "OProjRowParallelOp", expected "tuple[Union[MLPRowParallelOp, OProjRowParallelOp, MatmulAllreduceRowParallelOp, SequenceRowParallelOp, None], int, int]") [return-value]
return SequenceRowParallelOp(layer, prefix)

Check failure on line 443 in vllm_ascend/ops/linear_op.py (GitHub Actions / lint / pre-commit): Incompatible return value type (got "MatmulAllreduceRowParallelOp", expected "tuple[Union[MLPRowParallelOp, OProjRowParallelOp, MatmulAllreduceRowParallelOp, SequenceRowParallelOp, None], int, int]") [return-value]
return None


Check failure on line 446 in vllm_ascend/ops/linear_op.py (GitHub Actions / lint / pre-commit): Incompatible return value type (got "SequenceRowParallelOp", expected "tuple[Union[MLPRowParallelOp, OProjRowParallelOp, MatmulAllreduceRowParallelOp, SequenceRowParallelOp, None], int, int]") [return-value]
def get_parallel_op(disable_tp, prefix, layer, direct):
if disable_tp:
return None, 0, 1

custom_op: Optional[Union[MLPRowParallelOp, OProjRowParallelOp,
MatmulAllreduceRowParallelOp,
SequenceRowParallelOp]] = None
if "down_proj" in prefix and mlp_tp_enable():
custom_op = MLPRowParallelOp(layer)
elif "o_proj" in prefix and oproj_tp_enable():
custom_op = OProjRowParallelOp(layer)
elif matmul_allreduce_enable():
custom_op = MatmulAllreduceRowParallelOp(layer)
elif enable_sp():
custom_op = SequenceRowParallelOp(layer, prefix)
custom_op: Optional[Union[
MLPColumnParallelOp,
SequenceMergedColumnParallelOp,
SequenceQKVParallelOp,
MLPRowParallelOp,
OProjRowParallelOp,
MatmulAllreduceRowParallelOp,
SequenceRowParallelOp
]] = None
if direct == "row":
custom_op = _get_row_parallel_op(prefix, layer)

if direct == "column":
custom_op = _get_column_parallel_op(prefix, layer)

if custom_op is not None:
return custom_op, custom_op.tp_rank, custom_op.tp_size
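Putting the review comment and the mypy failures above together: the private helpers are called as _get_column_parallel_op(prefix, layer) and _get_row_parallel_op(prefix, layer), but they still carry the old parameter list and a tuple-returning annotation while actually returning bare ops (or None). One way those pieces could line up, sketched with placeholder op classes rather than the real ones in this file:

```python
from typing import Optional, Tuple

class DemoOp:
    # Placeholder for the CustomTensorParallelOp subclasses defined above.
    tp_rank = 0
    tp_size = 1

def _get_column_parallel_op(prefix: str, layer: object) -> Optional[DemoOp]:
    if "gate_up_proj" in prefix or "qkv_proj" in prefix or "in_proj" in prefix:
        return DemoOp()
    return None

def _get_row_parallel_op(prefix: str, layer: object) -> Optional[DemoOp]:
    if "down_proj" in prefix or "o_proj" in prefix or "out_proj" in prefix:
        return DemoOp()
    return None

def get_parallel_op(disable_tp: bool, prefix: str, layer: object,
                    direct: str) -> Tuple[Optional[DemoOp], int, int]:
    if disable_tp:
        return None, 0, 1
    custom_op = (_get_row_parallel_op(prefix, layer) if direct == "row"
                 else _get_column_parallel_op(prefix, layer))
    if custom_op is not None:
        return custom_op, custom_op.tp_rank, custom_op.tp_size
    # The real code falls back to the TP group's rank and world size here.
    return None, 0, 1

print(get_parallel_op(False, "model.layers.0.self_attn.o_proj", object(), "row"))
```

The dispatch conditions in the sketch are simplified; the real helpers also consult mlp_tp_enable(), oproj_tp_enable(), matmul_allreduce_enable() and enable_sp() as shown in the diff. The point is only that the helpers return an optional op while get_parallel_op owns the (op, tp_rank, tp_size) tuple.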
5 changes: 4 additions & 1 deletion vllm_ascend/utils.py
@@ -500,7 +500,9 @@ def register_ascend_customop(vllm_config: Optional[VllmConfig] = None):
from vllm_ascend.ops.activation import AscendQuickGELU, AscendSiluAndMul
from vllm_ascend.ops.common_fused_moe import (AscendFusedMoE,
AscendSharedFusedMoE)
from vllm_ascend.ops.layernorm import AscendQuantRMSNorm, AscendRMSNorm
from vllm_ascend.ops.layernorm import (AscendQuantRMSNorm,
AscendRMSNorm,
AscendGemmaRMSNorm)
from vllm_ascend.ops.linear import (AscendColumnParallelLinear,
AscendMergedColumnParallelLinear,
AscendQKVParallelLinear,
@@ -525,6 +527,7 @@ def register_ascend_customop(vllm_config: Optional[VllmConfig] = None):
"ParallelLMHead": AscendParallelLMHead,
"LogitsProcessor": AscendLogitsProcessor,
"RMSNorm": AscendRMSNorm,
"GemmaRMSNorm": AscendGemmaRMSNorm,
"FusedMoE": AscendFusedMoE,
"SharedFusedMoE": AscendSharedFusedMoE,
"MultiHeadLatentAttention": AscendMultiHeadLatentAttention,