diff --git a/vllm/config/cache.py b/vllm/config/cache.py
index 3d2496b7f21d..64be80ce3188 100644
--- a/vllm/config/cache.py
+++ b/vllm/config/cache.py
@@ -23,7 +23,7 @@
 BlockSize = Literal[1, 8, 16, 32, 64, 128]
 CacheDType = Literal["auto", "bfloat16", "fp8", "fp8_e4m3", "fp8_e5m2", "fp8_inc"]
-MambaDType = Literal["auto", "float32"]
+MambaDType = Literal["auto", "float32", "fp8", "fp8_e4m3"]
 PrefixCachingHashAlgo = Literal["sha256", "sha256_cbor"]
diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py
index 7589905ac927..d06195dab3ce 100644
--- a/vllm/model_executor/layers/mamba/mamba_mixer2.py
+++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py
@@ -46,6 +46,7 @@
     sharded_weight_loader,
 )
 from vllm.model_executor.utils import set_weight_attrs
+from vllm.platforms import current_platform
 from vllm.utils import direct_register_custom_op
 from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadata
@@ -465,7 +466,7 @@ def __init__(
         compilation_config.static_forward_context[prefix] = self
         # The tuple is (conv_state, ssm_state)
         self.kv_cache = (torch.tensor([]), torch.tensor([]))
-
+        self.fp8_dtype = current_platform.fp8_dtype()
         self.model_config = model_config
         self.cache_config = cache_config
         self.prefix = prefix
@@ -514,7 +515,10 @@ def forward_cuda(
         self_kv_cache = self.kv_cache[forward_context.virtual_engine]
         # conv_state = (..., dim, width-1) yet contiguous along 'dim'
         conv_state = self_kv_cache[0].transpose(-1, -2)
-        ssm_state = self_kv_cache[1]
+        if self.cache_config.mamba_ssm_cache_dtype.startswith("fp8"):
+            ssm_state = self_kv_cache[1].view(self.fp8_dtype)
+        else:
+            ssm_state = self_kv_cache[1]
         state_indices_tensor = attn_metadata.state_indices_tensor
         has_initial_states_p = attn_metadata.has_initial_states_p
         prep_initial_states = attn_metadata.prep_initial_states
@@ -689,6 +693,9 @@ def forward_cuda(
                     0,
                 )

+            # TODO: add fp8 dequantization logic here when loading
+            # ssm state. Should load scales tensors if available
+
             # NOTE: final output is an in-place update of out tensor
             varlen_states = mamba_chunk_scan_combined_varlen(
                 hidden_states_p.view(
@@ -791,6 +798,9 @@ def forward_cuda(
             # tensor
             ssm_state[state_indices_tensor_p] = varlen_states

+            # TODO: Add fp8 quantization logic here for storing back
+            # ssm state. Should also update scales if dynamic
+
         # Process decode requests
         if has_decode:
             if prefix_caching_enabled:
diff --git a/vllm/model_executor/layers/mamba/mamba_utils.py b/vllm/model_executor/layers/mamba/mamba_utils.py
index 21c36617a872..5e323f4d6fa6 100644
--- a/vllm/model_executor/layers/mamba/mamba_utils.py
+++ b/vllm/model_executor/layers/mamba/mamba_utils.py
@@ -51,6 +51,8 @@ def _mamba_state_dtype(
         mamba_cache_dtype: MambaDType,
         mamba_ssm_cache_dtype: MambaDType,
     ) -> tuple[torch.dtype, ...]:
+        if mamba_cache_dtype.startswith("fp8"):
+            raise ValueError("fp8 mamba conv state is not supported")
         conv_state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype, model_dtype)
         if mamba_ssm_cache_dtype == "auto":
            temporal_state_dtype = conv_state_dtype
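
The two TODOs above leave the fp8 handling of the ssm state as follow-up work. Below is a minimal sketch of what that logic could look like, assuming per-tensor dynamic scaling; the helper names quantize_ssm_state_fp8 and dequantize_ssm_state_fp8 are hypothetical and are not part of this diff or of vLLM's API.

# Hypothetical helpers sketching the TODOs above; not part of the diff.
# Assumes per-tensor dynamic scaling for the fp8 ssm state.
import torch


def quantize_ssm_state_fp8(
    ssm_state: torch.Tensor, fp8_dtype: torch.dtype
) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize the ssm state to fp8; return (quantized, dequant_scale)."""
    finfo = torch.finfo(fp8_dtype)
    # Clamp amax away from zero so an all-zero state does not divide by zero.
    amax = ssm_state.abs().max().clamp(min=1e-12)
    scale = finfo.max / amax
    quantized = (ssm_state * scale).clamp(min=finfo.min, max=finfo.max).to(fp8_dtype)
    # Return the reciprocal so the consumer can multiply to dequantize.
    return quantized, scale.reciprocal()


def dequantize_ssm_state_fp8(
    quantized: torch.Tensor, dequant_scale: torch.Tensor, out_dtype: torch.dtype
) -> torch.Tensor:
    """Dequantize an fp8 ssm state back to the compute dtype."""
    return quantized.to(out_dtype) * dequant_scale

In the prefill path above, dequantization would apply where the initial states are gathered from ssm_state[state_indices_tensor_p], and quantization where varlen_states is scattered back; a static, checkpoint-provided scale would replace the dynamic amax computation.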