| 1 | +import os |
1 | 2 | import unittest |
2 | 3 | import torch |
3 | 4 | from onnx_diagnostic.ext_test_case import ExtTestCase, never_test |
@@ -799,17 +800,19 @@ def test_imagetext2text_generation_gemma3_4b_it(self): |
799 | 800 | from transformers import AutoProcessor, Gemma3ForConditionalGeneration |
800 | 801 | |
801 | 802 | model_id = "google/gemma-3-4b-it" |
802 | | - # model_id = "google/gemma-3n-e4b-it" |
803 | | - # model_id = "qnaug/gemma-3-4b-med" |
804 | | - # model_id = "hf-internal-testing/tiny-random-Gemma3ForCausalLM" |
805 | | - # data = get_untrained_model_with_inputs( |
806 | | - # model_id, verbose=1, add_second_input=True, |
807 | | - # same_as_pretrained=True, use_pretrained=True |
808 | | - # ) |
809 | | - # model = data["model"] |
810 | | - model = Gemma3ForConditionalGeneration.from_pretrained( |
811 | | - model_id, device_map="cpu" |
812 | | - ).eval() |
| 803 | + if os.environ.get("PRETRAINED", ""): |
| 804 | + model = Gemma3ForConditionalGeneration.from_pretrained( |
| 805 | + model_id, device_map="cpu" |
| 806 | + ).eval() |
| 807 | + else: |
| 808 | + data = get_untrained_model_with_inputs( |
| 809 | + model_id, |
| 810 | + verbose=1, |
| 811 | + add_second_input=True, |
| 812 | + # same_as_pretrained=True, #use_pretrained=True |
| 813 | + ) |
| 814 | + model = data["model"] |
| 815 | + |
813 | 816 | print(f"-- model.device={model.device}") |
814 | 817 | processor = AutoProcessor.from_pretrained(model_id, use_fast=True) |
815 | 818 | print(f"-- processor={type(processor)}") |
@@ -845,11 +848,39 @@ def test_imagetext2text_generation_gemma3_4b_it(self): |
845 | 848 | # inputs.pop("token_type_ids", None) |
846 | 849 | print(f"-- inputs={self.string_type(inputs)}") |
847 | 850 | |
| 851 | + # iteration 1 |
| 852 | + # cache_position:T7s281, |
| 853 | + # past_key_values:StaticCache(key_cache=#0[], value_cache=#0[]), |
| 854 | + # input_ids:T7s1x281, |
| 855 | + # inputs_embeds:None, |
| 856 | + # token_type_ids:T7s1x281, |
| 857 | + # attention_mask:dict(sliding_attention:T9s1x1x281x580, |
| 858 | + # full_attention:T9s1x1x281x580), |
| 859 | + # position_ids:None, |
| 860 | + # use_cache:bool, |
| 861 | + # logits_to_keep:None, |
| 862 | + # pixel_values:T16s1x3x896x896, |
| 863 | + # return_dict:bool) |
| 864 | + # iteration 3 |
| 865 | + # cache_position:T7s1, |
| 866 | + # past_key_values:StaticCache(key_cache=#34[T1s1x4x580x256,...], |
| 867 | + # value_cache=#34[T1s1x4x580x256,...]), |
| 868 | + # input_ids:T7s1x1, |
| 869 | + # inputs_embeds:None, |
| 870 | + # token_type_ids:T7s1x1, |
| 871 | + # attention_mask:dict(sliding_attention:T9s1x1x1x580,full_attention:T9s1x1x1x580), |
| 872 | + # position_ids:None, |
| 873 | + # use_cache:bool,logits_to_keep:None,return_dict:bool) |
| 874 | + |
848 | 875 | print() |
849 | 876 | # steal forward creates a bug... |
850 | | - with steal_forward(model): # , torch.inference_mode(): |
| 877 | + with steal_forward( |
| 878 | + model, |
| 879 | + dump_file=self.get_dump_file("test_imagetext2text_generation_gemma3_4b_it.onnx"), |
| 880 | + dump_drop={"attention_mask", "past_key_values", "pixel_values"}, |
| 881 | + ): |
851 | 882 | generated_ids = model.generate( |
852 | | - **inputs, max_new_tokens=300, do_sample=False, cache_implementation="hybrid" |
| 883 | + **inputs, max_new_tokens=282, do_sample=False, cache_implementation="static" |
853 | 884 | ) |
854 | 885 | output_text = processor.decode( |
855 | 886 | generated_ids[0][inputs["input_ids"].shape[1] :], skip_special_tokens=False |
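
For context, a minimal standalone sketch (not part of the commit; it uses only the standard library, with the model id and the `PRETRAINED` variable name taken from the diff above) of the toggle this change introduces: the real Gemma-3 checkpoint is loaded only when `PRETRAINED` is set to a non-empty value, otherwise the test falls back to an untrained model built from the same configuration.

```python
import os

def use_pretrained_checkpoint() -> bool:
    # Mirrors the gate added in the diff: any non-empty PRETRAINED value
    # selects the real "google/gemma-3-4b-it" checkpoint; unset or empty
    # keeps the test lightweight with an untrained, config-matched model.
    return bool(os.environ.get("PRETRAINED", ""))

if __name__ == "__main__":
    source = "pretrained" if use_pretrained_checkpoint() else "untrained"
    print(f"test would run against the {source} model")
```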