Commit bad4c58

harden things a bit
Signed-off-by: Thomas Parnell <[email protected]>
1 parent f70e398 commit bad4c58

11 files changed (+88, -27 lines)
vllm/config/__init__.py

Lines changed: 9 additions & 1 deletion
@@ -1678,6 +1678,10 @@ def is_attention_free(self) -> bool:
     def is_hybrid(self) -> bool:
         return self._model_info.is_hybrid
 
+    @property
+    def has_mamba2(self) -> bool:
+        return self._model_info.has_mamba2
+
     @property
     def has_noops(self) -> bool:
         return self._model_info.has_noops
@@ -4215,14 +4219,18 @@ def try_verify_and_update_config(self):
             return
 
         from vllm.model_executor.models.config import (
-            MODELS_CONFIG_MAP, HybridAttentionMambaModelConfig)
+            MODELS_CONFIG_MAP, HybridAttentionMambaModelConfig,
+            Mamba2ModelConfig)
         cls = MODELS_CONFIG_MAP.get(architecture, None)
         if cls is not None:
             cls.verify_and_update_config(self)
 
         if self.model_config.is_hybrid:
             HybridAttentionMambaModelConfig.verify_and_update_config(self)
 
+        if self.model_config.has_mamba2:
+            Mamba2ModelConfig.verify_and_update_config(self)
+
         if self.model_config.convert_type == "classify":
             # Maybe convert ForCausalLM into ForSequenceClassification model.
             from vllm.model_executor.models.adapters import (

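For orientation, the two hunks above wire the new flag into config resolution: `ModelConfig.has_mamba2` simply forwards the registry's `_ModelInfo` flag, and `try_verify_and_update_config` now runs `Mamba2ModelConfig` after the hybrid hook whenever the flag is set. A minimal sketch of that dispatch order, using made-up stand-ins rather than the real vLLM classes:

```python
# Toy sketch of the dispatch order added above; _ToyModelConfig and
# apply_config_hooks are hypothetical stand-ins, not vLLM APIs.
class _ToyModelConfig:
    def __init__(self, is_hybrid: bool, has_mamba2: bool) -> None:
        self.is_hybrid = is_hybrid
        self.has_mamba2 = has_mamba2


def apply_config_hooks(model_config: _ToyModelConfig) -> list[str]:
    applied: list[str] = []
    if model_config.is_hybrid:
        applied.append("HybridAttentionMambaModelConfig")
    if model_config.has_mamba2:
        applied.append("Mamba2ModelConfig")
    return applied


# A hybrid model with mamba2 blocks passes through both hooks, in this
# order; a pure mamba2 model only hits the second one.
assert apply_config_hooks(_ToyModelConfig(True, True)) == [
    "HybridAttentionMambaModelConfig", "Mamba2ModelConfig"
]
assert apply_config_hooks(_ToyModelConfig(False, True)) == ["Mamba2ModelConfig"]
```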
vllm/config/compilation.py

Lines changed: 3 additions & 2 deletions
@@ -214,7 +214,7 @@ class CompilationConfig:
     are always used, it can set this to False. Otherwise, it should
     set this to True, and the compiler will copy the input to an
     internally managed buffer. Default is False."""
-    full_cuda_graph: bool = False
+    full_cuda_graph: Optional[bool] = None
     """whether to use a full cuda graph for the entire forward pass rather than
     splitting certain operations such as attention into subgraphs. Thus this
     flag cannot be used together with splitting_ops. This may provide
@@ -344,7 +344,8 @@ def __post_init__(self) -> None:
     def init_backend(self, vllm_config: VllmConfig) -> Union[str, Callable]:
         if self.level == CompilationLevel.NO_COMPILATION:
             raise ValueError("No compilation level is set.")
-
+        if self.full_cuda_graph is None:
+            self.full_cuda_graph = False
         from torch._dynamo.backends.registry import list_backends
         torch_backends = list_backends(exclude_tags=tuple())
         if self.level in [

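Changing `full_cuda_graph` from `bool = False` to `Optional[bool] = None` makes the flag tri-state: `None` means the user expressed no preference (so a model-specific hook may opt in later), an explicit value is left alone, and `init_backend` resolves a still-unset flag to `False`. A small sketch of those assumed semantics, using a toy dataclass rather than the real `CompilationConfig`:

```python
# Sketch of the tri-state default, based on the diff above; ToyCompilationConfig
# and mamba2_hook are hypothetical stand-ins for illustration only.
from dataclasses import dataclass
from typing import Optional


@dataclass
class ToyCompilationConfig:
    full_cuda_graph: Optional[bool] = None

    def resolve(self) -> None:
        # Mirrors init_backend(): anything still unset falls back to False.
        if self.full_cuda_graph is None:
            self.full_cuda_graph = False


def mamba2_hook(cfg: ToyCompilationConfig) -> None:
    # Mirrors Mamba2ModelConfig.verify_and_update_config(): only opt in
    # when the user has not expressed a preference.
    if cfg.full_cuda_graph is None:
        cfg.full_cuda_graph = True


user_off = ToyCompilationConfig(full_cuda_graph=False)
mamba2_hook(user_off)
user_off.resolve()
assert user_off.full_cuda_graph is False  # explicit user choice is preserved

unset = ToyCompilationConfig()
mamba2_hook(unset)
unset.resolve()
assert unset.full_cuda_graph is True  # mamba2 models default to full CUDA graphs
```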
vllm/model_executor/models/bamba.py

Lines changed: 3 additions & 3 deletions
@@ -38,8 +38,8 @@
 from vllm.sequence import IntermediateTensors
 from vllm.utils import LayerBlockType
 
-from .interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP,
-                         SupportsQuant)
+from .interfaces import (HasInnerState, HasMamba2, IsHybrid, SupportsLoRA,
+                         SupportsPP, SupportsQuant)
 from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
@@ -420,7 +420,7 @@ def load_weights(self, weights: Iterable[tuple[str,
 
 
 class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
-                       IsHybrid, SupportsQuant):
+                       IsHybrid, SupportsQuant, HasMamba2):
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",

vllm/model_executor/models/config.py

Lines changed: 19 additions & 6 deletions
@@ -275,6 +275,25 @@ def verify_and_update_config(vllm_config: "VllmConfig") -> None:
             "%d for performance.", 1024)
 
 
+class Mamba2ModelConfig(VerifyAndUpdateConfig):
+
+    @classmethod
+    def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
+        """
+        Enable full cuda graphs for decode-only batches to ensure that
+        V1 performance matches that of V0.
+
+        Args:
+            vllm_config: vLLM Config
+        """
+        if not envs.VLLM_USE_V1:
+            return
+
+        compilation_config = vllm_config.compilation_config
+        if compilation_config.full_cuda_graph is None:
+            compilation_config.full_cuda_graph = True
+
+
 class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
 
     @classmethod
@@ -296,7 +315,6 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
         cache_config = vllm_config.cache_config
         model_config = vllm_config.model_config
         parallel_config = vllm_config.parallel_config
-        compilation_config = vllm_config.compilation_config
 
         if cache_config.cache_dtype == "auto":
             kv_cache_dtype = model_config.dtype
@@ -362,11 +380,6 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
                 "that mamba page size and attention page size are "
                 "exactly equal.", mamba_padding_pct)
 
-        # enable full cuda graphs for decode-only batches
-        # note (tdoublep): this is currently necessary to
-        # match V0 performance
-        compilation_config.full_cuda_graph = True
-
 
 MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
     "GteModel": SnowflakeGteNewModelConfig,

vllm/model_executor/models/falcon_h1.py

Lines changed: 3 additions & 2 deletions
@@ -36,7 +36,8 @@
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 
-from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP
+from .interfaces import (HasInnerState, HasMamba2, IsHybrid, SupportsLoRA,
+                         SupportsPP)
 from .utils import (PPMissingLayer, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
@@ -507,7 +508,7 @@ def forward(
 
 
 class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
-                          IsHybrid):
+                          IsHybrid, HasMamba2):
     packed_modules_mapping = {
         "qkv_proj": ["q_proj", "k_proj", "v_proj"],
         "gate_up_proj": ["gate_proj", "up_proj"],

vllm/model_executor/models/granitemoehybrid.py

Lines changed: 4 additions & 3 deletions
@@ -38,8 +38,8 @@
 
 from .granitemoe import GraniteMoeMoE
 from .granitemoeshared import GraniteMoeSharedMLP
-from .interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP,
-                         SupportsQuant)
+from .interfaces import (HasInnerState, HasMamba2, IsHybrid, SupportsLoRA,
+                         SupportsPP, SupportsQuant)
 from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
@@ -513,7 +513,8 @@ def _load_expert(n, p, name, shard_id, expert_id):
 
 
 class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA,
-                                  SupportsPP, IsHybrid, SupportsQuant):
+                                  SupportsPP, IsHybrid, SupportsQuant,
+                                  HasMamba2):
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",

vllm/model_executor/models/interfaces.py

Lines changed: 27 additions & 0 deletions
@@ -468,6 +468,33 @@ def is_attention_free(
     return getattr(model, "is_attention_free", False)
 
 
+@runtime_checkable
+class HasMamba2(Protocol):
+    """The interface required for all models like mamba2, bamba, zamba2,
+    etc., that have mamba2 blocks"""
+
+    has_mamba2: ClassVar[Literal[True]] = True
+    """
+    A flag that indicates if the model has mamba2 blocks.
+    """
+
+
+@overload
+def has_mamba2(model: object) -> TypeIs[HasMamba2]:
+    ...
+
+
+@overload
+def has_mamba2(model: type[object]) -> TypeIs[type[HasMamba2]]:
+    ...
+
+
+def has_mamba2(
+    model: Union[type[object], object]
+) -> Union[TypeIs[type[HasMamba2]], TypeIs[HasMamba2]]:
+    return getattr(model, "has_mamba2", False)
+
+
 @runtime_checkable
 class IsHybrid(Protocol):
     """The interface required for all models like Jamba that have both

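The `HasMamba2` marker follows the same pattern as the other model interfaces in this module: a `runtime_checkable` `Protocol` carrying a `ClassVar[Literal[True]]` flag, plus a `getattr`-based helper so the check works on both classes and instances without any inheritance. A simplified, self-contained sketch of how the marker is consumed (`ToyMamba2Model` and `ToyTransformer` are made-up names, and the real helper returns `TypeIs` rather than a plain `bool`):

```python
# Simplified sketch of the HasMamba2 marker pattern; the Toy* classes are
# hypothetical and exist only to show how the check behaves.
from typing import ClassVar, Literal, Protocol, Union, runtime_checkable


@runtime_checkable
class HasMamba2(Protocol):
    """Simplified copy of the marker protocol above."""

    has_mamba2: ClassVar[Literal[True]] = True


class ToyMamba2Model:          # a model class that opts in via the flag
    has_mamba2: ClassVar[Literal[True]] = True


class ToyTransformer:          # a model class without the marker
    pass


def has_mamba2(model: Union[type, object]) -> bool:
    # Same getattr-based check as the helper in the diff: works on both
    # classes and instances, no inheritance from the protocol required.
    return bool(getattr(model, "has_mamba2", False))


assert has_mamba2(ToyMamba2Model) and has_mamba2(ToyMamba2Model())
assert not has_mamba2(ToyTransformer)
assert isinstance(ToyMamba2Model(), HasMamba2)  # runtime_checkable structural check
```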
vllm/model_executor/models/mamba2.py

Lines changed: 2 additions & 2 deletions
@@ -26,7 +26,7 @@
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.models.interfaces import (HasInnerState,
+from vllm.model_executor.models.interfaces import (HasInnerState, HasMamba2,
                                                    IsAttentionFree)
 from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
                                                     MambaCacheParams)
@@ -198,7 +198,7 @@ def load_weights(self, weights: Iterable[tuple[str,
         return loaded_params
 
 
-class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
+class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree, HasMamba2):
 
     @classmethod
     def get_mamba_state_shape_from_config(

vllm/model_executor/models/nemotron_h.py

Lines changed: 4 additions & 4 deletions
@@ -45,9 +45,9 @@
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.models.interfaces import (HasInnerState, IsHybrid,
-                                                   SupportsLoRA, SupportsPP,
-                                                   SupportsQuant)
+from vllm.model_executor.models.interfaces import (HasInnerState, HasMamba2,
+                                                   IsHybrid, SupportsLoRA,
+                                                   SupportsPP, SupportsQuant)
 from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
                                                     MambaCacheParams)
 from vllm.model_executor.models.utils import (
@@ -446,7 +446,7 @@ def load_weights(self, weights: Iterable[tuple[str,
 
 
 class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
-                           IsHybrid, SupportsQuant):
+                           IsHybrid, SupportsQuant, HasMamba2):
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",

vllm/model_executor/models/registry.py

Lines changed: 12 additions & 2 deletions
@@ -25,8 +25,8 @@
 from vllm.transformers_utils.dynamic_module import (
     try_get_class_from_dynamic_module)
 
-from .interfaces import (has_inner_state, has_noops, is_attention_free,
-                         is_hybrid, supports_cross_encoding,
+from .interfaces import (has_inner_state, has_mamba2, has_noops,
+                         is_attention_free, is_hybrid, supports_cross_encoding,
                          supports_multimodal, supports_multimodal_raw_input,
                          supports_pp, supports_transcription, supports_v0_only)
 from .interfaces_base import is_pooling_model, is_text_generation_model
@@ -312,6 +312,7 @@ class _ModelInfo:
     has_inner_state: bool
     is_attention_free: bool
     is_hybrid: bool
+    has_mamba2: bool
     has_noops: bool
     supports_transcription: bool
     supports_transcription_only: bool
@@ -329,6 +330,7 @@ def from_model_cls(model: type[nn.Module]) -> "_ModelInfo":
            supports_pp=supports_pp(model),
            has_inner_state=has_inner_state(model),
            is_attention_free=is_attention_free(model),
+           has_mamba2=has_mamba2(model),
            is_hybrid=is_hybrid(model),
            supports_transcription=supports_transcription(model),
            supports_transcription_only=(supports_transcription(model) and
@@ -760,6 +762,14 @@ def is_hybrid_model(
         model_cls, _ = self.inspect_model_cls(architectures, model_config)
         return model_cls.is_hybrid
 
+    def model_has_mamba2(
+        self,
+        architectures: Union[str, list[str]],
+        model_config: ModelConfig,
+    ) -> bool:
+        model_cls, _ = self.inspect_model_cls(architectures, model_config)
+        return model_cls.has_mamba2
+
     def is_noops_model(
         self,
         architectures: Union[str, list[str]],

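On the registry side, `_ModelInfo.from_model_cls` snapshots the flag once per model class via the `has_mamba2()` helper, and `ModelRegistry.model_has_mamba2()` exposes it by architecture, which is what `ModelConfig.has_mamba2` ultimately reads. A toy end-to-end sketch of that flow (simplified stand-ins, not the real registry code):

```python
# Toy sketch of how the registry propagates the flag; ToyModelInfo, ToyBamba,
# ToyLlama, and _toy_registry are hypothetical stand-ins for illustration.
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyModelInfo:
    has_mamba2: bool

    @staticmethod
    def from_model_cls(model_cls: type) -> "ToyModelInfo":
        # Same getattr-based probe used by the interfaces helper.
        return ToyModelInfo(has_mamba2=getattr(model_cls, "has_mamba2", False))


class ToyBamba:        # stands in for a model class tagged with HasMamba2
    has_mamba2 = True


class ToyLlama:        # a plain attention-only model, no marker
    pass


_toy_registry = {"ToyBamba": ToyBamba, "ToyLlama": ToyLlama}


def model_has_mamba2(architecture: str) -> bool:
    # Mirrors ModelRegistry.model_has_mamba2(): look up the class, read the flag.
    return ToyModelInfo.from_model_cls(_toy_registry[architecture]).has_mamba2


assert model_has_mamba2("ToyBamba")
assert not model_has_mamba2("ToyLlama")
```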