@@ -45,30 +45,36 @@ def _resolve_symbol(module, new: str, old: str) -> Callable[..., Any] | None:
45
45
return None
46
46
47
47
48
- if not has_deep_gemm ():
49
- _fp8_gemm_nt_impl : Callable [..., Any ] | None = None
50
- _grouped_impl : Callable [..., Any ] | None = None
51
- _grouped_masked_impl : Callable [..., Any ] | None = None
52
- _per_block_cast_impl : Callable [..., Any ] | None = None
53
- else :
54
- _dg = importlib .import_module ("deep_gemm" ) # type: ignore
55
-
56
- _fp8_gemm_nt_impl = _resolve_symbol (
57
- _dg ,
58
- "fp8_gemm_nt" ,
59
- "gemm_fp8_fp8_bf16_nt" ,
60
- )
48
+ _fp8_gemm_nt_impl : Callable [..., Any ] | None = None
49
+ _grouped_impl : Callable [..., Any ] | None = None
50
+ _grouped_masked_impl : Callable [..., Any ] | None = None
51
+ _per_block_cast_impl : Callable [..., Any ] | None = None
52
+
53
+
54
+ def _lazy_init () -> None :
55
+ """Import deep_gemm and resolve symbols on first use."""
56
+ global _fp8_gemm_nt_impl , _grouped_impl , _grouped_masked_impl , \
57
+ _per_block_cast_impl
58
+
59
+ # fast path
60
+ if (_fp8_gemm_nt_impl is not None or _grouped_impl is not None
61
+ or _grouped_masked_impl is not None
62
+ or _per_block_cast_impl is not None ):
63
+ return
64
+
65
+ if not has_deep_gemm ():
66
+ return
67
+
68
+ _dg = importlib .import_module ("deep_gemm" )
69
+
70
+ _fp8_gemm_nt_impl = _resolve_symbol (_dg , "fp8_gemm_nt" ,
71
+ "gemm_fp8_fp8_bf16_nt" )
61
72
_grouped_impl = _resolve_symbol (
62
- _dg ,
63
- "m_grouped_fp8_gemm_nt_contiguous" ,
64
- "m_grouped_gemm_fp8_fp8_bf16_nt_contiguous" ,
65
- )
73
+ _dg , "m_grouped_fp8_gemm_nt_contiguous" ,
74
+ "m_grouped_gemm_fp8_fp8_bf16_nt_contiguous" )
66
75
_grouped_masked_impl = _resolve_symbol (
67
- _dg ,
68
- "fp8_m_grouped_gemm_nt_masked" ,
69
- "m_grouped_gemm_fp8_fp8_bf16_nt_masked" ,
70
- )
71
-
76
+ _dg , "fp8_m_grouped_gemm_nt_masked" ,
77
+ "m_grouped_gemm_fp8_fp8_bf16_nt_masked" )
72
78
# Try to get per_token_cast_to_fp8 from DeepGEMM math utils.
73
79
try :
74
80
_math_mod = importlib .import_module (
@@ -80,24 +86,28 @@ def _resolve_symbol(module, new: str, old: str) -> Callable[..., Any] | None:
80
86
81
87
82
88
def fp8_gemm_nt (* args , ** kwargs ):
89
+ _lazy_init ()
83
90
if _fp8_gemm_nt_impl is None :
84
91
return _missing (* args , ** kwargs )
85
92
return _fp8_gemm_nt_impl (* args , ** kwargs )
86
93
87
94
88
95
def m_grouped_fp8_gemm_nt_contiguous (* args , ** kwargs ):
96
+ _lazy_init ()
89
97
if _grouped_impl is None :
90
98
return _missing (* args , ** kwargs )
91
99
return _grouped_impl (* args , ** kwargs )
92
100
93
101
94
102
def fp8_m_grouped_gemm_nt_masked (* args , ** kwargs ):
103
+ _lazy_init ()
95
104
if _grouped_masked_impl is None :
96
105
return _missing (* args , ** kwargs )
97
106
return _grouped_masked_impl (* args , ** kwargs )
98
107
99
108
100
109
def per_block_cast_to_fp8 (x , * args , ** kwargs ):
110
+ _lazy_init ()
101
111
if _per_block_cast_impl is not None and is_blackwell_deep_gemm_used ():
102
112
return _per_block_cast_impl (x , use_ue8m0 = True )
103
113
# TODO: refactor the `per_block_cast_to_fp8` from tests to vllm utils
0 commit comments