+import os
 import unittest
 import torch
 from onnx_diagnostic.ext_test_case import ExtTestCase, never_test
 from onnx_diagnostic.helpers import string_type
 from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache, make_encoder_decoder_cache
 from onnx_diagnostic.helpers.torch_helper import steal_forward
 from onnx_diagnostic.torch_models.hghub.model_inputs import get_untrained_model_with_inputs
+from onnx_diagnostic.torch_export_patches import torch_export_patches


 class TestHuggingFaceHubModel(ExtTestCase):
@@ -137,8 +139,9 @@ def test_text_generation_phi4_mini(self):
         import torch
         from transformers import RobertaTokenizer, T5ForConditionalGeneration

-        tokenizer = RobertaTokenizer.from_pretrained("microsoft/Phi-4-mini-instruct")
-        model = T5ForConditionalGeneration.from_pretrained("microsoft/Phi-4-mini-instruct")
+        model_id = "microsoft/Phi-4-mini-instruct"
+        tokenizer = RobertaTokenizer.from_pretrained(model_id)
+        model = T5ForConditionalGeneration.from_pretrained(model_id)

         text = "def greet(user): print(f'hello <extra_id_0>!')"
         input_ids = tokenizer(text, return_tensors="pt").input_ids
@@ -156,6 +159,41 @@ def test_text_generation_phi4_mini(self):
         )
         print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))

+    @never_test()
+    def test_text_generation_phi3_mini(self):
+        # clear&&NEVERTEST=1 python _unittests/ut_tasks/try_tasks.py -k phi3_mini
+
+        from transformers import Phi3ForCausalLM, AutoTokenizer
+
+        model_id = "microsoft/Phi-3-mini-4k-instruct"
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        model = Phi3ForCausalLM.from_pretrained(model_id)
+
+        messages = [
+            {
+                "role": "system",
+                "content": (
+                    "You are a helpful digital assistant. Please provide safe, "
+                    "ethical and accurate information to the user."
+                ),
+            },
+            {
+                "role": "user",
+                "content": (
+                    "Can you provide ways to eat combinations of bananas and dragonfruits?"
+                ),
+            },
+        ]
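+        # apply_chat_template renders the conversation with the model's chat
+        # template, appends the generation prompt (add_generation_prompt=True)
+        # and returns the token ids as a tensor.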
+        inputs = tokenizer.apply_chat_template(
+            messages, add_generation_prompt=True, return_tensors="pt"
+        )
+
+        # simply generate a single sequence
+        print()
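+        # steal_forward wraps the model while generate runs, so each forward
+        # call it intercepts can be inspected.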
+        with steal_forward(model):
+            generated_ids = model.generate(inputs, max_length=100)
+        print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))
+
     @never_test()
     @unittest.skip(
         reason="AttributeError: 'Phi4MMModel' object has no attribute "
@@ -791,6 +829,119 @@ def test_sentence_similary_alibaba_nlp_gte(self):
         scores = (embeddings[:1] @ embeddings[1:].T) * 100
         print(scores.tolist())

+    @never_test()
+    def test_imagetext2text_generation_gemma3_4b_it(self):
+        """
+        clear&&NEVERTEST=1 python _unittests/ut_tasks/try_tasks.py -k gemma3_4b_it
+        """
+        from transformers import AutoProcessor, Gemma3ForConditionalGeneration
+
+        model_id = "google/gemma-3-4b-it"
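+        # With PRETRAINED set, the real checkpoint is loaded; otherwise an
+        # untrained model with the same configuration is built locally, with
+        # input sizes matching the shapes captured in the comments further down.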
+        if os.environ.get("PRETRAINED", ""):
+            model = Gemma3ForConditionalGeneration.from_pretrained(
+                model_id, device_map="cpu"
+            ).eval()
+        else:
+            data = get_untrained_model_with_inputs(
+                model_id,
+                verbose=1,
+                add_second_input=False,
+                # same_as_pretrained=True, # use_pretrained=True
+                inputs_kwargs={
+                    "sequence_length": 281,
+                    "batch_size": 1,
+                    "max_sequence_length": 580,
+                    "n_images": 1,
+                },
+            )
+            model = data["model"]
+
+        print(f"-- model.device={model.device}")
+        processor = AutoProcessor.from_pretrained(model_id, use_fast=True)
+        print(f"-- processor={type(processor)}")
+
+        messages = [
+            {
+                "role": "system",
+                "content": [{"type": "text", "text": "You are a helpful assistant."}],
+            },
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "image",
+                        "image": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg",
+                    },
+                    {"type": "text", "text": "Describe this image in detail."},
+                ],
+            },
+        ]
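+        # apply_chat_template with return_dict=True should return input_ids,
+        # token_type_ids, attention_mask and pixel_values for this processor
+        # (see the captured shapes below).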
+        inputs = processor.apply_chat_template(
+            messages,
+            tokenize=True,
+            add_generation_prompt=True,
+            return_dict=True,
+            return_tensors="pt",
+        ).to(model.device, dtype=torch.bfloat16)
+        # if "token_type_ids" in inputs:
+        #     print(
+        #         f"-- remove token_type_ids: "
+        #         f"{self.string_type(inputs['token_type_ids'], with_shape=True)}"
+        #     )
+        #     inputs.pop("token_type_ids", None)
+        print(f"-- inputs={self.string_type(inputs)}")
+
+        # iteration merge = sequence > 1, cache not empty
+        # iteration 1 = sequence > 1, no cache
+        #   cache_position:T7s281,
+        #   past_key_values:StaticCache(key_cache=#0[], value_cache=#0[]),
+        #   input_ids:T7s1x281,
+        #   inputs_embeds:None,
+        #   token_type_ids:T7s1x281,
+        #   attention_mask:dict(sliding_attention:T9s1x1x281x580,
+        #                       full_attention:T9s1x1x281x580),
+        #   position_ids:None,
+        #   use_cache:bool,
+        #   logits_to_keep:None,
+        #   pixel_values:T16s1x3x896x896,
+        #   return_dict:bool)
+        # iteration 2 = sequence = 1, cache not empty
+        #   cache_position:T7s1,
+        #   past_key_values:StaticCache(key_cache=#34[T1s1x4x580x256,...],
+        #                               value_cache=#34[T1s1x4x580x256,...]),
+        #   input_ids:T7s1x1,
+        #   inputs_embeds:None,
+        #   token_type_ids:T7s1x1,
+        #   attention_mask:dict(sliding_attention:T9s1x1x1x580,full_attention:T9s1x1x1x580),
+        #   position_ids:None,
+        #   use_cache:bool,logits_to_keep:None,return_dict:bool)
+
+        print()
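+        # torch_export_patches(patch_transformers=True) enables this library's
+        # transformers patches while the forward calls are captured;
+        # steal_forward dumps the intercepted inputs to the dump_file below,
+        # leaving out the entries listed in dump_drop.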
+        with (
+            torch_export_patches(
+                patch_torch=False, patch_sympy=False, patch_transformers=True
+            ),
+            steal_forward(
+                model,
+                dump_file=self.get_dump_file(
+                    "test_imagetext2text_generation_gemma3_4b_it.onnx"
+                ),
+                dump_drop={"attention_mask", "past_key_values", "pixel_values"},
+                save_as_external_data=False,
+            ),
+        ):
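+            # cache_implementation="static" asks generate to allocate a
+            # StaticCache, matching the past_key_values shapes noted above.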
+            generated_ids = model.generate(
+                **inputs,
+                # 282 = value high enough to trigger multiple iterations of the model
+                max_new_tokens=282,
+                do_sample=False,
+                cache_implementation="static",
+            )
+        output_text = processor.decode(
+            generated_ids[0][inputs["input_ids"].shape[1]:], skip_special_tokens=False
+        )
+        print(output_text)
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)