Commit 4556713

add patch
1 parent 4bdaa1c commit 4556713

2 files changed: +53 -1 lines changed

_unittests/ut_tasks/test_tasks_image_text_to_text.py

Lines changed: 1 addition & 1 deletion
@@ -54,7 +54,7 @@ def test_image_text_to_text_tiny_gemma3(self):
 
     @hide_stdout()
     @requires_transformers("4.56.2")
-    @requires_torch("2.7.99")
+    @requires_torch("2.8.99")
     def test_image_text_to_text_gemma3_4b_it(self):
         mid = "google/gemma-3-4b-it"
         data = get_untrained_model_with_inputs(
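The change above only bumps the minimum torch version from 2.7.99 to 2.8.99; the .99 suffix presumably requires a build newer than any released 2.8.x, i.e. a 2.9 development build. As a rough illustration of what a version-gating decorator like requires_torch can look like, here is a hypothetical sketch (requires_torch_sketch is an invented name, not onnx_diagnostic's actual implementation):

import unittest

from packaging.version import Version
import torch


def requires_torch_sketch(min_version: str):
    """Skip the decorated test unless the installed torch is recent enough."""

    def decorator(fn):
        # torch.__version__ may carry a local suffix such as "+cu126".
        installed = Version(torch.__version__.split("+")[0])
        if installed < Version(min_version):
            reason = f"torch>={min_version} required, found {torch.__version__}"
            return unittest.skip(reason)(fn)
        return fn

    return decorator

Under this sketch, @requires_torch_sketch("2.8.99") skips the test on every released 2.8.x build and lets it run on 2.9 nightlies.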

onnx_diagnostic/torch_export_patches/patches/patch_transformers.py

Lines changed: 52 additions & 0 deletions
@@ -1840,3 +1840,55 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
             batch_size, sequence_length, hidden_dim
         )
         return final_hidden_states, router_logits
+
+
+try:
+    import transformers.models.gemma3.modeling_gemma3
+
+    patch_gemma3 = True
+except ImportError:
+    patch_gemma3 = False
+
+
+if patch_gemma3:
+
+    class patched_Gemma3Model(torch.nn.Module):
+        _PATCHES_ = ["get_placeholder_mask"]
+        _PATCHED_CLASS_ = transformers.models.gemma3.modeling_gemma3.Gemma3Model
+
+        def get_placeholder_mask(
+            self,
+            input_ids: torch.LongTensor,
+            inputs_embeds: torch.FloatTensor,
+            image_features: torch.FloatTensor,
+        ):
+            if input_ids is None:
+                special_image_mask = inputs_embeds == self.get_input_embeddings()(
+                    torch.tensor(
+                        self.config.image_token_id,
+                        dtype=torch.long,
+                        device=inputs_embeds.device,
+                    )
+                )
+                special_image_mask = special_image_mask.all(-1)
+            else:
+                special_image_mask = input_ids == self.config.image_token_id
+
+            n_image_tokens = special_image_mask.sum()
+            special_image_mask = (
+                special_image_mask.unsqueeze(-1)
+                .expand_as(inputs_embeds)
+                .to(inputs_embeds.device)
+            )
+            n_image_features = image_features.shape[0] * image_features.shape[1]
+            # PATCHED: torch._check
+            # if inputs_embeds[special_image_mask].numel() != image_features.numel():
+            #     raise ValueError( ... )
+            torch._check(
+                inputs_embeds[special_image_mask].numel() == image_features.numel(),
+                lambda: (
+                    f"Image features and image tokens do not match: tokens: "
+                    f"{n_image_tokens}, features {n_image_features}"
+                ),
+            )
+            return special_image_mask
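For context on why this patch exists: the substantive change is replacing the upstream `if ...: raise ValueError` with torch._check. Under torch.export, a Python branch on a data-dependent quantity such as inputs_embeds[special_image_mask].numel() guards on an unbacked symbolic value and aborts tracing, whereas torch._check records the same condition as a deferred runtime assertion the exporter can carry. Below is a minimal, self-contained sketch of the pattern, with a toy module and shapes that are illustrative rather than the Gemma3 code (assuming a reasonably recent torch, roughly 2.3+):

import torch


class CheckedSelect(torch.nn.Module):
    """Toy module using the same torch._check pattern as the patch."""

    def forward(self, mask: torch.Tensor, features: torch.Tensor) -> torch.Tensor:
        # Boolean indexing yields a tensor whose size depends on the data in mask.
        selected = features[mask]
        # A plain `if selected.numel() != features.numel(): raise ValueError(...)`
        # would fail to export; torch._check records the condition instead.
        torch._check(
            selected.numel() == features.numel(),
            lambda: "mask must select every element of features",
        )
        # The recorded fact also lets the exporter prove this reshape is valid.
        return selected.reshape(features.shape)


mask = torch.ones(4, dtype=torch.bool)
features = torch.arange(4.0)
ep = torch.export.export(CheckedSelect(), (mask, features))
print(ep.module()(mask, features))  # tensor([0., 1., 2., 3.])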
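The `_PATCHES_` and `_PATCHED_CLASS_` attributes presumably tell onnx_diagnostic's patching machinery which methods to swap onto which transformers class. A hypothetical sketch of how such a patch class could be consumed (the real machinery may differ):

def apply_patches(patch_cls):
    """Copy every method named in _PATCHES_ onto _PATCHED_CLASS_ and
    return the originals so the patch can be undone afterwards."""
    target = patch_cls._PATCHED_CLASS_
    originals = {name: getattr(target, name) for name in patch_cls._PATCHES_}
    for name in patch_cls._PATCHES_:
        setattr(target, name, getattr(patch_cls, name))
    return originals


def restore_patches(patch_cls, originals):
    """Put the original methods back on the patched class."""
    for name, fn in originals.items():
        setattr(patch_cls._PATCHED_CLASS_, name, fn)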
