 from vllm.config import CompilationConfig, CompilationLevel
 from vllm.distributed import cleanup_dist_env_and_memory
 from vllm.forward_context import get_forward_context
-from vllm.model_executor.models.gemma3n import Gemma3nForConditionalGeneration
+from vllm.model_executor.models.gemma3n_mm import (
+    Gemma3nForConditionalGeneration)
 from vllm.model_executor.models.registry import ModelRegistry
 from vllm.model_executor.models.utils import extract_layer_index
 from vllm.sequence import IntermediateTensors
@@ -32,12 +33,13 @@ def forward(
         inputs_embeds: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> Union[torch.Tensor, IntermediateTensors]:
-        hidden_states = self.model(input_ids, positions, intermediate_tensors,
-                                   inputs_embeds, **kwargs)
+        hidden_states = super().forward(input_ids, positions,
+                                        intermediate_tensors, inputs_embeds,
+                                        **kwargs)
         attn_metadata = get_forward_context().attn_metadata
         # attn_metadata is None during dummy runs
         if (attn_metadata is not None
-                and self.cache_config.kv_sharing_fast_prefill):
+                and self.language_model.cache_config.kv_sharing_fast_prefill):
             assert isinstance(attn_metadata, dict)  # true in V1
             # Gemma3n-E2B has 30 layers, with last 20 layers being
             # cross-decoder layers. Check attention metadata is correct
@@ -52,7 +54,7 @@ def forward(

             # Last layer will be a KV sharing layer
             layer_attn_metadata = attn_metadata[
-                self.model.language_model.layers[-1].self_attn.attn.layer_name]
+                self.language_model.model.layers[-1].self_attn.attn.layer_name]
             logits_indices_padded = (layer_attn_metadata.logits_indices_padded)
             assert logits_indices_padded is not None
             num_logits_indices = layer_attn_metadata.num_logits_indices