@@ -1840,3 +1840,55 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
             batch_size, sequence_length, hidden_dim
         )
         return final_hidden_states, router_logits
+
+
+try:
+    import transformers.models.gemma3.modeling_gemma3
+
+    patch_gemma3 = True
+except ImportError:
+    patch_gemma3 = False
+
+
+if patch_gemma3:
+
+    class patched_Gemma3Model(torch.nn.Module):
+        _PATCHES_ = ["get_placeholder_mask"]
+        _PATCHED_CLASS_ = transformers.models.gemma3.modeling_gemma3.Gemma3Model
+
+        def get_placeholder_mask(
+            self,
+            input_ids: torch.LongTensor,
+            inputs_embeds: torch.FloatTensor,
+            image_features: torch.FloatTensor,
+        ):
+            if input_ids is None:
+                special_image_mask = inputs_embeds == self.get_input_embeddings()(
+                    torch.tensor(
+                        self.config.image_token_id,
+                        dtype=torch.long,
+                        device=inputs_embeds.device,
+                    )
+                )
+                special_image_mask = special_image_mask.all(-1)
+            else:
+                special_image_mask = input_ids == self.config.image_token_id
+
+            n_image_tokens = special_image_mask.sum()
+            special_image_mask = (
+                special_image_mask.unsqueeze(-1)
+                .expand_as(inputs_embeds)
+                .to(inputs_embeds.device)
+            )
+            n_image_features = image_features.shape[0] * image_features.shape[1]
+            # PATCHED: torch._check
+            # if inputs_embeds[special_image_mask].numel() != image_features.numel():
+            #     raise ValueError( ... )
+            torch._check(
+                inputs_embeds[special_image_mask].numel() == image_features.numel(),
+                lambda: (
+                    f"Image features and image tokens do not match: tokens: "
+                    f"{n_image_tokens}, features {n_image_features}"
+                ),
+            )
+            return special_image_mask
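
The substantive change in this hunk is replacing the original `if ... raise ValueError` size comparison with `torch._check`. The comparison depends on runtime data (how many image tokens appear in the input), so a plain Python branch over it typically fails under torch.export with a data-dependent guard error, whereas `torch._check` records the condition as a runtime assertion the tracer can keep. A minimal sketch of the same pattern outside the model; the function name, toy shapes, and the `masked_scatter` call are illustrative assumptions, not part of this commit:

import torch


def scatter_image_features(inputs_embeds, special_image_mask, image_features):
    # Same pattern as the patch: assert that the masked positions and the
    # provided features have the same element count via torch._check instead
    # of a Python `if`/`raise` on a data-dependent value.
    torch._check(
        inputs_embeds[special_image_mask].numel() == image_features.numel(),
        lambda: "Image features and image tokens do not match",
    )
    return inputs_embeds.masked_scatter(special_image_mask, image_features)


# Eager sanity check with toy values.
embeds = torch.zeros(1, 4, 8)
mask = torch.zeros(1, 4, 8, dtype=torch.bool)
mask[0, :2] = True              # pretend the first two tokens are image tokens
feats = torch.ones(2, 8)        # two image features of hidden size 8
print(scatter_image_features(embeds, mask, feats).shape)  # torch.Size([1, 4, 8])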
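The `_PATCHES_` and `_PATCHED_CLASS_` attributes suggest the surrounding patching machinery copies the listed methods onto the target transformers class rather than instantiating `patched_Gemma3Model` directly. A hypothetical helper doing that by plain monkeypatching; the names `apply_patch`/`undo_patch` and the restore logic are assumptions, not part of this commit:

def apply_patch(patch_cls):
    # Copy every method named in _PATCHES_ from the patch class onto the
    # class it targets, keeping the originals so they can be restored.
    target = patch_cls._PATCHED_CLASS_
    originals = {name: getattr(target, name) for name in patch_cls._PATCHES_}
    for name in patch_cls._PATCHES_:
        setattr(target, name, getattr(patch_cls, name))
    return originals


def undo_patch(patch_cls, originals):
    # Put the original methods back after export.
    for name, fn in originals.items():
        setattr(patch_cls._PATCHED_CLASS_, name, fn)


# Usage sketch:
# originals = apply_patch(patched_Gemma3Model)
# ... run torch.export / ONNX export on the Gemma3 model ...
# undo_patch(patched_Gemma3Model, originals)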