
Commit 3af47c3

[Feature] Add Hopper DeepGEMM E8M0 for DeepSeekV3.1 scale_fmt (#23666)
Signed-off-by: yewentao256 <[email protected]>
Signed-off-by: youkaichao <[email protected]>
Co-authored-by: youkaichao <[email protected]>
1 parent 513c1fe commit 3af47c3

10 files changed: 68 additions & 53 deletions


tests/kernels/moe/test_block_fp8.py

Lines changed: 2 additions & 3 deletions
@@ -16,7 +16,7 @@
     fused_topk, modular_triton_fused_moe)
 from vllm.platforms import current_platform
 from vllm.utils import has_deep_gemm
-from vllm.utils.deep_gemm import is_blackwell_deep_gemm_e8m0_used
+from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used
 
 dg_available = has_deep_gemm()
 
@@ -226,8 +226,7 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed,
 @pytest.mark.parametrize("topk", TOP_KS)
 @pytest.mark.parametrize("seed", SEEDS)
 @pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.")
-@pytest.mark.skipif(is_blackwell_deep_gemm_e8m0_used(),
-                    reason="Not E8M0 scale MOE")
+@pytest.mark.skipif(is_deep_gemm_e8m0_used(), reason="Not E8M0 scale MOE")
 @torch.inference_mode()
 def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed,
                                             monkeypatch):

tests/kernels/moe/test_deepep_deepgemm_moe.py

Lines changed: 3 additions & 4 deletions
@@ -20,8 +20,7 @@
     FusedMoEModularKernel)
 from vllm.platforms import current_platform
 from vllm.utils import has_deep_ep, has_deep_gemm
-from vllm.utils.deep_gemm import (is_blackwell_deep_gemm_e8m0_used,
-                                  is_deep_gemm_supported)
+from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used, is_deep_gemm_supported
 
 from ...utils import multi_gpu_test
 from .parallel_utils import ProcessGroupInfo, parallel_launch
@@ -374,7 +373,7 @@ def _test_deepep_deepgemm_moe(
 @multi_gpu_test(num_gpus=2)
 @requires_deep_ep
 @requires_deep_gemm
-@pytest.mark.skipif(is_blackwell_deep_gemm_e8m0_used(),
+@pytest.mark.skipif(is_deep_gemm_e8m0_used(),
                     reason="Skipping test for Blackwell DeepGEMM")
 def test_ht_deepep_deepgemm_moe(mnk: tuple[int, int, int], num_experts: int,
                                 topk: int, world_dp_size: tuple[int, int]):
@@ -432,7 +431,7 @@ def test_ht_deepep_deepgemm_moe(mnk: tuple[int, int, int], num_experts: int,
 @multi_gpu_test(num_gpus=2)
 @requires_deep_ep
 @requires_deep_gemm
-@pytest.mark.skipif(is_blackwell_deep_gemm_e8m0_used(),
+@pytest.mark.skipif(is_deep_gemm_e8m0_used(),
                     reason="Skipping test for Blackwell DeepGEMM")
 def test_ll_deepep_deepgemm_moe(
     mnk: tuple[int, int, int],

vllm/envs.py

Lines changed: 7 additions & 1 deletion
@@ -131,6 +131,7 @@
     VLLM_TPU_USING_PATHWAYS: bool = False
     VLLM_USE_DEEP_GEMM: bool = False
     VLLM_USE_DEEP_GEMM_E8M0: bool = True
+    VLLM_USE_DEEP_GEMM_E8M0_HOPPER: bool = False
     VLLM_SKIP_DEEP_GEMM_WARMUP: bool = False
     VLLM_USE_FUSED_MOE_GROUPED_TOPK: bool = True
     VLLM_USE_FLASHINFER_MOE_FP8: bool = False
@@ -954,9 +955,12 @@ def get_vllm_port() -> Optional[int]:
     lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM", "0"))),
 
     # Whether to use E8M0 scaling when DeepGEMM is used on Blackwell GPUs.
-    # E8M0 is faster on B200 but may reduce accuracy.
     "VLLM_USE_DEEP_GEMM_E8M0":
     lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM_E8M0", "1"))),
+    # TODO(wentao): unify the two E8M0 flags after verifying the correctness.
+    # Whether to use E8M0 scaling when DeepGEMM is used on Hopper GPUs.
+    "VLLM_USE_DEEP_GEMM_E8M0_HOPPER":
+    lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM_E8M0_HOPPER", "0"))),
     # DeepGemm JITs the kernels on-demand. The warmup attempts to make DeepGemm
     # JIT all the required kernels before model execution so there is no
     # JIT'ing in the hot-path. However, this warmup increases the engine
@@ -1244,6 +1248,8 @@ def compute_hash() -> str:
         "VLLM_USE_FLASHINFER_SAMPLER",
         "VLLM_DISABLED_KERNELS",
         "VLLM_USE_DEEP_GEMM",
+        "VLLM_USE_DEEP_GEMM_E8M0",
+        "VLLM_USE_DEEP_GEMM_E8M0_HOPPER",
         "VLLM_USE_TRTLLM_FP4_GEMM",
         "VLLM_USE_FUSED_MOE_GROUPED_TOPK",
         "VLLM_USE_FLASHINFER_MOE_FP8",

vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py

Lines changed: 2 additions & 2 deletions
@@ -12,7 +12,7 @@
 from vllm.model_executor.layers.fused_moe.utils import _resize_cache
 from vllm.triton_utils import tl, triton
 from vllm.utils.deep_gemm import (fp8_m_grouped_gemm_nt_masked,
-                                  is_blackwell_deep_gemm_e8m0_used)
+                                  is_deep_gemm_e8m0_used)
 
 logger = init_logger(__name__)
 
@@ -174,7 +174,7 @@ def silu_mul_fp8_quant_deep_gemm(
         eps,
         fp8_min,
         fp8_max,
-        is_blackwell_deep_gemm_e8m0_used(),
+        is_deep_gemm_e8m0_used(),
         BLOCK=group_size,
         NUM_STAGES=4,
         num_warps=1,

vllm/model_executor/layers/fused_moe/fused_moe.py

Lines changed: 3 additions & 4 deletions
@@ -40,7 +40,7 @@
 from vllm.platforms import current_platform
 from vllm.triton_utils import tl, triton
 from vllm.utils import direct_register_custom_op, is_torch_equal_or_newer
-from vllm.utils.deep_gemm import is_blackwell_deep_gemm_e8m0_used
+from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used
 
 from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled
 
@@ -1431,9 +1431,8 @@ def fused_experts(hidden_states: torch.Tensor,
     # E8M0 scale, which means we requantize the weight and input to the specific
     # scale. Fallen back to cutlass or triton for some cases would cause
     # accuracy issue.
-    if (allow_deep_gemm and use_fp8_w8a8
-            and (is_blackwell_deep_gemm_e8m0_used()
-                 or _valid_deep_gemm(hidden_states, w1, w2))):
+    if (allow_deep_gemm and use_fp8_w8a8 and
+            (is_deep_gemm_e8m0_used() or _valid_deep_gemm(hidden_states, w1, w2))):
         assert apply_router_weight_on_input is False
         assert is_act_and_mul, (
             "DeepGemm only supports is_act_and_mul=True for now.")

vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py

Lines changed: 3 additions & 3 deletions
@@ -10,7 +10,7 @@
     DeepGemmExperts, _valid_deep_gemm, _valid_deep_gemm_shape,
     deep_gemm_block_shape)
 from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts
-from vllm.utils.deep_gemm import is_blackwell_deep_gemm_e8m0_used
+from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used
 
 
 class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
@@ -107,7 +107,7 @@ def workspace_shapes(
         # Note: the deep gemm workspaces are strictly larger than the triton
         # workspaces so we can be pessimistic here and allocate for DeepGemm
         # even if we fall back to triton later, e.g. if expert maps are set.
-        if self.allow_deep_gemm and (is_blackwell_deep_gemm_e8m0_used()
+        if self.allow_deep_gemm and (is_deep_gemm_e8m0_used()
                                      or _valid_deep_gemm_shape(M, N, K)):
             assert self.deep_gemm_expert is not None
             return self.deep_gemm_expert.workspace_shapes(
@@ -143,7 +143,7 @@ def apply(
     ):
         use_deep_gemm = (self.allow_deep_gemm
                          and (_valid_deep_gemm(hidden_states, w1, w2)
-                              or is_blackwell_deep_gemm_e8m0_used()))
+                              or is_deep_gemm_e8m0_used()))
 
         experts = self.deep_gemm_expert if use_deep_gemm else self.triton_expert
         assert experts is not None

vllm/model_executor/layers/quantization/fp8.py

Lines changed: 4 additions & 5 deletions
@@ -48,8 +48,7 @@
 from vllm.platforms import current_platform
 from vllm.scalar_type import scalar_types
 from vllm.utils import has_deep_gemm
-from vllm.utils.deep_gemm import (is_blackwell_deep_gemm_e8m0_used,
-                                  is_deep_gemm_supported)
+from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used, is_deep_gemm_supported
 from vllm.utils.flashinfer import has_flashinfer_moe
 
 if TYPE_CHECKING:
@@ -427,7 +426,7 @@ def process_weights_after_loading(self, layer: Module) -> None:
             # On B200, if E8M0 for DeepGemm is used, we need to
             # requantize the weight and input to the specific scale
             # at the same time.
-            if is_blackwell_deep_gemm_e8m0_used():
+            if is_deep_gemm_e8m0_used():
                 assert layer.weight_block_size is not None
                 block_sz = tuple(layer.weight_block_size)
                 requant_weight_ue8m0_inplace(
@@ -734,7 +733,7 @@ def process_weights_after_loading(self, layer: Module) -> None:
 
             # DeepGemm scales need to be transposed and aligned. We try to do
             # it ahead of time for performance reasons.
-            if self.allow_deep_gemm and not is_blackwell_deep_gemm_e8m0_used():
+            if self.allow_deep_gemm and not is_deep_gemm_e8m0_used():
                 # Lazy import to avoid CUDA initialization problems.
                 if _is_col_major(layer.w13_weight_scale_inv):
                     layer.w13_weight_scale_inv = \
@@ -871,7 +870,7 @@ def process_weights_after_loading(self, layer: Module) -> None:
                 del layer.w13_input_scale
                 del layer.w2_input_scale
 
-        if is_blackwell_deep_gemm_e8m0_used():
+        if is_deep_gemm_e8m0_used():
            assert layer.weight_block_size is not None
            # Re-quantise the expert weights so their scales are UE8M0.
            block_sz = tuple(layer.weight_block_size)
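
For context (not part of the diff): re-quantizing a block-quantized FP8 weight so its scales become UE8M0 amounts to rounding each per-block scale up to the next power of two and re-deriving the quantized blocks against the new scales. A minimal sketch with hypothetical names; this is not vLLM's requant_weight_ue8m0_inplace.

# Hypothetical sketch: w_q is a 2-D FP8 weight, scale holds one float32 entry
# per (block x block) tile, and both dimensions divide evenly by `block`.
import torch

def requant_ue8m0_sketch(w_q: torch.Tensor, scale: torch.Tensor, block: int = 128):
    s_full = scale.repeat_interleave(block, 0).repeat_interleave(block, 1)
    w = w_q.to(torch.float32) * s_full                      # dequantize
    new_scale = torch.exp2(torch.ceil(torch.log2(scale)))   # round scale up to 2^k
    s_new_full = new_scale.repeat_interleave(block, 0).repeat_interleave(block, 1)
    return (w / s_new_full).to(w_q.dtype), new_scale        # requantize with UE8M0 scale

Because the new scale is never smaller than the old one, the requantized values stay within the FP8 range; the cost is a small loss of scale precision, which is why the fallback-to-cutlass/triton paths above are avoided once E8M0 is in use.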

vllm/model_executor/layers/quantization/utils/fp8_utils.py

Lines changed: 2 additions & 2 deletions
@@ -20,7 +20,7 @@
 from vllm.platforms import current_platform
 from vllm.triton_utils import tl, triton
 from vllm.utils import cdiv, direct_register_custom_op
-from vllm.utils.deep_gemm import (is_blackwell_deep_gemm_e8m0_used,
+from vllm.utils.deep_gemm import (is_deep_gemm_e8m0_used,
                                   should_use_deepgemm_for_fp8_linear)
 
 logger = init_logger(__name__)
@@ -385,7 +385,7 @@ def per_token_group_quant_fp8(
         scaling factor.
     """
     if use_ue8m0 is None:
-        use_ue8m0 = is_blackwell_deep_gemm_e8m0_used()
+        use_ue8m0 = is_deep_gemm_e8m0_used()
     dtype = current_platform.fp8_dtype() if dtype is None else dtype
     assert (x.shape[-1] % group_size == 0), (
         f"the last dimension of `x` {x.shape[-1]} must be divisible "

vllm/transformers_utils/config.py

Lines changed: 18 additions & 0 deletions
@@ -501,6 +501,24 @@ def get_config(
 
     if quantization_config is not None:
         config.quantization_config = quantization_config
+        # auto-enable DeepGEMM UE8M0 on Hopper if model config requests it
+        scale_fmt = quantization_config.get("scale_fmt", None)
+        if scale_fmt in ("ue8m0", ):
+            if not envs.is_set("VLLM_USE_DEEP_GEMM_E8M0_HOPPER"):
+                os.environ["VLLM_USE_DEEP_GEMM_E8M0_HOPPER"] = "1"
+                logger.info_once(
+                    ("Detected quantization_config.scale_fmt=%s; "
+                     "enabling Hopper UE8M0."),
+                    scale_fmt,
+                )
+            elif not envs.VLLM_USE_DEEP_GEMM_E8M0_HOPPER:
+                logger.warning_once(
+                    ("Model config requests UE8M0 "
+                     "(quantization_config.scale_fmt=%s), but "
+                     "VLLM_USE_DEEP_GEMM_E8M0_HOPPER=0 is set; "
+                     "Hopper UE8M0 disabled."),
+                    scale_fmt,
+                )
 
     if hf_overrides_kw:
         logger.debug("Overriding HF config with %s", hf_overrides_kw)

vllm/utils/deep_gemm.py

Lines changed: 24 additions & 29 deletions
@@ -31,34 +31,33 @@ def is_deep_gemm_supported() -> bool:
 
 
 @functools.cache
-def is_blackwell_deep_gemm_e8m0_used() -> bool:
+def is_deep_gemm_e8m0_used() -> bool:
     """Return ``True`` if vLLM is configured to use DeepGEMM "
-    "E8M0 scale on a Blackwell-class GPU.
+    "E8M0 scale on a Hopper or Blackwell-class GPU.
     """
     if not is_deep_gemm_supported():
-        logger.debug_once(
+        logger.info_once(
             "DeepGEMM E8M0 disabled: DeepGEMM not supported on this system.")
         return False
 
-    if not envs.VLLM_USE_DEEP_GEMM_E8M0:
-        logger.debug_once("DeepGEMM E8M0 disabled: VLLM_USE_DEEP_GEMM_E8M0=0.")
-        return False
-
     _lazy_init()
 
     if _fp8_gemm_nt_impl is None:
-        logger.debug_once(
-            "DeepGEMM E8M0 disabled: _fp8_gemm_nt_impl not found")
+        logger.info_once("DeepGEMM E8M0 disabled: _fp8_gemm_nt_impl not found")
         return False
 
-    enabled = (current_platform.is_cuda()
-               and current_platform.has_device_capability(100))
-    if enabled:
-        logger.debug_once("DeepGEMM E8M0 enabled on Blackwell GPU.")
-    else:
-        logger.debug_once(
-            "DeepGEMM E8M0 disabled: not running on Blackwell GPU.")
-    return enabled
+    if current_platform.is_device_capability(100) and \
+        envs.VLLM_USE_DEEP_GEMM_E8M0:
+        logger.info_once("DeepGEMM E8M0 enabled on Blackwell GPU.")
+        return True
+
+    if current_platform.is_device_capability(90) and \
+        envs.VLLM_USE_DEEP_GEMM_E8M0_HOPPER:
+        logger.info_once("DeepGEMM E8M0 enabled on Hopper GPU.")
+        return True
+
+    logger.info_once("DeepGEMM E8M0 disabled on current configuration.")
+    return False
 
 
 def _missing(*_: Any, **__: Any) -> NoReturn:
@@ -124,30 +123,26 @@ def fp8_gemm_nt(*args, **kwargs):
     _lazy_init()
     if _fp8_gemm_nt_impl is None:
         return _missing(*args, **kwargs)
-    return _fp8_gemm_nt_impl(
-        *args,
-        disable_ue8m0_cast=not is_blackwell_deep_gemm_e8m0_used(),
-        **kwargs)
+    return _fp8_gemm_nt_impl(*args,
+                             disable_ue8m0_cast=not is_deep_gemm_e8m0_used(),
+                             **kwargs)
 
 
 def m_grouped_fp8_gemm_nt_contiguous(*args, **kwargs):
     _lazy_init()
     if _grouped_impl is None:
         return _missing(*args, **kwargs)
-    return _grouped_impl(
-        *args,
-        disable_ue8m0_cast=not is_blackwell_deep_gemm_e8m0_used(),
-        **kwargs)
+    return _grouped_impl(*args,
+                         disable_ue8m0_cast=not is_deep_gemm_e8m0_used(),
+                         **kwargs)
 
 
 def fp8_m_grouped_gemm_nt_masked(*args, **kwargs):
     _lazy_init()
     if _grouped_masked_impl is None:
         return _missing(*args, **kwargs)
     return _grouped_masked_impl(
-        *args,
-        disable_ue8m0_cast=not is_blackwell_deep_gemm_e8m0_used(),
-        **kwargs)
+        *args, disable_ue8m0_cast=not is_deep_gemm_e8m0_used(), **kwargs)
 
 
 def _ceil_to_ue8m0(x: torch.Tensor):
@@ -211,7 +206,7 @@ def should_use_deepgemm_for_fp8_linear(output_dtype: torch.dtype,
     "m_grouped_fp8_gemm_nt_contiguous",
     "fp8_m_grouped_gemm_nt_masked",
     "per_block_cast_to_fp8",
-    "is_blackwell_deep_gemm_e8m0_used",
+    "is_deep_gemm_e8m0_used",
     "is_deep_gemm_supported",
     "should_use_deepgemm_for_fp8_linear",
 ]
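
Not part of the diff: a standalone sketch of the new gating logic in is_deep_gemm_e8m0_used(), with the device-capability and env-flag checks stubbed out. The real function additionally requires DeepGEMM to be importable and its fp8 GEMM entry point to resolve.

# Stubbed decision logic: Blackwell (SM100) uses the default-on flag,
# Hopper (SM90) needs the new opt-in flag added in this commit.
def e8m0_used_sketch(capability_major: int, e8m0: bool, e8m0_hopper: bool) -> bool:
    if capability_major == 10 and e8m0:          # Blackwell, VLLM_USE_DEEP_GEMM_E8M0
        return True
    if capability_major == 9 and e8m0_hopper:    # Hopper, VLLM_USE_DEEP_GEMM_E8M0_HOPPER
        return True
    return False

assert e8m0_used_sketch(10, True, False)      # B200 with default flags -> enabled
assert not e8m0_used_sketch(9, True, False)   # H100 without opt-in -> disabled
assert e8m0_used_sketch(9, True, True)        # H100 with the new flag (or scale_fmt=ue8m0) -> enabled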
