Skip to content

Commit e69a92a

Browse files
authored
[Bug] DeepGemm: Fix Cuda Init Error (#21312)
Signed-off-by: yewentao256 <[email protected]>
1 parent 8425f78 commit e69a92a

File tree

1 file changed

+32
-22
lines changed

1 file changed

+32
-22
lines changed

vllm/utils/deep_gemm.py

Lines changed: 32 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -45,30 +45,36 @@ def _resolve_symbol(module, new: str, old: str) -> Callable[..., Any] | None:
4545
return None
4646

4747

48-
if not has_deep_gemm():
49-
_fp8_gemm_nt_impl: Callable[..., Any] | None = None
50-
_grouped_impl: Callable[..., Any] | None = None
51-
_grouped_masked_impl: Callable[..., Any] | None = None
52-
_per_block_cast_impl: Callable[..., Any] | None = None
53-
else:
54-
_dg = importlib.import_module("deep_gemm") # type: ignore
55-
56-
_fp8_gemm_nt_impl = _resolve_symbol(
57-
_dg,
58-
"fp8_gemm_nt",
59-
"gemm_fp8_fp8_bf16_nt",
60-
)
48+
_fp8_gemm_nt_impl: Callable[..., Any] | None = None
49+
_grouped_impl: Callable[..., Any] | None = None
50+
_grouped_masked_impl: Callable[..., Any] | None = None
51+
_per_block_cast_impl: Callable[..., Any] | None = None
52+
53+
54+
def _lazy_init() -> None:
55+
"""Import deep_gemm and resolve symbols on first use."""
56+
global _fp8_gemm_nt_impl, _grouped_impl, _grouped_masked_impl, \
57+
_per_block_cast_impl
58+
59+
# fast path
60+
if (_fp8_gemm_nt_impl is not None or _grouped_impl is not None
61+
or _grouped_masked_impl is not None
62+
or _per_block_cast_impl is not None):
63+
return
64+
65+
if not has_deep_gemm():
66+
return
67+
68+
_dg = importlib.import_module("deep_gemm")
69+
70+
_fp8_gemm_nt_impl = _resolve_symbol(_dg, "fp8_gemm_nt",
71+
"gemm_fp8_fp8_bf16_nt")
6172
_grouped_impl = _resolve_symbol(
62-
_dg,
63-
"m_grouped_fp8_gemm_nt_contiguous",
64-
"m_grouped_gemm_fp8_fp8_bf16_nt_contiguous",
65-
)
73+
_dg, "m_grouped_fp8_gemm_nt_contiguous",
74+
"m_grouped_gemm_fp8_fp8_bf16_nt_contiguous")
6675
_grouped_masked_impl = _resolve_symbol(
67-
_dg,
68-
"fp8_m_grouped_gemm_nt_masked",
69-
"m_grouped_gemm_fp8_fp8_bf16_nt_masked",
70-
)
71-
76+
_dg, "fp8_m_grouped_gemm_nt_masked",
77+
"m_grouped_gemm_fp8_fp8_bf16_nt_masked")
7278
# Try to get per_token_cast_to_fp8 from DeepGEMM math utils.
7379
try:
7480
_math_mod = importlib.import_module(
@@ -80,24 +86,28 @@ def _resolve_symbol(module, new: str, old: str) -> Callable[..., Any] | None:
8086

8187

8288
def fp8_gemm_nt(*args, **kwargs):
89+
_lazy_init()
8390
if _fp8_gemm_nt_impl is None:
8491
return _missing(*args, **kwargs)
8592
return _fp8_gemm_nt_impl(*args, **kwargs)
8693

8794

8895
def m_grouped_fp8_gemm_nt_contiguous(*args, **kwargs):
96+
_lazy_init()
8997
if _grouped_impl is None:
9098
return _missing(*args, **kwargs)
9199
return _grouped_impl(*args, **kwargs)
92100

93101

94102
def fp8_m_grouped_gemm_nt_masked(*args, **kwargs):
103+
_lazy_init()
95104
if _grouped_masked_impl is None:
96105
return _missing(*args, **kwargs)
97106
return _grouped_masked_impl(*args, **kwargs)
98107

99108

100109
def per_block_cast_to_fp8(x, *args, **kwargs):
110+
_lazy_init()
101111
if _per_block_cast_impl is not None and is_blackwell_deep_gemm_used():
102112
return _per_block_cast_impl(x, use_ue8m0=True)
103113
# TODO: refactor the `per_block_cast_to_fp8` from tests to vllm utils

0 commit comments

Comments
 (0)