
[V1] Enable prefill optimization for Gemma3n #22628

Open · wants to merge 2 commits into base: main
55 changes: 0 additions & 55 deletions tests/v1/e2e/test_kv_sharing_fast_prefill.py
@@ -2,73 +2,20 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import random
from typing import Optional, Union

import pytest
import torch

from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig, CompilationLevel
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.forward_context import get_forward_context
from vllm.model_executor.models.gemma3n import Gemma3nForConditionalGeneration
from vllm.model_executor.models.registry import ModelRegistry
from vllm.model_executor.models.utils import extract_layer_index
from vllm.sequence import IntermediateTensors

from ...utils import fork_new_process_for_each_test

# global seed
SEED = 42


class TestGemma3nForConditionalGeneration(Gemma3nForConditionalGeneration):

def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs,
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, intermediate_tensors,
inputs_embeds, **kwargs)
attn_metadata = get_forward_context().attn_metadata
# attn_metadata is None during dummy runs
if (attn_metadata is not None
and self.cache_config.kv_sharing_fast_prefill):
assert isinstance(attn_metadata, dict) # true in V1
# Gemma3n-E2B has 30 layers, with last 20 layers being
# cross-decoder layers. Check attention metadata is correct
for layer_name, metadata in attn_metadata.items():
layer_idx = extract_layer_index(layer_name)
if layer_idx >= 20:
assert hasattr(metadata, 'logits_indices_padded')
assert hasattr(metadata, 'num_logits_indices')
else:
assert not hasattr(metadata, 'logits_indices_padded')
assert not hasattr(metadata, 'num_logits_indices')

# Last layer will be a KV sharing layer
layer_attn_metadata = attn_metadata[
self.model.language_model.layers[-1].self_attn.attn.layer_name]
logits_indices_padded = (layer_attn_metadata.logits_indices_padded)
assert logits_indices_padded is not None
num_logits_indices = layer_attn_metadata.num_logits_indices
assert num_logits_indices > 0
# Reset hidden states to random values and
# only set logits at logits_indices to valid values
# Because logits_indices are the only positions that are used
# for output token sampling, this still produces same outputs
logits_hs = hidden_states[logits_indices_padded]
hidden_states = torch.randn_like(hidden_states)
gen_indices = logits_indices_padded[:num_logits_indices]
hidden_states[gen_indices] = logits_hs[:num_logits_indices]

return hidden_states


@pytest.fixture
def test_prompts():
"""
@@ -122,8 +69,6 @@ def test_kv_sharing_fast_prefill(
enforce_eager: bool,
test_prompts: list[str],
):
ModelRegistry.register_model("Gemma3nForConditionalGeneration",
TestGemma3nForConditionalGeneration)
sampling_params = SamplingParams(temperature=0.0, max_tokens=100)
compilation_config = CompilationConfig(
# This allows vLLM compilation backend to handle allocating and
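For reference, the scenario this test now exercises end to end is roughly the following. This is a minimal sketch, assuming the `kv_sharing_fast_prefill` engine argument maps to `CacheConfig.kv_sharing_fast_prefill` and using a placeholder Gemma3n checkpoint id; it is not code from this PR.

```python
from vllm import LLM, SamplingParams

prompts = ["What is the capital of France?"]
sampling_params = SamplingParams(temperature=0.0, max_tokens=100)

# Baseline: fast prefill disabled.
ref_llm = LLM(model="google/gemma-3n-E2B-it", enforce_eager=True)
ref_outputs = ref_llm.generate(prompts, sampling_params)

# Optimized: skip prefill compute in the cross-decoder (KV-sharing) layers.
# Assumption: `kv_sharing_fast_prefill=True` is forwarded to CacheConfig.
opt_llm = LLM(model="google/gemma-3n-E2B-it",
              enforce_eager=True,
              kv_sharing_fast_prefill=True)
opt_outputs = opt_llm.generate(prompts, sampling_params)

# With greedy sampling, the outputs should match with and without the
# optimization, since only the positions used for sampling are affected.
for ref, opt in zip(ref_outputs, opt_outputs):
    assert ref.outputs[0].text == opt.outputs[0].text
```

In practice the two engines would be built in separate processes (as the test does via `fork_new_process_for_each_test`) so GPU memory is released between runs.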
44 changes: 11 additions & 33 deletions vllm/attention/layers/chunked_local_attention.py
@@ -1,48 +1,19 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
from typing import List, Optional

import torch

from vllm import envs
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.selector import get_attn_backend
from vllm.config import CacheConfig, QuantizationConfig
from vllm.v1.attention.backends.utils import (
CommonAttentionMetadata, make_local_attention_virtual_batches,
subclass_attention_backend, subclass_attention_metadata_builder)
CommonAttentionMetadata, create_custom_attention_backend,
make_local_attention_virtual_batches)

from ..layer import Attention


@functools.lru_cache
def create_chunked_local_attention_backend(
underlying_attn_backend: AttentionBackend,
attention_chunk_size: int,
block_size: int,
) -> type[AttentionBackend]:
prefix = f"ChunkedLocalAttention_{attention_chunk_size}_{block_size}_"

def build_preprocess_fn(cm: CommonAttentionMetadata):
return make_local_attention_virtual_batches(attention_chunk_size, cm,
block_size)

# Dynamically create a new attention backend that wraps the
# underlying attention backend but applies
# `make_local_attention_virtual_batches` before calling `build(...)`
builder_cls = subclass_attention_metadata_builder(
name_prefix=prefix,
builder_cls=underlying_attn_backend.get_builder_cls(),
build_preprocess_fn=build_preprocess_fn)
attn_backend = subclass_attention_backend(
name_prefix=prefix,
attention_backend_cls=underlying_attn_backend,
builder_cls=builder_cls)

return attn_backend


class ChunkedLocalAttention(Attention):

def __init__(self,
@@ -69,8 +40,15 @@ def __init__(self,
kv_cache_dtype,
block_size)

attn_backend = create_chunked_local_attention_backend(
underlying_attn_backend, attention_chunk_size, block_size)
backend_prefix = \
f"ChunkedLocalAttention_{attention_chunk_size}_{block_size}_"

def build_preprocess_fn(cm: CommonAttentionMetadata):
return make_local_attention_virtual_batches(
attention_chunk_size, cm, block_size)

attn_backend = create_custom_attention_backend(
backend_prefix, underlying_attn_backend, build_preprocess_fn)
else:
# in v0 the local attention is handled inside the backends
attn_backend = None
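The dedicated `create_chunked_local_attention_backend` helper is removed; the layer now calls the shared `create_custom_attention_backend` factory with a `build_preprocess_fn`. That factory's implementation is not part of this diff, so the following is only an illustrative sketch of the wrapping pattern the old helper used (subclass the underlying metadata builder, preprocess the common metadata, then delegate), under the assumption that the builder's `build(...)` signature is roughly as shown.

```python
from typing import Callable

from vllm.attention.backends.abstract import AttentionBackend
from vllm.v1.attention.backends.utils import CommonAttentionMetadata


def sketch_create_custom_attention_backend(
    prefix: str,
    underlying_backend: type[AttentionBackend],
    build_preprocess_fn: Callable[[CommonAttentionMetadata],
                                  CommonAttentionMetadata],
) -> type[AttentionBackend]:
    """Illustrative sketch only; the real factory lives in
    vllm.v1.attention.backends.utils and may differ in details."""
    underlying_builder = underlying_backend.get_builder_cls()

    class WrappedBuilder(underlying_builder):  # type: ignore[misc, valid-type]

        def build(self, common_prefix_len, common_attn_metadata, **kwargs):
            # Rewrite the metadata (e.g. carve out chunked-local virtual
            # batches) before the underlying builder sees it.
            common_attn_metadata = build_preprocess_fn(common_attn_metadata)
            return super().build(common_prefix_len, common_attn_metadata,
                                 **kwargs)

    class WrappedBackend(underlying_backend):  # type: ignore[misc, valid-type]

        @staticmethod
        def get_name() -> str:
            return prefix + underlying_backend.get_name()

        @staticmethod
        def get_builder_cls():
            return WrappedBuilder

    return WrappedBackend
```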
22 changes: 20 additions & 2 deletions vllm/compilation/decorators.py
@@ -60,6 +60,14 @@ def support_torch_compile(
...


@overload
def support_torch_compile(
*,
compile_cond: Optional[Callable[[VllmConfig], bool]] = None,
) -> Callable[[_T], _T]:
...


@overload
def support_torch_compile(cls: _T) -> _T:
...
@@ -69,6 +77,7 @@ def support_torch_compile(
cls: Optional[_T] = None,
*,
dynamic_arg_dims: Optional[dict[str, Union[int, list[int]]]] = None,
compile_cond: Optional[Callable[[VllmConfig], bool]] = None,
) -> Union[Callable[[_T], _T], _T]:
"""
A decorator to add support for compiling the forward method of a class.
@@ -118,6 +127,11 @@ def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]):
NOTE: if an argument is `None`, it should always be passed as `None` during
the lifetime of the model, otherwise, it cannot be captured as a single
computation graph.
`compile_cond` is a function that takes a `VllmConfig` object as input and
returns a boolean value indicating whether to compile the model or not.
This is useful if you want to compile the model only when certain
conditions are met.
"""

def cls_decorator_helper(cls: _T) -> _T:
@@ -149,7 +163,8 @@ def cls_decorator_helper(cls: _T) -> _T:
if k not in sig.parameters:
raise ValueError(
f"Argument {k} not found in the forward method of {cls}")
return _support_torch_compile(cls, inferred_dynamic_arg_dims)
return _support_torch_compile(cls, inferred_dynamic_arg_dims,
compile_cond)

if cls is not None:
# use `support_torch_compile` as a decorator without arguments
@@ -162,6 +177,7 @@ def cls_decorator_helper(cls: _T) -> _T:
def _support_torch_compile(
cls: _T,
dynamic_arg_dims: dict[str, Union[int, list[int]]],
compile_cond: Optional[Callable[[VllmConfig], bool]] = None,
) -> _T:
"""
A decorator to add support for compiling the forward method of a class.
@@ -182,13 +198,15 @@ def _support_torch_compile(
def __init__(self, *, vllm_config: VllmConfig, prefix: str = '', **kwargs):
old_init(self, vllm_config=vllm_config, prefix=prefix, **kwargs)
self.vllm_config = vllm_config
compile_cond_satisfied = compile_cond is None or compile_cond(
vllm_config)
# for CompilationLevel.DYNAMO_AS_IS , the upper level model runner
# will handle the compilation, so we don't need to do anything here.
self.do_not_compile = \
vllm_config.compilation_config.level in [
CompilationLevel.NO_COMPILATION, CompilationLevel.DYNAMO_AS_IS
] or not supports_dynamo() or _should_ignore_torch_compile(
self.__class__)
self.__class__) or not compile_cond_satisfied
if self.do_not_compile:
return

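A minimal usage sketch of the new `compile_cond` hook follows. The module and predicate are hypothetical and purely for illustration; the hook simply lets a decorated class opt in or out of compilation based on the resolved `VllmConfig`, as described in the docstring above.

```python
import torch
from torch import nn

from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig


def _compile_only_with_fast_prefill(vllm_config: VllmConfig) -> bool:
    # Hypothetical predicate: compile this block only when the fast-prefill
    # optimization is enabled.
    return vllm_config.cache_config.kv_sharing_fast_prefill


@support_torch_compile(dynamic_arg_dims={"hidden_states": 0},
                       compile_cond=_compile_only_with_fast_prefill)
class ToyDecoderBlock(nn.Module):
    """Hypothetical module, not part of this PR."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__()
        self.proj = nn.Linear(16, 16)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.proj(hidden_states)
```

When the predicate returns False, `do_not_compile` is set and the decorated class falls back to its eager `forward`, exactly as if compilation were disabled globally.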
6 changes: 6 additions & 0 deletions vllm/config/__init__.py
@@ -3908,6 +3908,12 @@ def __post_init__(self):
# local attention.
self.scheduler_config.disable_hybrid_kv_cache_manager = True

if self.cache_config.kv_sharing_fast_prefill:
# There is an IMA issue currently when using fast prefill with
# hybrid kv cache manager (e.g. interleaved sliding window)
# TODO(sarckk): investigate and fix
self.scheduler_config.disable_hybrid_kv_cache_manager = True

def update_sizes_for_sequence_parallelism(self,
possible_sizes: list) -> list:
# remove the sizes that not multiple of tp_size when
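The effect of this guard can be checked from the resolved engine config. A sketch, assuming `EngineArgs` exposes a `kv_sharing_fast_prefill` field mirroring the `--kv-sharing-fast-prefill` flag, and using a placeholder Gemma3n checkpoint id:

```python
from vllm.engine.arg_utils import EngineArgs

# Assumption: kv_sharing_fast_prefill is a pass-through to CacheConfig.
engine_args = EngineArgs(model="google/gemma-3n-E2B-it",
                         kv_sharing_fast_prefill=True)
vllm_config = engine_args.create_engine_config()

assert vllm_config.cache_config.kv_sharing_fast_prefill
# __post_init__ above force-disables the hybrid KV cache manager to avoid the
# IMA issue noted in the code comment.
assert vllm_config.scheduler_config.disable_hybrid_kv_cache_manager
```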
12 changes: 7 additions & 5 deletions vllm/config/cache.py
@@ -133,12 +133,19 @@ def __post_init__(self) -> None:

self._verify_cache_dtype()
self._verify_prefix_caching()
self._verify_kv_sharing_fast_prefill()

def metrics_info(self):
# convert cache_config to dict(key: str, value: str) for prometheus
# metrics info
return {key: str(value) for key, value in self.__dict__.items()}

def _verify_kv_sharing_fast_prefill(self) -> None:
if self.kv_sharing_fast_prefill and not envs.VLLM_USE_V1:
raise NotImplementedError(
"Fast prefill optimization for KV sharing is not supported "
"in V0 currently.")

@model_validator(mode='after')
def _verify_args(self) -> Self:
if self.cpu_offload_gb < 0:
@@ -150,11 +157,6 @@ def _verify_args(self) -> Self:
"GPU memory utilization must be less than 1.0. Got "
f"{self.gpu_memory_utilization}.")

if self.kv_sharing_fast_prefill:
logger.warning_once(
"--kv-sharing-fast-prefill is currently work in progress "
"and not functional yet (i.e. no prefill savings)")

return self

def _verify_cache_dtype(self) -> None: