@@ -4,6 +4,7 @@
 from typing import TYPE_CHECKING
 
 import vllm.envs as envs
+from vllm.config.compilation import CUDAGraphMode
 from vllm.logger import init_logger
 from vllm.model_executor.models import ModelRegistry
 from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv
@@ -275,6 +276,42 @@ def verify_and_update_config(vllm_config: "VllmConfig") -> None:
             "%d for performance.", 1024)
 
 
+class MambaModelConfig(VerifyAndUpdateConfig):
+
+    @classmethod
+    def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
+        """
+        Enable FULL_AND_PIECEWISE cuda graph mode by default (required
+        to get good performance for mamba layers in V1).
+
+        Args:
+            vllm_config: vLLM Config
+        """
+
+        if not envs.VLLM_USE_V1:
+            return
+
+        model_config = vllm_config.model_config
+        compilation_config = vllm_config.compilation_config
+
+        model_cls, _ = ModelRegistry.resolve_model_cls(
+            model_config.architecture,
+            model_config=model_config,
+        )
+
+        # TODO(tdoublep): remove as full cuda graph support is added
+        FCG_NOT_SUPPORTED_MODELS = [
+            "Lfm2ForCausalLM", "MiniMaxText01ForCausalLM"
+        ]
+
+        if (model_config.architecture not in FCG_NOT_SUPPORTED_MODELS
+                and compilation_config.cudagraph_mode is None):
+            logger.info(
+                "Hybrid or mamba-based model detected: setting cudagraph mode "
+                "to FULL_AND_PIECEWISE in order to optimize performance.")
+            compilation_config.cudagraph_mode = CUDAGraphMode.FULL_AND_PIECEWISE
+
+
 class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
 
     @classmethod
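Note that the hook above only fills in a default: the `compilation_config.cudagraph_mode is None` guard ensures an explicitly configured mode is never overridden. A minimal standalone sketch of that behavior, using `SimpleNamespace` stand-ins rather than the real `VllmConfig` and `CUDAGraphMode` objects (all names below are illustrative, not vLLM code):

# Standalone sketch of the default-only-when-unset behavior above.
# SimpleNamespace objects stand in for VllmConfig/CompilationConfig;
# the string "FULL_AND_PIECEWISE" stands in for the CUDAGraphMode member.
from types import SimpleNamespace

def set_cudagraph_default(vllm_config):
    compilation_config = vllm_config.compilation_config
    if compilation_config.cudagraph_mode is None:
        # Only applied when the user left the mode unset.
        compilation_config.cudagraph_mode = "FULL_AND_PIECEWISE"

unset = SimpleNamespace(compilation_config=SimpleNamespace(cudagraph_mode=None))
set_cudagraph_default(unset)
assert unset.compilation_config.cudagraph_mode == "FULL_AND_PIECEWISE"

explicit = SimpleNamespace(
    compilation_config=SimpleNamespace(cudagraph_mode="PIECEWISE"))
set_cudagraph_default(explicit)
assert explicit.compilation_config.cudagraph_mode == "PIECEWISE"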
@@ -293,6 +330,9 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
         if not envs.VLLM_USE_V1:
             return
 
+        # Enable FULL_AND_PIECEWISE by default
+        MambaModelConfig.verify_and_update_config(vllm_config)
+
         cache_config = vllm_config.cache_config
         model_config = vllm_config.model_config
         parallel_config = vllm_config.parallel_config
@@ -374,4 +414,6 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
     "JambaForSequenceClassification": JambaForSequenceClassificationConfig,
     "GraniteMoeHybridForCausalLM": GraniteMoeHybridModelConfig,
     "GptOssForCausalLM": GptOssForCausalLMConfig,
+    "MambaForCausalLM": MambaModelConfig,
+    "Mamba2ForCausalLM": MambaModelConfig,
 }
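The final hunk is what wires the new class in: the architecture string is used as a key into this map, and the mapped class's verify_and_update_config is applied to the config. A simplified dispatch sketch under that assumption (the stub class, map name, and helper here are hypothetical; only the architecture-to-class mapping mirrors the diff above):

# Simplified sketch of architecture-based config dispatch. The stub class
# and apply_verify_and_update helper are illustrative; only the mapping of
# architecture names to config classes mirrors the diff above.
class MambaModelConfigStub:
    @classmethod
    def verify_and_update_config(cls, vllm_config) -> None:
        print("applying mamba defaults to", vllm_config)

CONFIG_MAP = {
    "MambaForCausalLM": MambaModelConfigStub,
    "Mamba2ForCausalLM": MambaModelConfigStub,
}

def apply_verify_and_update(architecture: str, vllm_config) -> None:
    config_cls = CONFIG_MAP.get(architecture)
    if config_cls is not None:
        config_cls.verify_and_update_config(vllm_config)

apply_verify_and_update("Mamba2ForCausalLM", object())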