
Commit 7702ce2

tdoublep and amd-xiaoyu12 authored and committed
[V1] [Hybrid] Enable Full CUDA graph by default for hybrid models in V1 (vllm-project#22594)
Signed-off-by: Thomas Parnell <[email protected]>
Signed-off-by: Xiao Yu <[email protected]>
1 parent 05f913f commit 7702ce2

File tree

1 file changed: +42 -0 lines changed

vllm/model_executor/models/config.py

Lines changed: 42 additions & 0 deletions
@@ -4,6 +4,7 @@
 from typing import TYPE_CHECKING
 
 import vllm.envs as envs
+from vllm.config.compilation import CUDAGraphMode
 from vllm.logger import init_logger
 from vllm.model_executor.models import ModelRegistry
 from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv
@@ -275,6 +276,42 @@ def verify_and_update_config(vllm_config: "VllmConfig") -> None:
             "%d for performance.", 1024)
 
 
+class MambaModelConfig(VerifyAndUpdateConfig):
+
+    @classmethod
+    def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
+        """
+        Enable FULL_AND_PIECEWISE cuda graph mode by default (required
+        to get good performance for mamba layers in V1).
+
+        Args:
+            vllm_config: vLLM Config
+        """
+
+        if not envs.VLLM_USE_V1:
+            return
+
+        model_config = vllm_config.model_config
+        compilation_config = vllm_config.compilation_config
+
+        model_cls, _ = ModelRegistry.resolve_model_cls(
+            model_config.architecture,
+            model_config=model_config,
+        )
+
+        # TODO(tdoublep): remove as full cuda graph support is added
+        FCG_NOT_SUPPORTED_MODELS = [
+            "Lfm2ForCausalLM", "MiniMaxText01ForCausalLM"
+        ]
+
+        if (model_config.architecture not in FCG_NOT_SUPPORTED_MODELS
+                and compilation_config.cudagraph_mode is None):
+            logger.info(
+                "Hybrid or mamba-based model detected: setting cudagraph mode "
+                "to FULL_AND_PIECEWISE in order to optimize performance.")
+            compilation_config.cudagraph_mode = CUDAGraphMode.FULL_AND_PIECEWISE
+
+
 class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
 
     @classmethod
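
Note that the new hook only fills in a default: an explicit `cudagraph_mode` chosen by the user is left untouched, and architectures in `FCG_NOT_SUPPORTED_MODELS` are skipped. A minimal, self-contained sketch of that decision flow, using stub enum and config classes in place of the real vLLM types:

```python
from dataclasses import dataclass
from enum import Enum, auto
from typing import Optional


class CUDAGraphMode(Enum):
    # Stand-in for vllm.config.compilation.CUDAGraphMode
    NONE = auto()
    PIECEWISE = auto()
    FULL_AND_PIECEWISE = auto()


@dataclass
class CompilationConfig:
    # None means "the user did not choose", so a default may be applied.
    cudagraph_mode: Optional[CUDAGraphMode] = None


# Architectures that do not yet support full CUDA graphs (from the patch).
FCG_NOT_SUPPORTED_MODELS = ["Lfm2ForCausalLM", "MiniMaxText01ForCausalLM"]


def apply_mamba_default(architecture: str, cfg: CompilationConfig) -> None:
    """Mirror of the patch's gating: set FULL_AND_PIECEWISE only when unset."""
    if (architecture not in FCG_NOT_SUPPORTED_MODELS
            and cfg.cudagraph_mode is None):
        cfg.cudagraph_mode = CUDAGraphMode.FULL_AND_PIECEWISE


# Unset -> the new default kicks in.
cfg = CompilationConfig()
apply_mamba_default("Mamba2ForCausalLM", cfg)
assert cfg.cudagraph_mode is CUDAGraphMode.FULL_AND_PIECEWISE

# Explicit user choice -> left untouched.
cfg = CompilationConfig(cudagraph_mode=CUDAGraphMode.PIECEWISE)
apply_mamba_default("Mamba2ForCausalLM", cfg)
assert cfg.cudagraph_mode is CUDAGraphMode.PIECEWISE

# Unsupported architecture -> no default applied.
cfg = CompilationConfig()
apply_mamba_default("Lfm2ForCausalLM", cfg)
assert cfg.cudagraph_mode is None
```

The `is None` check is what makes this a default rather than an override.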
@@ -293,6 +330,9 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
         if not envs.VLLM_USE_V1:
             return
 
+        # Enable FULL_AND_PIECEWISE by default
+        MambaModelConfig.verify_and_update_config(vllm_config)
+
         cache_config = vllm_config.cache_config
         model_config = vllm_config.model_config
         parallel_config = vllm_config.parallel_config
@@ -374,4 +414,6 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
     "JambaForSequenceClassification": JambaForSequenceClassificationConfig,
     "GraniteMoeHybridForCausalLM": GraniteMoeHybridModelConfig,
     "GptOssForCausalLM": GptOssForCausalLMConfig,
+    "MambaForCausalLM": MambaModelConfig,
+    "Mamba2ForCausalLM": MambaModelConfig,
 }
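
The dict extended in this last hunk maps HF architecture strings to per-model config hooks; registering `MambaModelConfig` for the two pure-mamba architectures is what makes the new default apply to them (hybrid models get it via the delegation call added in `HybridAttentionMambaModelConfig`). The dict's name and its call site are not shown in the hunk, so `apply_model_overrides` below is an illustrative name for the dispatch pattern, with stub classes standing in for the real ones:

```python
from typing import Any, Dict, Type


class VerifyAndUpdateConfig:
    """Stand-in for the real base class of the per-model config hooks."""

    @classmethod
    def verify_and_update_config(cls, vllm_config: Any) -> None:
        pass


class MambaModelConfig(VerifyAndUpdateConfig):
    @classmethod
    def verify_and_update_config(cls, vllm_config: Any) -> None:
        print("applying mamba cudagraph default")


# Same shape as the dict the diff extends: architecture name -> hook class.
MODELS_CONFIG_MAP: Dict[str, Type[VerifyAndUpdateConfig]] = {
    "MambaForCausalLM": MambaModelConfig,
    "Mamba2ForCausalLM": MambaModelConfig,
}


def apply_model_overrides(architecture: str, vllm_config: Any) -> None:
    """Hypothetical dispatch: run the hook registered for this architecture."""
    hook = MODELS_CONFIG_MAP.get(architecture)
    if hook is not None:
        hook.verify_and_update_config(vllm_config)


apply_model_overrides("Mamba2ForCausalLM", vllm_config=None)
```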
