From 7c1933e494b1d8c0e117e9282e244748ae15aca7 Mon Sep 17 00:00:00 2001 From: Yong Hoon Shin Date: Thu, 7 Aug 2025 11:49:21 -0700 Subject: [PATCH 1/2] Enable fast prefill for Gemma3n Signed-off-by: Yong Hoon Shin --- tests/v1/e2e/test_kv_sharing_fast_prefill.py | 55 --- .../layers/chunked_local_attention.py | 44 +- vllm/compilation/decorators.py | 22 +- vllm/config/cache.py | 7 + vllm/model_executor/models/gemma3n.py | 416 +++++++++++++++--- vllm/v1/attention/backends/utils.py | 141 +++++- vllm/v1/spec_decode/eagle.py | 1 + vllm/v1/worker/gpu_model_runner.py | 141 ++++-- vllm/v1/worker/tpu_model_runner.py | 39 +- vllm/v1/worker/utils.py | 30 +- 10 files changed, 657 insertions(+), 239 deletions(-) diff --git a/tests/v1/e2e/test_kv_sharing_fast_prefill.py b/tests/v1/e2e/test_kv_sharing_fast_prefill.py index f5a7b9cc276b..7bc7f44dd7ab 100644 --- a/tests/v1/e2e/test_kv_sharing_fast_prefill.py +++ b/tests/v1/e2e/test_kv_sharing_fast_prefill.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import random -from typing import Optional, Union import pytest import torch @@ -10,11 +9,6 @@ from vllm import LLM, SamplingParams from vllm.config import CompilationConfig, CompilationLevel from vllm.distributed import cleanup_dist_env_and_memory -from vllm.forward_context import get_forward_context -from vllm.model_executor.models.gemma3n import Gemma3nForConditionalGeneration -from vllm.model_executor.models.registry import ModelRegistry -from vllm.model_executor.models.utils import extract_layer_index -from vllm.sequence import IntermediateTensors from ...utils import fork_new_process_for_each_test @@ -22,53 +16,6 @@ SEED = 42 -class TestGemma3nForConditionalGeneration(Gemma3nForConditionalGeneration): - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - **kwargs, - ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds, **kwargs) - attn_metadata = get_forward_context().attn_metadata - # attn_metadata is None during dummy runs - if (attn_metadata is not None - and self.cache_config.kv_sharing_fast_prefill): - assert isinstance(attn_metadata, dict) # true in V1 - # Gemma3n-E2B has 30 layers, with last 20 layers being - # cross-decoder layers. 
Check attention metadata is correct - for layer_name, metadata in attn_metadata.items(): - layer_idx = extract_layer_index(layer_name) - if layer_idx >= 20: - assert hasattr(metadata, 'logits_indices_padded') - assert hasattr(metadata, 'num_logits_indices') - else: - assert not hasattr(metadata, 'logits_indices_padded') - assert not hasattr(metadata, 'num_logits_indices') - - # Last layer will be a KV sharing layer - layer_attn_metadata = attn_metadata[ - self.model.language_model.layers[-1].self_attn.attn.layer_name] - logits_indices_padded = (layer_attn_metadata.logits_indices_padded) - assert logits_indices_padded is not None - num_logits_indices = layer_attn_metadata.num_logits_indices - assert num_logits_indices > 0 - # Reset hidden states to random values and - # only set logits at logits_indices to valid values - # Because logits_indices are the only positions that are used - # for output token sampling, this still produces same outputs - logits_hs = hidden_states[logits_indices_padded] - hidden_states = torch.randn_like(hidden_states) - gen_indices = logits_indices_padded[:num_logits_indices] - hidden_states[gen_indices] = logits_hs[:num_logits_indices] - - return hidden_states - - @pytest.fixture def test_prompts(): """ @@ -122,8 +69,6 @@ def test_kv_sharing_fast_prefill( enforce_eager: bool, test_prompts: list[str], ): - ModelRegistry.register_model("Gemma3nForConditionalGeneration", - TestGemma3nForConditionalGeneration) sampling_params = SamplingParams(temperature=0.0, max_tokens=100) compilation_config = CompilationConfig( # This allows vLLM compilation backend to handle allocating and diff --git a/vllm/attention/layers/chunked_local_attention.py b/vllm/attention/layers/chunked_local_attention.py index 892077ba91e0..9c9a9c24c52e 100644 --- a/vllm/attention/layers/chunked_local_attention.py +++ b/vllm/attention/layers/chunked_local_attention.py @@ -1,48 +1,19 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import functools from typing import List, Optional import torch from vllm import envs -from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.selector import get_attn_backend from vllm.config import CacheConfig, QuantizationConfig from vllm.v1.attention.backends.utils import ( - CommonAttentionMetadata, make_local_attention_virtual_batches, - subclass_attention_backend, subclass_attention_metadata_builder) + CommonAttentionMetadata, create_custom_attention_backend, + make_local_attention_virtual_batches) from ..layer import Attention -@functools.lru_cache -def create_chunked_local_attention_backend( - underlying_attn_backend: AttentionBackend, - attention_chunk_size: int, - block_size: int, -) -> type[AttentionBackend]: - prefix = f"ChunkedLocalAttention_{attention_chunk_size}_{block_size}_" - - def build_preprocess_fn(cm: CommonAttentionMetadata): - return make_local_attention_virtual_batches(attention_chunk_size, cm, - block_size) - - # Dynamically create a new attention backend that wraps the - # underlying attention backend but applies - # `make_local_attention_virtual_batches` before calling `build(...)` - builder_cls = subclass_attention_metadata_builder( - name_prefix=prefix, - builder_cls=underlying_attn_backend.get_builder_cls(), - build_preprocess_fn=build_preprocess_fn) - attn_backend = subclass_attention_backend( - name_prefix=prefix, - attention_backend_cls=underlying_attn_backend, - builder_cls=builder_cls) - - return attn_backend - - class ChunkedLocalAttention(Attention): 
def __init__(self, @@ -69,8 +40,15 @@ def __init__(self, kv_cache_dtype, block_size) - attn_backend = create_chunked_local_attention_backend( - underlying_attn_backend, attention_chunk_size, block_size) + prefix = \ + f"ChunkedLocalAttention_{attention_chunk_size}_{block_size}_" + + def build_preprocess_fn(cm: CommonAttentionMetadata): + return make_local_attention_virtual_batches( + attention_chunk_size, cm, block_size) + + attn_backend = create_custom_attention_backend( + prefix, underlying_attn_backend, build_preprocess_fn) else: # in v0 the local attention is handled inside the backends attn_backend = None diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py index 1370862d580a..10e12bcb5130 100644 --- a/vllm/compilation/decorators.py +++ b/vllm/compilation/decorators.py @@ -60,6 +60,14 @@ def support_torch_compile( ... +@overload +def support_torch_compile( + *, + compile_cond: Optional[Callable[[VllmConfig], bool]] = None, +) -> Callable[[_T], _T]: + ... + + @overload def support_torch_compile(cls: _T) -> _T: ... @@ -69,6 +77,7 @@ def support_torch_compile( cls: Optional[_T] = None, *, dynamic_arg_dims: Optional[dict[str, Union[int, list[int]]]] = None, + compile_cond: Optional[Callable[[VllmConfig], bool]] = None, ) -> Union[Callable[[_T], _T], _T]: """ A decorator to add support for compiling the forward method of a class. @@ -118,6 +127,11 @@ def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]): NOTE: if an argument is `None`, it should always be passed as `None` during the lifetime of the model, otherwise, it cannot be captured as a single computation graph. + + `compile_cond` is a function that takes a `VllmConfig` object as input and + returns a boolean value indicating whether to compile the model or not. + This is useful if you want to compile the model only when certain + conditions are met. """ def cls_decorator_helper(cls: _T) -> _T: @@ -149,7 +163,8 @@ def cls_decorator_helper(cls: _T) -> _T: if k not in sig.parameters: raise ValueError( f"Argument {k} not found in the forward method of {cls}") - return _support_torch_compile(cls, inferred_dynamic_arg_dims) + return _support_torch_compile(cls, inferred_dynamic_arg_dims, + compile_cond) if cls is not None: # use `support_torch_compile` as a decorator without arguments @@ -162,6 +177,7 @@ def cls_decorator_helper(cls: _T) -> _T: def _support_torch_compile( cls: _T, dynamic_arg_dims: dict[str, Union[int, list[int]]], + compile_cond: Optional[Callable[[VllmConfig], bool]] = None, ) -> _T: """ A decorator to add support for compiling the forward method of a class. @@ -182,13 +198,15 @@ def _support_torch_compile( def __init__(self, *, vllm_config: VllmConfig, prefix: str = '', **kwargs): old_init(self, vllm_config=vllm_config, prefix=prefix, **kwargs) self.vllm_config = vllm_config + compile_cond_satisfied = compile_cond is None or compile_cond( + vllm_config) # for CompilationLevel.DYNAMO_AS_IS , the upper level model runner # will handle the compilation, so we don't need to do anything here. 
self.do_not_compile = \ vllm_config.compilation_config.level in [ CompilationLevel.NO_COMPILATION, CompilationLevel.DYNAMO_AS_IS ] or not supports_dynamo() or _should_ignore_torch_compile( - self.__class__) + self.__class__) or not compile_cond_satisfied if self.do_not_compile: return diff --git a/vllm/config/cache.py b/vllm/config/cache.py index 69cb0d9732fa..3104f5b620a8 100644 --- a/vllm/config/cache.py +++ b/vllm/config/cache.py @@ -133,11 +133,18 @@ def __post_init__(self) -> None: self._verify_cache_dtype() self._verify_prefix_caching() + self._verify_kv_sharing_fast_prefill() def metrics_info(self): # convert cache_config to dict(key: str, value: str) for prometheus # metrics info return {key: str(value) for key, value in self.__dict__.items()} + + def _verify_kv_sharing_fast_prefill(self) -> None: + if self.kv_sharing_fast_prefill and not envs.VLLM_USE_V1: + raise NotImplementedError( + "Fast prefill optimization for KV sharing is not supported " + "in V0 currently.") @model_validator(mode='after') def _verify_args(self) -> Self: diff --git a/vllm/model_executor/models/gemma3n.py b/vllm/model_executor/models/gemma3n.py index ffec3408702c..c9f8e404ebf7 100644 --- a/vllm/model_executor/models/gemma3n.py +++ b/vllm/model_executor/models/gemma3n.py @@ -23,9 +23,11 @@ from transformers.models.gemma3n.configuration_gemma3n import Gemma3nTextConfig from vllm.attention import Attention +from vllm.compilation.backends import set_model_tag from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.forward_context import get_forward_context from vllm.logger import init_logger from vllm.model_executor.layers.activation import (_ACTIVATION_REGISTRY, GeluAndMul, @@ -45,6 +47,8 @@ default_weight_loader, maybe_remap_kv_scale_name) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors +from vllm.v1.attention.backends.utils import ( + KVSharingFastPrefillAttentionMetadata) from .interfaces import SupportsQuant from .utils import (AutoWeightsLoader, extract_layer_index, @@ -533,7 +537,177 @@ def forward( return corrected_predictions -@support_torch_compile +# This enables torch.compile if --kv-sharing-fast-prefill passed +@support_torch_compile(compile_cond=lambda vllm_config: vllm_config. 
+ cache_config.kv_sharing_fast_prefill) +class Gemma3nSelfDecoder(nn.Module): + """ + Includes altup embedding and self decoder layers + """ + + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + decoder_layers: list[Gemma3nDecoderLayer], + layer_idx_start: int, + per_layer_model_projection: ColumnParallelLinear, + embed_scale_per_layer: torch.Tensor, + embed_tokens_per_layer: VocabParallelEmbedding, + per_layer_projection_norm: RMSNorm, + per_layer_input_scale: torch.Tensor, + altup_projections: nn.ModuleList, + eps: torch.Tensor, + embed_tokens: VocabParallelEmbedding, + embed_scale: torch.Tensor, + ): + super().__init__() + self.decoder_layers = decoder_layers + self.layer_idx_start = layer_idx_start + self.per_layer_model_projection = per_layer_model_projection + self.config = vllm_config.model_config.hf_config.text_config + self.embed_scale_per_layer = embed_scale_per_layer + self.embed_tokens_per_layer = embed_tokens_per_layer + self.per_layer_projection_norm = per_layer_projection_norm + self.per_layer_input_scale = per_layer_input_scale + self.altup_projections = altup_projections + self.eps = eps + self.embed_tokens = embed_tokens + self.embed_scale = embed_scale + + def get_per_layer_input_embeddings( + self, input_ids: torch.Tensor) -> torch.Tensor: + # Deal with the fact that vocab_size_per_layer_input < vocab_size + # which causes us to have some out of vocab tokens by setting + # those token ids to 0. This matches the HF implementation. + per_layer_inputs_mask = torch.logical_and( + input_ids >= 0, input_ids < self.config.vocab_size_per_layer_input) + per_layer_inputs_tokens = torch.where(per_layer_inputs_mask, input_ids, + torch.zeros_like(input_ids)) + return self.embed_tokens_per_layer( + per_layer_inputs_tokens) * self.embed_scale_per_layer + + def get_per_layer_inputs( + self, + input_ids: torch.Tensor, + hidden_states_0: torch.Tensor, + ) -> torch.Tensor: + per_layer_inputs = self.get_per_layer_input_embeddings(input_ids) + per_layer_inputs = per_layer_inputs.reshape( + -1, self.config.num_hidden_layers, + self.config.hidden_size_per_layer_input) + per_layer_projection = self.per_layer_model_projection(hidden_states_0) + per_layer_projection = per_layer_projection.reshape( + *hidden_states_0.shape[:-1], + self.config.num_hidden_layers, + self.config.hidden_size_per_layer_input, + ) + per_layer_projection = self.per_layer_projection_norm( + per_layer_projection) + per_layer_inputs = per_layer_projection + per_layer_inputs + per_layer_inputs *= self.per_layer_input_scale + return per_layer_inputs + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) * self.embed_scale + + def altup_embed(self, hidden_states_0: torch.Tensor) -> torch.Tensor: + # Altup embed. 
+ hidden_states = [hidden_states_0] * self.config.altup_num_inputs + target_magnitude = torch.mean(hidden_states_0**2, dim=-1, + keepdim=True)**0.5 + for i in range(1, self.config.altup_num_inputs): + hidden_states[i] = self.altup_projections[i - 1](hidden_states[i]) + new_magnitude = torch.mean(hidden_states[i]**2, + dim=-1, + keepdim=True)**0.5 + hidden_states[i] *= target_magnitude / torch.maximum( + new_magnitude, self.eps) + hidden_states = torch.stack(hidden_states, dim=-1) + return hidden_states + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs, + ) -> tuple[torch.Tensor, torch.Tensor]: + if inputs_embeds is not None: + hidden_states_0 = inputs_embeds + else: + hidden_states_0 = self.get_input_embeddings(input_ids) + + per_layer_inputs = self.get_per_layer_inputs(input_ids, + hidden_states_0) + hidden_states = self.altup_embed(hidden_states_0) + + # [altnum_inputs, num_tokens, hidden_size] + hidden_states = hidden_states.permute(2, 0, 1) + + for idx, layer in enumerate(self.decoder_layers): + layer_idx = idx + self.layer_idx_start + # [altup_num_inputs, num_tokens, hidden_size] + hidden_states = layer( + positions=positions, + hidden_states=hidden_states, + per_layer_input=per_layer_inputs[:, layer_idx, :], + **kwargs, + ) + + # [num_tokens, hidden_size, altnum_inputs] + hidden_states = hidden_states.permute(1, 2, 0) + + return hidden_states, per_layer_inputs + + +# This enables torch.compile if --kv-sharing-fast-prefill passed +@support_torch_compile(compile_cond=lambda vllm_config: vllm_config. + cache_config.kv_sharing_fast_prefill) +class Gemma3nCrossDecoder(nn.Module): + """ + Cross-decoder layers + """ + + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + decoder_layers: list[Gemma3nDecoderLayer], + layer_idx_start: int, + ): + super().__init__() + self.decoder_layers = decoder_layers + self.layer_idx_start = layer_idx_start + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + per_layer_inputs: torch.Tensor, + **kwargs, + ) -> torch.Tensor: + # [altnum_inputs, num_tokens, hidden_size] + hidden_states = hidden_states.permute(2, 0, 1) + for idx, layer in enumerate(self.decoder_layers): + layer_idx = idx + self.layer_idx_start + # [altup_num_inputs, num_tokens, hidden_size] + hidden_states = layer( + positions=positions, + hidden_states=hidden_states, + per_layer_input=per_layer_inputs[:, layer_idx, :], + **kwargs, + ) + # [num_tokens, hidden_size, altnum_inputs] + hidden_states = hidden_states.permute(1, 2, 0) + return hidden_states + + +# This disables torch.compile if --kv-sharing-fast-prefill passed +@support_torch_compile(compile_cond=lambda vllm_config: not vllm_config. 
+ cache_config.kv_sharing_fast_prefill) class Gemma3nTextModel(nn.Module, SupportsQuant): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): @@ -543,7 +717,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): quant_config = vllm_config.quant_config self.config = config self.quant_config = quant_config - self.embed_tokens = VocabParallelEmbedding( config.vocab_size, config.hidden_size, @@ -613,95 +786,208 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): lambda prefix: Gemma3nDecoderLayer( config, cache_config, quant_config, prefix=prefix), prefix=f"{prefix}.layers") + + self.eps = torch.tensor(torch.finfo().min) + + first_kv_shared_layer_idx = (config.num_hidden_layers - + config.num_kv_shared_layers) + # Layer idx 0-19 are self-decoder layers in You Only Cache Once (YOCO) + with set_model_tag("self_decoder"): + self.self_decoder = Gemma3nSelfDecoder( + vllm_config=vllm_config, + prefix=f"{prefix}.self_decoder", + decoder_layers=self.layers[:first_kv_shared_layer_idx], + layer_idx_start=0, + per_layer_model_projection=self.per_layer_model_projection, + embed_scale_per_layer=self.embed_scale_per_layer, + embed_tokens_per_layer=self.embed_tokens_per_layer, + per_layer_projection_norm=self.per_layer_projection_norm, + per_layer_input_scale=self.per_layer_input_scale, + altup_projections=self.altup_projections, + eps=self.eps, + embed_tokens=self.embed_tokens, + embed_scale=self.embed_scale, + ) + # Layer idx 20-30 are cross-decoder layers in YOCO + with set_model_tag("cross_decoder"): + self.cross_decoder = Gemma3nCrossDecoder( + vllm_config=vllm_config, + prefix=f"{prefix}.cross_decoder", + decoder_layers=self.layers[first_kv_shared_layer_idx:], + layer_idx_start=first_kv_shared_layer_idx, + ) + self.norm = RMSNorm( config.hidden_size, eps=config.rms_norm_eps, ) - self.eps = torch.tensor(torch.finfo().min) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.embed_tokens(input_ids) * self.embed_scale + self.fast_prefill_enabled = cache_config.kv_sharing_fast_prefill + + if self.fast_prefill_enabled: + # Allocate static buffers for CUDAGraph + # TODO(sarckk): Extract this functionality to interface + max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens + device = next(self.parameters()).device + self.positions = torch.zeros(max_num_tokens, + dtype=torch.int64, + device=device) + self.hidden_states = torch.zeros( + (max_num_tokens, config.hidden_size, + self.config.altup_num_inputs), + dtype=self.embed_tokens.weight.dtype, + device=device, + ) + self.per_layer_inputs = torch.zeros( + (max_num_tokens, self.config.num_hidden_layers, + self.config.hidden_size_per_layer_input), + dtype=self.embed_tokens.weight.dtype, + device=device, + ) - def get_per_layer_input_embeddings( - self, input_ids: torch.Tensor) -> torch.Tensor: - # Deal with the fact that vocab_size_per_layer_input < vocab_size - # which causes us to have some out of vocab tokens by setting - # those token ids to 0. This matches the HF implementation. 
- per_layer_inputs_mask = torch.logical_and( - input_ids >= 0, input_ids < self.config.vocab_size_per_layer_input) - per_layer_inputs_tokens = torch.where(per_layer_inputs_mask, input_ids, - torch.zeros_like(input_ids)) - return self.embed_tokens_per_layer( - per_layer_inputs_tokens) * self.embed_scale_per_layer + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.self_decoder.get_input_embeddings(input_ids) - def forward( + def fast_prefill_forward( self, - input_ids: Optional[torch.Tensor], + input_ids: torch.Tensor, positions: torch.Tensor, - per_layer_inputs: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs, - ) -> Union[torch.Tensor, IntermediateTensors]: - if inputs_embeds is not None: - hidden_states_0 = inputs_embeds - else: - hidden_states_0 = self.get_input_embeddings(input_ids) + ) -> torch.Tensor: + logits_indices_padded, num_logits_indices = None, None + attn_metadata = get_forward_context().attn_metadata + + # attn_metadata is None during dummy runs + if (self.fast_prefill_enabled and attn_metadata is not None): + assert isinstance(attn_metadata, dict) + # Last layer is a KV sharing layer + layer_attn_metadata = attn_metadata[ + self.layers[-1].self_attn.attn.layer_name] + if (isinstance(layer_attn_metadata, + KVSharingFastPrefillAttentionMetadata)): + logits_indices_padded = ( + layer_attn_metadata.logits_indices_padded) + num_logits_indices = layer_attn_metadata.num_logits_indices + + # Copy inputs for cudagraph + batch_size = positions.size(0) + self.positions[:batch_size].copy_(positions) + # input_ids and inputs_embeds are allocated in model runner + self_decoder_hidden_states, per_layer_inputs = self.self_decoder( + input_ids=input_ids, + positions=self.positions[:batch_size], + inputs_embeds=inputs_embeds, + **kwargs, + ) - per_layer_projection = self.per_layer_model_projection(hidden_states_0) - per_layer_projection = per_layer_projection.reshape( - *hidden_states_0.shape[:-1], - self.config.num_hidden_layers, - self.config.hidden_size_per_layer_input, + if logits_indices_padded is None: + logits_indices_padded = torch.arange( + positions.size(0), + dtype=positions.dtype, + device=positions.device, + ) + + # NOTE(sarckk): There is currently a bug caused by + # vLLM converting output of last piecewise CUDA graph + # to weakref, causing memory to be prematurely freed + # when there are multiple compilation units + # Keep .clone() until fix in + # https://github.com/vllm-project/vllm/pull/22282 + hidden_states = self_decoder_hidden_states.clone() + + # Copy inputs for cudagraph + num_padded_logits_indices = logits_indices_padded.size(0) + self.positions[:num_padded_logits_indices].copy_( + positions[logits_indices_padded]) + self.hidden_states[:num_padded_logits_indices].copy_( + self_decoder_hidden_states[logits_indices_padded]) + self.per_layer_inputs[:num_padded_logits_indices].copy_( + per_layer_inputs[logits_indices_padded]) + cross_decoder_hidden_states = self.cross_decoder( + positions=self.positions[:num_padded_logits_indices], + hidden_states=self.hidden_states[:num_padded_logits_indices], + per_layer_inputs=self.per_layer_inputs[:num_padded_logits_indices], + **kwargs, ) - per_layer_projection = self.per_layer_projection_norm( - per_layer_projection) - if per_layer_inputs is not None: - # Profiling run does not compute per_layer_inputs - per_layer_inputs = per_layer_projection + per_layer_inputs - per_layer_inputs *= self.per_layer_input_scale + 
if num_logits_indices is not None: + assert num_logits_indices > 0 + # Merge cross-decoder and self-decoder hidden states + hidden_states[logits_indices_padded[:num_logits_indices]] = ( + cross_decoder_hidden_states[:num_logits_indices]) else: - per_layer_inputs = per_layer_projection + hidden_states = cross_decoder_hidden_states - # Altup embed. - hidden_states = [hidden_states_0] * self.config.altup_num_inputs - target_magnitude = torch.mean(hidden_states_0**2, dim=-1, - keepdim=True)**0.5 - for i in range(1, self.config.altup_num_inputs): - hidden_states[i] = self.altup_projections[i - 1](hidden_states[i]) - new_magnitude = torch.mean(hidden_states[i]**2, - dim=-1, - keepdim=True)**0.5 - hidden_states[i] *= target_magnitude / torch.maximum( - new_magnitude, self.eps) - hidden_states = torch.stack(hidden_states, dim=0) + return hidden_states - # Transformer blocks. - for layer_idx, layer in enumerate(self.layers): - # [altup_num_inputs, num_tokens, hidden_size] - hidden_states = layer( - positions=positions, - hidden_states=hidden_states, - per_layer_input=per_layer_inputs[:, layer_idx, :], - **kwargs, - ) + def normal_forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + hidden_states, per_layer_inputs = self.self_decoder( + input_ids=input_ids, + positions=positions, + inputs_embeds=inputs_embeds, + **kwargs, + ) + hidden_states = self.cross_decoder( + positions=positions, + hidden_states=hidden_states, + per_layer_inputs=per_layer_inputs, + **kwargs, + ) + return hidden_states + def altup_unembed( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor: # Altup unembed. - target_magnitude = torch.mean(hidden_states[0]**2, + target_magnitude = torch.mean(hidden_states[..., 0]**2, dim=-1, keepdim=True)**0.5 for i in range(1, self.config.altup_num_inputs): - hidden_states[i] = self.altup_unembed_projections[i - 1]( - hidden_states[i]) - new_magnitude = torch.mean(hidden_states[i]**2, + hidden_states[..., i] = self.altup_unembed_projections[i - 1]( + hidden_states[..., i]) + new_magnitude = torch.mean(hidden_states[..., i]**2, dim=-1, keepdim=True)**0.5 - hidden_states[i] *= target_magnitude / torch.maximum( + hidden_states[..., i] *= target_magnitude / torch.maximum( new_magnitude, self.eps) - # [altup_num_inputs,num_tokens,hidden_size] -> [num_tokens,hidden_size] - hidden_states = torch.mean(hidden_states, dim=0) + # [num_tokens,hidden_size, altup_num_inputs] -> [num_tokens,hidden_size] + hidden_states = torch.mean(hidden_states, dim=-1) + return hidden_states + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs, + ) -> Union[torch.Tensor, IntermediateTensors]: + # Per layer inputs. 
+ if input_ids is None: + raise ValueError("Passing None for input ids is not supported.") + + if self.fast_prefill_enabled: + hidden_states = self.fast_prefill_forward( + input_ids, + positions, + inputs_embeds, + **kwargs, + ) + else: + hidden_states = self.normal_forward( + input_ids, + positions, + inputs_embeds, + **kwargs, + ) + hidden_states = self.altup_unembed(hidden_states) return self.norm(hidden_states) def load_weights(self, weights: Iterable[tuple[str, diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index e23dd8bc5bbb..498870497287 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -4,12 +4,14 @@ import enum import functools from abc import abstractmethod -from dataclasses import dataclass, make_dataclass +from collections.abc import Hashable +from dataclasses import dataclass, fields, make_dataclass from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Generic, Optional, - TypeVar) + Protocol, TypeVar) import numpy as np import torch +from typing_extensions import runtime_checkable from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.utils import cdiv @@ -62,6 +64,8 @@ class CommonAttentionMetadata: block_table_tensor: torch.Tensor slot_mapping: torch.Tensor + logits_indices: Optional[torch.Tensor] = None + causal: bool = True @@ -530,8 +534,73 @@ def make_local_attention_virtual_batches( max_query_len=seqlens_q_local.max(), block_table_tensor=block_table_local, slot_mapping=common_attn_metadata.slot_mapping, + logits_indices=common_attn_metadata.logits_indices, + causal=True, + ) + + +def make_kv_sharing_fast_prefill_common_attn_metadata( + common_attn_metadata: CommonAttentionMetadata, +) -> CommonAttentionMetadata: + if common_attn_metadata.max_query_len == 1: + # All requests are decode (assume 1 token for now) + # Skip computing fast prefill path + return common_attn_metadata + + if common_attn_metadata.logits_indices is None: + # Logits_indices can be None if prompt_logprobs is + # set for at least one request in the current iteration + # fast prefill is not compatible with prompt_logprobs + # so skip computing fast prefill path + return common_attn_metadata + + logits_indices = common_attn_metadata.logits_indices + num_reqs = common_attn_metadata.num_reqs + query_start_loc = common_attn_metadata.query_start_loc + seq_lens = common_attn_metadata.seq_lens + # Example inputs + # num_reqs: 3 + # generation_indices: [14, 18, 19, 27] + # query_start_loc: [0, 15, 20, 28] + # seq_lens: [41, 31, 40] + + # Find how many decode indices belong to each request + # request_ids: [0, 1, 1, 2] + request_ids = torch.bucketize(logits_indices, + query_start_loc[1:], + right=True) + + # Figure out how many tokens are in each request + # num_decode_tokens: [1, 2, 1] + num_decode_tokens = torch.bincount(request_ids, minlength=num_reqs) + + # Calculate new query_start_loc with tokens in generation_indices + # decode_query_start_loc: [0, 1, 3, 4] + decode_query_start_loc = torch.empty(num_reqs + 1, + device=query_start_loc.device, + dtype=query_start_loc.dtype) + + decode_query_start_loc[0] = 0 + decode_query_start_loc[1:] = torch.cumsum(num_decode_tokens, dim=0) + decode_max_query_len = int(num_decode_tokens.max().item()) + total_num_decode_tokens = int(num_decode_tokens.sum().item()) + + common_attn_metadata = CommonAttentionMetadata( + query_start_loc=decode_query_start_loc, + # TODO: optimize + query_start_loc_cpu=decode_query_start_loc.cpu(), + seq_lens=seq_lens, + 
seq_lens_cpu=seq_lens.cpu(), + num_computed_tokens_cpu=common_attn_metadata.num_computed_tokens_cpu, + num_reqs=num_reqs, + num_actual_tokens=total_num_decode_tokens, + max_query_len=decode_max_query_len, + block_table_tensor=common_attn_metadata.block_table_tensor, + slot_mapping=common_attn_metadata.slot_mapping, + logits_indices=logits_indices, causal=True, ) + return common_attn_metadata def subclass_attention_metadata_builder( @@ -700,13 +769,77 @@ def subclass_attention_metadata( return Wrapped +@functools.lru_cache def make_kv_sharing_fast_prefill_attention_metadata( - metadata_cls: Any, ) -> Any: + metadata_cls: Hashable, ) -> Any: """ Return a new subclass of `metadata_cls` for fast prefill """ - return subclass_attention_metadata( + attn_metadata_dataclass = subclass_attention_metadata( name_prefix="KVSharingFastPrefill", metadata_cls=metadata_cls, fields=KV_SHARING_FAST_PREFILL_METADATA_FIELDS, ) + # Make attention metadata type inherit + # KVSharingFastPrefillAttentionMetadata type + fast_prefill_metadata_type = type( + attn_metadata_dataclass.__name__, + ( + attn_metadata_dataclass, + KVSharingFastPrefillAttentionMetadata, + ), + {}, + ) + return fast_prefill_metadata_type + + +@runtime_checkable +class KVSharingFastPrefillAttentionMetadata(Protocol): + logits_indices_padded: torch.Tensor + num_logits_indices: int + + +def create_kv_sharing_fast_prefill_attn_metadata_subclass( + attn_metadata_i: Any, + logits_indices_padded: torch.Tensor, + num_logits_indices: int, +): + # Dynamically create a a dataclass type that inherits + # from attention metadata type but includes additional + # fields logits_indices_padded and num_logits_indices + # which are required for prefill truncation + fast_prefill_metadata_type = ( + make_kv_sharing_fast_prefill_attention_metadata( + metadata_cls=type(attn_metadata_i), )) # type: ignore + # Avoid deepcopy caused by dict.asdict + attn_metadata_fields = {} + for field in fields(attn_metadata_i.__class__): + attn_metadata_fields[field.name] = getattr(attn_metadata_i, field.name) + attn_metadata_i = fast_prefill_metadata_type( + **attn_metadata_fields, + logits_indices_padded=logits_indices_padded, + num_logits_indices=num_logits_indices, + ) + return attn_metadata_i + + +@functools.lru_cache +def create_custom_attention_backend( + prefix: str, + underlying_attn_backend: AttentionBackend, + build_preprocess_fn: Callable[[CommonAttentionMetadata], + CommonAttentionMetadata], +) -> type[AttentionBackend]: + # Dynamically create a new attention backend that wraps the + # underlying attention backend but applies + # `build_preproces_fn` before calling `build(...)` + builder_cls = subclass_attention_metadata_builder( + name_prefix=prefix, + builder_cls=underlying_attn_backend.get_builder_cls(), + build_preprocess_fn=build_preprocess_fn) + attn_backend = subclass_attention_backend( + name_prefix=prefix, + attention_backend_cls=underlying_attn_backend, + builder_cls=builder_cls) + + return attn_backend diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index f75d76dd978f..c805499aa25a 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -609,6 +609,7 @@ def prepare_inputs( max_query_len=new_query_len_per_req.max().item(), block_table_tensor=common_attn_metadata.block_table_tensor, slot_mapping=common_attn_metadata.slot_mapping[token_indices], + logits_indices=common_attn_metadata.logits_indices, causal=True, ) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 
48ff50fd6bd8..c6b035f14c03 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import dataclasses import gc import itertools import time @@ -52,7 +51,9 @@ from vllm.v1.attention.backends.mamba_selectors import get_mamba_attn_backend from vllm.v1.attention.backends.utils import ( AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata, - make_kv_sharing_fast_prefill_attention_metadata, + create_custom_attention_backend, + create_kv_sharing_fast_prefill_attn_metadata_subclass, + make_kv_sharing_fast_prefill_common_attn_metadata, reorder_batch_to_split_decodes_and_prefills) from vllm.v1.kv_cache_interface import (AttentionSpec, ChunkedLocalAttentionSpec, @@ -75,9 +76,10 @@ from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin from ..sample.logits_processor import LogitsProcessorManager -from .utils import (AttentionGroup, MultiModalBudget, bind_kv_cache, - gather_mm_placeholders, initialize_kv_cache_for_kv_sharing, - sanity_check_mm_encoder_outputs, scatter_mm_placeholders) +from .utils import (AttentionGroup, MultiModalBudget, + add_kv_sharing_layers_to_kv_cache_groups, bind_kv_cache, + gather_mm_placeholders, sanity_check_mm_encoder_outputs, + scatter_mm_placeholders) if TYPE_CHECKING: import xgrammar as xgr @@ -802,6 +804,14 @@ def _prepare_inputs( if attn_module.attn_type == AttentionType.ENCODER_ONLY: attn_metadata[layer_name] = encoder_attn_metadata + if (self.cache_config.kv_sharing_fast_prefill + and self.input_batch.num_prompt_logprobs): + logger.warning_once( + "Encountered at least one request with prompt_logprobs set " + "with --kv-sharing-fast-prefill enabled. Fast prefill doesn't " + "produce correct logits for prompt tokens, so fast prefill " + "will be disabled for scheduling rounds with prompt_logprobs.") + # Prepare the attention metadata for each KV cache group and make layers # in the same group share the same metadata. for kv_cache_group_id, kv_cache_group_spec in enumerate( @@ -847,33 +857,35 @@ def _prepare_inputs( builder, ) + # If there is at least one request with prompt_logprobs set, + # we cannot enable this optimization as the logits of prompt + # tokens will no longer be valid when doing fast prefill. + is_fast_prefill = ( + attn_group.layer_names[0] + in self.kv_sharing_fast_prefill_eligible_layers + and not self.input_batch.num_prompt_logprobs) + if is_fast_prefill: + # If logits_indices is set, builder.build(...) 
will + # preprocess the common metadata to skip prefill tokens + common_attn_metadata.logits_indices = logits_indices + # TODO(sarckk): Enable cascade attention for fast prefill + common_prefix_len = 0 + attn_metadata_i = (builder.build( common_prefix_len=common_prefix_len, common_attn_metadata=common_attn_metadata, )) - fast_prefill_metadata = attn_metadata_i - if (self.cache_config.kv_sharing_fast_prefill - and self.kv_sharing_fast_prefill_eligible_layers): - # Dynamically create a a dataclass type that inherits - # from attention metadata type but includes additional - # fields logits_indices_padded and num_logits_indices - # which are required for prefill truncation - fast_prefill_metadata_type = ( - make_kv_sharing_fast_prefill_attention_metadata( - metadata_cls=type(attn_metadata_i), )) - fast_prefill_metadata = fast_prefill_metadata_type( - **dataclasses.asdict(attn_metadata_i), - logits_indices_padded=logits_indices_padded, - num_logits_indices=logits_indices.size(0), - ) + if is_fast_prefill: + # Eligible layers need extra metadata for use in the model. + attn_metadata_i = \ + create_kv_sharing_fast_prefill_attn_metadata_subclass( + attn_metadata_i, + logits_indices_padded, + logits_indices.size(0), + ) for layer_name in attn_group.layer_names: - if (self.cache_config.kv_sharing_fast_prefill - and layer_name - in self.kv_sharing_fast_prefill_eligible_layers): - attn_metadata[layer_name] = fast_prefill_metadata - continue attn_metadata[layer_name] = attn_metadata_i attention_cuda_graphs = all( @@ -2559,6 +2571,14 @@ def get_attn_backends_for_layers( # layer. for layer_name in layer_names: attn_backend = attn_layers[layer_name].get_attn_backend() + + if layer_name in self.kv_sharing_fast_prefill_eligible_layers: + attn_backend = create_custom_attention_backend( + "FastPrefill", + attn_backend, + make_kv_sharing_fast_prefill_common_attn_metadata, + ) + key = attn_backend.full_cls_name() attn_backends[key] = attn_backend attn_backend_layers[key].append(layer_name) @@ -2735,7 +2755,10 @@ def _allocate_kv_cache_tensors( layer_names = set() for group in kv_cache_config.kv_cache_groups: layer_names.update(group.layer_names) - assert layer_names == set(kv_cache_raw_tensors.keys( + + kv_allocating_layers = layer_names - set( + self.shared_kv_cache_layers.keys()) + assert kv_allocating_layers == set(kv_cache_raw_tensors.keys( )), "Some layers are not correctly initialized" return kv_cache_raw_tensors @@ -2772,6 +2795,9 @@ def _reshape_kv_cache_tensors( for kv_cache_spec, group in self._kv_cache_spec_attn_group_iterator(): attn_backend = group.backend for layer_name in group.layer_names: + if layer_name in self.shared_kv_cache_layers: + # Skip layers without KV cache + continue raw_tensor = kv_cache_raw_tensors[layer_name] assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0 num_blocks = (raw_tensor.numel() // @@ -2883,31 +2909,33 @@ def initialize_kv_cache_tensors( kv_caches = self._reshape_kv_cache_tensors(kv_cache_config, kv_cache_raw_tensors) - # Setup `kv_cache_config` and `kv_caches` for models - # with cross-layer KV sharing - if self.shared_kv_cache_layers: - initialize_kv_cache_for_kv_sharing( - self.shared_kv_cache_layers, - kv_cache_config.kv_cache_groups, - kv_caches, - self.attn_groups, - ) - attn_layers = get_layers_from_vllm_config(self.vllm_config, - Attention) - # Iterate in reversed order and add layers that re-use KV cache - # e.g. in YOCO-like KV sharing setups (e.g. 
Gemma3n) - for layer_name in reversed(attn_layers): - if layer_name in self.shared_kv_cache_layers: - self.kv_sharing_fast_prefill_eligible_layers.add( - layer_name) - else: - break + # Set up cross-layer KV cache sharing + for layer_name, target_layer_name in self.shared_kv_cache_layers.items( + ): + logger.debug("%s reuses KV cache of %s", layer_name, + target_layer_name) + kv_caches[layer_name] = kv_caches[target_layer_name] bind_kv_cache(kv_caches, self.compilation_config.static_forward_context, self.kv_caches) return kv_caches + def maybe_add_kv_sharing_layers_to_kv_cache_groups( + self, kv_cache_config: KVCacheConfig) -> None: + """ + Add layers that re-use KV cache to KV cache group of its target layer. + Mapping of KV cache tensors happens in `initialize_kv_cache_tensors()` + """ + if not self.shared_kv_cache_layers: + # No cross-layer KV sharing, return + return + + add_kv_sharing_layers_to_kv_cache_groups( + self.shared_kv_cache_layers, + kv_cache_config.kv_cache_groups, + ) + def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: """ Initialize KV cache based on `kv_cache_config`. @@ -2917,6 +2945,7 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: """ self.kv_cache_config = kv_cache_config self.may_reinitialize_input_batch(kv_cache_config) + self.maybe_add_kv_sharing_layers_to_kv_cache_groups(kv_cache_config) self.initialize_attn_backend(kv_cache_config) kv_caches = self.initialize_kv_cache_tensors(kv_cache_config) @@ -2929,6 +2958,26 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: if has_kv_transfer_group(): get_kv_transfer_group().register_kv_caches(kv_caches) + def maybe_add_kv_sharing_fast_prefill_layers(self, + attn_layers: dict[str, + Attention]): + """ + In You Only Cache Once (https://arxiv.org/abs/2405.05254), or other + similar KV sharing setups, the layers that re-use the shared KV cache + (cross-decoder layers) can skip prefill, as only the earlier layers + that generate KV caches are involved in the prefill phase. 
+ """ + if not self.cache_config.kv_sharing_fast_prefill: + # Optimization disabled, return + return + + # Iterate in reversed order and add layers that re-use KV cache + for layer_name in reversed(attn_layers): + if layer_name in self.shared_kv_cache_layers: + self.kv_sharing_fast_prefill_eligible_layers.add(layer_name) + else: + break + def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: """ Generates the KVCacheSpec by parsing the kv cache format from each @@ -3016,6 +3065,8 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: page_size_padded=page_size_padded, mamba_type=mamba_module.mamba_type) + self.maybe_add_kv_sharing_fast_prefill_layers(attn_layers) + return kv_cache_spec def _build_encoder_only_attn_metadata( diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 915869726fbf..0dc3050eb15d 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -55,9 +55,8 @@ from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin from vllm.v1.worker.tpu_input_batch import CachedRequestState, InputBatch -from .utils import (MultiModalBudget, bind_kv_cache, - initialize_kv_cache_for_kv_sharing, - sanity_check_mm_encoder_outputs) +from .utils import (MultiModalBudget, add_kv_sharing_layers_to_kv_cache_groups, + bind_kv_cache, sanity_check_mm_encoder_outputs) if TYPE_CHECKING: from vllm.v1.core.sched.output import SchedulerOutput @@ -1595,6 +1594,30 @@ def profile_run( self.encoder_cache.clear() gc.collect() + def maybe_setup_cross_layer_kv_sharing( + self, + kv_caches: dict[str, torch.Tensor], + kv_cache_config: KVCacheConfig, + ) -> None: + """ + Add layers that re-use KV cache to KV cache group of its target layer. + Mapping of KV cache tensors happens in `initialize_kv_cache_tensors()` + """ + if not self.shared_kv_cache_layers: + # No cross-layer KV sharing, return + return + + add_kv_sharing_layers_to_kv_cache_groups( + self.shared_kv_cache_layers, + kv_cache_config.kv_cache_groups, + ) + + for layer_name, target_layer_name in self.shared_kv_cache_layers.items( + ): + logger.debug("%s reuses KV cache of %s", layer_name, + target_layer_name) + kv_caches[layer_name] = kv_caches[target_layer_name] + def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: """ Initialize KV cache based on `kv_cache_config`. 
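The maybe_setup_cross_layer_kv_sharing hook added above (and its GPU-runner counterpart) reduces cross-layer KV sharing to two steps: extend the target layer's KV cache group with every layer that borrows its cache, then alias the allocated tensors so the borrowing layers point at the same storage. Below is a minimal, self-contained sketch of those two steps; it uses toy layer names and a stand-in for KVCacheGroupSpec rather than the real vLLM classes.

    from dataclasses import dataclass, field

    import torch


    @dataclass
    class ToyKVCacheGroup:
        # Stand-in for vLLM's KVCacheGroupSpec: only the layer_names list
        # matters for this sketch.
        layer_names: list[str] = field(default_factory=list)


    def add_sharing_layers_to_groups(shared_kv_cache_layers: dict[str, str],
                                     groups: list[ToyKVCacheGroup]) -> None:
        # Mirrors add_kv_sharing_layers_to_kv_cache_groups: every layer that
        # re-uses another layer's KV cache joins its target layer's group.
        layer_to_group = {
            name: group
            for group in groups for name in group.layer_names
        }
        for layer_name, target_layer_name in shared_kv_cache_layers.items():
            layer_to_group[target_layer_name].layer_names.append(layer_name)


    # Toy YOCO-style layout: layers 2 and 3 re-use the KV cache of layer 1.
    shared = {"layers.2.attn": "layers.1.attn", "layers.3.attn": "layers.1.attn"}
    groups = [ToyKVCacheGroup(["layers.0.attn", "layers.1.attn"])]
    add_sharing_layers_to_groups(shared, groups)
    assert groups[0].layer_names == [
        "layers.0.attn", "layers.1.attn", "layers.2.attn", "layers.3.attn"
    ]

    # Tensor-aliasing step kept in the runners: borrowing layers point at the
    # same storage as their target, so no extra KV cache memory is allocated.
    kv_caches = {
        "layers.0.attn": torch.zeros(4, 16),
        "layers.1.attn": torch.zeros(4, 16),
    }
    for layer_name, target_layer_name in shared.items():
        kv_caches[layer_name] = kv_caches[target_layer_name]
    assert kv_caches["layers.2.attn"].data_ptr() == \
        kv_caches["layers.1.attn"].data_ptr()

The group-extension step is what the new add_kv_sharing_layers_to_kv_cache_groups helper in vllm/v1/worker/utils.py (see the hunk below) implements, while the tensor aliasing stays in each model runner so the GPU and TPU paths can share the helper.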
@@ -1660,14 +1683,8 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: else: raise NotImplementedError - # Setup `kv_cache_config` and `kv_caches` for models - # with cross-layer KV sharing - if self.shared_kv_cache_layers: - initialize_kv_cache_for_kv_sharing( - self.shared_kv_cache_layers, - kv_cache_config.kv_cache_groups, - kv_caches, - ) + # Set up cross-layer KV cache sharing if needed + self.maybe_setup_cross_layer_kv_sharing(kv_caches, kv_cache_config) bind_kv_cache( kv_caches, diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py index e7079235d651..504fa51a8a3e 100644 --- a/vllm/v1/worker/utils.py +++ b/vllm/v1/worker/utils.py @@ -202,12 +202,9 @@ def gather_mm_placeholders( return placeholders[is_embed] -def initialize_kv_cache_for_kv_sharing( +def add_kv_sharing_layers_to_kv_cache_groups( shared_kv_cache_layers: dict[str, str], kv_cache_groups: list[KVCacheGroupSpec], - kv_caches: dict[str, torch.Tensor], - # Optional for now to avoid breaking TPU - attn_groups: Optional[list[list[AttentionGroup]]] = None, ) -> None: """ Sets up KV cache sharing by reusing the allocated KV caches in `kv_caches` @@ -221,30 +218,15 @@ def initialize_kv_cache_for_kv_sharing( means this layer will perform attention using the keys and values from the KV cache of `shared_kv_cache_layers[layer_name]`. kv_cache_groups: The KV cache groups of the model. - kv_caches: The allocated kv_caches with layer names as keys. - Note that layers in shared_kv_cache_layers.keys() are not - originally included as it only contains layers which have its own - KV cache allocation. """ - # Record index of KV cache group for each layer that allocates a KV cache. - layer_to_kv_cache_group_idx: dict[str, int] = {} - for i, kv_cache_group in enumerate(kv_cache_groups): + layer_to_kv_cache_group: dict[str, KVCacheGroupSpec] = {} + for kv_cache_group in kv_cache_groups: for layer_name in kv_cache_group.layer_names: - layer_to_kv_cache_group_idx[layer_name] = i + layer_to_kv_cache_group[layer_name] = kv_cache_group for layer_name, target_layer_name in shared_kv_cache_layers.items(): - kv_caches[layer_name] = kv_caches[target_layer_name] - group_idx = layer_to_kv_cache_group_idx[target_layer_name] - kv_cache_groups[group_idx].layer_names.append(layer_name) - - if attn_groups is not None: - assert len(attn_groups[group_idx]) == 1, ( - "Only one attention group per KV cache group is supported " - "for KV-cache sharing for now.") - # TODO(lucas): I think in the future the layers that re-use a - # KV cache will be in a different attention group so we can - # remove this code from here. 
- attn_groups[group_idx][0].layer_names.append(layer_name) + tgt_kv_cache_group = layer_to_kv_cache_group[target_layer_name] + tgt_kv_cache_group.layer_names.append(layer_name) def bind_kv_cache( From 1853ce195aed128901dca2bf33806240178dba3b Mon Sep 17 00:00:00 2001 From: Yong Hoon Shin Date: Sun, 10 Aug 2025 17:18:57 -0700 Subject: [PATCH 2/2] Fix rebase Signed-off-by: Yong Hoon Shin --- .../layers/chunked_local_attention.py | 4 +- vllm/config/__init__.py | 6 +++ vllm/config/cache.py | 7 +-- vllm/model_executor/models/gemma3n.py | 53 ++++++++++--------- vllm/model_executor/models/gemma3n_mm.py | 2 +- 5 files changed, 39 insertions(+), 33 deletions(-) diff --git a/vllm/attention/layers/chunked_local_attention.py b/vllm/attention/layers/chunked_local_attention.py index 9c9a9c24c52e..5c2deda76cc9 100644 --- a/vllm/attention/layers/chunked_local_attention.py +++ b/vllm/attention/layers/chunked_local_attention.py @@ -40,7 +40,7 @@ def __init__(self, kv_cache_dtype, block_size) - prefix = \ + backend_prefix = \ f"ChunkedLocalAttention_{attention_chunk_size}_{block_size}_" def build_preprocess_fn(cm: CommonAttentionMetadata): @@ -48,7 +48,7 @@ def build_preprocess_fn(cm: CommonAttentionMetadata): attention_chunk_size, cm, block_size) attn_backend = create_custom_attention_backend( - prefix, underlying_attn_backend, build_preprocess_fn) + backend_prefix, underlying_attn_backend, build_preprocess_fn) else: # in v0 the local attention is handled inside the backends attn_backend = None diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py index 700d29f956a8..c848794ae1fb 100644 --- a/vllm/config/__init__.py +++ b/vllm/config/__init__.py @@ -3908,6 +3908,12 @@ def __post_init__(self): # local attention. self.scheduler_config.disable_hybrid_kv_cache_manager = True + if self.cache_config.kv_sharing_fast_prefill: + # There is an IMA issue currently when using fast prefill with + # hybrid kv cache manager (e.g. interleaved sliding window) + # TODO(sarckk): investigate and fix + self.scheduler_config.disable_hybrid_kv_cache_manager = True + def update_sizes_for_sequence_parallelism(self, possible_sizes: list) -> list: # remove the sizes that not multiple of tp_size when diff --git a/vllm/config/cache.py b/vllm/config/cache.py index 3104f5b620a8..3636289f2e41 100644 --- a/vllm/config/cache.py +++ b/vllm/config/cache.py @@ -139,7 +139,7 @@ def metrics_info(self): # convert cache_config to dict(key: str, value: str) for prometheus # metrics info return {key: str(value) for key, value in self.__dict__.items()} - + def _verify_kv_sharing_fast_prefill(self) -> None: if self.kv_sharing_fast_prefill and not envs.VLLM_USE_V1: raise NotImplementedError( @@ -157,11 +157,6 @@ def _verify_args(self) -> Self: "GPU memory utilization must be less than 1.0. Got " f"{self.gpu_memory_utilization}.") - if self.kv_sharing_fast_prefill: - logger.warning_once( - "--kv-sharing-fast-prefill is currently work in progress " - "and not functional yet (i.e. 
no prefill savings)") - return self def _verify_cache_dtype(self) -> None: diff --git a/vllm/model_executor/models/gemma3n.py b/vllm/model_executor/models/gemma3n.py index c9f8e404ebf7..42d1fd2e1528 100644 --- a/vllm/model_executor/models/gemma3n.py +++ b/vllm/model_executor/models/gemma3n.py @@ -566,7 +566,7 @@ def __init__( self.decoder_layers = decoder_layers self.layer_idx_start = layer_idx_start self.per_layer_model_projection = per_layer_model_projection - self.config = vllm_config.model_config.hf_config.text_config + self.config = vllm_config.model_config.hf_config self.embed_scale_per_layer = embed_scale_per_layer self.embed_tokens_per_layer = embed_tokens_per_layer self.per_layer_projection_norm = per_layer_projection_norm @@ -590,13 +590,9 @@ def get_per_layer_input_embeddings( def get_per_layer_inputs( self, - input_ids: torch.Tensor, hidden_states_0: torch.Tensor, + per_layer_inputs: Optional[torch.Tensor], ) -> torch.Tensor: - per_layer_inputs = self.get_per_layer_input_embeddings(input_ids) - per_layer_inputs = per_layer_inputs.reshape( - -1, self.config.num_hidden_layers, - self.config.hidden_size_per_layer_input) per_layer_projection = self.per_layer_model_projection(hidden_states_0) per_layer_projection = per_layer_projection.reshape( *hidden_states_0.shape[:-1], @@ -605,8 +601,12 @@ def get_per_layer_inputs( ) per_layer_projection = self.per_layer_projection_norm( per_layer_projection) - per_layer_inputs = per_layer_projection + per_layer_inputs - per_layer_inputs *= self.per_layer_input_scale + if per_layer_inputs is not None: + # Profiling run does not compute per_layer_inputs + per_layer_inputs = per_layer_projection + per_layer_inputs + per_layer_inputs *= self.per_layer_input_scale + else: + per_layer_inputs = per_layer_projection return per_layer_inputs def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: @@ -632,6 +632,7 @@ def forward( input_ids: torch.Tensor, positions: torch.Tensor, inputs_embeds: Optional[torch.Tensor] = None, + per_layer_inputs: Optional[torch.Tensor] = None, **kwargs, ) -> tuple[torch.Tensor, torch.Tensor]: if inputs_embeds is not None: @@ -639,8 +640,8 @@ def forward( else: hidden_states_0 = self.get_input_embeddings(input_ids) - per_layer_inputs = self.get_per_layer_inputs(input_ids, - hidden_states_0) + adjusted_per_layer_inputs = self.get_per_layer_inputs( + hidden_states_0, per_layer_inputs) hidden_states = self.altup_embed(hidden_states_0) # [altnum_inputs, num_tokens, hidden_size] @@ -652,14 +653,14 @@ def forward( hidden_states = layer( positions=positions, hidden_states=hidden_states, - per_layer_input=per_layer_inputs[:, layer_idx, :], + per_layer_input=adjusted_per_layer_inputs[:, layer_idx, :], **kwargs, ) # [num_tokens, hidden_size, altnum_inputs] hidden_states = hidden_states.permute(1, 2, 0) - return hidden_states, per_layer_inputs + return hidden_states, adjusted_per_layer_inputs # This enables torch.compile if --kv-sharing-fast-prefill passed @@ -853,6 +854,7 @@ def fast_prefill_forward( input_ids: torch.Tensor, positions: torch.Tensor, inputs_embeds: Optional[torch.Tensor] = None, + per_layer_inputs: Optional[torch.Tensor] = None, **kwargs, ) -> torch.Tensor: logits_indices_padded, num_logits_indices = None, None @@ -873,13 +875,14 @@ def fast_prefill_forward( # Copy inputs for cudagraph batch_size = positions.size(0) self.positions[:batch_size].copy_(positions) - # input_ids and inputs_embeds are allocated in model runner - self_decoder_hidden_states, per_layer_inputs = self.self_decoder( - 
input_ids=input_ids, - positions=self.positions[:batch_size], - inputs_embeds=inputs_embeds, - **kwargs, - ) + self_decoder_hidden_states, per_layer_inputs_adjusted = \ + self.self_decoder( + input_ids=input_ids, + positions=self.positions[:batch_size], + inputs_embeds=inputs_embeds, + per_layer_inputs=per_layer_inputs, + **kwargs, + ) if logits_indices_padded is None: logits_indices_padded = torch.arange( @@ -903,7 +906,7 @@ def fast_prefill_forward( self.hidden_states[:num_padded_logits_indices].copy_( self_decoder_hidden_states[logits_indices_padded]) self.per_layer_inputs[:num_padded_logits_indices].copy_( - per_layer_inputs[logits_indices_padded]) + per_layer_inputs_adjusted[logits_indices_padded]) cross_decoder_hidden_states = self.cross_decoder( positions=self.positions[:num_padded_logits_indices], hidden_states=self.hidden_states[:num_padded_logits_indices], @@ -926,12 +929,14 @@ def normal_forward( input_ids: torch.Tensor, positions: torch.Tensor, inputs_embeds: Optional[torch.Tensor] = None, + per_layer_inputs: Optional[torch.Tensor] = None, **kwargs, ) -> torch.Tensor: hidden_states, per_layer_inputs = self.self_decoder( input_ids=input_ids, positions=positions, inputs_embeds=inputs_embeds, + per_layer_inputs=per_layer_inputs, **kwargs, ) hidden_states = self.cross_decoder( @@ -966,18 +971,17 @@ def forward( self, input_ids: Optional[torch.Tensor], positions: torch.Tensor, + per_layer_inputs: Optional[torch.Tensor] = None, + intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs, ) -> Union[torch.Tensor, IntermediateTensors]: - # Per layer inputs. - if input_ids is None: - raise ValueError("Passing None for input ids is not supported.") - if self.fast_prefill_enabled: hidden_states = self.fast_prefill_forward( input_ids, positions, inputs_embeds, + per_layer_inputs, **kwargs, ) else: @@ -985,6 +989,7 @@ def forward( input_ids, positions, inputs_embeds, + per_layer_inputs, **kwargs, ) hidden_states = self.altup_unembed(hidden_states) diff --git a/vllm/model_executor/models/gemma3n_mm.py b/vllm/model_executor/models/gemma3n_mm.py index a0c3bb50070b..4361cd5a1f40 100644 --- a/vllm/model_executor/models/gemma3n_mm.py +++ b/vllm/model_executor/models/gemma3n_mm.py @@ -624,7 +624,7 @@ def get_input_embeddings( # NOTE (NickLucche) Each pass needs tokens to compute PLE so we cache # them here, as the model forward has only access to the input_embeds. if input_ids is not None: - per_layer_inputs = self.language_model.model.get_per_layer_input_embeddings( + per_layer_inputs = self.language_model.model.self_decoder.get_per_layer_input_embeddings( input_ids) per_layer_inputs = per_layer_inputs.reshape( -1, self.config.text_config.num_hidden_layers,
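With both patches applied, the fast-prefill path can be exercised end to end, as in tests/v1/e2e/test_kv_sharing_fast_prefill.py. The sketch below is illustrative only: the kv_sharing_fast_prefill keyword is assumed to be the engine-argument counterpart of the --kv-sharing-fast-prefill flag referenced in the comments above, and google/gemma-3n-E2B-it is just an example checkpoint.

    from vllm import LLM, SamplingParams

    # Sketch only: enable the YOCO-style fast prefill path exercised by
    # tests/v1/e2e/test_kv_sharing_fast_prefill.py. The kwarg name below is
    # assumed to be the EngineArgs counterpart of --kv-sharing-fast-prefill.
    llm = LLM(
        model="google/gemma-3n-E2B-it",  # illustrative checkpoint
        kv_sharing_fast_prefill=True,
        enforce_eager=True,
    )

    prompts = ["The capital of France is"]
    params = SamplingParams(temperature=0.0, max_tokens=32)
    outputs = llm.generate(prompts, params)
    print(outputs[0].outputs[0].text)

When the flag is set, Gemma3nTextModel dispatches to fast_prefill_forward, so the cross-decoder layers that re-use the self-decoder KV cache only process the (padded) logits indices during prefill instead of the full prompt.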