Skip to content

Commit e9899fb

Browse files
authored
[Model] Enable FP8 QKV in MoE and refine kernel tuning script (#5039)
1 parent a377f0b commit e9899fb

8 files changed

+711
-114
lines changed

benchmarks/kernels/benchmark_mixtral_moe.py

Lines changed: 36 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,25 +11,36 @@
1111
from vllm.model_executor.layers.fused_moe import (fused_moe,
1212
get_config_file_name)
1313

14-
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
1514

16-
17-
def main(dtype: str):
15+
def main(model, tp_size, gpu, dtype: str):
16+
os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu)
1817
method = fused_moe
1918
for bs in [
2019
1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536,
2120
2048, 3072, 4096
2221
]:
23-
run_grid(bs, method=method, dtype=dtype)
24-
25-
26-
def run_grid(bs, method, dtype: str):
27-
d_model = 4096
22+
run_grid(bs,
23+
model=model,
24+
method=method,
25+
gpu=gpu,
26+
tp_size=tp_size,
27+
dtype=dtype)
28+
29+
30+
def run_grid(bs, model, method, gpu, tp_size, dtype: str):
31+
if model == '8x7B':
32+
d_model = 4096
33+
model_intermediate_size = 14336
34+
num_layers = 32
35+
elif model == '8x22B':
36+
d_model = 6144
37+
model_intermediate_size = 16384
38+
num_layers = 56
39+
else:
40+
raise ValueError(f'Unsupported Mixtral model {model}')
2841
num_total_experts = 8
2942
top_k = 2
30-
tp_size = 2
31-
model_intermediate_size = 14336
32-
num_layers = 32
43+
# tp_size = 2
3344
num_calls = 100
3445

