Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions vllm/attention/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ def __init__(
attn_type: str = AttentionType.DECODER,
kv_sharing_target_layer_name: Optional[str] = None,
attn_backend: Optional[type[AttentionBackend]] = None,
per_layer_kv_cache_dtype: Optional[str] = None,
**extra_impl_args,
) -> None:
"""
Expand All @@ -109,10 +110,20 @@ def __init__(
sliding_window = None

if cache_config is not None:
kv_cache_dtype = cache_config.cache_dtype
# Use per-layer dtype if provided
if per_layer_kv_cache_dtype is not None:
kv_cache_dtype = per_layer_kv_cache_dtype
# For mixed-dtype: only calculate scales for FP8 layers
calculate_kv_scales = per_layer_kv_cache_dtype.startswith(
'fp8') and cache_config.calculate_kv_scales
else:
# This is a skip layer in mixed-dtype scenario
# Use "auto" to resolve to model dtype
kv_cache_dtype = "auto"
calculate_kv_scales = False

block_size = cache_config.block_size
is_attention_free = cache_config.is_attention_free
calculate_kv_scales = cache_config.calculate_kv_scales
else:
kv_cache_dtype = "auto"
block_size = 16
Expand Down
4 changes: 4 additions & 0 deletions vllm/config/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,10 @@ class CacheConfig:
"""The data type to use for the Mamba cache (ssm state only, conv state will
still be controlled by mamba_cache_dtype). If set to 'auto', the data type
for the ssm state will be determined by mamba_cache_dtype."""
skip_kv_quantization_layers: Optional[list[int | str]] = None
"""List of layer IDs to skip KV cache quantization for, or 'sliding_window'
to skip all sliding window layers. These layers will use the model's
default dtype instead of the quantized cache_dtype (e.g., fp8)."""

# Will be set after profiling.
num_gpu_blocks: Optional[int] = field(default=None, init=False)
Expand Down
53 changes: 53 additions & 0 deletions vllm/engine/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,7 @@
config_format: str = ModelConfig.config_format
dtype: ModelDType = ModelConfig.dtype
kv_cache_dtype: CacheDType = CacheConfig.cache_dtype
skip_kv_quantization_layers: Optional[str] = None
seed: Optional[int] = ModelConfig.seed
max_model_len: Optional[int] = ModelConfig.max_model_len
cuda_graph_sizes: list[int] = get_field(SchedulerConfig,
Expand Down Expand Up @@ -782,6 +783,14 @@
**cache_kwargs["mamba_cache_dtype"])
cache_group.add_argument("--mamba-ssm-cache-dtype",
**cache_kwargs["mamba_ssm_cache_dtype"])
cache_group.add_argument(
"--skip-kv-quantization-layers",
type=str,
default=None,
help=("Comma-separated or dash-separated list of layer indices to "
"skip KV quantization, or 'sliding_window' to skip all "
"sliding window layers. Example: '0,31', '1-6-32', or "
"'sliding_window'"))

# Multimodal related configs
multimodal_kwargs = get_kwargs(MultiModalConfig)
Expand Down Expand Up @@ -1123,6 +1132,45 @@
})
return SpeculativeConfig(**self.speculative_config)

def _parse_skip_kv_quantization_layers(
self, layers_str: Optional[str]) -> Optional[list[int]]:
"""Parse skip_kv_quantization_layers string to list of integers.

