From 7ac2e2d28cfdc9c5d342252fd93f5baddf3baf24 Mon Sep 17 00:00:00 2001 From: Luciano Martins Date: Fri, 21 Nov 2025 19:19:36 +0000 Subject: [PATCH] [Model] Restore Gemma3 GGUF multimodal support with GGUF-only guards Restores custom attention mask generation for Gemma3 GGUF multimodal models that was partially reverted in #28995. Implements robust GGUF-only guards to ensure the feature only applies to GGUF models and does not affect HF models. Changes: - Add uses_custom_attention_masks() utility with GGUF file format check - Add uses_custom_attention_masks property to ModelConfig - Initialize uses_custom_attention_masks in GPUModelRunner - Restore generate_attention_masks() method to Gemma3ForConditionalGeneration - Implement 3-layer defense-in-depth guard mechanism The implementation uses check_gguf_file() to guarantee that custom attention mask logic only triggers for GGUF files, preventing the issue that caused the original revert where HF models incorrectly triggered the custom logic. Tested with GGUF models (1B, 4B, 270M) for both text-only and multimodal inference. HF model compatibility verified via pytest multimodal test suite. Signed-off-by: Luciano Martins --- vllm/config/model.py | 5 ++ vllm/model_executor/models/gemma3_mm.py | 91 +++++++++++++++++++++++++ vllm/transformers_utils/config.py | 26 +++++++ vllm/v1/worker/gpu_model_runner.py | 21 ++++++ 4 files changed, 143 insertions(+) diff --git a/vllm/config/model.py b/vllm/config/model.py index 8f59673f4e1c..fdc04ec1f6a6 100644 --- a/vllm/config/model.py +++ b/vllm/config/model.py @@ -32,6 +32,7 @@ try_get_generation_config, try_get_safetensors_metadata, try_get_tokenizer_config, + uses_custom_attention_masks, uses_mrope, ) from vllm.transformers_utils.gguf_utils import ( @@ -1604,6 +1605,10 @@ def uses_alibi(self) -> bool: def uses_mrope(self) -> bool: return uses_mrope(self.hf_config) + @property + def uses_custom_attention_masks(self) -> bool: + return uses_custom_attention_masks(self.hf_config, self.model) + @property def is_multimodal_model(self) -> bool: return self.multimodal_config is not None diff --git a/vllm/model_executor/models/gemma3_mm.py b/vllm/model_executor/models/gemma3_mm.py index 43c69e5e1399..3feb6a494de9 100644 --- a/vllm/model_executor/models/gemma3_mm.py +++ b/vllm/model_executor/models/gemma3_mm.py @@ -644,6 +644,97 @@ def forward( return hidden_states + def generate_attention_masks( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + mask_dtype: torch.dtype, + ) -> dict[str, Any]: + """Generate custom attention masks for Gemma3 GGUF multimodal inputs. + + This is called by V1 engine's gpu_model_runner during preprocessing + to generate attention masks that allow bidirectional attention between + image tokens while maintaining causal attention for text. + + NOTE: This method is ONLY called for GGUF models due to the guard in + gpu_model_runner.py. HF models handle attention masks internally. + + Args: + input_ids: Input token IDs + positions: Position IDs + mask_dtype: Data type for the attention mask tensors + + Returns: + Dictionary containing: + - has_images: Always True (method is only called for multimodal) + - seq_lens: List of sequence lengths + - global_attn_masks: Global causal masks with bidirectional image attention + - local_attn_masks: Local sliding window masks (if applicable) + """ + # NOTE(woosuk): Here, we distinguish the sequences by the position id 0. + # This is a HACK. Fix this. 
+ start_indices = (positions == 0).cpu().nonzero() + num_seqs = len(start_indices) + seq_lens = [] + for i in range(num_seqs): + start_idx = start_indices[i].item() + end_idx = ( + start_indices[i + 1].item() if i < num_seqs - 1 else len(input_ids) + ) + seq_lens.append(end_idx - start_idx) + + global_attn_masks = [] + local_attn_masks = [] + start_idx = 0 + for seq_idx, seq_len in enumerate(seq_lens): + end_idx = start_idx + seq_len + input_token_ids = input_ids[start_idx:end_idx] + + # Find image token positions + img_pos = input_token_ids == self.config.image_token_index + + start_idx = end_idx + + # Create a global causal mask + global_attn_mask = torch.empty( + 1, + 1, + seq_len, + seq_len, + dtype=mask_dtype, + device=input_ids.device, + ) + global_attn_mask.fill_(float("-inf")) + # Fill the lower triangle with 0 (causal attention) + global_attn_mask = global_attn_mask.triu(diagonal=1) + + # Enable bidirectional attention between image tokens + # Use advanced indexing for better performance + img_indices = torch.where(img_pos)[0] + global_attn_mask[:, :, img_indices[:, None], img_indices] = 0 + global_attn_masks.append(global_attn_mask) + + # GGUF compatibility: config might be Gemma3TextConfig directly + text_config = getattr(self.config, "text_config", self.config) + sliding_window = text_config.sliding_window + if sliding_window is not None: + # Create a mask for tokens outside the sliding window + outside_window_mask = torch.ones_like( + global_attn_mask, dtype=torch.bool + ).tril(diagonal=-sliding_window) + + # Start with the global mask and apply the sliding window constraint + local_attn_mask = global_attn_mask.clone() + local_attn_mask[outside_window_mask] = float("-inf") + local_attn_masks.append(local_attn_mask) + + return { + "has_images": True, + "seq_lens": seq_lens, + "global_attn_masks": global_attn_masks, + "local_attn_masks": local_attn_masks, + } + def compute_logits( self, hidden_states: torch.Tensor, diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index df24738477e7..2306f6dbe04d 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -520,6 +520,32 @@ def is_interleaved(config: PretrainedConfig) -> bool: return False +def uses_custom_attention_masks(config: PretrainedConfig, model_path: str) -> bool: + """Detect if model uses custom attention mask generation for multimodal. + + Some multimodal models require custom attention masks that enable + bidirectional attention between image tokens while maintaining causal + attention for text tokens. Currently applies ONLY to Gemma3 GGUF models. 
+ + Args: + config: HuggingFace model config + model_path: Path to the model (used to check if it's a GGUF file) + + Returns: + True if the model is a GGUF Gemma3 multimodal model + """ + from vllm.transformers_utils.utils import check_gguf_file + + # Check architecture + architectures = getattr(config, "architectures", []) + is_gemma3 = "Gemma3ForConditionalGeneration" in architectures + + # CRITICAL: Only return True for GGUF models + is_gguf = check_gguf_file(model_path) + + return is_gemma3 and is_gguf + + def _maybe_update_auto_config_kwargs(kwargs: dict[str, Any], model_type: str): """ Update kwargs for AutoConfig initialization based on model_type diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index e786cd8bc7c9..d6f9336d50e2 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -324,6 +324,7 @@ def __init__( # Multi-modal data support self.mm_registry = MULTIMODAL_REGISTRY self.uses_mrope = model_config.uses_mrope + self.uses_custom_attention_masks = model_config.uses_custom_attention_masks self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs( model_config ) @@ -2336,6 +2337,26 @@ def _preprocess( **self._init_model_kwargs(num_scheduled_tokens), **self._extract_mm_kwargs(scheduler_output), } + + # Generate custom attention masks for models that require them. + # V1 pre-generates embeddings, so forward() skips prepare_attn_masks(). + # Check mm_features (mm_embeds is empty during decode). + has_mm_features = any( + req_state.mm_features for req_state in self.requests.values() + ) + # Defense-in-depth: Check flag (GGUF-only), multimodal presence, + # and method existence before generating custom attention masks. + if ( + self.uses_custom_attention_masks + and has_mm_features + and hasattr(self.model, "generate_attention_masks") + ): + mask_kwargs = self.model.generate_attention_masks( + self.input_ids.gpu[:num_scheduled_tokens], + self.positions.gpu[:num_scheduled_tokens], + mask_dtype=self.model.dtype, + ) + model_kwargs.update(mask_kwargs) elif self.enable_prompt_embeds and is_first_rank: # Get the input embeddings for the tokens that are not input embeds, # then put them into the appropriate positions.
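
Reviewer note (illustrative only, not part of the patch): a minimal, self-contained sketch of the per-sequence masking scheme that generate_attention_masks() builds. The token ids, the image token id, and the sliding window size below are made-up placeholders; in the patch they come from config.image_token_index and text_config.sliding_window.

import torch

IMAGE_TOKEN_ID = 262144  # placeholder; the patch reads config.image_token_index
SLIDING_WINDOW = 4       # placeholder; the patch reads text_config.sliding_window

# Toy single-sequence prompt: text, a run of image tokens, more text.
input_ids = torch.tensor(
    [2, 106, IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, 107, 108]
)
seq_len = input_ids.shape[0]

# Global mask: causal (-inf strictly above the diagonal, 0 elsewhere)...
global_attn_mask = torch.full((1, 1, seq_len, seq_len), float("-inf"))
global_attn_mask = global_attn_mask.triu(diagonal=1)

# ...with bidirectional attention restored between image-token positions.
img_indices = torch.where(input_ids == IMAGE_TOKEN_ID)[0]
global_attn_mask[:, :, img_indices[:, None], img_indices] = 0

# Local mask: the global mask with positions at least SLIDING_WINDOW steps
# in the past masked out again (sliding-window attention).
outside_window = torch.ones_like(global_attn_mask, dtype=torch.bool).tril(
    diagonal=-SLIDING_WINDOW
)
local_attn_mask = global_attn_mask.clone()
local_attn_mask[outside_window] = float("-inf")

print(global_attn_mask[0, 0])
print(local_attn_mask[0, 0])

Printing the two masks shows the strict upper triangle at -inf, zeros at the intersections of the three image-token positions (bidirectional), and, in the local mask, additional -inf entries for tokens outside the 4-token window.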
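A second illustrative sketch (also not part of the patch) of the GGUF-only guard in uses_custom_attention_masks(). The config stand-in and the GGUF path are hypothetical; check_gguf_file() only reports True for an actual GGUF file, so HF repo ids and safetensors checkpoints never enable the custom-mask path.

from types import SimpleNamespace

from vllm.transformers_utils.config import uses_custom_attention_masks

# Minimal stand-in for a Gemma3 multimodal HF config (hypothetical).
gemma3_cfg = SimpleNamespace(architectures=["Gemma3ForConditionalGeneration"])

# An HF repo id is not a GGUF file on disk, so the guard stays off.
assert not uses_custom_attention_masks(gemma3_cfg, "google/gemma-3-4b-it")

# With a real local GGUF checkpoint (placeholder path), the guard turns on:
# assert uses_custom_attention_masks(gemma3_cfg, "/models/gemma-3-4b-it-q4_0.gguf")

This mirrors the guarantee described in the commit message: the custom attention mask logic can only trigger when the model argument points at a GGUF file, which is what prevents the HF-model regression that led to the original revert.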