import os
import unittest
import numpy as np
import onnx
import onnx.helper as oh
import onnx.numpy_helper as onh
import torch
from onnx_diagnostic.ext_test_case import (
    ExtTestCase,
    has_onnxruntime_genai,
    hide_stdout,
    ignore_warnings,
    requires_torch,
    requires_transformers,
    skipif_ci_windows,
)
from onnx_diagnostic.helpers.rt_helper import (
    onnx_generate,
    generate_and_validate,
    onnx_generate_with_genai,
    name_type_to_onnx_dtype,
    js_profile_to_dataframe,
    plot_ort_profile_timeline,
    plot_ort_profile,
    _process_shape,
)
from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
from onnx_diagnostic.torch_export_patches import torch_export_patches
@@ -24,6 +33,7 @@ class TestRtSession(ExtTestCase):
    @requires_transformers("4.55")
    @requires_torch("2.9")
    @hide_stdout()
    @ignore_warnings(FutureWarning)
    def test_onnx_generate(self):
        mid = "arnir0/Tiny-LLM"
        print("-- test_onnx_generate: get model")
@@ -84,6 +94,253 @@ def test_name_type_to_onnx_dtype(self):
            expected = getattr(onnx.TensorProto, name.upper())
            self.assertEqual(expected, name_type_to_onnx_dtype(look))

    def test_shapes(self):
        tests = [
            (
                "U8[1x128x768]+F+U8",
                [{"uint8": [1, 128, 768]}, {"float": []}, {"uint8": []}],
            ),
            ("F[1x128x768]", [{"float": [1, 128, 768]}]),
        ]
        for expected, shapes in tests:
            with self.subTest(shapes=shapes):
                out = _process_shape(shapes)
                self.assertEqual(expected, out)

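    # The tests below exercise the onnxruntime profiling helpers: a tiny model is
    # run with profiling enabled, the trace file returned by end_profiling() is
    # turned into a DataFrame with js_profile_to_dataframe, and the plotting
    # utilities are checked on that DataFrame.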
    def _get_model(self):
        """Builds a small five-node ONNX model on a 1D input."""
        model_def0 = oh.make_model(
            oh.make_graph(
                [
                    oh.make_node("Add", ["X", "init1"], ["X1"]),
                    oh.make_node("Abs", ["X"], ["X2"]),
                    oh.make_node("Add", ["X", "init3"], ["inter"]),
                    oh.make_node("Mul", ["X1", "inter"], ["Xm"]),
                    oh.make_node("Sub", ["X2", "Xm"], ["final"]),
                ],
                "test",
                [oh.make_tensor_value_info("X", onnx.TensorProto.FLOAT, [None])],
                [oh.make_tensor_value_info("final", onnx.TensorProto.FLOAT, [None])],
                [
                    onh.from_array(np.array([1], dtype=np.float32), name="init1"),
                    onh.from_array(np.array([3], dtype=np.float32), name="init3"),
                ],
            ),
            opset_imports=[oh.make_opsetid("", 18)],
            ir_version=9,
        )
        return model_def0

    def test_js_profile_to_dataframe(self):
        import onnxruntime

        sess_options = onnxruntime.SessionOptions()
        sess_options.enable_profiling = True
        sess = onnxruntime.InferenceSession(
            self._get_model().SerializeToString(),
            sess_options,
            providers=["CPUExecutionProvider"],
        )
        # Run several times so the profile contains more than the first iteration.
        for _ in range(11):
            sess.run(None, dict(X=np.arange(10).astype(np.float32)))
        prof = sess.end_profiling()

        df = js_profile_to_dataframe(prof, first_it_out=True)
        self.assertEqual(df.shape, (79, 18))
        self.assertEqual(
            set(df.columns),
            {
                "cat",
                "pid",
                "tid",
                "dur",
                "ts",
                "ph",
                "name",
                "args_op_name",
                "op_name",
                "args_thread_scheduling_stats",
                "args_output_size",
                "args_parameter_size",
                "args_activation_size",
                "args_node_index",
                "args_provider",
                "event_name",
                "iteration",
                "it==0",
            },
        )

        df = js_profile_to_dataframe(prof, agg=True)
        self.assertEqual(df.shape, (9, 1))
        self.assertEqual(list(df.columns), ["dur"])

        df = js_profile_to_dataframe(prof, agg_op_name=True)
        self.assertEqual(df.shape, (79, 17))
        self.assertEqual(
            set(df.columns),
            {
                "cat",
                "pid",
                "tid",
                "dur",
                "ts",
                "ph",
                "name",
                "args_op_name",
                "op_name",
                "args_thread_scheduling_stats",
                "args_output_size",
                "args_parameter_size",
                "args_activation_size",
                "args_node_index",
                "args_provider",
                "event_name",
                "iteration",
            },
        )

        os.remove(prof)

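    # The plotting tests render the profiling DataFrame with matplotlib; they are
    # skipped on Windows CI (see skipif_ci_windows), apparently because of tkinter.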
    @ignore_warnings(UserWarning)
    @skipif_ci_windows("failing because of tkinter?")
    def test_plot_profile_2(self):
        import matplotlib.pyplot as plt
        import onnxruntime

        sess_options = onnxruntime.SessionOptions()
        sess_options.enable_profiling = True
        sess = onnxruntime.InferenceSession(
            self._get_model().SerializeToString(),
            sess_options,
            providers=["CPUExecutionProvider"],
        )
        for _ in range(11):
            sess.run(None, dict(X=np.arange(10).astype(np.float32)))
        prof = sess.end_profiling()

        df = js_profile_to_dataframe(prof, first_it_out=True)

        fig, ax = plt.subplots(1, 2, figsize=(10, 5))
        plot_ort_profile(df, ax[0], ax[1], "test_title")
        # fig.savefig("graph1.png")
        self.assertNotEmpty(fig)

        os.remove(prof)

    @ignore_warnings(UserWarning)
    @skipif_ci_windows("failing because of tkinter?")
    def test_plot_profile_2_shape(self):
        import matplotlib.pyplot as plt
        import onnxruntime

        sess_options = onnxruntime.SessionOptions()
        sess_options.enable_profiling = True
        sess = onnxruntime.InferenceSession(
            self._get_model().SerializeToString(),
            sess_options,
            providers=["CPUExecutionProvider"],
        )
        for _ in range(11):
            sess.run(None, dict(X=np.arange(10).astype(np.float32)))
        prof = sess.end_profiling()

        df = js_profile_to_dataframe(prof, first_it_out=True, with_shape=True)

        fig, ax = plt.subplots(1, 2, figsize=(10, 5))
        plot_ort_profile(df, ax[0], ax[1], "test_title")
        # fig.savefig("graph1.png")
        self.assertNotEmpty(fig)

        os.remove(prof)

    @ignore_warnings(UserWarning)
    @skipif_ci_windows("failing because of tkinter?")
    def test_plot_profile_agg(self):
        import matplotlib.pyplot as plt
        import onnxruntime

        sess_options = onnxruntime.SessionOptions()
        sess_options.enable_profiling = True
        sess = onnxruntime.InferenceSession(
            self._get_model().SerializeToString(),
            sess_options,
            providers=["CPUExecutionProvider"],
        )
        for _ in range(11):
            sess.run(None, dict(X=np.arange(10).astype(np.float32)))
        prof = sess.end_profiling()

        df = js_profile_to_dataframe(prof, first_it_out=True, agg=True)

        fig, ax = plt.subplots(1, 1, figsize=(10, 5))
        plot_ort_profile(df, ax, title="test_title")
        fig.tight_layout()
        # fig.savefig("graph2.png")
        self.assertNotEmpty(fig)

        os.remove(prof)

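    # _get_model2 differs from _get_model by adding a MatMul on 2D inputs,
    # presumably so the timeline plot below has one clearly dominant operator.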
    def _get_model2(self):
        """Builds a variant of the model above with a MatMul on 2D inputs."""
        model_def0 = oh.make_model(
            oh.make_graph(
                [
                    oh.make_node("Add", ["X", "init1"], ["X1"]),
                    oh.make_node("Abs", ["X"], ["X2"]),
                    oh.make_node("Add", ["X", "init3"], ["inter"]),
                    oh.make_node("Mul", ["X1", "inter"], ["Xm"]),
                    oh.make_node("MatMul", ["X1", "Xm"], ["Xm2"]),
                    oh.make_node("Sub", ["X2", "Xm2"], ["final"]),
                ],
                "test",
                [oh.make_tensor_value_info("X", onnx.TensorProto.FLOAT, [None, None])],
                [oh.make_tensor_value_info("final", onnx.TensorProto.FLOAT, [None, None])],
                [
                    onh.from_array(np.array([1], dtype=np.float32), name="init1"),
                    onh.from_array(np.array([3], dtype=np.float32), name="init3"),
                ],
            ),
            opset_imports=[oh.make_opsetid("", 18)],
            ir_version=9,
        )
        return model_def0

    @ignore_warnings(UserWarning)
    @skipif_ci_windows("failing because of tkinter?")
    def test_plot_profile_timeline(self):
        import matplotlib.pyplot as plt
        import onnxruntime

        sess_options = onnxruntime.SessionOptions()
        sess_options.enable_profiling = True
        sess = onnxruntime.InferenceSession(
            self._get_model2().SerializeToString(),
            sess_options,
            providers=["CPUExecutionProvider"],
        )
        for _ in range(11):
            sess.run(None, dict(X=np.random.rand(2**10, 2**10).astype(np.float32)))
        prof = sess.end_profiling()

        df = js_profile_to_dataframe(prof, first_it_out=True)

        fig, ax = plt.subplots(1, 1, figsize=(5, 10))
        plot_ort_profile_timeline(df, ax, title="test_timeline", quantile=0.5)
        fig.tight_layout()
        fig.savefig("test_plot_profile_timeline.png")
        self.assertNotEmpty(fig)

        os.remove(prof)

if __name__ == "__main__":
    import logging

    # Silence verbose matplotlib/PIL loggers triggered by the plotting tests.
    for name in [
        "matplotlib.font_manager",
        "PIL.PngImagePlugin",
        "matplotlib",
        "matplotlib.pyplot",
    ]:
        log = logging.getLogger(name)
        log.setLevel(logging.ERROR)
    unittest.main(verbosity=2)