+import os
 import unittest
 import torch
 from onnx_diagnostic.ext_test_case import ExtTestCase, never_test
@@ -163,9 +164,12 @@ def test_text_generation_tiny_llm(self):
 
         # simply generate a single sequence
         print()
-        with torch_export_patches(
-            patch_transformers=True, patch_torch=False, patch_sympy=False
-        ), steal_forward(model):
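+        # note: grouping several context managers inside parentheses, as done
+        # below, requires Python >= 3.10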
+        with (
+            torch_export_patches(
+                patch_transformers=True, patch_torch=False, patch_sympy=False
+            ),
+            steal_forward(model),
+        ):
             generated_ids = model.generate(
                 input_ids=input_ids,
                 max_length=100,
@@ -181,8 +185,9 @@ def test_text_generation_phi4_mini(self):
         import torch
         from transformers import RobertaTokenizer, T5ForConditionalGeneration
 
-        tokenizer = RobertaTokenizer.from_pretrained("microsoft/Phi-4-mini-instruct")
-        model = T5ForConditionalGeneration.from_pretrained("microsoft/Phi-4-mini-instruct")
+        model_id = "microsoft/Phi-4-mini-instruct"
+        tokenizer = RobertaTokenizer.from_pretrained(model_id)
+        model = T5ForConditionalGeneration.from_pretrained(model_id)
 
         text = "def greet(user): print(f'hello <extra_id_0>!')"
         input_ids = tokenizer(text, return_tensors="pt").input_ids
@@ -200,6 +205,41 @@ def test_text_generation_phi4_mini(self):
         )
         print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))
 
+    @never_test()
+    def test_text_generation_phi3_mini(self):
+        # clear&&NEVERTEST=1 python _unittests/ut_tasks/try_tasks.py -k phi3_mini
+
+        from transformers import Phi3ForCausalLM, AutoTokenizer
+
+        model_id = "microsoft/Phi-3-mini-4k-instruct"
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        model = Phi3ForCausalLM.from_pretrained(model_id)
+
+        messages = [
+            {
+                "role": "system",
+                "content": (
+                    "You are a helpful digital assistant. Please provide safe, "
+                    "ethical and accurate information to the user."
+                ),
+            },
+            {
+                "role": "user",
+                "content": (
+                    "Can you provide ways to eat combinations of bananas and dragonfruits?"
+                ),
+            },
+        ]
+        inputs = tokenizer.apply_chat_template(
+            messages, add_generation_prompt=True, return_tensors="pt"
+        )
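+        # with tokenize=True (the default) and no return_dict, apply_chat_template
+        # returns the prompt token ids directly, which is why they are passed
+        # positionally to generate below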
+
+        # simply generate a single sequence
+        print()
+        with steal_forward(model):
+            generated_ids = model.generate(inputs, max_length=100)
+        print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))
+
     @never_test()
     @unittest.skip(
         reason="AttributeError: 'Phi4MMModel' object has no attribute "
@@ -835,6 +875,119 @@ def test_sentence_similary_alibaba_nlp_gte(self):
         scores = (embeddings[:1] @ embeddings[1:].T) * 100
         print(scores.tolist())
 
+    @never_test()
+    def test_imagetext2text_generation_gemma3_4b_it(self):
+        """
+        clear&&NEVERTEST=1 python _unittests/ut_tasks/try_tasks.py -k gemma3_4b_it
+        """
+        from transformers import AutoProcessor, Gemma3ForConditionalGeneration
+
+        model_id = "google/gemma-3-4b-it"
+        if os.environ.get("PRETRAINED", ""):
+            model = Gemma3ForConditionalGeneration.from_pretrained(
+                model_id, device_map="cpu"
+            ).eval()
+        else:
+            data = get_untrained_model_with_inputs(
+                model_id,
+                verbose=1,
+                add_second_input=False,
+                # same_as_pretrained=True, # use_pretrained=True
+                inputs_kwargs={
+                    "sequence_length": 281,
+                    "batch_size": 1,
+                    "max_sequence_length": 580,
+                    "n_images": 1,
+                },
+            )
+            model = data["model"]
+
+        print(f"-- model.device={model.device}")
+        processor = AutoProcessor.from_pretrained(model_id, use_fast=True)
+        print(f"-- processor={type(processor)}")
+
+        messages = [
+            {
+                "role": "system",
+                "content": [{"type": "text", "text": "You are a helpful assistant."}],
+            },
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "image",
+                        "image": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg",
+                    },
+                    {"type": "text", "text": "Describe this image in detail."},
+                ],
+            },
+        ]
+        inputs = processor.apply_chat_template(
+            messages,
+            tokenize=True,
+            add_generation_prompt=True,
+            return_dict=True,
+            return_tensors="pt",
+        ).to(model.device, dtype=torch.bfloat16)
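+        # the dtype argument should only cast the floating-point tensors
+        # (pixel_values) to bfloat16; integer inputs such as input_ids keep
+        # their dtype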
+        # if "token_type_ids" in inputs:
+        #     print(
+        #         f"-- remove token_type_ids: "
+        #         f"{self.string_type(inputs['token_type_ids'], with_shape=True)}"
+        #     )
+        # inputs.pop("token_type_ids", None)
+        print(f"-- inputs={self.string_type(inputs)}")
+
+        # iteration merge = sequence > 1, cache not empty
+        # iteration 1 = sequence > 1, no cache
+        #   cache_position:T7s281,
+        #   past_key_values:StaticCache(key_cache=#0[], value_cache=#0[]),
+        #   input_ids:T7s1x281,
+        #   inputs_embeds:None,
+        #   token_type_ids:T7s1x281,
+        #   attention_mask:dict(sliding_attention:T9s1x1x281x580,
+        #                       full_attention:T9s1x1x281x580),
+        #   position_ids:None,
+        #   use_cache:bool,
+        #   logits_to_keep:None,
+        #   pixel_values:T16s1x3x896x896,
+        #   return_dict:bool)
+        # iteration 2 = sequence = 1, cache not empty
+        #   cache_position:T7s1,
+        #   past_key_values:StaticCache(key_cache=#34[T1s1x4x580x256,...],
+        #                               value_cache=#34[T1s1x4x580x256,...]),
+        #   input_ids:T7s1x1,
+        #   inputs_embeds:None,
+        #   token_type_ids:T7s1x1,
+        #   attention_mask:dict(sliding_attention:T9s1x1x1x580,full_attention:T9s1x1x1x580),
+        #   position_ids:None,
+        #   use_cache:bool, logits_to_keep:None, return_dict:bool)
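+        # in short: the first forward pass consumes the whole prompt (281 tokens)
+        # with an empty static cache, while later passes feed one token at a time
+        # against the pre-allocated cache of size max_sequence_length=580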
+
+        print()
+        with (
+            torch_export_patches(
+                patch_torch=False, patch_sympy=False, patch_transformers=True
+            ),
+            steal_forward(
+                model,
+                dump_file=self.get_dump_file(
+                    "test_imagetext2text_generation_gemma3_4b_it.onnx"
+                ),
+                dump_drop={"attention_mask", "past_key_values", "pixel_values"},
+                save_as_external_data=False,
+            ),
+        ):
+            generated_ids = model.generate(
+                **inputs,
+                # 282 = value high enough to trigger multiple iterations of the model
+                max_new_tokens=282,
+                do_sample=False,
+                cache_implementation="static",
+            )
+        output_text = processor.decode(
+            generated_ids[0][inputs["input_ids"].shape[1] :], skip_special_tokens=False
+        )
+        print(output_text)
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)