Commit f7dcce7

[Feature] Add VLLM_USE_DEEP_GEMM_E8M0 Env to Control E8M0 Scale (#21968)
Signed-off-by: yewentao256 <[email protected]>
1 parent 8e13d9f commit f7dcce7

9 files changed: +65 −39 lines

tests/kernels/moe/test_block_fp8.py

Lines changed: 3 additions & 2 deletions
@@ -16,7 +16,7 @@
     fused_topk, modular_triton_fused_moe)
 from vllm.platforms import current_platform
 from vllm.utils import has_deep_gemm
-from vllm.utils.deep_gemm import is_blackwell_deep_gemm_used
+from vllm.utils.deep_gemm import is_blackwell_deep_gemm_e8m0_used

 dg_available = has_deep_gemm()

@@ -224,7 +224,8 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed,
 @pytest.mark.parametrize("topk", TOP_KS)
 @pytest.mark.parametrize("seed", SEEDS)
 @pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.")
-@pytest.mark.skipif(is_blackwell_deep_gemm_used(), reason="Not E8M0 scale MOE")
+@pytest.mark.skipif(is_blackwell_deep_gemm_e8m0_used(),
+                    reason="Not E8M0 scale MOE")
 @torch.inference_mode()
 def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed,
                                             monkeypatch):

tests/kernels/moe/test_deepep_deepgemm_moe.py

Lines changed: 3 additions & 3 deletions
@@ -20,7 +20,7 @@
     FusedMoEModularKernel)
 from vllm.platforms import current_platform
 from vllm.utils import has_deep_ep, has_deep_gemm
-from vllm.utils.deep_gemm import (is_blackwell_deep_gemm_used,
+from vllm.utils.deep_gemm import (is_blackwell_deep_gemm_e8m0_used,
                                   is_deep_gemm_supported)

 from .parallel_utils import ProcessGroupInfo, parallel_launch
@@ -370,7 +370,7 @@ def _test_deepep_deepgemm_moe(
 @pytest.mark.parametrize("world_dp_size", [(2, 1)])
 @requires_deep_ep
 @requires_deep_gemm
-@pytest.mark.skipif(is_blackwell_deep_gemm_used(),
+@pytest.mark.skipif(is_blackwell_deep_gemm_e8m0_used(),
                     reason="Skipping test for Blackwell DeepGEMM")
 def test_ht_deepep_deepgemm_moe(mnk: tuple[int, int, int], num_experts: int,
                                 topk: int, world_dp_size: tuple[int, int]):
@@ -427,7 +427,7 @@ def test_ht_deepep_deepgemm_moe(mnk: tuple[int, int, int], num_experts: int,
 @pytest.mark.parametrize("world_dp_size", [(2, 1)])
 @requires_deep_ep
 @requires_deep_gemm
-@pytest.mark.skipif(is_blackwell_deep_gemm_used(),
+@pytest.mark.skipif(is_blackwell_deep_gemm_e8m0_used(),
                     reason="Skipping test for Blackwell DeepGEMM")
 def test_ll_deepep_deepgemm_moe(
     mnk: tuple[int, int, int],

vllm/envs.py

Lines changed: 5 additions & 0 deletions
@@ -127,6 +127,7 @@
     VLLM_TPU_MOST_MODEL_LEN: Optional[int] = None
     VLLM_TPU_USING_PATHWAYS: bool = False
     VLLM_USE_DEEP_GEMM: bool = False
+    VLLM_USE_DEEP_GEMM_E8M0: bool = True
     VLLM_SKIP_DEEP_GEMM_WARMUP: bool = False
     VLLM_USE_FLASHINFER_MOE_FP8: bool = False
     VLLM_USE_FLASHINFER_MOE_FP4: bool = False
@@ -925,6 +926,10 @@ def get_vllm_port() -> Optional[int]:
     "VLLM_USE_DEEP_GEMM":
     lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM", "0"))),

+    # Whether to use E8M0 scaling when DeepGEMM is used on Blackwell GPUs.
+    # E8M0 is faster on B200 but may reduce accuracy.
+    "VLLM_USE_DEEP_GEMM_E8M0":
+    lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM_E8M0", "1"))),
     # DeepGemm JITs the kernels on-demand. The warmup attempts to make DeepGemm
     # JIT all the required kernels before model execution so there is no
     # JIT'ing in the hot-path. However, this warmup increases the engine
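The new flag defaults to "1", so E8M0 scaling is only skipped when a user opts out. A minimal usage sketch (illustrative, not part of this commit; the model id is a placeholder) showing how to keep DeepGEMM enabled while forcing non-E8M0 scales:

import os

# Keep DeepGEMM on but opt out of E8M0 scales on Blackwell.
# is_blackwell_deep_gemm_e8m0_used() is cached with functools.cache, so set
# these variables before vLLM evaluates it for the first time in the process.
os.environ["VLLM_USE_DEEP_GEMM"] = "1"
os.environ["VLLM_USE_DEEP_GEMM_E8M0"] = "0"

from vllm import LLM  # noqa: E402

llm = LLM(model="some-fp8-block-quantized-model")  # placeholder model id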

vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py

Lines changed: 2 additions & 2 deletions
@@ -12,7 +12,7 @@
 from vllm.model_executor.layers.fused_moe.utils import _resize_cache
 from vllm.triton_utils import tl, triton
 from vllm.utils.deep_gemm import (fp8_m_grouped_gemm_nt_masked,
-                                  is_blackwell_deep_gemm_used)
+                                  is_blackwell_deep_gemm_e8m0_used)

 logger = init_logger(__name__)

@@ -176,7 +176,7 @@ def silu_mul_fp8_quant_deep_gemm(
         eps,
         fp8_min,
         fp8_max,
-        is_blackwell_deep_gemm_used(),
+        is_blackwell_deep_gemm_e8m0_used(),
         BLOCK=group_size,
         NUM_STAGES=8,
         num_warps=1,

vllm/model_executor/layers/fused_moe/fused_moe.py

Lines changed: 3 additions & 3 deletions
@@ -40,7 +40,7 @@
 from vllm.platforms import current_platform
 from vllm.triton_utils import tl, triton
 from vllm.utils import direct_register_custom_op, is_torch_equal_or_newer
-from vllm.utils.deep_gemm import is_blackwell_deep_gemm_used
+from vllm.utils.deep_gemm import is_blackwell_deep_gemm_e8m0_used

 from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled

@@ -1387,8 +1387,8 @@ def fused_experts(hidden_states: torch.Tensor,
     # E8M0 scale, which means we requantize the weight and input to the specific
     # scale. Fallen back to cutlass or triton for some cases would cause
     # accuracy issue.
-    should_use_deep_gemm = is_blackwell_deep_gemm_used() or _valid_deep_gemm(
-        hidden_states, w1, w2)
+    should_use_deep_gemm = is_blackwell_deep_gemm_e8m0_used(
+    ) or _valid_deep_gemm(hidden_states, w1, w2)
     if (allow_deep_gemm and use_fp8_w8a8 and should_use_deep_gemm):
         assert apply_router_weight_on_input is False
         assert is_act_and_mul, (

vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py

Lines changed: 3 additions & 3 deletions
@@ -10,7 +10,7 @@
     DeepGemmExperts, _valid_deep_gemm, _valid_deep_gemm_shape,
     deep_gemm_block_shape)
 from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts
-from vllm.utils.deep_gemm import is_blackwell_deep_gemm_used
+from vllm.utils.deep_gemm import is_blackwell_deep_gemm_e8m0_used


 class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
@@ -107,7 +107,7 @@ def workspace_shapes(
         # Note: the deep gemm workspaces are strictly larger than the triton
         # workspaces so we can be pessimistic here and allocate for DeepGemm
         # even if we fall back to triton later, e.g. if expert maps are set.
-        if self.allow_deep_gemm and (is_blackwell_deep_gemm_used()
+        if self.allow_deep_gemm and (is_blackwell_deep_gemm_e8m0_used()
                                      or _valid_deep_gemm_shape(M, N, K)):
             assert self.deep_gemm_expert is not None
             return self.deep_gemm_expert.workspace_shapes(
@@ -133,7 +133,7 @@ def apply(self, output: torch.Tensor, hidden_states: torch.Tensor,
               extra_expert_args: Optional[dict[str, Any]]):
         use_deep_gemm = (self.allow_deep_gemm
                          and (_valid_deep_gemm(hidden_states, w1, w2)
-                              or is_blackwell_deep_gemm_used()))
+                              or is_blackwell_deep_gemm_e8m0_used()))

         experts = self.deep_gemm_expert if use_deep_gemm else self.triton_expert
         assert experts is not None

vllm/model_executor/layers/quantization/fp8.py

Lines changed: 7 additions & 12 deletions
@@ -45,7 +45,8 @@
 from vllm.platforms import current_platform
 from vllm.scalar_type import scalar_types
 from vllm.utils import has_deep_gemm
-from vllm.utils.deep_gemm import is_blackwell_deep_gemm_used
+from vllm.utils.deep_gemm import (is_blackwell_deep_gemm_e8m0_used,
+                                  is_deep_gemm_supported)
 from vllm.utils.flashinfer import has_flashinfer_moe

 if TYPE_CHECKING:
@@ -415,10 +416,10 @@ def process_weights_after_loading(self, layer: Module) -> None:
             # Activations not quantized for marlin.
             del layer.input_scale

-            # On B200, DeepGemm only support E8M0 scale, which means we need to
+            # On B200, if E8M0 for DeepGemm is used, we need to
             # requantize the weight and input to the specific scale
             # at the same time.
-            if is_blackwell_deep_gemm_used():
+            if is_blackwell_deep_gemm_e8m0_used():
                 assert layer.weight_block_size is not None
                 block_sz = tuple(layer.weight_block_size)
                 requant_weight_ue8m0_inplace(
@@ -505,15 +506,9 @@ def __init__(self, quant_config: Fp8Config):
         elif not self.block_quant:
             logger.warning_once("Model is not block quantized. Not using "
                                 "DeepGemm kernels")
-        elif (current_platform.is_cuda()
-              and current_platform.is_device_capability(90)):
+        elif (is_deep_gemm_supported()):
             logger.info_once("Using DeepGemm kernels for Fp8MoEMethod.")
             self.allow_deep_gemm = True
-        elif (current_platform.is_cuda()
-              and is_blackwell_deep_gemm_used()):
-            logger.info_once("Using DeepGemm SM100 kernels for "
-                             "Fp8MoEMethod.")
-            self.allow_deep_gemm = True
         else:
             logger.warning_once(
                 "DeepGemm not supported on the current platform.")
@@ -725,7 +720,7 @@ def process_weights_after_loading(self, layer: Module) -> None:

         # DeepGemm scales need to be transposed and aligned. We try to do
         # it ahead of time for performance reasons.
-        if self.allow_deep_gemm and not is_blackwell_deep_gemm_used():
+        if self.allow_deep_gemm and not is_blackwell_deep_gemm_e8m0_used():
             # Lazy import to avoid CUDA initialization problems.
             if _is_col_major(layer.w13_weight_scale_inv):
                 layer.w13_weight_scale_inv = \
@@ -851,7 +846,7 @@ def process_weights_after_loading(self, layer: Module) -> None:
         del layer.w13_input_scale
         del layer.w2_input_scale

-        if is_blackwell_deep_gemm_used():
+        if is_blackwell_deep_gemm_e8m0_used():
             assert layer.weight_block_size is not None
             # Re-quantise the expert weights so their scales are UE8M0.
             block_sz = tuple(layer.weight_block_size)
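Both process_weights_after_loading branches above call requant_weight_ue8m0_inplace so every block scale becomes representable in UE8M0, an exponent-only (power-of-two) format. A rough standalone sketch of that idea, assuming one scale per (block_n, block_k) tile and evenly dividing shapes; this is illustrative and not vLLM's actual helper:

import torch

def ceil_to_ue8m0(scale: torch.Tensor) -> torch.Tensor:
    # UE8M0 scales carry only an exponent, so round each scale up to the
    # nearest power of two; rounding up keeps requantized values in range.
    return torch.exp2(torch.ceil(torch.log2(scale)))

def requant_block_fp8_ue8m0(w_q: torch.Tensor, scale: torch.Tensor,
                            block: tuple[int, int]):
    """Illustrative only: w_q is (N, K) FP8, scale is (N/bn, K/bk) FP32."""
    bn, bk = block
    new_scale = ceil_to_ue8m0(scale)
    # Expand per-tile scales to per-element granularity.
    old = scale.repeat_interleave(bn, 0).repeat_interleave(bk, 1)
    new = new_scale.repeat_interleave(bn, 0).repeat_interleave(bk, 1)
    w = w_q.to(torch.float32) * old              # dequantize with old scales
    w_q_new = (w / new).to(torch.float8_e4m3fn)  # requantize with UE8M0 scales
    return w_q_new, new_scale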

vllm/model_executor/layers/quantization/utils/fp8_utils.py

Lines changed: 2 additions & 4 deletions
@@ -20,7 +20,7 @@
 from vllm.platforms import current_platform
 from vllm.triton_utils import tl, triton
 from vllm.utils import cdiv, direct_register_custom_op, has_deep_gemm
-from vllm.utils.deep_gemm import is_blackwell_deep_gemm_used
+from vllm.utils.deep_gemm import is_blackwell_deep_gemm_e8m0_used

 logger = init_logger(__name__)

@@ -394,10 +394,8 @@ def per_token_group_quant_fp8(
         tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
         scaling factor.
     """
-    # TODO(wentao): refactor this
-    # use_ue8m0 should be a global flag that could be set by user
     if use_ue8m0 is None:
-        use_ue8m0 = is_blackwell_deep_gemm_used()
+        use_ue8m0 = is_blackwell_deep_gemm_e8m0_used()
    dtype = current_platform.fp8_dtype() if dtype is None else dtype
    assert (x.shape[-1] % group_size == 0), (
        f"the last dimension of `x` {x.shape[-1]} must be divisible "

vllm/utils/deep_gemm.py

Lines changed: 37 additions & 10 deletions
@@ -31,19 +31,37 @@ def is_deep_gemm_supported() -> bool:


 @functools.cache
-def is_blackwell_deep_gemm_used() -> bool:
-    """Return ``True`` if vLLM is configured to use DeepGEMM on a
-    Blackwell-class GPU.
+def is_blackwell_deep_gemm_e8m0_used() -> bool:
+    """Return ``True`` if vLLM is configured to use DeepGEMM "
+    "E8M0 scale on a Blackwell-class GPU.
     """
-    if not (envs.VLLM_USE_DEEP_GEMM and has_deep_gemm()):
+    if not (envs.VLLM_USE_DEEP_GEMM):
+        logger.debug_once("DeepGEMM E8M0 disabled: VLLM_USE_DEEP_GEMM=0.")
+        return False
+
+    if not has_deep_gemm():
+        logger.debug_once("DeepGEMM E8M0 disabled: DeepGEMM backend missing.")
+        return False
+
+    if not envs.VLLM_USE_DEEP_GEMM_E8M0:
+        logger.debug_once("DeepGEMM E8M0 disabled: VLLM_USE_DEEP_GEMM_E8M0=0.")
         return False

     _lazy_init()
+
     if _fp8_gemm_nt_impl is None:
+        logger.debug_once(
+            "DeepGEMM E8M0 disabled: _fp8_gemm_nt_impl not found")
         return False

-    return (current_platform.is_cuda()
-            and current_platform.is_device_capability(100))
+    enabled = (current_platform.is_cuda()
+               and current_platform.has_device_capability(100))
+    if enabled:
+        logger.debug_once("DeepGEMM E8M0 enabled on Blackwell GPU.")
+    else:
+        logger.debug_once(
+            "DeepGEMM E8M0 disabled: not running on Blackwell GPU.")
+    return enabled


 def _missing(*_: Any, **__: Any) -> NoReturn:
@@ -109,21 +127,30 @@ def fp8_gemm_nt(*args, **kwargs):
     _lazy_init()
     if _fp8_gemm_nt_impl is None:
         return _missing(*args, **kwargs)
-    return _fp8_gemm_nt_impl(*args, **kwargs)
+    return _fp8_gemm_nt_impl(
+        *args,
+        disable_ue8m0_cast=not is_blackwell_deep_gemm_e8m0_used(),
+        **kwargs)


 def m_grouped_fp8_gemm_nt_contiguous(*args, **kwargs):
     _lazy_init()
     if _grouped_impl is None:
         return _missing(*args, **kwargs)
-    return _grouped_impl(*args, **kwargs)
+    return _grouped_impl(
+        *args,
+        disable_ue8m0_cast=not is_blackwell_deep_gemm_e8m0_used(),
+        **kwargs)


 def fp8_m_grouped_gemm_nt_masked(*args, **kwargs):
     _lazy_init()
     if _grouped_masked_impl is None:
         return _missing(*args, **kwargs)
-    return _grouped_masked_impl(*args, **kwargs)
+    return _grouped_masked_impl(
+        *args,
+        disable_ue8m0_cast=not is_blackwell_deep_gemm_e8m0_used(),
+        **kwargs)


 def _ceil_to_ue8m0(x: torch.Tensor):
@@ -181,6 +208,6 @@ def calc_diff(x: torch.Tensor, y: torch.Tensor):
     "m_grouped_fp8_gemm_nt_contiguous",
     "fp8_m_grouped_gemm_nt_masked",
     "per_block_cast_to_fp8",
-    "is_blackwell_deep_gemm_used",
+    "is_blackwell_deep_gemm_e8m0_used",
     "is_deep_gemm_supported",
 ]
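Because is_blackwell_deep_gemm_e8m0_used is wrapped in functools.cache, the decision is made once per process; anything that flips VLLM_USE_DEEP_GEMM_E8M0 afterwards (for example a test) has to drop the cached value first. A minimal sketch using pytest's monkeypatch fixture; the test itself is illustrative and not part of this commit:

from vllm.utils import deep_gemm

def test_e8m0_can_be_disabled(monkeypatch):
    # Force the flag off, then clear the cached decision so the helper
    # re-reads the environment on the next call.
    monkeypatch.setenv("VLLM_USE_DEEP_GEMM_E8M0", "0")
    deep_gemm.is_blackwell_deep_gemm_e8m0_used.cache_clear()
    assert deep_gemm.is_blackwell_deep_gemm_e8m0_used() is False
    # Clean up so later tests see a fresh decision.
    deep_gemm.is_blackwell_deep_gemm_e8m0_used.cache_clear()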
