From e704df9be9a5163f1cce9aa729ef29704e9f4ec3 Mon Sep 17 00:00:00 2001 From: Jonas Kuebler Date: Mon, 15 Sep 2025 07:16:17 +0000 Subject: [PATCH] enable skipping of SW layers for attn quant Signed-off-by: Jonas Kuebler --- vllm/attention/layer.py | 15 ++- vllm/config/cache.py | 4 + vllm/engine/arg_utils.py | 53 ++++++++++ vllm/model_executor/models/gpt_oss.py | 34 +++++- vllm/v1/attention/backends/flash_attn.py | 48 +++++++-- vllm/v1/core/kv_cache_utils.py | 128 +++++++++++++++++++++-- vllm/v1/worker/gpu_model_runner.py | 21 +++- 7 files changed, 278 insertions(+), 25 deletions(-) diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 237802afccde..075e59dcdd3d 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -83,6 +83,7 @@ def __init__( attn_type: str = AttentionType.DECODER, kv_sharing_target_layer_name: Optional[str] = None, attn_backend: Optional[type[AttentionBackend]] = None, + per_layer_kv_cache_dtype: Optional[str] = None, **extra_impl_args, ) -> None: """ @@ -100,10 +101,20 @@ def __init__( sliding_window = None if cache_config is not None: - kv_cache_dtype = cache_config.cache_dtype + # Use per-layer dtype if provided + if per_layer_kv_cache_dtype is not None: + kv_cache_dtype = per_layer_kv_cache_dtype + # For mixed-dtype: only calculate scales for FP8 layers + calculate_kv_scales = per_layer_kv_cache_dtype.startswith( + 'fp8') and cache_config.calculate_kv_scales + else: + # This is a skip layer in mixed-dtype scenario + # Use "auto" to resolve to model dtype + kv_cache_dtype = "auto" + calculate_kv_scales = False + block_size = cache_config.block_size is_attention_free = cache_config.is_attention_free - calculate_kv_scales = cache_config.calculate_kv_scales else: kv_cache_dtype = "auto" block_size = 16 diff --git a/vllm/config/cache.py b/vllm/config/cache.py index bf85aad452d0..7665a54babd5 100644 --- a/vllm/config/cache.py +++ b/vllm/config/cache.py @@ -96,6 +96,10 @@ class CacheConfig: """The data type to use for the Mamba cache (ssm state only, conv state will still be controlled by mamba_cache_dtype). If set to 'auto', the data type for the ssm state will be determined by mamba_cache_dtype.""" + skip_kv_quantization_layers: Optional[list[int | str]] = None + """List of layer IDs to skip KV cache quantization for, or 'sliding_window' + to skip all sliding window layers. These layers will use the model's + default dtype instead of the quantized cache_dtype (e.g., fp8).""" # Will be set after profiling. 
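# Illustrative sketch, not part of the patch: how the per-layer KV cache dtype
# is expected to resolve from the new CacheConfig.skip_kv_quantization_layers
# field together with the Attention.__init__ logic above. The helper name and
# the example values below are hypothetical.
def _resolve_layer_kv_cache_dtype(layer_idx, cache_dtype, skip_layers):
    # Skipped layers fall back to "auto" (i.e. the model dtype); all other
    # layers keep the quantized cache_dtype (e.g. "fp8").
    if skip_layers and layer_idx in skip_layers:
        return "auto"
    return cache_dtype

# For a 4-layer model with an fp8 KV cache where layers 0 and 2 are skipped:
assert [_resolve_layer_kv_cache_dtype(i, "fp8", [0, 2])
        for i in range(4)] == ["auto", "fp8", "auto", "fp8"]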
num_gpu_blocks: Optional[int] = field(default=None, init=False) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 94c984116131..bec4ed04ab6c 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -293,6 +293,7 @@ class EngineArgs: config_format: str = ModelConfig.config_format dtype: ModelDType = ModelConfig.dtype kv_cache_dtype: CacheDType = CacheConfig.cache_dtype + skip_kv_quantization_layers: Optional[str] = None seed: Optional[int] = ModelConfig.seed max_model_len: Optional[int] = ModelConfig.max_model_len cuda_graph_sizes: list[int] = get_field(SchedulerConfig, @@ -751,6 +752,14 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: **cache_kwargs["mamba_cache_dtype"]) cache_group.add_argument("--mamba-ssm-cache-dtype", **cache_kwargs["mamba_ssm_cache_dtype"]) + cache_group.add_argument( + "--skip-kv-quantization-layers", + type=str, + default=None, + help=("Comma-separated or dash-separated list of layer indices to " + "skip KV quantization, or 'sliding_window' to skip all " + "sliding window layers. Example: '0,31', '1-6-32', or " + "'sliding_window'")) # Multimodal related configs multimodal_kwargs = get_kwargs(MultiModalConfig) @@ -1081,6 +1090,45 @@ def create_speculative_config( }) return SpeculativeConfig(**self.speculative_config) + def _parse_skip_kv_quantization_layers( + self, layers_str: Optional[str]) -> Optional[list[int]]: + """Parse skip_kv_quantization_layers string to list of integers. + + Supports formats: + - "0,31" (comma-separated) + - "0-31" (dash-separated) + - "0-1-30-31" (multiple dash-separated) + - "sliding_window" (skip all sliding window layers) + """ + if layers_str is None: + return None + + layers_str = layers_str.strip() + if not layers_str: + return None + + # Handle special keyword + if layers_str.lower() == "sliding_window": + # Return special marker - will be resolved later with model info + return ["sliding_window"] + else: + raise NotImplementedError("Currently we only support skipping " + "sliding window layers.") + + # Try comma-separated first, then dash-separated + if ',' in layers_str: + parts = layers_str.split(',') + else: + parts = layers_str.split('-') + + try: + return [int(part.strip()) for part in parts if part.strip()] + except ValueError as e: + raise ValueError( + f"Invalid skip_kv_quantization_layers format: '{layers_str}'. " + f"Expected comma-separated (0,31), dash-separated (0-31) " + f"integers, or 'sliding_window'. Error: {e}") from e + def create_engine_config( self, usage_context: Optional[UsageContext] = None, @@ -1167,6 +1215,10 @@ def create_engine_config( f"dcp_size={self.decode_context_parallel_size}." 
) + # Parse skip_kv_quantization_layers string to list of integers + parsed_skip_layers = self._parse_skip_kv_quantization_layers( + self.skip_kv_quantization_layers) + cache_config = CacheConfig( block_size=self.block_size, gpu_memory_utilization=self.gpu_memory_utilization, @@ -1180,6 +1232,7 @@ def create_engine_config( cpu_offload_gb=self.cpu_offload_gb, calculate_kv_scales=self.calculate_kv_scales, kv_sharing_fast_prefill=self.kv_sharing_fast_prefill, + skip_kv_quantization_layers=parsed_skip_layers, mamba_cache_dtype=self.mamba_cache_dtype, mamba_ssm_cache_dtype=self.mamba_ssm_cache_dtype, ) diff --git a/vllm/model_executor/models/gpt_oss.py b/vllm/model_executor/models/gpt_oss.py index e0b4df772875..7a137c1b34b4 100644 --- a/vllm/model_executor/models/gpt_oss.py +++ b/vllm/model_executor/models/gpt_oss.py @@ -43,6 +43,7 @@ def __init__( quant_config: Optional[QuantizationConfig] = None, cache_config: Optional[CacheConfig] = None, prefix: str = "", + per_layer_kv_cache_dtype: Optional[str] = None, ): super().__init__() self.layer_idx = extract_layer_index(prefix) @@ -117,6 +118,7 @@ def __init__( attn_type=AttentionType.DECODER, prefix=f"{prefix}.attn", sinks=self.sinks, + per_layer_kv_cache_dtype=per_layer_kv_cache_dtype, ) def forward(self, hidden_states: torch.Tensor, @@ -177,9 +179,23 @@ def __init__( ): super().__init__() self.layer_idx = extract_layer_index(prefix) - self.attn = OAIAttention(config, - prefix=f"{prefix}.attn", - cache_config=cache_config) + + # Determine per-layer KV cache dtype for mixed-dtype support + per_layer_kv_cache_dtype = None + if cache_config is not None and self.layer_idx is not None: + skip_layers = cache_config.skip_kv_quantization_layers + + # Simple check - sliding_window is already resolved to layer indices + if skip_layers and self.layer_idx in skip_layers: + per_layer_kv_cache_dtype = None + else: + per_layer_kv_cache_dtype = cache_config.cache_dtype + + self.attn = OAIAttention( + config, + prefix=f"{prefix}.attn", + cache_config=cache_config, + per_layer_kv_cache_dtype=per_layer_kv_cache_dtype) self.mlp = MLPBlock(config, self.layer_idx, quant_config=quant_config, @@ -227,6 +243,18 @@ def __init__( self.config.vocab_size, self.config.hidden_size, ) + + # Resolve "sliding_window" keyword for GPT-OSS + if (self.cache_config.skip_kv_quantization_layers + and len(self.cache_config.skip_kv_quantization_layers) == 1 + and self.cache_config.skip_kv_quantization_layers[0] + == "sliding_window"): + # Replace with actual sliding window layer indices (even layers) + sliding_window_layers = [ + i for i in range(self.config.num_hidden_layers) if i % 2 == 0 + ] + self.cache_config.skip_kv_quantization_layers = ( + sliding_window_layers) self.start_layer, self.end_layer, self.layers = make_layers( self.config.num_hidden_layers, lambda prefix: TransformerBlock( diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 3cc67acd04c6..25813fe84886 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -529,6 +529,16 @@ def forward( descale_shape = (cu_seqlens_q.shape[0] - 1, self.num_kv_heads) + # Only pass scale parameters for FP8 layers + if self.kv_cache_dtype.startswith("fp8"): + q_descale = layer._q_scale.expand(descale_shape) + k_descale = layer._k_scale.expand(descale_shape) + v_descale = layer._v_scale.expand(descale_shape) + else: + q_descale = None + k_descale = None + v_descale = None + flash_attn_varlen_func( q=query[:num_actual_tokens], k=key_cache, @@ 
-546,15 +556,25 @@ def forward( softcap=self.logits_soft_cap, scheduler_metadata=scheduler_metadata, fa_version=self.vllm_flash_attn_version, - q_descale=layer._q_scale.expand(descale_shape), - k_descale=layer._k_scale.expand(descale_shape), - v_descale=layer._v_scale.expand(descale_shape), + q_descale=q_descale, + k_descale=k_descale, + v_descale=v_descale, num_splits=attn_metadata.max_num_splits, s_aux=self.sinks, ) return output # Cascade attention (rare case). + # Only pass scale parameters for FP8 layers + if self.kv_cache_dtype.startswith("fp8"): + q_descale = layer._q_scale + k_descale = layer._k_scale + v_descale = layer._v_scale + else: + q_descale = None + k_descale = None + v_descale = None + cascade_attention( output[:num_actual_tokens], query[:num_actual_tokens], @@ -575,9 +595,9 @@ def forward( fa_version=self.vllm_flash_attn_version, prefix_scheduler_metadata=attn_metadata.prefix_scheduler_metadata, suffix_scheduler_metadata=attn_metadata.scheduler_metadata, - q_descale=layer._q_scale, - k_descale=layer._k_scale, - v_descale=layer._v_scale, + q_descale=q_descale, + k_descale=k_descale, + v_descale=v_descale, ) return output @@ -615,6 +635,16 @@ def _forward_encoder_attention( cu_seqlens_q.shape[0] - 1, # type: ignore[union-attr] self.num_kv_heads) + # Only pass scale parameters for FP8 layers + if self.kv_cache_dtype.startswith("fp8"): + q_descale = layer._q_scale.expand(descale_shape) + k_descale = layer._k_scale.expand(descale_shape) + v_descale = layer._v_scale.expand(descale_shape) + else: + q_descale = None + k_descale = None + v_descale = None + # Call flash attention directly on Q, K, V tensors flash_attn_varlen_func( q=query, @@ -631,9 +661,9 @@ def _forward_encoder_attention( window_size=self.sliding_window, softcap=self.logits_soft_cap, fa_version=self.vllm_flash_attn_version, - q_descale=layer._q_scale.expand(descale_shape), - k_descale=layer._k_scale.expand(descale_shape), - v_descale=layer._v_scale.expand(descale_shape), + q_descale=q_descale, + k_descale=k_descale, + v_descale=v_descale, ) return output diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 2c0eac3ddd79..ec4bd8cab857 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -8,6 +8,8 @@ from dataclasses import astuple, dataclass from typing import Any, Callable, NewType, Optional, Union +import torch + from vllm import envs from vllm.config import VllmConfig from vllm.logger import init_logger @@ -744,6 +746,89 @@ def create_kv_cache_group_specs( return kv_cache_groups +def is_kv_cache_dtype_uniform(kv_cache_spec: dict[str, KVCacheSpec]) -> bool: + """ + Check if all layers use the same dtype for KV cache. + + Args: + kv_cache_spec: The kv cache spec of each attention layer in the model + + Returns: + True if all layers use the same dtype, False otherwise + """ + dtypes = set() + for spec in kv_cache_spec.values(): + dtype = getattr(spec, 'dtype', None) + dtypes.add(dtype) + + return len(dtypes) <= 1 + + +def _get_kv_cache_config_mixed_dtype(vllm_config: VllmConfig, + kv_cache_spec: dict[str, KVCacheSpec], + available_memory: int) -> KVCacheConfig: + """ + Handle mixed-dtype KV cache for GPT-OSS with sliding window skipping. + + Sliding window layers: BF16 (skipped quantization) + Full attention layers: FP8 (quantized) + + Adjusts block_size to achieve uniform page sizes, then uses standard + hybrid attention framework. 
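+
+    Example (illustrative sizes, not taken from a real config): with
+    num_kv_heads=8, head_size=64 and block_size=16, a BF16 layer needs
+    2 * 8 * 64 * 2 = 2048 bytes per token while an FP8 layer needs 1024.
+    The FP8 layers therefore get block_size 16 * (2048 // 1024) = 32, so
+    both layer types end up with the same page size per block
+    (16 * 2048 == 32 * 1024 == 32 KiB).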
+ """ + # Calculate max kv_hidden_size (BF16 will be larger than FP8) + max_kv_hidden_size = 0 + for spec in kv_cache_spec.values(): + if isinstance(spec, (FullAttentionSpec, SlidingWindowSpec)): + dtype_size = 2 if spec.dtype == torch.bfloat16 else 1 + kv_hidden_size = 2 * spec.num_kv_heads * spec.head_size * dtype_size + max_kv_hidden_size = max(max_kv_hidden_size, kv_hidden_size) + + # Adjust block_size for FP8 layers to match BF16 page size + adjusted_kv_cache_spec: dict[str, KVCacheSpec] = {} + for layer_name, spec in kv_cache_spec.items(): + if isinstance(spec, (FullAttentionSpec, SlidingWindowSpec)): + dtype_size = 2 if spec.dtype == torch.bfloat16 else 1 + current_kv_hidden_size = (2 * spec.num_kv_heads * spec.head_size * + dtype_size) + + # Calculate multiplier (will be 1 for BF16, >1 for FP8) + block_size_multiplier = max_kv_hidden_size // current_kv_hidden_size + adjusted_block_size = spec.block_size * block_size_multiplier + + # Create adjusted spec with new block_size + adjusted_spec: KVCacheSpec + if isinstance(spec, FullAttentionSpec): + adjusted_spec = FullAttentionSpec( + block_size=adjusted_block_size, + num_kv_heads=spec.num_kv_heads, + head_size=spec.head_size, + dtype=spec.dtype, + use_mla=spec.use_mla, + sliding_window=spec.sliding_window, + ) + else: # SlidingWindowSpec + adjusted_spec = SlidingWindowSpec( + block_size=adjusted_block_size, + num_kv_heads=spec.num_kv_heads, + head_size=spec.head_size, + dtype=spec.dtype, + use_mla=spec.use_mla, + sliding_window=spec.sliding_window, + ) + + adjusted_kv_cache_spec[layer_name] = adjusted_spec + else: + raise NotImplementedError( + f"Mixed-dtype KV cache does not support spec type: {type(spec)}" + ) + + # Use standard hybrid attention with adjusted specs + return _get_kv_cache_config_uniform_page_size(vllm_config, + adjusted_kv_cache_spec, + available_memory) + + def is_kv_cache_type_uniform(kv_cache_spec: dict[str, KVCacheSpec]) -> bool: """ Whether all layers in the given KVCacheSpec have the same KV cache spec. @@ -1116,11 +1201,15 @@ def get_kv_cache_config( # to allow for the KVCache manager to handle attention free models. return _get_kv_cache_config_attention_free() elif is_kv_cache_type_uniform(kv_cache_spec): - # KV cache of all layers are the same, which is true for - # most models. Allocate the same amount of memory for - # each layer. - return _get_kv_cache_config_uniform_type(vllm_config, kv_cache_spec, - available_memory) + # Check if it's uniform due to mixed-dtype support + if not is_kv_cache_dtype_uniform(kv_cache_spec): + # Mixed-dtype scenario - use specialized handler + return _get_kv_cache_config_mixed_dtype(vllm_config, kv_cache_spec, + available_memory) + else: + return _get_kv_cache_config_uniform_type(vllm_config, + kv_cache_spec, + available_memory) elif is_kv_cache_page_size_uniform(kv_cache_spec): # Model contains multiple attention types, but KV cache of all layers # have the same physical memory per block per layer. Split the layers @@ -1129,10 +1218,32 @@ def get_kv_cache_config( return _get_kv_cache_config_uniform_page_size(vllm_config, kv_cache_spec, available_memory) + elif not is_kv_cache_dtype_uniform(kv_cache_spec): + # Mixed-dtype scenario with different attention types + # Fall back to mixed-dtype handler + return _get_kv_cache_config_mixed_dtype(vllm_config, kv_cache_spec, + available_memory) raise NotImplementedError +def _get_sortable_spec_key(kv_cache_spec): + """ + Create a sortable key from a KVCacheSpec, handling torch dtypes properly. 
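+
+    With mixed-dtype specs, astuple() can yield tuples that differ only in
+    their torch.dtype entry (e.g. torch.bfloat16 vs. torch.float8_e4m3fn).
+    torch dtypes do not define an ordering, so sorting such tuples directly
+    would raise TypeError; converting the dtype to its string form keeps the
+    sort in unify_kv_cache_configs deterministic.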
+ """ + spec_tuple = astuple(kv_cache_spec) + sortable_tuple = [] + + for item in spec_tuple: + if isinstance(item, torch.dtype): + # Convert torch dtype to string for sorting + sortable_tuple.append(str(item)) + else: + sortable_tuple.append(item) + + return tuple(sortable_tuple) + + def unify_kv_cache_configs(kv_cache_configs: list[KVCacheConfig]): """ Make the KV cache configurations for each worker consistent, so that all @@ -1144,12 +1255,15 @@ def unify_kv_cache_configs(kv_cache_configs: list[KVCacheConfig]): kv_cache_configs: The KV cache configurations for each worker. Will be in-place modified to make them consistent. """ + if len(kv_cache_configs) == 0: + return # Sort the kv cache groups by their KV cache spec. # This can avoid the inconsistency caused by the order of groups. for kv_cache_config in kv_cache_configs: - kv_cache_config.kv_cache_groups.sort(key=lambda x: (type( - x.kv_cache_spec).__name__, astuple(x.kv_cache_spec))) + kv_cache_config.kv_cache_groups.sort( + key=lambda x: (type(x.kv_cache_spec).__name__, + _get_sortable_spec_key(x.kv_cache_spec))) # Verify that the groups of each rank are the same. for kv_cache_config in kv_cache_configs[1:]: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 549c5dd2bbb2..e33c009a0b62 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -3452,7 +3452,8 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: use_mla = self.vllm_config.model_config.use_mla kv_cache_spec: dict[str, KVCacheSpec] = {} attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention) - for layer_name, attn_module in attn_layers.items(): + for layer_idx, (layer_name, + attn_module) in enumerate(attn_layers.items()): if (kv_tgt_layer := attn_module.kv_sharing_target_layer_name) is not None: # The layer doesn't need its own KV cache and will use that of @@ -3474,7 +3475,11 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: block_size=block_size, num_kv_heads=attn_module.num_kv_heads, head_size=attn_module.head_size, - dtype=self.kv_cache_dtype, + dtype=(self.dtype + if self.cache_config.skip_kv_quantization_layers + and layer_idx + in self.cache_config.skip_kv_quantization_layers + else self.kv_cache_dtype), sliding_window=attn_module.sliding_window, use_mla=use_mla) elif self.attention_chunk_size is not None \ @@ -3483,7 +3488,11 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: block_size=block_size, num_kv_heads=attn_module.num_kv_heads, head_size=attn_module.head_size, - dtype=self.kv_cache_dtype, + dtype=(self.dtype + if self.cache_config.skip_kv_quantization_layers + and layer_idx + in self.cache_config.skip_kv_quantization_layers + else self.kv_cache_dtype), attention_chunk_size=self.attention_chunk_size, use_mla=use_mla) else: @@ -3491,7 +3500,11 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: block_size=block_size, num_kv_heads=attn_module.num_kv_heads, head_size=attn_module.head_size, - dtype=self.kv_cache_dtype, + dtype=(self.dtype + if self.cache_config.skip_kv_quantization_layers + and layer_idx + in self.cache_config.skip_kv_quantization_layers + else self.kv_cache_dtype), use_mla=use_mla) elif attn_module.attn_type in (AttentionType.ENCODER, AttentionType.ENCODER_ONLY):
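
For reference, a hedged usage sketch of the new knob with the offline API. The model id is illustrative, and this assumes LLM() forwards extra keyword arguments to EngineArgs unchanged, so that skip_kv_quantization_layers reaches the parser added above; per that parser, only the 'sliding_window' keyword is currently accepted, and explicit layer lists raise NotImplementedError.

    from vllm import LLM

    # Quantize the KV cache of the full-attention layers to fp8 while the
    # sliding window layers (the even-indexed layers in GPT-OSS) keep the
    # model dtype.
    llm = LLM(
        model="openai/gpt-oss-120b",
        kv_cache_dtype="fp8",
        skip_kv_quantization_layers="sliding_window",
    )

    # The equivalent server invocation would pass --kv-cache-dtype fp8 and
    # --skip-kv-quantization-layers sliding_window on the vllm CLI.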