
Commit 05181cc

[Hybrid] Add mamba_block_size to Engine Args (#27289)
Signed-off-by: asafg <[email protected]>
1 parent 259504e commit 05181cc

File tree

3 files changed: +25 -8 lines changed


vllm/config/cache.py

Lines changed: 13 additions & 3 deletions
@@ -5,7 +5,7 @@
 from dataclasses import field
 from typing import TYPE_CHECKING, Any, Literal
 
-from pydantic import Field, SkipValidation, field_validator
+from pydantic import Field, SkipValidation, field_validator, model_validator
 from pydantic.dataclasses import dataclass
 
 from vllm.config.utils import config
@@ -90,8 +90,10 @@ class CacheConfig:
     mamba_page_size_padded: int | None = None
     """ Optional override for mamba page size; used by hybrid mamba/attention
     models to ensure exact alignment with attention page size."""
-    mamba_block_size: int | None = None
-    """Size of a contiguous cache block in number of tokens for mamba cache."""
+    mamba_block_size: int | None = Field(default=None, gt=0)
+    """Size of a contiguous cache block in number of tokens for mamba cache.
+    Can be set only when prefix caching is enabled.
+    Value must be a multiple of 8 to align with causal_conv1d kernel."""
     mamba_cache_dtype: MambaDType = "auto"
     """The data type to use for the Mamba cache (both the conv as well as the
     ssm state). If set to 'auto', the data type will be inferred from the model
@@ -183,3 +185,11 @@ def verify_with_parallel_config(
             raise ValueError("Too large swap space. " + msg)
         elif cpu_memory_usage > 0.4 * total_cpu_memory:
             logger.warning("Possibly too large swap space. %s", msg)
+
+    @model_validator(mode="after")
+    def validate_mamba_block_size(self) -> "CacheConfig":
+        if self.mamba_block_size is not None and not self.enable_prefix_caching:
+            raise ValueError(
+                "--mamba-block-size can only be set with --enable-prefix-caching"
+            )
+        return self
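
The gt=0 constraint and the after-mode model validator are standard pydantic v2 mechanics. As a minimal standalone sketch of how the same pattern behaves (DemoCacheConfig is a made-up class for illustration, not vLLM code):

from pydantic import Field, ValidationError, model_validator
from pydantic.dataclasses import dataclass


@dataclass
class DemoCacheConfig:
    # Mirrors the two CacheConfig fields involved in the new check.
    enable_prefix_caching: bool = False
    mamba_block_size: int | None = Field(default=None, gt=0)

    @model_validator(mode="after")
    def validate_mamba_block_size(self) -> "DemoCacheConfig":
        # Reject mamba_block_size unless prefix caching is enabled,
        # matching the validator added in this commit.
        if self.mamba_block_size is not None and not self.enable_prefix_caching:
            raise ValueError(
                "--mamba-block-size can only be set with --enable-prefix-caching"
            )
        return self


DemoCacheConfig(enable_prefix_caching=True, mamba_block_size=1024)  # accepted
try:
    DemoCacheConfig(mamba_block_size=1024)  # prefix caching off -> rejected
except ValidationError as e:
    print(e)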

vllm/engine/arg_utils.py

Lines changed: 5 additions & 0 deletions
@@ -535,6 +535,7 @@ class EngineArgs:
     calculate_kv_scales: bool = CacheConfig.calculate_kv_scales
     mamba_cache_dtype: MambaDType = CacheConfig.mamba_cache_dtype
     mamba_ssm_cache_dtype: MambaDType = CacheConfig.mamba_ssm_cache_dtype
+    mamba_block_size: int | None = get_field(CacheConfig, "mamba_block_size")
 
     additional_config: dict[str, Any] = get_field(VllmConfig, "additional_config")
 
@@ -893,6 +894,9 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
         cache_group.add_argument(
             "--mamba-ssm-cache-dtype", **cache_kwargs["mamba_ssm_cache_dtype"]
         )
+        cache_group.add_argument(
+            "--mamba-block-size", **cache_kwargs["mamba_block_size"]
+        )
 
         # Multimodal related configs
         multimodal_kwargs = get_kwargs(MultiModalConfig)
@@ -1390,6 +1394,7 @@ def create_engine_config(
             kv_sharing_fast_prefill=self.kv_sharing_fast_prefill,
             mamba_cache_dtype=self.mamba_cache_dtype,
             mamba_ssm_cache_dtype=self.mamba_ssm_cache_dtype,
+            mamba_block_size=self.mamba_block_size,
         )
 
         ray_runtime_env = None
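
With the field wired through EngineArgs, the knob is reachable from the CLI as --mamba-block-size (accepted only together with --enable-prefix-caching, per the CacheConfig validator above) and from the offline API, since LLM forwards engine keyword arguments to EngineArgs. A hedged sketch of the offline path, with a placeholder model name and an arbitrary example value:

from vllm import LLM

# Sketch only: the model name is a placeholder for a hybrid mamba/attention
# checkpoint, and 1024 is an arbitrary example value (> 0 and a multiple of 8,
# per the CacheConfig docstring).
llm = LLM(
    model="<hybrid-mamba-attention-model>",
    enable_prefix_caching=True,
    mamba_block_size=1024,
)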

vllm/model_executor/models/config.py

Lines changed: 7 additions & 5 deletions
@@ -291,9 +291,8 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
         model_config = vllm_config.model_config
         cache_config = vllm_config.cache_config
 
-        # Set mamba block size to max_model_len (this may get
-        # override by prefix caching logic later)
-        cache_config.mamba_block_size = model_config.max_model_len
+        if cache_config.mamba_block_size is None:
+            cache_config.mamba_block_size = model_config.max_model_len
 
         if cache_config.enable_prefix_caching:
             if model_config.supports_mamba_prefix_caching:
@@ -333,6 +332,8 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
         if not envs.VLLM_USE_V1:
             return
 
+        # Save the user input before it gets modified by MambaModelConfig
+        mamba_block_size = vllm_config.cache_config.mamba_block_size
         # Enable FULL_AND_PIECEWISE by default
         MambaModelConfig.verify_and_update_config(vllm_config)
 
@@ -386,7 +387,7 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
         # With prefix caching, select attention block size to
         # optimize for mamba kernel performance
 
-        # mamba SSD kernel uses a chunk_size, e.g. 256
+        # Mamba2 SSD kernel uses a chunk_size, e.g. 256
         # Align the block to the kernel: use lowest multiple of chunk_size
         # of attention tokens that would fit mamba_page_size:
         # e.g. for mamba page size = 788kB
@@ -404,7 +405,8 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
         def lcm(a, b):
             return a * b // gcd(a, b)
 
-        base_chunk_size = model_config.get_mamba_chunk_size()
+        base_chunk_size = mamba_block_size or model_config.get_mamba_chunk_size()
+
         attn_tokens_per_mamba_state = cdiv(mamba_page_size, attn_page_size_1_token)
 
         chunk_size = lcm(base_chunk_size, kernel_block_alignment_size)
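
The comments in the last two hunks describe the alignment arithmetic in words; with this commit a user-supplied mamba_block_size replaces get_mamba_chunk_size() as the base chunk size in that calculation. A hedged numeric walkthrough of the same idea (the per-token attention page size and the 16-token kernel alignment below are illustrative assumptions, not values from a specific model; the 256 chunk size and ~788 kB mamba page size come from the diff comments, and the final rounding step is a sketch of "lowest multiple of chunk_size that fits mamba_page_size", not the exact vLLM code):

from math import gcd


def cdiv(a: int, b: int) -> int:
    # Ceiling division, as in vLLM's cdiv helper.
    return -(-a // b)


def lcm(a: int, b: int) -> int:
    return a * b // gcd(a, b)


base_chunk_size = 256                # Mamba2 SSD kernel chunk_size (example from the comment)
kernel_block_alignment_size = 16     # assumed attention-kernel block alignment
mamba_page_size = 788 * 1024         # ~788 kB mamba state per sequence (example from the comment)
attn_page_size_1_token = 2 * 1024    # assumed attention KV-cache bytes per token

chunk_size = lcm(base_chunk_size, kernel_block_alignment_size)               # 256
attn_tokens_per_mamba_state = cdiv(mamba_page_size, attn_page_size_1_token)  # 394

# Lowest multiple of chunk_size that covers one mamba state:
attn_block_size = chunk_size * cdiv(attn_tokens_per_mamba_state, chunk_size)
print(attn_block_size)  # 512 attention tokens per block in this example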
