@@ -2,75 +2,20 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import random
-from typing import Optional, Union
 
 import pytest
 import torch
 
 from vllm import LLM, SamplingParams
 from vllm.config import CompilationConfig, CompilationLevel
 from vllm.distributed import cleanup_dist_env_and_memory
-from vllm.forward_context import get_forward_context
-from vllm.model_executor.models.gemma3n_mm import (
-    Gemma3nForConditionalGeneration)
-from vllm.model_executor.models.registry import ModelRegistry
-from vllm.model_executor.models.utils import extract_layer_index
-from vllm.sequence import IntermediateTensors
 
 from ...utils import fork_new_process_for_each_test
 
 # global seed
 SEED = 42
 
 
-class TestGemma3nForConditionalGeneration(Gemma3nForConditionalGeneration):
-
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        positions: torch.Tensor,
-        intermediate_tensors: Optional[IntermediateTensors] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-        **kwargs,
-    ) -> Union[torch.Tensor, IntermediateTensors]:
-        hidden_states = super().forward(input_ids, positions,
-                                        intermediate_tensors, inputs_embeds,
-                                        **kwargs)
-        attn_metadata = get_forward_context().attn_metadata
-        # attn_metadata is None during dummy runs
-        if (attn_metadata is not None
-                and self.language_model.cache_config.kv_sharing_fast_prefill):
-            assert isinstance(attn_metadata, dict)  # true in V1
-            # Gemma3n-E2B has 30 layers, with the last 20 being
-            # cross-decoder layers. Check the attention metadata is correct.
-            for layer_name, metadata in attn_metadata.items():
-                layer_idx = extract_layer_index(layer_name)
-                if layer_idx >= 20:
-                    assert hasattr(metadata, 'logits_indices_padded')
-                    assert hasattr(metadata, 'num_logits_indices')
-                else:
-                    assert not hasattr(metadata, 'logits_indices_padded')
-                    assert not hasattr(metadata, 'num_logits_indices')
-
-            # The last layer will be a KV sharing layer
-            layer_attn_metadata = attn_metadata[
-                self.language_model.model.layers[-1].self_attn.attn.layer_name]
-            logits_indices_padded = layer_attn_metadata.logits_indices_padded
-            assert logits_indices_padded is not None
-            num_logits_indices = layer_attn_metadata.num_logits_indices
-            assert num_logits_indices > 0
-            # Reset the hidden states to random values, keeping valid
-            # values only at the logits_indices positions. Because these
-            # are the only positions used for output token sampling, the
-            # sampled outputs remain the same.
-            logits_hs = hidden_states[logits_indices_padded]
-            hidden_states = torch.randn_like(hidden_states)
-            gen_indices = logits_indices_padded[:num_logits_indices]
-            hidden_states[gen_indices] = logits_hs[:num_logits_indices]
-
-        return hidden_states
-
-
 @pytest.fixture
 def test_prompts():
     """
@@ -124,8 +69,6 @@ def test_kv_sharing_fast_prefill(
     enforce_eager: bool,
     test_prompts: list[str],
 ):
-    ModelRegistry.register_model("Gemma3nForConditionalGeneration",
-                                 TestGemma3nForConditionalGeneration)
     sampling_params = SamplingParams(temperature=0.0, max_tokens=100)
     compilation_config = CompilationConfig(
         # This allows vLLM compilation backend to handle allocating and
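The removed hook's final check rests on a simple invariant of greedy decoding: the sampler reads hidden states only at the logits indices, so randomizing every other row must leave the sampled tokens unchanged. The following self-contained PyTorch toy demonstrates the same invariant; all shapes, names, and values are illustrative and unrelated to vLLM's real attention metadata.

    import torch

    torch.manual_seed(42)

    num_tokens, hidden_size, vocab_size = 16, 8, 32
    hidden_states = torch.randn(num_tokens, hidden_size)
    lm_head = torch.randn(vocab_size, hidden_size)

    # Positions whose logits feed token sampling; the repeated last entry
    # mimics the padding seen in logits_indices_padded.
    logits_indices_padded = torch.tensor([3, 7, 15, 15])
    num_logits_indices = 3


    def sample(hs: torch.Tensor) -> torch.Tensor:
        # Greedy sampling reads only the rows selected by the indices.
        logits = hs[logits_indices_padded[:num_logits_indices]] @ lm_head.T
        return logits.argmax(dim=-1)


    baseline = sample(hidden_states)

    # Scramble everything, then restore only the sampled positions,
    # exactly as the removed forward() override did.
    logits_hs = hidden_states[logits_indices_padded]
    hidden_states = torch.randn_like(hidden_states)
    gen_indices = logits_indices_padded[:num_logits_indices]
    hidden_states[gen_indices] = logits_hs[:num_logits_indices]

    # The sampled tokens are unchanged.
    assert torch.equal(sample(hidden_states), baseline)

If fast prefill ever routed sampling through positions outside the logits indices, the final assert would fail, which is exactly the regression the removed hook guarded against.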