@@ -19,7 +19,7 @@

 import torch
 from vllm.forward_context import get_forward_context
-from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm


 def _addrmsnorm_forward_oot(
@@ -130,3 +130,30 @@ def forward_oot(
             x, residual = super().forward_oot(x, residual)
             return x.add_(self.bias), residual
         return cast(torch.Tensor, super().forward_oot(x)).add_(self.bias)
+
+
+class AscendGemmaRMSNorm(GemmaRMSNorm):
+
+    def forward_oot(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        import torch_npu
+
+        from vllm_ascend.utils import is_310p
+        if residual is not None:
+            if is_310p():
+                orig_dtype = residual.dtype
+                x = x + residual.to(x.dtype)
+                residual = x.to(orig_dtype)
+                x, _ = torch_npu.npu_rms_norm(x, 1.0 + self.weight,
+                                              self.variance_epsilon)
+            else:
+                x, _, residual = torch_npu.npu_add_rms_norm(
+                    x, residual, 1.0 + self.weight, self.variance_epsilon)
+            return x, residual
+
+        x, _ = torch_npu.npu_rms_norm(x, 1.0 + self.weight,
+                                      self.variance_epsilon)
+        return x
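For context, not part of the diff: the Gemma variant stores its norm weight zero-centered, so the effective scale is (1 + weight), which is why the new forward_oot passes 1.0 + self.weight to the fused NPU kernels, and the 310P branch performs the residual add in eager mode before calling the plain RMS norm kernel. Below is a minimal eager-mode sketch of the computation the npu_rms_norm / npu_add_rms_norm calls are assumed to perform here; gemma_rms_norm_reference and its signature are illustrative names for this sketch, not part of vllm-ascend.

# Minimal CPU reference sketch (illustrative only) of the Gemma-style RMSNorm
# that the fused NPU kernels above are assumed to compute.
import torch
from typing import Optional, Tuple, Union


def gemma_rms_norm_reference(
    x: torch.Tensor,
    weight: torch.Tensor,
    eps: float = 1e-6,
    residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    # Optional fused residual add, mirroring the npu_add_rms_norm branch.
    if residual is not None:
        x = x + residual
        residual = x
    # Normalize in float32 for numerical stability, then cast back.
    orig_dtype = x.dtype
    x = x.float()
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    # Gemma stores weight as (gamma - 1), hence the (1 + weight) scale.
    x = (x * (1.0 + weight.float())).to(orig_dtype)
    return x if residual is None else (x, residual)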