@@ -1,4 +1,5 @@
 import os
+import time
 import unittest
 import torch
 from onnx_diagnostic.ext_test_case import ExtTestCase, never_test, ignore_warnings
@@ -45,6 +46,7 @@ def test_imagetext2text_qwen_2_5_vl_instruct_visual(self):
             EXPORTER=custom \\
             python _unittests/ut_tasks/try_export.py -k qwen_2_5_vl_instruct_visual
         """
+        begin = time.perf_counter()
         device = os.environ.get("TESTDEVICE", "cpu")
         dtype = os.environ.get("TESTDTYPE", "float32")
         torch_dtype = {
@@ -87,13 +89,18 @@ def _config_reduction(config, task):
         )
         model = data["model"]

+        print(f"-- MODEL LOADED IN {time.perf_counter() - begin}")
+        begin = time.perf_counter()
         model = model.to(device).to(getattr(torch, dtype))
+        print(f"-- MODEL MOVED IN {time.perf_counter() - begin}")

         print(f"-- config._attn_implementation={model.config._attn_implementation}")
         print(f"-- model.dtype={model.dtype}")
         print(f"-- model.device={model.device}")
+        begin = time.perf_counter()
         processor = AutoProcessor.from_pretrained(model_id, use_fast=True)
         print(f"-- processor={type(processor)}")
+        print(f"-- PROCESSOR LOADED IN {time.perf_counter() - begin}")

         big_inputs = dict(
             hidden_states=torch.rand((14308, 1176), dtype=torch_dtype).to(device),
@@ -104,14 +111,19 @@ def _config_reduction(config, task):
             hidden_states=torch.rand((1292, 1176), dtype=torch_dtype).to(device),
             grid_thw=torch.tensor([[1, 34, 38]], dtype=torch.int64).to(device),
         )
-        print("-- save inputs")
-        torch.save(big_inputs, self.get_dump_file("qwen_2_5_vl_instruct_visual.inputs.big.pt"))
-        torch.save(inputs, self.get_dump_file("qwen_2_5_vl_instruct_visual.inputs.pt"))
+        if not self.unit_test_going():
+            print("-- save inputs")
+            torch.save(
+                big_inputs, self.get_dump_file("qwen_2_5_vl_instruct_visual.inputs.big.pt")
+            )
+            torch.save(inputs, self.get_dump_file("qwen_2_5_vl_instruct_visual.inputs.pt"))

         print(f"-- inputs: {self.string_type(inputs, with_shape=True)}")
         # this is too long
         model_to_export = model.visual if hasattr(model, "visual") else model.model.visual
+        begin = time.perf_counter()
         expected = model_to_export(**inputs)
+        print(f"-- MODEL RUN IN {time.perf_counter() - begin}")
         print(f"-- expected: {self.string_type(expected, with_shape=True)}")

         filename = self.get_dump_file(
@@ -126,6 +138,7 @@ def _config_reduction(config, task):
         )

         # fake_inputs = make_fake_with_dynamic_dimensions(inputs, dynamic_shapes)[0]
+        begin = time.perf_counter()
         export_inputs = inputs
         print()
         with torch_export_patches(
@@ -148,14 +161,21 @@ def _config_reduction(config, task):
                 onnx_plugs=PLUGS,
             )

+        print(f"-- MODEL CONVERTED IN {time.perf_counter() - begin}")
+
         pt2_files = [f"{fileep}.backup.pt2", f"{fileep}.ep.pt2", f"{fileep}.pt2"]
-        pt2_file = [f for f in pt2_files if os.path.exists(f)]
-        assert pt2_file, f"Unable to find an existing file among {pt2_files}"
-        pt2_file = pt2_file[0]
+        pt2_files = [f for f in pt2_files if os.path.exists(f)]
+        assert (
+            self.unit_test_going() or pt2_files
+        ), f"Unable to find an existing file among {pt2_files!r}"
+        pt2_file = (
+            (pt2_files[0] if pt2_files else None) if not self.unit_test_going() else None
+        )
         # self.assertExists(pt2_file)
         # ep = torch.export.load(pt2_file)
         # diff = self.max_diff(ep.module()(**export_inputs), model.visual(**export_inputs))
         # print("----------- diff", diff)
+        begin = time.perf_counter()
         self.assert_onnx_disc(
             f"test_imagetext2text_qwen_2_5_vl_instruct_visual.{device}.{dtype}.{exporter}",
             filename,
@@ -171,9 +191,10 @@ def _config_reduction(config, task):
             atol=0.02,
             rtol=10,
             ort_optimized_graph=False,
-            # ep=pt2_file,
+            ep=pt2_file,
             expected=expected,
         )
+        print(f"-- MODEL VERIFIED IN {time.perf_counter() - begin}")
         if self.unit_test_going():
             self.clean_dump()

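Below is a minimal, self-contained sketch of the timing pattern this change inlines around each stage (model load, move, processor load, run, conversion, verification). The timed helper is hypothetical and not part of onnx_diagnostic; it only illustrates the repeated begin = time.perf_counter() / print(elapsed) sequence used above.

import time
from contextlib import contextmanager


@contextmanager
def timed(label):
    # Hypothetical helper mirroring the inlined pattern:
    #   begin = time.perf_counter(); ...; print(f"-- <STEP> IN {elapsed}")
    begin = time.perf_counter()
    try:
        yield
    finally:
        print(f"-- {label} IN {time.perf_counter() - begin}")


# Usage sketch (load_model is a placeholder, not a real API):
# with timed("MODEL LOADED"):
#     model = load_model(model_id)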