@@ -3,8 +3,11 @@
 
 Run `pytest tests/kernels/test_moe.py`.
 """
+
 import pytest
 import torch
+from torch.nn import Parameter
+from torch.nn import functional as F
 from transformers import MixtralConfig
 from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
 
@@ -37,6 +40,7 @@
 @pytest.mark.parametrize("topk", TOP_KS)
 @pytest.mark.parametrize("ep_size", EP_SIZE)
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("padding", [True, False])
 def test_fused_moe(
     m: int,
     n: int,
@@ -45,6 +49,7 @@ def test_fused_moe(
     topk: int,
     ep_size: int,
     dtype: torch.dtype,
+    padding: bool,
 ):
     a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
     w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
@@ -65,16 +70,7 @@ def test_fused_moe(
     else:
         e_map = None
 
-    triton_output = fused_moe(a,
-                              w1,
-                              w2,
-                              score,
-                              topk,
-                              global_num_experts=e,
-                              expert_map=e_map,
-                              renormalize=False)
     torch_output = torch_moe(a, w1, w2, score, topk, e_map)
-    torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
     iterative_output = iterative_moe(a,
                                      w1,
                                      w2,
@@ -83,6 +79,23 @@ def test_fused_moe(
                                      global_num_experts=e,
                                      expert_map=e_map,
                                      renormalize=False)
+
+    # Pad the weight if moe padding is enabled
+    if padding:
+        w1 = F.pad(w1, (0, 128), "constant", 0)[..., 0:-128]
+        torch.cuda.empty_cache()
+        w2 = F.pad(w2, (0, 128), "constant", 0)[..., 0:-128]
+        torch.cuda.empty_cache()
+
+    triton_output = fused_moe(a,
+                              w1,
+                              w2,
+                              score,
+                              topk,
+                              global_num_experts=e,
+                              expert_map=e_map,
+                              renormalize=False)
+    torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
     torch.testing.assert_close(iterative_output,
                                torch_output,
                                atol=2e-2,
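Note on the added block above: `F.pad(w, (0, 128), "constant", 0)[..., 0:-128]` leaves the logical shape and values of `w` untouched, but the result is a strided view into freshly allocated storage whose last dimension is 128 elements wider. `fused_moe` is now called only after this optional padding, so the Triton kernel runs against the padded layout while still having to match the unpadded `torch_moe` reference. A minimal standalone sketch of the effect (the tensor shape below is made up for illustration and is not taken from the test):

# Standalone sketch, not part of the diff: what the pad-then-slice pattern does.
import torch
import torch.nn.functional as F

w = torch.randn(4, 32, 256)  # illustrative (experts, rows, hidden) shape
w_padded = F.pad(w, (0, 128), "constant", 0)[..., 0:-128]

assert w_padded.shape == w.shape                # logical shape is unchanged
assert torch.equal(w_padded, w)                 # values are unchanged
assert not w_padded.is_contiguous()             # but it is now a strided view ...
assert w_padded.stride(1) == w.stride(1) + 128  # ... over rows padded by 128 elements

In the test itself, rebinding `w1`/`w2` drops the unpadded GPU buffers, and the `torch.cuda.empty_cache()` calls return the freed blocks to the driver.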
@@ -202,8 +215,9 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
 
 @pytest.mark.parametrize("dtype",
                          [torch.float32, torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("padding", [True, False])
 @torch.inference_mode()
-def test_mixtral_moe(dtype: torch.dtype):
+def test_mixtral_moe(dtype: torch.dtype, padding: bool):
     """Make sure our Mixtral MoE implementation agrees with the one from
     huggingface."""
 
@@ -233,6 +247,17 @@ def test_mixtral_moe(dtype: torch.dtype):
     # vLLM uses 1D query [num_tokens, hidden_dim]
     vllm_inputs = hf_inputs.flatten(0, 1)
 
+    # Pad the weight if moe padding is enabled
+    if padding:
+        vllm_moe.experts.w13_weight = Parameter(F.pad(
+            vllm_moe.experts.w13_weight, (0, 128), "constant", 0)[..., 0:-128],
+                                                requires_grad=False)
+        torch.cuda.empty_cache()
+        vllm_moe.experts.w2_weight = Parameter(F.pad(
+            vllm_moe.experts.w2_weight, (0, 128), "constant", 0)[..., 0:-128],
+                                               requires_grad=False)
+        torch.cuda.empty_cache()
+
     # Run forward passes for both MoE blocks
     hf_states, _ = hf_moe.forward(hf_inputs)
     vllm_states = vllm_moe.forward(vllm_inputs)
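In the Mixtral test the padding is applied to an already constructed layer, so the padded tensors are re-wrapped in `Parameter(..., requires_grad=False)`: slicing the output of `F.pad` yields a plain tensor, not an `nn.Parameter`. A hedged sketch of how the same transform could be factored into a helper (the helper itself is hypothetical and not part of this diff; the `vllm_moe.experts.*` attribute names are taken from the test above):

# Hypothetical helper, not from the PR: pad-then-slice an existing expert weight.
import torch
from torch.nn import Parameter
from torch.nn import functional as F

def pad_moe_weight(weight: torch.Tensor, pad: int = 128) -> Parameter:
    """Re-wrap `weight` so its last dim sits in storage padded by `pad` elements."""
    padded = F.pad(weight.detach(), (0, pad), "constant", 0)[..., :-pad]
    return Parameter(padded, requires_grad=False)

# Usage mirroring the added block above:
# vllm_moe.experts.w13_weight = pad_moe_weight(vllm_moe.experts.w13_weight)
# torch.cuda.empty_cache()
# vllm_moe.experts.w2_weight = pad_moe_weight(vllm_moe.experts.w2_weight)
# torch.cuda.empty_cache()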