
Commit bbce496

add one more model to test (#268)
1 parent e632c4d commit bbce496


2 files changed: 71 additions, 0 deletions

Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
import unittest
import torch
from onnx_diagnostic.ext_test_case import (
    ExtTestCase,
    hide_stdout,
    requires_transformers,
    requires_torch,
)
from onnx_diagnostic.helpers.torch_helper import torch_deepcopy
from onnx_diagnostic.torch_models.hghub.model_inputs import get_untrained_model_with_inputs
from onnx_diagnostic.torch_export_patches import torch_export_patches
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str


class TestTasksTextGeneration(ExtTestCase):
    @hide_stdout()
    @requires_transformers("4.53")
    @requires_torch("2.7.99")
    def test_image_text_to_text_gemma3_for_causallm(self):
        mid = "hf-internal-testing/tiny-random-Gemma3ForCausalLM"
        data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
        self.assertEqual(data["task"], "text-generation")
        model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
        model(**torch_deepcopy(inputs))
        model(**data["inputs2"])
        with torch_export_patches(patch_transformers=True, verbose=10, patch_torch=False):
            torch.export.export(
                model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
            )


if __name__ == "__main__":
    unittest.main(verbosity=2)
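
As context for the test, here is a minimal sketch (not part of the commit) of the dictionary returned by get_untrained_model_with_inputs, based only on the keys the test relies on ("task", "model", "inputs", "inputs2" via add_second_input=True, and "dynamic_shapes"); the commented expectations are assumptions beyond what the diff shows.

from onnx_diagnostic.torch_models.hghub.model_inputs import get_untrained_model_with_inputs

# Builds an untrained (randomly initialized) copy of the tiny Gemma3 model
# together with example inputs, without downloading pretrained weights.
data = get_untrained_model_with_inputs(
    "hf-internal-testing/tiny-random-Gemma3ForCausalLM", add_second_input=True
)
print(data["task"])                  # expected: "text-generation"
print(type(data["model"]).__name__)  # the instantiated architecture
print(sorted(data["inputs"]))        # keyword inputs fed as model(**inputs)
print(data["dynamic_shapes"])        # dynamic dimensions used for torch.export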

onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py

Lines changed: 38 additions & 0 deletions
@@ -4865,3 +4865,41 @@ def _ccached_google_gemma_3_4b_it_like():
            },
        }
    )


def _ccached_hf_internal_testing_tiny_random_gemma3_for_causal_lm():
    "hf-internal-testing/tiny-random-Gemma3ForCausalLM"
    return transformers.Gemma3TextConfig(
        **{
            "architectures": ["Gemma3ForCausalLM"],
            "attention_bias": false,
            "attention_dropout": 0.0,
            "attn_logit_softcapping": null,
            "bos_token_id": 2,
            "cache_implementation": "hybrid",
            "eos_token_id": [1, 106],
            "final_logit_softcapping": null,
            "head_dim": 8,
            "hidden_activation": "gelu_pytorch_tanh",
            "hidden_size": 16,
            "initializer_range": 0.02,
            "intermediate_size": 32,
            "max_position_embeddings": 32768,
            "model_type": "gemma3_text",
            "num_attention_heads": 2,
            "num_hidden_layers": 2,
            "num_key_value_heads": 1,
            "pad_token_id": 0,
            "query_pre_attn_scalar": 256,
            "rms_norm_eps": 1e-06,
            "rope_local_base_freq": 10000,
            "rope_scaling": null,
            "rope_theta": 1000000,
            "sliding_window": 512,
            "sliding_window_pattern": 6,
            "torch_dtype": "float32",
            "transformers_version": "4.52.0.dev0",
            "use_cache": true,
            "vocab_size": 262144,
        }
    )
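
For illustration only, a hedged sketch of how a cached configuration like this one could be turned into a tiny, randomly initialized model for offline experiments. Instantiating transformers.Gemma3ForCausalLM directly from the config, and importing the private helper above, are assumptions made for the sketch, not something this commit does; the repository itself goes through get_untrained_model_with_inputs, as the test above shows, and a recent transformers release with Gemma3 support is required.

import transformers
from onnx_diagnostic.torch_models.hghub.hub_data_cached_configs import (
    _ccached_hf_internal_testing_tiny_random_gemma3_for_causal_lm,
)

# The cached config mirrors hf-internal-testing/tiny-random-Gemma3ForCausalLM,
# so a model with random weights can be built without touching the Hub.
config = _ccached_hf_internal_testing_tiny_random_gemma3_for_causal_lm()
model = transformers.Gemma3ForCausalLM(config)
print(sum(p.numel() for p in model.parameters()))  # tiny: 2 layers, hidden_size=16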
