
Commit 7ac2e2d

[Model] Restore Gemma3 GGUF multimodal support with GGUF-only guards
Restores custom attention mask generation for Gemma3 GGUF multimodal models that was partially reverted in #28995. Implements robust GGUF-only guards so the feature applies only to GGUF models and does not affect HF models.

Changes:
- Add uses_custom_attention_masks() utility with a GGUF file-format check
- Add a uses_custom_attention_masks property to ModelConfig
- Initialize uses_custom_attention_masks in GPUModelRunner
- Restore the generate_attention_masks() method to Gemma3ForConditionalGeneration
- Implement a 3-layer defense-in-depth guard mechanism

The implementation uses check_gguf_file() to guarantee that the custom attention mask logic only triggers for GGUF files, preventing the issue that caused the original revert, where HF models incorrectly triggered the custom logic.

Tested with GGUF models (1B, 4B, 270M) for both text-only and multimodal inference. HF model compatibility verified via the pytest multimodal test suite.

Signed-off-by: Luciano Martins <[email protected]>
1 parent 1840c5c commit 7ac2e2d

4 files changed: +143, -0 lines changed

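The commit message describes a 3-layer, GGUF-only guard spread across ModelConfig, GPUModelRunner, and the model class itself. Below is a minimal sketch of how those layers compose, using simplified stand-ins rather than the real vLLM classes (the stand-in check_gguf_file only checks for an existing local file with a .gguf suffix, whereas the real helper may also inspect the file contents):

    from pathlib import Path


    def check_gguf_file(model_path: str) -> bool:
        # Simplified stand-in for vllm.transformers_utils.utils.check_gguf_file:
        # a GGUF model must be an existing local file, so HF repo ids never match.
        path = Path(model_path)
        return path.is_file() and path.suffix == ".gguf"


    def uses_custom_attention_masks(architectures: list[str], model_path: str) -> bool:
        # Layer 1 (ModelConfig): the flag is True only for Gemma3 + GGUF checkpoints.
        is_gemma3 = "Gemma3ForConditionalGeneration" in architectures
        return is_gemma3 and check_gguf_file(model_path)


    def should_generate_masks(flag: bool, has_mm_features: bool, model: object) -> bool:
        # Layers 2 and 3 (GPUModelRunner): the batch must carry multimodal features,
        # and the loaded model must implement the generate_attention_masks() hook.
        # (should_generate_masks is an illustrative name, not a vLLM function.)
        return flag and has_mm_features and hasattr(model, "generate_attention_masks")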

vllm/config/model.py

Lines changed: 5 additions & 0 deletions
@@ -32,6 +32,7 @@
     try_get_generation_config,
     try_get_safetensors_metadata,
     try_get_tokenizer_config,
+    uses_custom_attention_masks,
     uses_mrope,
 )
 from vllm.transformers_utils.gguf_utils import (
@@ -1604,6 +1605,10 @@ def uses_alibi(self) -> bool:
     def uses_mrope(self) -> bool:
         return uses_mrope(self.hf_config)
 
+    @property
+    def uses_custom_attention_masks(self) -> bool:
+        return uses_custom_attention_masks(self.hf_config, self.model)
+
     @property
     def is_multimodal_model(self) -> bool:
         return self.multimodal_config is not None

vllm/model_executor/models/gemma3_mm.py

Lines changed: 91 additions & 0 deletions
@@ -644,6 +644,97 @@ def forward(
 
         return hidden_states
 
+    def generate_attention_masks(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        mask_dtype: torch.dtype,
+    ) -> dict[str, Any]:
+        """Generate custom attention masks for Gemma3 GGUF multimodal inputs.
+
+        This is called by V1 engine's gpu_model_runner during preprocessing
+        to generate attention masks that allow bidirectional attention between
+        image tokens while maintaining causal attention for text.
+
+        NOTE: This method is ONLY called for GGUF models due to the guard in
+        gpu_model_runner.py. HF models handle attention masks internally.
+
+        Args:
+            input_ids: Input token IDs
+            positions: Position IDs
+            mask_dtype: Data type for the attention mask tensors
+
+        Returns:
+            Dictionary containing:
+            - has_images: Always True (method is only called for multimodal)
+            - seq_lens: List of sequence lengths
+            - global_attn_masks: Global causal masks with bidirectional image attention
+            - local_attn_masks: Local sliding window masks (if applicable)
+        """
+        # NOTE(woosuk): Here, we distinguish the sequences by the position id 0.
+        # This is a HACK. Fix this.
+        start_indices = (positions == 0).cpu().nonzero()
+        num_seqs = len(start_indices)
+        seq_lens = []
+        for i in range(num_seqs):
+            start_idx = start_indices[i].item()
+            end_idx = (
+                start_indices[i + 1].item() if i < num_seqs - 1 else len(input_ids)
+            )
+            seq_lens.append(end_idx - start_idx)
+
+        global_attn_masks = []
+        local_attn_masks = []
+        start_idx = 0
+        for seq_idx, seq_len in enumerate(seq_lens):
+            end_idx = start_idx + seq_len
+            input_token_ids = input_ids[start_idx:end_idx]
+
+            # Find image token positions
+            img_pos = input_token_ids == self.config.image_token_index
+
+            start_idx = end_idx
+
+            # Create a global causal mask
+            global_attn_mask = torch.empty(
+                1,
+                1,
+                seq_len,
+                seq_len,
+                dtype=mask_dtype,
+                device=input_ids.device,
+            )
+            global_attn_mask.fill_(float("-inf"))
+            # Fill the lower triangle with 0 (causal attention)
+            global_attn_mask = global_attn_mask.triu(diagonal=1)
+
+            # Enable bidirectional attention between image tokens
+            # Use advanced indexing for better performance
+            img_indices = torch.where(img_pos)[0]
+            global_attn_mask[:, :, img_indices[:, None], img_indices] = 0
+            global_attn_masks.append(global_attn_mask)
+
+            # GGUF compatibility: config might be Gemma3TextConfig directly
+            text_config = getattr(self.config, "text_config", self.config)
+            sliding_window = text_config.sliding_window
+            if sliding_window is not None:
+                # Create a mask for tokens outside the sliding window
+                outside_window_mask = torch.ones_like(
+                    global_attn_mask, dtype=torch.bool
+                ).tril(diagonal=-sliding_window)
+
+                # Start with the global mask and apply the sliding window constraint
+                local_attn_mask = global_attn_mask.clone()
+                local_attn_mask[outside_window_mask] = float("-inf")
+                local_attn_masks.append(local_attn_mask)
+
+        return {
+            "has_images": True,
+            "seq_lens": seq_lens,
+            "global_attn_masks": global_attn_masks,
+            "local_attn_masks": local_attn_masks,
+        }
+
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
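To make the mask construction above concrete, here is a small self-contained illustration on a toy sequence; the token ids and the IMAGE_TOKEN value are placeholders (the real method reads the image token id from self.config.image_token_index), and only the global mask is shown:

    import torch

    IMAGE_TOKEN = 99  # placeholder image token id, not Gemma3's real value
    input_ids = torch.tensor([1, 2, IMAGE_TOKEN, IMAGE_TOKEN, 3])
    seq_len = input_ids.shape[0]

    # Global causal mask: -inf above the diagonal, 0 on and below it.
    mask = torch.empty(1, 1, seq_len, seq_len, dtype=torch.float32)
    mask.fill_(float("-inf"))
    mask = mask.triu(diagonal=1)

    # Let image tokens attend to each other in both directions.
    img_indices = torch.where(input_ids == IMAGE_TOKEN)[0]
    mask[:, :, img_indices[:, None], img_indices] = 0

    print(mask[0, 0])
    # Position 2 (the first image token) now has a 0 in column 3, so it can attend
    # forward to the second image token; all text positions remain strictly causal.

The local sliding-window mask is derived from this global mask in the same loop: keys that fall outside the window behind each query are re-masked to -inf via the tril(diagonal=-sliding_window) trick shown in the diff.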

vllm/transformers_utils/config.py

Lines changed: 26 additions & 0 deletions
@@ -520,6 +520,32 @@ def is_interleaved(config: PretrainedConfig) -> bool:
     return False
 
 
+def uses_custom_attention_masks(config: PretrainedConfig, model_path: str) -> bool:
+    """Detect if model uses custom attention mask generation for multimodal.
+
+    Some multimodal models require custom attention masks that enable
+    bidirectional attention between image tokens while maintaining causal
+    attention for text tokens. Currently applies ONLY to Gemma3 GGUF models.
+
+    Args:
+        config: HuggingFace model config
+        model_path: Path to the model (used to check if it's a GGUF file)
+
+    Returns:
+        True if the model is a GGUF Gemma3 multimodal model
+    """
+    from vllm.transformers_utils.utils import check_gguf_file
+
+    # Check architecture
+    architectures = getattr(config, "architectures", [])
+    is_gemma3 = "Gemma3ForConditionalGeneration" in architectures
+
+    # CRITICAL: Only return True for GGUF models
+    is_gguf = check_gguf_file(model_path)
+
+    return is_gemma3 and is_gguf
+
+
 def _maybe_update_auto_config_kwargs(kwargs: dict[str, Any], model_type: str):
     """
     Update kwargs for AutoConfig initialization based on model_type
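A hedged usage sketch of the new helper; the config objects and paths are placeholders, and the True case assumes the path points to an existing local .gguf file, since check_gguf_file() returns False for anything that is not a GGUF file on disk (so a Hugging Face repo id never enables the custom-mask path):

    from transformers import PretrainedConfig

    from vllm.transformers_utils.config import uses_custom_attention_masks

    gemma3_cfg = PretrainedConfig(architectures=["Gemma3ForConditionalGeneration"])
    llama_cfg = PretrainedConfig(architectures=["LlamaForCausalLM"])

    # True, assuming this placeholder path is an existing local GGUF file.
    uses_custom_attention_masks(gemma3_cfg, "/models/gemma-3-4b-it-q4_0.gguf")

    # False: a HF repo id is not a local .gguf file.
    uses_custom_attention_masks(gemma3_cfg, "google/gemma-3-4b-it")

    # False: GGUF file, but not the Gemma3 multimodal architecture.
    uses_custom_attention_masks(llama_cfg, "/models/llama-3.1-8b-q4_0.gguf")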

vllm/v1/worker/gpu_model_runner.py

Lines changed: 21 additions & 0 deletions
@@ -324,6 +324,7 @@ def __init__(
         # Multi-modal data support
         self.mm_registry = MULTIMODAL_REGISTRY
         self.uses_mrope = model_config.uses_mrope
+        self.uses_custom_attention_masks = model_config.uses_custom_attention_masks
         self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(
             model_config
         )
@@ -2336,6 +2337,26 @@ def _preprocess(
                 **self._init_model_kwargs(num_scheduled_tokens),
                 **self._extract_mm_kwargs(scheduler_output),
             }
+
+            # Generate custom attention masks for models that require them.
+            # V1 pre-generates embeddings, so forward() skips prepare_attn_masks().
+            # Check mm_features (mm_embeds is empty during decode).
+            has_mm_features = any(
+                req_state.mm_features for req_state in self.requests.values()
+            )
+            # Defense-in-depth: Check flag (GGUF-only), multimodal presence,
+            # and method existence before generating custom attention masks.
+            if (
+                self.uses_custom_attention_masks
+                and has_mm_features
+                and hasattr(self.model, "generate_attention_masks")
+            ):
+                mask_kwargs = self.model.generate_attention_masks(
+                    self.input_ids.gpu[:num_scheduled_tokens],
+                    self.positions.gpu[:num_scheduled_tokens],
+                    mask_dtype=self.model.dtype,
+                )
+                model_kwargs.update(mask_kwargs)
         elif self.enable_prompt_embeds and is_first_rank:
             # Get the input embeddings for the tokens that are not input embeds,
             # then put them into the appropriate positions.
