Commit f533b58

gshtras and charlifu authored
[ROCm][Kernel] MoE weights padding (#14454)
Signed-off-by: Gregory Shtrasberg <[email protected]>
Signed-off-by: charlifu <[email protected]>
Co-authored-by: charlifu <[email protected]>
1 parent 8279201 commit f533b58

File tree

tests/kernels/test_moe.py
vllm/envs.py
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/quantization/fp8.py

5 files changed: +65 -16 lines changed


tests/kernels/test_moe.py

Lines changed: 35 additions & 10 deletions

@@ -3,8 +3,11 @@
 
 Run `pytest tests/kernels/test_moe.py`.
 """
+
 import pytest
 import torch
+from torch.nn import Parameter
+from torch.nn import functional as F
 from transformers import MixtralConfig
 from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
 
@@ -37,6 +40,7 @@
 @pytest.mark.parametrize("topk", TOP_KS)
 @pytest.mark.parametrize("ep_size", EP_SIZE)
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("padding", [True, False])
 def test_fused_moe(
     m: int,
     n: int,
@@ -45,6 +49,7 @@ def test_fused_moe(
     topk: int,
     ep_size: int,
     dtype: torch.dtype,
+    padding: bool,
 ):
     a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
     w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
@@ -65,16 +70,7 @@ def test_fused_moe(
     else:
         e_map = None
 
-    triton_output = fused_moe(a,
-                              w1,
-                              w2,
-                              score,
-                              topk,
-                              global_num_experts=e,
-                              expert_map=e_map,
-                              renormalize=False)
     torch_output = torch_moe(a, w1, w2, score, topk, e_map)
-    torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
     iterative_output = iterative_moe(a,
                                      w1,
                                      w2,
@@ -83,6 +79,23 @@ def test_fused_moe(
                                      global_num_experts=e,
                                      expert_map=e_map,
                                      renormalize=False)
+
+    # Pad the weight if moe padding is enabled
+    if padding:
+        w1 = F.pad(w1, (0, 128), "constant", 0)[..., 0:-128]
+        torch.cuda.empty_cache()
+        w2 = F.pad(w2, (0, 128), "constant", 0)[..., 0:-128]
+        torch.cuda.empty_cache()
+
+    triton_output = fused_moe(a,
+                              w1,
+                              w2,
+                              score,
+                              topk,
+                              global_num_experts=e,
+                              expert_map=e_map,
+                              renormalize=False)
+    torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
     torch.testing.assert_close(iterative_output,
                                torch_output,
                                atol=2e-2,
@@ -202,8 +215,9 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
 
 @pytest.mark.parametrize("dtype",
                          [torch.float32, torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("padding", [True, False])
 @torch.inference_mode()
-def test_mixtral_moe(dtype: torch.dtype):
+def test_mixtral_moe(dtype: torch.dtype, padding: bool):
     """Make sure our Mixtral MoE implementation agrees with the one from
     huggingface."""
 
@@ -233,6 +247,17 @@ def test_mixtral_moe(dtype: torch.dtype):
     # vLLM uses 1D query [num_tokens, hidden_dim]
     vllm_inputs = hf_inputs.flatten(0, 1)
 
+    # Pad the weight if moe padding is enabled
+    if padding:
+        vllm_moe.experts.w13_weight = Parameter(F.pad(
+            vllm_moe.experts.w13_weight, (0, 128), "constant", 0)[..., 0:-128],
+                                                requires_grad=False)
+        torch.cuda.empty_cache()
+        vllm_moe.experts.w2_weight = Parameter(F.pad(
+            vllm_moe.experts.w2_weight, (0, 128), "constant", 0)[..., 0:-128],
+                                               requires_grad=False)
+        torch.cuda.empty_cache()
+
     # Run forward passes for both MoE blocks
     hf_states, _ = hf_moe.forward(hf_inputs)
     vllm_states = vllm_moe.forward(vllm_inputs)
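Note on the tests: they pad w1 and w2 in place and then compare the Triton output against the unpadded torch_moe reference. A minimal sketch of what that pad-then-slice idiom does to a tensor (the shape below is arbitrary and not taken from the diff): the values and logical shape are unchanged; only the allocation, and hence the row stride, grows by the pad width.

import torch
from torch.nn import functional as F

w = torch.randn(4, 256, 512, dtype=torch.bfloat16)         # arbitrary expert weight
w_padded = F.pad(w, (0, 128), "constant", 0)[..., 0:-128]   # pad 128 columns, slice them off

assert w_padded.shape == w.shape                            # logical shape unchanged
assert w_padded.stride(-2) == 512 + 128                     # row stride grew by the pad width
torch.testing.assert_close(w_padded, w)                     # values are untouched

That is why comparing the padded fused_moe run against the unpadded reference with the same tolerances remains a valid check.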

vllm/envs.py

Lines changed: 5 additions & 0 deletions

@@ -75,6 +75,7 @@
     VLLM_ROCM_USE_AITER: bool = False
     VLLM_ROCM_USE_AITER_RMSNORM: bool = True
     VLLM_ROCM_FP8_PADDING: bool = True
+    VLLM_ROCM_MOE_PADDING: bool = True
     VLLM_ENABLE_V1_MULTIPROCESSING: bool = True
     VLLM_LOG_BATCHSIZE_INTERVAL: float = -1
     VLLM_DISABLE_COMPILE_CACHE: bool = False
@@ -520,6 +521,10 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
     "VLLM_ROCM_FP8_PADDING":
     lambda: bool(int(os.getenv("VLLM_ROCM_FP8_PADDING", "1"))),
 
+    # Pad the weights for the moe kernel
+    "VLLM_ROCM_MOE_PADDING":
+    lambda: bool(int(os.getenv("VLLM_ROCM_MOE_PADDING", "1"))),
+
     # Divisor for dynamic query scale factor calculation for FP8 KV Cache
     "Q_SCALE_CONSTANT":
     lambda: int(os.getenv("Q_SCALE_CONSTANT", "200")),
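For reference, a small sketch of how the new flag is consumed; the parsing below mirrors the lambda registered above, so exporting VLLM_ROCM_MOE_PADDING=0 disables the MoE weight padding, while the default ("1") keeps it enabled.

import os

# Simulate disabling the flag; any non-zero integer string re-enables it.
os.environ["VLLM_ROCM_MOE_PADDING"] = "0"

# Same parsing as the lambda registered in vllm/envs.py above.
moe_padding_enabled = bool(int(os.getenv("VLLM_ROCM_MOE_PADDING", "1")))
assert moe_padding_enabled is False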

vllm/model_executor/layers/fused_moe/fused_moe.py

Lines changed: 3 additions & 3 deletions

@@ -800,7 +800,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
         expert_ids,
         num_tokens_post_padded,
         B.shape[1],
-        A.shape[1],
+        B.shape[2],
         EM,
         topk_ids.numel(),
         A.stride(0),
@@ -1322,8 +1322,8 @@ def fused_experts_impl(hidden_states: torch.Tensor,
 
     assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
     assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
-    assert w1.is_contiguous(), "Expert weights1 must be contiguous"
-    assert w2.is_contiguous(), "Expert weights2 must be contiguous"
+    assert w1.stride(-1) == 1, "Stride of last dimension must be 1"
+    assert w2.stride(-1) == 1, "Stride of last dimension must be 1"
     assert hidden_states.dtype in [
         torch.float32, torch.float16, torch.bfloat16
     ]
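A short sketch (again with an arbitrary shape) of why the contiguity asserts had to be relaxed: a padded expert weight is no longer contiguous, but its last dimension still has unit stride, which is exactly what the new assert checks. Relatedly, the kernel invocation above now takes K from the weight tensor itself (B.shape[2]) instead of from the activations (A.shape[1]).

import torch
from torch.nn import functional as F

w1 = torch.randn(4, 256, 512)                         # arbitrary expert weight
w1 = F.pad(w1, (0, 128), "constant", 0)[..., :-128]   # padded layout

assert not w1.is_contiguous()   # the old assert would have rejected padded weights
assert w1.stride(-1) == 1       # the relaxed assert still passes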

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 19 additions & 0 deletions

@@ -5,6 +5,7 @@
 from typing import Callable, List, Optional, Tuple
 
 import torch
+import torch.nn.functional as F
 from torch.nn.parameter import UninitializedParameter
 
 from vllm import envs
@@ -96,9 +97,27 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
         layer.register_parameter("w2_weight", w2_weight)
         set_weight_attrs(w2_weight, extra_weight_attrs)
 
+    def _maybe_pad_weight(self, weight: torch.Tensor) -> torch.Tensor:
+        # Pad the weight tensor. This is an optimization on ROCm platform, which
+        # can benefit from tensors located far enough from one another in memory
+        if (envs.VLLM_ROCM_MOE_PADDING and current_platform.is_rocm()
+                and weight.stride(-1) == 1
+                and (weight.stride(-2) * weight.element_size()) % 512 == 0):
+            num_pad = 256 // weight.element_size()
+            weight = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad]
+            torch.cuda.empty_cache()
+        return weight
+
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         super().process_weights_after_loading(layer)
 
+        layer.w13_weight = torch.nn.Parameter(self._maybe_pad_weight(
+            layer.w13_weight.data),
+                                              requires_grad=False)
+        layer.w2_weight = torch.nn.Parameter(self._maybe_pad_weight(
+            layer.w2_weight.data),
+                                             requires_grad=False)
+
         if current_platform.is_cpu():
             if current_platform.get_cpu_architecture() == CpuArchEnum.X86:
                 import intel_extension_for_pytorch as ipex
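A hedged worked example of the gate in _maybe_pad_weight, using an arbitrary bfloat16 weight: padding is applied only when the last dimension is densely packed and the row pitch is already 512-byte aligned, and the pad width is chosen so that 256 bytes are appended per row, i.e. 128 elements for 2-byte fp16/bf16 weights, which matches the hard-coded 128 used in the tests above.

import torch
from torch.nn import functional as F

weight = torch.randn(8, 1024, 2048, dtype=torch.bfloat16)   # arbitrary MoE weight
row_bytes = weight.stride(-2) * weight.element_size()        # 2048 * 2 = 4096 bytes

# Same gate as _maybe_pad_weight, minus the ROCm/env checks.
if weight.stride(-1) == 1 and row_bytes % 512 == 0:
    num_pad = 256 // weight.element_size()                    # 128 elements for bfloat16
    weight = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad]

assert weight.stride(-2) * weight.element_size() == row_bytes + 256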

vllm/model_executor/layers/quantization/fp8.py

Lines changed: 3 additions & 3 deletions

@@ -255,7 +255,7 @@ def create_weights(
         else:
             layer.register_parameter("input_scale", None)
 
-    def add_padding_to_weight(self, weight: torch.Tensor) -> torch.Tensor:
+    def _maybe_pad_weight(self, weight: torch.Tensor) -> torch.Tensor:
         # Pad the weight tensor. This is an optimization on ROCm platform, which
         # can benefit from tensors located far enough from one another in memory
         if (envs.VLLM_ROCM_FP8_PADDING and current_platform.is_rocm()
@@ -279,7 +279,7 @@ def process_weights_after_loading(self, layer: Module) -> None:
             weight = layer.weight.data
             weight_scale_inv = layer.weight_scale_inv.data
 
-            weight = self.add_padding_to_weight(weight)
+            weight = self._maybe_pad_weight(weight)
 
             # Torch.compile cannot use Parameter subclasses.
             layer.weight = Parameter(weight, requires_grad=False)
@@ -343,7 +343,7 @@ def process_weights_after_loading(self, layer: Module) -> None:
                 logical_widths=layer.logical_widths,
            )
 
-            weight = self.add_padding_to_weight(weight)
+            weight = self._maybe_pad_weight(weight)
             # Update layer with new values.
             layer.weight = Parameter(weight.t(), requires_grad=False)
             layer.weight_scale = Parameter(weight_scale, requires_grad=False)