Supports formats:
- "0,31" (comma-separated)
- "0-31" (dash-separated)
- "0-1-30-31" (multiple dash-separated)
- "sliding_window" (skip all sliding window layers)
"""
if layers_str is None:
return None

layers_str = layers_str.strip()
if not layers_str:
return None

# Handle special keyword
if layers_str.lower() == "sliding_window":
# Return special marker - will be resolved later with model info
return ["sliding_window"]

Check failure on line 1155 in vllm/engine/arg_utils.py

View workflow job for this annotation

GitHub Actions / pre-commit

List item 0 has incompatible type "str"; expected "int" [list-item]

Check failure on line 1155 in vllm/engine/arg_utils.py

View workflow job for this annotation

GitHub Actions / pre-commit

List item 0 has incompatible type "str"; expected "int" [list-item]

Check failure on line 1155 in vllm/engine/arg_utils.py

View workflow job for this annotation

GitHub Actions / pre-commit

List item 0 has incompatible type "str"; expected "int" [list-item]

Check failure on line 1155 in vllm/engine/arg_utils.py

View workflow job for this annotation

GitHub Actions / pre-commit

List item 0 has incompatible type "str"; expected "int" [list-item]

Check failure on line 1155 in vllm/engine/arg_utils.py

View workflow job for this annotation

GitHub Actions / pre-commit

List item 0 has incompatible type "str"; expected "int" [list-item]

Check failure on line 1155 in vllm/engine/arg_utils.py

View workflow job for this annotation

GitHub Actions / pre-commit

List item 0 has incompatible type "str"; expected "int" [list-item]

Check failure on line 1155 in vllm/engine/arg_utils.py

View workflow job for this annotation

GitHub Actions / pre-commit

List item 0 has incompatible type "str"; expected "int" [list-item]

Check failure on line 1155 in vllm/engine/arg_utils.py

View workflow job for this annotation

GitHub Actions / pre-commit

List item 0 has incompatible type "str"; expected "int" [list-item]

Check failure on line 1155 in vllm/engine/arg_utils.py

View workflow job for this annotation

GitHub Actions / pre-commit

List item 0 has incompatible type "str"; expected "int" [list-item]

Check failure on line 1155 in vllm/engine/arg_utils.py

View workflow job for this annotation

GitHub Actions / pre-commit

List item 0 has incompatible type "str"; expected "int" [list-item]
else:
raise NotImplementedError("Currently we only support skipping "
"sliding window layers.")

# Try comma-separated first, then dash-separated
if ',' in layers_str:
parts = layers_str.split(',')
else:
parts = layers_str.split('-')

try:
return [int(part.strip()) for part in parts if part.strip()]
except ValueError as e:
raise ValueError(
f"Invalid skip_kv_quantization_layers format: '{layers_str}'. "
f"Expected comma-separated (0,31), dash-separated (0-31) "
f"integers, or 'sliding_window'. Error: {e}") from e

def create_engine_config(
self,
usage_context: Optional[UsageContext] = None,
Expand Down Expand Up @@ -1209,6 +1257,10 @@
f"dcp_size={self.decode_context_parallel_size}."
)

# Parse skip_kv_quantization_layers string to list of integers
parsed_skip_layers = self._parse_skip_kv_quantization_layers(
self.skip_kv_quantization_layers)

cache_config = CacheConfig(
block_size=self.block_size,
gpu_memory_utilization=self.gpu_memory_utilization,
Expand All @@ -1223,6 +1275,7 @@
cpu_offload_gb=self.cpu_offload_gb,
calculate_kv_scales=self.calculate_kv_scales,
kv_sharing_fast_prefill=self.kv_sharing_fast_prefill,
skip_kv_quantization_layers=parsed_skip_layers,
mamba_cache_dtype=self.mamba_cache_dtype,
mamba_ssm_cache_dtype=self.mamba_ssm_cache_dtype,
)
Expand Down
34 changes: 31 additions & 3 deletions vllm/model_executor/models/gpt_oss.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def __init__(
quant_config: Optional[QuantizationConfig] = None,
cache_config: Optional[CacheConfig] = None,
prefix: str = "",
per_layer_kv_cache_dtype: Optional[str] = None,
):
super().__init__()
self.layer_idx = extract_layer_index(prefix)
Expand Down Expand Up @@ -117,6 +118,7 @@ def __init__(
attn_type=AttentionType.DECODER,
prefix=f"{prefix}.attn",
sinks=self.sinks,
per_layer_kv_cache_dtype=per_layer_kv_cache_dtype,
)

def forward(self, hidden_states: torch.Tensor,
Expand Down Expand Up @@ -177,9 +179,23 @@ def __init__(
):
super().__init__()
self.layer_idx = extract_layer_index(prefix)
self.attn = OAIAttention(config,
prefix=f"{prefix}.attn",
cache_config=cache_config)

# Determine per-layer KV cache dtype for mixed-dtype support
per_layer_kv_cache_dtype = None
if cache_config is not None and self.layer_idx is not None:
skip_layers = cache_config.skip_kv_quantization_layers

# Simple check - sliding_window is already resolved to layer indices
if skip_layers and self.layer_idx in skip_layers:
per_layer_kv_cache_dtype = None
else:
per_layer_kv_cache_dtype = cache_config.cache_dtype

self.attn = OAIAttention(
config,
prefix=f"{prefix}.attn",
cache_config=cache_config,
per_layer_kv_cache_dtype=per_layer_kv_cache_dtype)
self.mlp = MLPBlock(config,
self.layer_idx,
quant_config=quant_config,
Expand Down Expand Up @@ -227,6 +243,18 @@ def __init__(
self.config.vocab_size,
self.config.hidden_size,
)

# Resolve "sliding_window" keyword for GPT-OSS
if (self.cache_config.skip_kv_quantization_layers
and len(self.cache_config.skip_kv_quantization_layers) == 1
and self.cache_config.skip_kv_quantization_layers[0]
== "sliding_window"):
# Replace with actual sliding window layer indices (even layers)
sliding_window_layers = [
i for i in range(self.config.num_hidden_layers) if i % 2 == 0
]
self.cache_config.skip_kv_quantization_layers = (
sliding_window_layers)
self.start_layer, self.end_layer, self.layers = make_layers(
self.config.num_hidden_layers,
lambda prefix: TransformerBlock(
Expand Down
48 changes: 39 additions & 9 deletions vllm/v1/attention/backends/flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,6 +528,16 @@ def forward(

descale_shape = (cu_seqlens_q.shape[0] - 1, self.num_kv_heads)

# Only pass scale parameters for FP8 layers
if self.kv_cache_dtype.startswith("fp8"):
q_descale = layer._q_scale.expand(descale_shape)
k_descale = layer._k_scale.expand(descale_shape)
v_descale = layer._v_scale.expand(descale_shape)
else:
q_descale = None
k_descale = None
v_descale = None

flash_attn_varlen_func(
q=query[:num_actual_tokens],
k=key_cache,
Expand All @@ -545,15 +555,25 @@ def forward(
softcap=self.logits_soft_cap,
scheduler_metadata=scheduler_metadata,
fa_version=self.vllm_flash_attn_version,
q_descale=layer._q_scale.expand(descale_shape),
k_descale=layer._k_scale.expand(descale_shape),
v_descale=layer._v_scale.expand(descale_shape),
q_descale=q_descale,
k_descale=k_descale,
v_descale=v_descale,
num_splits=attn_metadata.max_num_splits,
s_aux=self.sinks,
)
return output

# Cascade attention (rare case).
# Only pass scale parameters for FP8 layers
if self.kv_cache_dtype.startswith("fp8"):
q_descale = layer._q_scale
k_descale = layer._k_scale
v_descale = layer._v_scale
else:
q_descale = None
k_descale = None
v_descale = None

cascade_attention(
output[:num_actual_tokens],
query[:num_actual_tokens],
Expand All @@ -574,9 +594,9 @@ def forward(
fa_version=self.vllm_flash_attn_version,
prefix_scheduler_metadata=attn_metadata.prefix_scheduler_metadata,
suffix_scheduler_metadata=attn_metadata.scheduler_metadata,
q_descale=layer._q_scale,
k_descale=layer._k_scale,
v_descale=layer._v_scale,
q_descale=q_descale,
k_descale=k_descale,
v_descale=v_descale,
)
return output

Expand Down Expand Up @@ -614,6 +634,16 @@ def _forward_encoder_attention(
cu_seqlens_q.shape[0] - 1, # type: ignore[union-attr]
self.num_kv_heads)

# Only pass scale parameters for FP8 layers
if self.kv_cache_dtype.startswith("fp8"):
q_descale = layer._q_scale.expand(descale_shape)
k_descale = layer._k_scale.expand(descale_shape)
v_descale = layer._v_scale.expand(descale_shape)
else:
q_descale = None
k_descale = None
v_descale = None

# Call flash attention directly on Q, K, V tensors
flash_attn_varlen_func(
q=query,
Expand All @@ -630,9 +660,9 @@ def _forward_encoder_attention(
window_size=self.sliding_window,
softcap=self.logits_soft_cap,
fa_version=self.vllm_flash_attn_version,
q_descale=layer._q_scale.expand(descale_shape),
k_descale=layer._k_scale.expand(descale_shape),
v_descale=layer._v_scale.expand(descale_shape),
q_descale=q_descale,
k_descale=k_descale,
v_descale=v_descale,
)

return output
Expand Down
Loading