3546
num_warmup_trials = 1
@@ -211,5 +222,18 @@ def run_timing(num_calls: int, bs: int, d_model: int, num_total_experts: int,
211222
choices=['float8', 'float16'],
212223
help='Data type used for fused_moe kernel computations',
213224
)
225+
parser.add_argument('--model',
226+
type=str,
227+
default='8x7B',
228+
choices=['8x7B', '8x22B'],
229+
help='The Mixtral model to benchmark')
230+
parser.add_argument('--tp-size',
231+
type=int,
232+
default=2,
233+
help='Tensor parallel size')
234+
parser.add_argument('--gpu',
235+
type=int,
236+
default=0,
237+
help="GPU ID for benchmarking")
214238
args = parser.parse_args()
215-
sys.exit(main(args.dtype))
239+
sys.exit(main(args.model, args.tp_size, args.gpu, args.dtype))
Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
{
2+
"1": {
3+
"BLOCK_SIZE_M": 64,
4+
"BLOCK_SIZE_N": 64,
5+
"BLOCK_SIZE_K": 128,
6+
"GROUP_SIZE_M": 64,
7+
"num_warps": 8,
8+
"num_stages": 5
9+
},
10+
"2": {
11+
"BLOCK_SIZE_M": 64,
12+
"BLOCK_SIZE_N": 128,
13+
"BLOCK_SIZE_K": 256,
14+
"GROUP_SIZE_M": 64,
15+
"num_warps": 4,
16+
"num_stages": 3
17+
},
18+
"4": {
19+
"BLOCK_SIZE_M": 64,
20+
"BLOCK_SIZE_N": 128,
21+
"BLOCK_SIZE_K": 256,
22+
"GROUP_SIZE_M": 1,
23+
"num_warps": 4,
24+
"num_stages": 4
25+
},
26+
"8": {
27+
"BLOCK_SIZE_M": 64,
28+
"BLOCK_SIZE_N": 128,
29+
"BLOCK_SIZE_K": 256,
30+
"GROUP_SIZE_M": 32,
31+
"num_warps": 4,
32+
"num_stages": 4
33+
},
34+
"16": {
35+
"BLOCK_SIZE_M": 64,
36+
"BLOCK_SIZE_N": 64,
37+
"BLOCK_SIZE_K": 128,
38+
"GROUP_SIZE_M": 1,
39+
"num_warps": 4,
40+
"num_stages": 3
41+
},
42+
"24": {
43+
"BLOCK_SIZE_M": 64,
44+
"BLOCK_SIZE_N": 128,
45+
"BLOCK_SIZE_K": 256,
46+
"GROUP_SIZE_M": 1,
47+
"num_warps": 4,
48+
"num_stages": 4
49+
},
50+
"32": {
51+
"BLOCK_SIZE_M": 64,
52+
"BLOCK_SIZE_N": 128,
53+
"BLOCK_SIZE_K": 256,
54+
"GROUP_SIZE_M": 1,
55+
"num_warps": 4,
56+
"num_stages": 4
57+
},
58+
"48": {
59+
"BLOCK_SIZE_M": 64,
60+
"BLOCK_SIZE_N": 128,
61+
"BLOCK_SIZE_K": 256,
62+
"GROUP_SIZE_M": 1,
63+
"num_warps": 4,
64+
"num_stages": 4
65+
},
66+
"64": {
67+
"BLOCK_SIZE_M": 64,
68+
"BLOCK_SIZE_N": 128,
69+
"BLOCK_SIZE_K": 256,
70+
"GROUP_SIZE_M": 1,
71+
"num_warps": 4,
72+
"num_stages": 4
73+
},
74+
"96": {
75+
"BLOCK_SIZE_M": 64,
76+
"BLOCK_SIZE_N": 128,
77+
"BLOCK_SIZE_K": 256,
78+
"GROUP_SIZE_M": 1,
79+
"num_warps": 4,
80+
"num_stages": 2
81+
},
82+
"128": {
83+
"BLOCK_SIZE_M": 64,
84+
"BLOCK_SIZE_N": 128,
85+
"BLOCK_SIZE_K": 256,
86+
"GROUP_SIZE_M": 1,
87+
"num_warps": 4,
88+
"num_stages": 2
89+
},
90+
"256": {
91+
"BLOCK_SIZE_M": 128,
92+
"BLOCK_SIZE_N": 128,
93+
"BLOCK_SIZE_K": 128,
94+
"GROUP_SIZE_M": 1,
95+
"num_warps": 8,
96+
"num_stages": 3
97+
},
98+
"512": {
99+
"BLOCK_SIZE_M": 128,
100+
"BLOCK_SIZE_N": 256,
101+
"BLOCK_SIZE_K": 128,
102+
"GROUP_SIZE_M": 64,
103+
"num_warps": 8,
104+
"num_stages": 4
105+
},
106+
"1024": {
107+
"BLOCK_SIZE_M": 128,
108+
"BLOCK_SIZE_N": 256,
109+
"BLOCK_SIZE_K": 128,
110+
"GROUP_SIZE_M": 64,
111+
"num_warps": 8,
112+
"num_stages": 4
113+
},
114+
"1536": {
115+
"BLOCK_SIZE_M": 128,
116+
"BLOCK_SIZE_N": 256,
117+
"BLOCK_SIZE_K": 128,
118+
"GROUP_SIZE_M": 64,
119+
"num_warps": 8,
120+
"num_stages": 3
121+
},
122+
"2048": {
123+
"BLOCK_SIZE_M": 128,
124+
"BLOCK_SIZE_N": 256,
125+
"BLOCK_SIZE_K": 128,
126+
"GROUP_SIZE_M": 64,
127+
"num_warps": 8,
128+
"num_stages": 3
129+
},
130+
"3072": {
131+
"BLOCK_SIZE_M": 128,
132+
"BLOCK_SIZE_N": 256,
133+
"BLOCK_SIZE_K": 128,
134+
"GROUP_SIZE_M": 32,
135+
"num_warps": 8,
136+
"num_stages": 3
137+
}
138+
}
Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
{
2+
"1": {
3+
"BLOCK_SIZE_M": 16,
4+
"BLOCK_SIZE_N": 32,
5+
"BLOCK_SIZE_K": 64,
6+
"GROUP_SIZE_M": 1,
7+
"num_warps": 8,
8+
"num_stages": 3
9+
},
10+
"2": {
11+
"BLOCK_SIZE_M": 16,
12+
"BLOCK_SIZE_N": 64,
13+
"BLOCK_SIZE_K": 64,
14+
"GROUP_SIZE_M": 16,
15+
"num_warps": 4,
16+
"num_stages": 5
17+
},
18+
"4": {
19+
"BLOCK_SIZE_M": 64,
20+
"BLOCK_SIZE_N": 32,
21+
"BLOCK_SIZE_K": 256,
22+
"GROUP_SIZE_M": 64,
23+
"num_warps": 8,
24+
"num_stages": 5
25+
},
26+
"8": {
27+
"BLOCK_SIZE_M": 16,
28+
"BLOCK_SIZE_N": 64,
29+
"BLOCK_SIZE_K": 256,
30+
"GROUP_SIZE_M": 1,
31+
"num_warps": 4,
32+
"num_stages": 3
33+
},
34+
"16": {
35+
"BLOCK_SIZE_M": 128,
36+
"BLOCK_SIZE_N": 64,
37+
"BLOCK_SIZE_K": 64,
38+
"GROUP_SIZE_M": 1,
39+
"num_warps": 8,
40+
"num_stages": 3
41+
},
42+
"24": {
43+
"BLOCK_SIZE_M": 64,
44+
"BLOCK_SIZE_N": 256,
45+
"BLOCK_SIZE_K": 128,
46+
"GROUP_SIZE_M": 64,
47+
"num_warps": 8,
48+
"num_stages": 2
49+
},
50+
"32": {
51+
"BLOCK_SIZE_M": 64,
52+
"BLOCK_SIZE_N": 256,
53+
"BLOCK_SIZE_K": 64,
54+
"GROUP_SIZE_M": 1,
55+
"num_warps": 8,
56+
"num_stages": 3
57+
},
58+
"48": {
59+
"BLOCK_SIZE_M": 64,
60+
"BLOCK_SIZE_N": 256,
61+
"BLOCK_SIZE_K": 128,
62+
"GROUP_SIZE_M": 64,
63+
"num_warps": 8,
64+
"num_stages": 4
65+
},
66+
"64": {
67+
"BLOCK_SIZE_M": 128,
68+
"BLOCK_SIZE_N": 256,
69+
"BLOCK_SIZE_K": 128,
70+
"GROUP_SIZE_M": 64,
71+
"num_warps": 8,
72+
"num_stages": 2
73+
},
74+
"96": {
75+
"BLOCK_SIZE_M": 64,
76+
"BLOCK_SIZE_N": 256,
77+
"BLOCK_SIZE_K": 128,
78+
"GROUP_SIZE_M": 16,
79+
"num_warps": 4,
80+
"num_stages": 4
81+
},
82+
"128": {
83+
"BLOCK_SIZE_M": 64,
84+
"BLOCK_SIZE_N": 256,
85+
"BLOCK_SIZE_K": 64,
86+
"GROUP_SIZE_M": 32,
87+
"num_warps": 8,
88+
"num_stages": 3
89+
},
90+
"256": {
91+
"BLOCK_SIZE_M": 128,
92+
"BLOCK_SIZE_N": 64,
93+
"BLOCK_SIZE_K": 128,
94+
"GROUP_SIZE_M": 64,
95+
"num_warps": 8,
96+
"num_stages": 4
97+
},
98+
"512": {
99+
"BLOCK_SIZE_M": 64,
100+
"BLOCK_SIZE_N": 256,
101+
"BLOCK_SIZE_K": 128,
102+
"GROUP_SIZE_M": 64,
103+
"num_warps": 4,
104+
"num_stages": 3
105+
},
106+
"1024": {
107+
"BLOCK_SIZE_M": 128,
108+
"BLOCK_SIZE_N": 256,
109+
"BLOCK_SIZE_K": 128,
110+
"GROUP_SIZE_M": 32,
111+
"num_warps": 8,
112+
"num_stages": 4
113+
},
114+
"1536": {
115+
"BLOCK_SIZE_M": 128,
116+
"BLOCK_SIZE_N": 256,
117+
"BLOCK_SIZE_K": 128,
118+
"GROUP_SIZE_M": 64,
119+
"num_warps": 8,
120+
"num_stages": 4
121+
},
122+
"2048": {
123+
"BLOCK_SIZE_M": 128,
124+
"BLOCK_SIZE_N": 256,
125+
"BLOCK_SIZE_K": 128,
126+
"GROUP_SIZE_M": 64,
127+
"num_warps": 8,
128+
"num_stages": 4
129+
},
130+
"3072": {
131+
"BLOCK_SIZE_M": 128,
132+
"BLOCK_SIZE_N": 256,
133+
"BLOCK_SIZE_K": 128,
134+
"GROUP_SIZE_M": 32,
135+
"num_warps": 8,
136+
"num_stages": 4
137+
},
138+
"4096": {
139+
"BLOCK_SIZE_M": 128,
140+
"BLOCK_SIZE_N": 256,
141+
"BLOCK_SIZE_K": 128,
142+
"GROUP_SIZE_M": 32,
143+
"num_warps": 8,
144+
"num_stages": 4
145+
}
146+
}

0 commit comments

Comments
 (0)