
Commit 41f1cf3

[Feature][OCP MX] Support mxfp6 and mixed mxfp6-mxfp4 (#21166)
1 parent 08d26a1 commit 41f1cf3

File tree

18 files changed, +658 -182 lines changed


docs/features/quantization/quark.md

Lines changed: 8 additions & 4 deletions
@@ -231,27 +231,31 @@ python3 quantize_quark.py --model_dir meta-llama/Llama-2-70b-chat-hf \
     --tasks gsm8k
 ```
 
-## Using MXFP4 models
+## Using OCP MX (MXFP4, MXFP6) models
 
-vLLM supports loading MXFP4 models quantized offline through AMD Quark, compliant with [Open Compute Project (OCP) specification](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf).
+vLLM supports loading MXFP4 and MXFP6 models quantized offline through AMD Quark, compliant with the [Open Compute Project (OCP) specification](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf).
 
 The scheme currently only supports dynamic quantization for activations.
 
 Example usage, after installing the latest AMD Quark release:
 
 ```bash
 vllm serve fxmarty/qwen_1.5-moe-a2.7b-mxfp4 --tensor-parallel-size 1
+# or, for a model using fp6 activations and fp4 weights:
+vllm serve fxmarty/qwen1.5_moe_a2.7b_chat_w_fp4_a_fp6_e2m3 --tensor-parallel-size 1
 ```
 
-A simulation of the matrix multiplication execution in MXFP4 can be run on devices that do not support MXFP4 operations natively (e.g. AMD Instinct MI325, MI300 and MI250), dequantizing weights from MXFP4 to half precision on the fly, using a fused kernel. This is useful e.g. to evaluate MXFP4 models using vLLM, or alternatively to benefit from the ~4x memory savings (compared to float16 and bfloat16).
+A simulation of the matrix multiplication execution in MXFP4/MXFP6 can be run on devices that do not support OCP MX operations natively (e.g. AMD Instinct MI325, MI300 and MI250), dequantizing weights from FP4/FP6 to half precision on the fly using a fused kernel. This is useful e.g. to evaluate FP4/FP6 models with vLLM, or alternatively to benefit from the ~2.5-4x memory savings (compared to float16 and bfloat16).
 
 To generate offline models quantized using MXFP4 data type, the easiest approach is to use AMD Quark's [quantization script](https://quark.docs.amd.com/latest/pytorch/example_quark_torch_llm_ptq.html), as an example:
 
 ```bash
 python quantize_quark.py --model_dir Qwen/Qwen1.5-MoE-A2.7B-Chat \
-    --quant_scheme w_mxfp4_a_mxfp4_sym \
+    --quant_scheme w_mxfp4_a_mxfp4 \
     --output_dir qwen_1.5-moe-a2.7b-mxfp4 \
     --skip_evaluation \
     --model_export hf_format \
     --group_size 32
 ```
+
+The current integration supports [all combinations of FP4, FP6_E3M2 and FP6_E2M3](https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/ocp_mx_utils.py) for either weights or activations. Additionally, some target hardware, such as AMD Instinct MI350/MI355, supports mixed-precision GEMM, for example using FP6 for activations and FP4 for weights.
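The documentation change above describes simulating OCP MX matrix multiplications by dequantizing block-scaled weights to half precision on the fly. As a rough, self-contained illustration of that dequantization step (this is not vLLM's fused kernel; the nibble packing order and tensor shapes are assumptions made for this example), an MXFP4-to-bfloat16 sketch with the OCP group size of 32 could look like:

```python
import torch

# E2M1 (FP4) magnitudes for the 16 possible codes; bit 3 is the sign bit.
FP4_E2M1_LUT = torch.tensor(
    [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
     -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0]
)


def dequant_mxfp4(packed: torch.Tensor, scales_e8m0: torch.Tensor,
                  group_size: int = 32) -> torch.Tensor:
    """Illustrative MXFP4 -> bfloat16 dequantization.

    packed: uint8, shape (rows, cols // 2), two FP4 codes per byte
            (low nibble first -- an assumed layout, not necessarily vLLM's).
    scales_e8m0: uint8, shape (rows, cols // group_size), biased exponents
                 of the per-group E8M0 shared scales.
    """
    lo = packed & 0x0F
    hi = (packed >> 4) & 0x0F
    codes = torch.stack((lo, hi), dim=-1).flatten(-2)      # (rows, cols)
    values = FP4_E2M1_LUT[codes.long()]
    # E8M0 shared scale: 2 ** (stored_exponent - 127), one per group of 32.
    scales = torch.pow(2.0, scales_e8m0.float() - 127.0)
    values = values.view(values.shape[0], -1, group_size) * scales.unsqueeze(-1)
    return values.reshape(codes.shape).to(torch.bfloat16)


# The "simulated" GEMM then simply runs in half precision on the dequantized weights.
w = dequant_mxfp4(torch.randint(0, 256, (64, 16), dtype=torch.uint8),
                  torch.full((64, 1), 127, dtype=torch.uint8))
x = torch.randn(4, 32, dtype=torch.bfloat16)
y = x @ w.T
```

MXFP6 (E3M2 or E2M3) would follow the same pattern with a different value table and a 6-bit packing; the fused kernel in vLLM performs this unpacking and the matmul in one pass.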

tests/kernels/moe/test_mxfp4_moe.py renamed to tests/kernels/moe/test_ocp_mx_moe.py

Lines changed: 23 additions & 17 deletions
@@ -10,13 +10,6 @@
 import torch
 from packaging import version
 
-from vllm.model_executor.layers.quantization.quark.quark import (
-    QuarkLinearMethod,
-    QuarkW4A4MXFP4,
-)
-from vllm.model_executor.layers.quantization.quark.quark_moe import (
-    QuarkW4A4MXFp4MoEMethod,
-)
 from vllm.platforms import current_platform
 from vllm.utils.flashinfer import has_flashinfer
 
@@ -63,9 +56,11 @@ def enable_pickle(monkeypatch):
 @pytest.mark.parametrize(
     "model_case",
     [
-        ModelCase("fxmarty/qwen_1.5-moe-a2.7b-mxfp4", tp=1),
+        ModelCase("fxmarty/qwen_1.5-moe-a2.7b-mxfp4", tp=2),
         ModelCase("fxmarty/deepseek_r1_3_layers_mxfp4", tp=8),
         ModelCase("fxmarty/Llama-4-Scout-17B-16E-Instruct-2-layers-mxfp4", tp=1),
+        ModelCase("fxmarty/Llama-3.1-70B-Instruct-2-layers-mxfp6", tp=1),
+        ModelCase("fxmarty/Llama-3.1-70B-Instruct-2-layers-mxfp6", tp=4),
     ],
 )
 @pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, reason="amd-quark>=0.9 is not available")
@@ -76,22 +71,33 @@ def test_mxfp4_loading_and_execution_moe(vllm_runner, model_case: ModelCase):
             f"{torch.cuda.device_count()}"
         )
 
+    # `cuda_graph_sizes=[16]` to reduce load time.
     with vllm_runner(
-        model_case.model_id, tensor_parallel_size=model_case.tp, load_format="dummy"
+        model_case.model_id,
+        tensor_parallel_size=model_case.tp,
+        load_format="dummy",
+        cuda_graph_sizes=[16],
     ) as llm:
+        # Disabled as check_model is broken: https://github.com/vllm-project/vllm/pull/18465#issuecomment-3329880562
+        # def check_model(model):
+        #     from vllm.model_executor.layers.quantization.quark.quark import (  # noqa: E501
+        #         QuarkLinearMethod)
+        #     from vllm.model_executor.layers.quantization.quark.schemes.quark_ocp_mx import QuarkOCP_MX  # noqa: E501
+        #     from vllm.model_executor.layers.quantization.quark.quark_moe import (  # noqa: E501
+        #         QuarkOCP_MX_MoEMethod)
 
-        def check_model(model):
-            layer = model.model.layers[0]
+        #     layer = model.model.layers[0]
 
-            qkv_proj = layer.self_attn.qkv_proj
+        #     qkv_proj = layer.self_attn.qkv_proj
 
-            assert isinstance(qkv_proj.quant_method, QuarkLinearMethod)
-            assert isinstance(qkv_proj.scheme, QuarkW4A4MXFP4)
+        #     assert isinstance(qkv_proj.quant_method, QuarkLinearMethod)
+        #     assert isinstance(qkv_proj.scheme, QuarkOCP_MX)
 
-            assert isinstance(layer.mlp.experts.quant_method, QuarkW4A4MXFp4MoEMethod)
+        #     assert isinstance(layer.mlp.experts.quant_method,
+        #                       QuarkOCP_MX_MoEMethod)
 
-        if model_case.model_id == "fxmarty/qwen_1.5-moe-a2.7b-mxfp4":
-            llm.apply_model(check_model)
+        # if model_case.model_id == "fxmarty/qwen_1.5-moe-a2.7b-mxfp4":
+        #     llm.apply_model(check_model)
 
         output = llm.generate_greedy("Today I am in the French Alps and", max_tokens=20)
         assert output
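For reference, the renamed smoke test above boils down to loading one of the listed checkpoints with dummy weights and generating a few greedy tokens. Outside pytest, roughly the same thing can be done with the plain vLLM `LLM` API; this is only a sketch mirroring the test (the dummy weights make the generated text meaningless, it just exercises the OCP MX code path):

```python
from vllm import LLM, SamplingParams

# "dummy" (random) weights load quickly and still exercise the quantized path.
llm = LLM(
    model="fxmarty/Llama-3.1-70B-Instruct-2-layers-mxfp6",
    tensor_parallel_size=1,
    load_format="dummy",
)
out = llm.generate(
    "Today I am in the French Alps and",
    SamplingParams(temperature=0.0, max_tokens=20),
)
print(out[0].outputs[0].text)
```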

tests/quantization/test_quark.py

Lines changed: 74 additions & 19 deletions
@@ -11,6 +11,7 @@
 import os
 from dataclasses import dataclass
 from importlib.util import find_spec
+from typing import Optional
 
 import huggingface_hub
 import lm_eval
@@ -148,39 +149,93 @@ def get_state_dict(model):
 
 
 @dataclass
-class ModelCase:
-    model_id: str
-    tp: int
-
-
-@dataclass
-class GSM8KAccuracyTestConfig:
+class AccuracyTestConfig:
     model_name: str
     excepted_value: float
 
-    def get_model_args(self) -> str:
-        return (
-            f"pretrained={self.model_name},"
-            "dtype=auto,add_bos_token=True,tensor_parallel_size=8,gpu_memory_utilization=0.7,max_model_len=38768"
-        )
-
-
-ACCURACY_CONFIGS = [
+    def get_model_args(
+        self,
+        tp_size: int,
+        model_max_len: Optional[int] = None,
+        kwargs: Optional[dict] = None,
+    ) -> dict:
+        if kwargs is None:
+            kwargs = {}
+
+        model_args = {
+            "pretrained": self.model_name,
+            "dtype": "auto",
+            "add_bos_token": True,
+            "tensor_parallel_size": tp_size,
+            "gpu_memory_utilization": 0.7,
+            **kwargs,
+        }
+        if model_max_len is not None:
+            model_args["max_model_len"] = model_max_len
+
+        return model_args
+
+
+GSM8K_ACCURACY_CONFIGS = [
     # Private model.
-    GSM8KAccuracyTestConfig(
+    AccuracyTestConfig(
        model_name="amd/DeepSeek-R1-WMXFP4-AMXFP4-Scale-UINT8-MoE-Quant",
        excepted_value=0.96,
     ),
 ]
 
+WIKITEXT_ACCURACY_CONFIGS = [
+    AccuracyTestConfig(
+        model_name="fxmarty/qwen1.5_moe_a2.7b_chat_w_fp4_a_fp6_e2m3",
+        excepted_value=11.3,
+    ),
+    AccuracyTestConfig(
+        model_name="fxmarty/qwen1.5_moe_a2.7b_chat_w_fp6_e3m2_a_fp6_e3m2",
+        excepted_value=10.6,
+    ),
+    AccuracyTestConfig(
+        model_name="fxmarty/qwen_1.5-moe-a2.7b-mxfp4", excepted_value=12.4
+    ),
+]
+
+
+@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, reason="amd-quark>=0.9 is not available")
+@pytest.mark.parametrize("config", WIKITEXT_ACCURACY_CONFIGS)
+@pytest.mark.parametrize("tp_size", [1, 2])
+def test_ocp_mx_wikitext_correctness(config: AccuracyTestConfig, tp_size: int):
+    if torch.cuda.device_count() < tp_size:
+        pytest.skip(
+            f"This test requires >={tp_size} gpus, got only {torch.cuda.device_count()}"
+        )
+
+    task = "wikitext"
+    rtol = 0.1
+
+    # Smaller cuda_graph_sizes to speed up the test.
+    results = lm_eval.simple_evaluate(
+        model="vllm",
+        model_args=config.get_model_args(
+            tp_size=tp_size, kwargs={"cuda_graph_sizes": [16]}
+        ),
+        tasks=task,
+        batch_size=64,
+    )
+
+    EXPECTED_VALUE = config.excepted_value
+    measured_value = results["results"][task]["word_perplexity,none"]
+    assert (
+        measured_value < EXPECTED_VALUE + rtol
+        and measured_value > EXPECTED_VALUE - rtol
+    ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}"
+
 
-@pytest.mark.parametrize("config", ACCURACY_CONFIGS)
+@pytest.mark.parametrize("config", GSM8K_ACCURACY_CONFIGS)
 @pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, reason="amd-quark>=0.9 is not available")
 @pytest.mark.skipif(
     not HF_HUB_AMD_ORG_ACCESS,
     reason="Read access to huggingface.co/amd is required for this test.",
 )
-def test_mxfp4_gsm8k_correctness(config: GSM8KAccuracyTestConfig):
+def test_mxfp4_gsm8k_correctness(config: AccuracyTestConfig):
     if torch.cuda.device_count() < 8:
         pytest.skip(
             f"This test requires >=8 gpus, got only {torch.cuda.device_count()}"
@@ -193,7 +248,7 @@ def test_mxfp4_gsm8k_correctness(config: GSM8KAccuracyTestConfig):
 
     results = lm_eval.simple_evaluate(
         model="vllm",
-        model_args=config.get_model_args(),
+        model_args=config.get_model_args(tp_size=8, model_max_len=38768),
         tasks=task,
         batch_size=64,
         num_fewshot=8,
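The new `test_ocp_mx_wikitext_correctness` test is essentially a single `lm_eval.simple_evaluate` call against vLLM. Extracted from the test, a standalone reproduction for one of the mixed-precision checkpoints (assuming a single GPU) would look roughly like this:

```python
import lm_eval

results = lm_eval.simple_evaluate(
    model="vllm",
    model_args={
        "pretrained": "fxmarty/qwen1.5_moe_a2.7b_chat_w_fp4_a_fp6_e2m3",
        "dtype": "auto",
        "add_bos_token": True,
        "tensor_parallel_size": 1,
        "gpu_memory_utilization": 0.7,
    },
    tasks="wikitext",
    batch_size=64,
)
# Per the expected values above, word perplexity should land around 11.3 (+/- 0.1).
print(results["results"]["wikitext"]["word_perplexity,none"])
```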

vllm/model_executor/layers/fused_moe/config.py

Lines changed: 56 additions & 17 deletions
@@ -9,6 +9,10 @@
 from vllm.config import ParallelConfig
 from vllm.distributed import get_dp_group, get_tensor_model_parallel_rank
 from vllm.logger import init_logger
+from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import (
+    OCP_MX_DTYPES,
+    OCP_MX_Scheme,
+)
 from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
 from vllm.utils import cdiv, has_triton_kernels
 from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
@@ -30,7 +34,7 @@ def _get_config_dtype_str(
     use_fp8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     use_int4_w4a16: bool = False,
-    use_mxfp4_w4a4: bool = False,
+    ocp_mx_scheme: Optional[str] = None,
 ) -> Optional[str]:
     """
     Return a string used to construct the filename that contains the
@@ -43,8 +47,11 @@
         return "int8_w8a16"
     elif use_int4_w4a16:
         return "int4_w4a16"
-    elif use_mxfp4_w4a4:
-        return "mxfp4_w4a4"
+    elif ocp_mx_scheme is not None:
+        # The output of this function is passed to `try_get_optimal_moe_config`,
+        # and as we only simulate OCP MX execution in fused_moe for now,
+        # we will NOT look for `*,dtype=w_mxfp4_a_mxfp4.json` for now.
+        return None
     elif dtype == torch.float:
         # avoiding cases where kernel fails when float32 MoE
         # use fp16/bfloat16 configs
@@ -289,8 +296,23 @@ def use_int4_w4a16(self) -> bool:
         return self._a1.dtype is None and self._w1.dtype == "int4"
 
     @property
-    def use_mxfp4_w4a4(self) -> bool:
-        return self._a1.dtype == "mxfp4" and self._w1.dtype == "mxfp4"
+    def ocp_mx_scheme(self) -> Union[str, None]:
+        if not hasattr(self, "_ocp_mx_scheme"):
+            if (self._a1.dtype is not None and not isinstance(self._a1.dtype, str)) or (
+                self._w1.dtype is not None and not isinstance(self._w1.dtype, str)
+            ):
+                self._ocp_mx_scheme = None
+            else:
+                ocp_mx_scheme = OCP_MX_Scheme.from_quant_dtype(
+                    self._a1.dtype, self._w1.dtype
+                )
+
+                if ocp_mx_scheme is not None:
+                    ocp_mx_scheme = ocp_mx_scheme.value
+
+                self._ocp_mx_scheme = ocp_mx_scheme
+
+        return self._ocp_mx_scheme
 
     @property
     def use_mxfp4_w4a16(self) -> bool:
@@ -310,7 +332,7 @@ def config_name(self, dtype: torch.dtype) -> Optional[str]:
             use_fp8_w8a8=self.use_fp8_w8a8,
             use_int8_w8a16=self.use_int8_w8a16,
             use_int4_w4a16=self.use_int4_w4a16,
-            use_mxfp4_w4a4=self.use_mxfp4_w4a4,
+            ocp_mx_scheme=self.ocp_mx_scheme,
             dtype=dtype,
         )
 
@@ -371,12 +393,14 @@ def make(
         w2_bias: Optional[torch.Tensor] = None,
         w1_zp: Optional[torch.Tensor] = None,
         w2_zp: Optional[torch.Tensor] = None,
+        weight_dtype: Union[torch.dtype, str, None] = None,
     ) -> "FusedMoEQuantConfig":
         """
         General builder function for a FusedMoEQuantConfig.
         - quant_dtype: Optional quantization type. None if activations are
-          unquantized or quantized prior to calling. Note: "nvfp4" and
-          "mxfp4" are the only valid string values for quant_dtype.
+          unquantized or quantized prior to calling. Note: "nvfp4", "mxfp4",
+          "mxfp6_e3m2", "mxfp6_e2m3" are the only valid string values
+          for quant_dtype.
         - per_act_token_quant: Activations have per token quantization.
         - per_out_ch_quant: Outputs have per channel quantization. (only
           for cutlass).
@@ -395,22 +419,33 @@
         - w1_zp: Optional w1 zero points for int4/int8 quantization.
         - w2_zp: Optional w2 zero points for int4/int8 quantization.
         """
-        assert (
-            not isinstance(quant_dtype, str)
-            or quant_dtype == "nvfp4"
-            or quant_dtype == "mxfp4"
-        )
+        assert not isinstance(quant_dtype, str) or quant_dtype in {
+            "nvfp4",
+            "mxfp4",
+            "mxfp6_e3m2",
+            "mxfp6_e2m3",
+        }
+        assert not isinstance(weight_dtype, str) or weight_dtype in {
+            "nvfp4",
+            "mxfp4",
+            "mxfp6_e3m2",
+            "mxfp6_e2m3",
+        }
+
+        if weight_dtype is None:
+            weight_dtype = quant_dtype
+
         a_shape, w_shape = _quant_flags_to_group_shape(
             quant_dtype, per_act_token_quant, per_out_ch_quant, block_shape
         )
         quant_config = FusedMoEQuantConfig(
             _a1=FusedMoEQuantDesc(quant_dtype, a_shape, a1_scale, a1_gscale),
             _a2=FusedMoEQuantDesc(quant_dtype, a_shape, a2_scale, a2_gscale),
             _w1=FusedMoEQuantDesc(
-                quant_dtype, w_shape, w1_scale, g1_alphas, w1_zp, w1_bias
+                weight_dtype, w_shape, w1_scale, g1_alphas, w1_zp, w1_bias
             ),
             _w2=FusedMoEQuantDesc(
-                quant_dtype, w_shape, w2_scale, g2_alphas, w2_zp, w2_bias
+                weight_dtype, w_shape, w2_scale, g2_alphas, w2_zp, w2_bias
             ),
         )
         assert quant_config.per_act_token_quant == per_act_token_quant
@@ -482,9 +517,11 @@ def mxfp4_w4a16_moe_quant_config(
     )
 
 
-def mxfp4_w4a4_moe_quant_config(
+def ocp_mx_moe_quant_config(
+    quant_dtype: str,
     w1_scale: Union[torch.Tensor, "PrecisionConfig"],
     w2_scale: Union[torch.Tensor, "PrecisionConfig"],
+    weight_dtype: Optional[str] = None,
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
     w1_bias: Optional[torch.Tensor] = None,
@@ -494,8 +531,10 @@
     """
     Construct a quant config for mxfp4 activations and mxfp4 weights.
     """
+    assert quant_dtype in OCP_MX_DTYPES
     return FusedMoEQuantConfig.make(
-        "mxfp4",
+        quant_dtype=quant_dtype,
+        weight_dtype=weight_dtype,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
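To tie the config changes together, a hypothetical call to the new `ocp_mx_moe_quant_config` builder for the mixed FP6-activation/FP4-weight case might look like the sketch below. The scale shapes are illustrative only (one E8M0 exponent byte per group of 32 weight elements); the exact layouts are defined by vLLM's MoE layers, not by this example:

```python
import torch

from vllm.model_executor.layers.fused_moe.config import ocp_mx_moe_quant_config

num_experts, hidden, intermediate, group = 8, 2048, 1408, 32

# Illustrative per-group E8M0 scales (uint8 biased exponents; 127 == scale of 1.0).
w1_scale = torch.full(
    (num_experts, 2 * intermediate, hidden // group), 127, dtype=torch.uint8
)
w2_scale = torch.full(
    (num_experts, hidden, intermediate // group), 127, dtype=torch.uint8
)

# MXFP6 (E2M3) activations with MXFP4 weights, i.e. the mixed-precision case
# added by this commit.
quant_config = ocp_mx_moe_quant_config(
    quant_dtype="mxfp6_e2m3",
    weight_dtype="mxfp4",
    w1_scale=w1_scale,
    w2_scale=w2_scale,
)
print(quant_config.ocp_mx_scheme)
```

The `ocp_mx_scheme` property is then expected to resolve the activation/weight dtype pair to the scheme string that the fused MoE path dispatches on.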
