Skip to content

Commit b7d8e27

Browse files
committed
Merge branch 'main' of https://github.com/sdpython/onnx-diagnostic into latetr
2 parents 5a01286 + 54c0b00 commit b7d8e27

File tree

4 files changed

+814
-5
lines changed

4 files changed

+814
-5
lines changed

CHANGELOGS.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ Change Logs
44
0.8.2
55
+++++
66

7+
* :pr:`302`: adds helpers to analyse onnxruntime profiling
78
* :pr:`297`: experiment around a higher ops ``loop_for``
89
* :pr:`292`, :pr:`293`, :pr:`294`, :pr:`295`: new patches for Qwen models
910

_unittests/ut_helpers/test_rt_helper.py

Lines changed: 258 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,28 @@
11
import os
22
import unittest
3+
import numpy as np
34
import onnx
5+
import onnx.helper as oh
6+
import onnx.numpy_helper as onh
47
import torch
58
from onnx_diagnostic.ext_test_case import (
69
ExtTestCase,
710
has_onnxruntime_genai,
811
hide_stdout,
9-
requires_transformers,
12+
ignore_warnings,
1013
requires_torch,
14+
requires_transformers,
15+
skipif_ci_windows,
1116
)
1217
from onnx_diagnostic.helpers.rt_helper import (
1318
onnx_generate,
1419
generate_and_validate,
1520
onnx_generate_with_genai,
1621
name_type_to_onnx_dtype,
22+
js_profile_to_dataframe,
23+
plot_ort_profile_timeline,
24+
plot_ort_profile,
25+
_process_shape,
1726
)
1827
from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
1928
from onnx_diagnostic.torch_export_patches import torch_export_patches
@@ -24,6 +33,7 @@ class TestRtSession(ExtTestCase):
2433
@requires_transformers("4.55")
2534
@requires_torch("2.9")
2635
@hide_stdout()
36+
@ignore_warnings(FutureWarning)
2737
def test_onnx_generate(self):
2838
mid = "arnir0/Tiny-LLM"
2939
print("-- test_onnx_generate: get model")
@@ -84,6 +94,253 @@ def test_name_type_to_onnx_dtype(self):
8494
expected = getattr(onnx.TensorProto, name.upper())
8595
self.assertEqual(expected, name_type_to_onnx_dtype(look))
8696

97+
def test_shapes(self):
    """Check that ``_process_shape`` compactly encodes dtype/shape dict lists."""
    # Each input entry maps an onnx dtype name to a shape; an empty shape
    # means a scalar and is rendered without brackets.
    cases = [
        (
            "U8[1x128x768]+F+U8",
            [{"uint8": [1, 128, 768]}, {"float": []}, {"uint8": []}],
        ),
        ("F[1x128x768]", [{"float": [1, 128, 768]}]),
    ]
    for expected, shapes in cases:
        with self.subTest(shapes=shapes):
            self.assertEqual(expected, _process_shape(shapes))
109+
110+
def _get_model(self):
111+
model_def0 = oh.make_model(
112+
oh.make_graph(
113+
[
114+
oh.make_node("Add", ["X", "init1"], ["X1"]),
115+
oh.make_node("Abs", ["X"], ["X2"]),
116+
oh.make_node("Add", ["X", "init3"], ["inter"]),
117+
oh.make_node("Mul", ["X1", "inter"], ["Xm"]),
118+
oh.make_node("Sub", ["X2", "Xm"], ["final"]),
119+
],
120+
"test",
121+
[oh.make_tensor_value_info("X", onnx.TensorProto.FLOAT, [None])],
122+
[oh.make_tensor_value_info("final", onnx.TensorProto.FLOAT, [None])],
123+
[
124+
onh.from_array(np.array([1], dtype=np.float32), name="init1"),
125+
onh.from_array(np.array([3], dtype=np.float32), name="init3"),
126+
],
127+
),
128+
opset_imports=[oh.make_opsetid("", 18)],
129+
ir_version=9,
130+
)
131+
return model_def0
132+
133+
def test_js_profile_to_dataframe(self):
    """Profile a small model with onnxruntime and validate the dataframe
    produced by ``js_profile_to_dataframe`` for three option combinations."""
    import onnxruntime

    opts = onnxruntime.SessionOptions()
    opts.enable_profiling = True
    sess = onnxruntime.InferenceSession(
        self._get_model().SerializeToString(),
        opts,
        providers=["CPUExecutionProvider"],
    )
    feeds = dict(X=np.arange(10).astype(np.float32))
    for _ in range(11):
        sess.run(None, feeds)
    prof = sess.end_profiling()

    # Columns common to the flat (non-aggregated) dataframes.
    base_columns = {
        "cat",
        "pid",
        "tid",
        "dur",
        "ts",
        "ph",
        "name",
        "args_op_name",
        "op_name",
        "args_thread_scheduling_stats",
        "args_output_size",
        "args_parameter_size",
        "args_activation_size",
        "args_node_index",
        "args_provider",
        "event_name",
        "iteration",
    }

    # first_it_out=True adds a flag column marking the first iteration.
    df = js_profile_to_dataframe(prof, first_it_out=True)
    self.assertEqual(df.shape, (79, 18))
    self.assertEqual(set(df.columns), base_columns | {"it==0"})

    # agg=True collapses the events into a single duration column.
    df = js_profile_to_dataframe(prof, agg=True)
    self.assertEqual(df.shape, (9, 1))
    self.assertEqual(list(df.columns), ["dur"])

    # agg_op_name=True keeps the flat layout without the first-iteration flag.
    df = js_profile_to_dataframe(prof, agg_op_name=True)
    self.assertEqual(df.shape, (79, 17))
    self.assertEqual(set(df.columns), base_columns)

    os.remove(prof)
203+
204+
@ignore_warnings(UserWarning)
@skipif_ci_windows("failing because of tkinter?")
def test_plot_profile_2(self):
    """Render a profiling dataframe on two axes with ``plot_ort_profile``."""
    import matplotlib.pyplot as plt
    import onnxruntime

    opts = onnxruntime.SessionOptions()
    opts.enable_profiling = True
    sess = onnxruntime.InferenceSession(
        self._get_model().SerializeToString(),
        opts,
        providers=["CPUExecutionProvider"],
    )
    for _ in range(11):
        sess.run(None, dict(X=np.arange(10).astype(np.float32)))
    prof = sess.end_profiling()

    data = js_profile_to_dataframe(prof, first_it_out=True)

    fig, axes = plt.subplots(1, 2, figsize=(10, 5))
    plot_ort_profile(data, axes[0], axes[1], "test_title")
    # fig.savefig("graph1.png")
    self.assertNotEmpty(fig)

    os.remove(prof)
229+
230+
@ignore_warnings(UserWarning)
@skipif_ci_windows("failing because of tkinter?")
def test_plot_profile_2_shape(self):
    """Same as test_plot_profile_2 but the dataframe carries shape columns
    (``with_shape=True``)."""
    import matplotlib.pyplot as plt
    import onnxruntime

    opts = onnxruntime.SessionOptions()
    opts.enable_profiling = True
    sess = onnxruntime.InferenceSession(
        self._get_model().SerializeToString(),
        opts,
        providers=["CPUExecutionProvider"],
    )
    for _ in range(11):
        sess.run(None, dict(X=np.arange(10).astype(np.float32)))
    prof = sess.end_profiling()

    data = js_profile_to_dataframe(prof, first_it_out=True, with_shape=True)

    fig, axes = plt.subplots(1, 2, figsize=(10, 5))
    plot_ort_profile(data, axes[0], axes[1], "test_title")
    # fig.savefig("graph1.png")
    self.assertNotEmpty(fig)

    os.remove(prof)
255+
256+
@ignore_warnings(UserWarning)
@skipif_ci_windows("failing because of tkinter?")
def test_plot_profile_agg(self):
    """Render an aggregated profiling dataframe on a single axis."""
    import matplotlib.pyplot as plt
    import onnxruntime

    opts = onnxruntime.SessionOptions()
    opts.enable_profiling = True
    sess = onnxruntime.InferenceSession(
        self._get_model().SerializeToString(),
        opts,
        providers=["CPUExecutionProvider"],
    )
    for _ in range(11):
        sess.run(None, dict(X=np.arange(10).astype(np.float32)))
    prof = sess.end_profiling()

    data = js_profile_to_dataframe(prof, first_it_out=True, agg=True)

    # Aggregated output needs only one axis.
    fig, ax = plt.subplots(1, 1, figsize=(10, 5))
    plot_ort_profile(data, ax, title="test_title")
    fig.tight_layout()
    # fig.savefig("graph2.png")
    self.assertNotEmpty(fig)

    os.remove(prof)
282+
283+
def _get_model2(self):
    """Variant of :meth:`_get_model` with a 2-D input and an extra MatMul,
    heavy enough to make the profiling timeline non-trivial."""
    graph = oh.make_graph(
        [
            oh.make_node("Add", ["X", "init1"], ["X1"]),
            oh.make_node("Abs", ["X"], ["X2"]),
            oh.make_node("Add", ["X", "init3"], ["inter"]),
            oh.make_node("Mul", ["X1", "inter"], ["Xm"]),
            oh.make_node("MatMul", ["X1", "Xm"], ["Xm2"]),
            oh.make_node("Sub", ["X2", "Xm2"], ["final"]),
        ],
        "test",
        [oh.make_tensor_value_info("X", onnx.TensorProto.FLOAT, [None, None])],
        [oh.make_tensor_value_info("final", onnx.TensorProto.FLOAT, [None, None])],
        [
            onh.from_array(np.array([1], dtype=np.float32), name="init1"),
            onh.from_array(np.array([3], dtype=np.float32), name="init3"),
        ],
    )
    return oh.make_model(
        graph,
        opset_imports=[oh.make_opsetid("", 18)],
        ir_version=9,
    )
306+
307+
@ignore_warnings(UserWarning)
@skipif_ci_windows("failing because of tkinter?")
def test_plot_profile_timeline(self):
    """Profile a model containing a MatMul and render it with
    ``plot_ort_profile_timeline``; checks the figure is produced
    without raising."""
    import matplotlib.pyplot as plt
    import onnxruntime

    sess_options = onnxruntime.SessionOptions()
    sess_options.enable_profiling = True
    sess = onnxruntime.InferenceSession(
        self._get_model2().SerializeToString(),
        sess_options,
        providers=["CPUExecutionProvider"],
    )
    # Large input so the MatMul dominates and the timeline is not degenerate.
    for _ in range(11):
        sess.run(None, dict(X=np.random.rand(2**10, 2**10).astype(np.float32)))
    prof = sess.end_profiling()

    df = js_profile_to_dataframe(prof, first_it_out=True)

    fig, ax = plt.subplots(1, 1, figsize=(5, 10))
    plot_ort_profile_timeline(df, ax, title="test_timeline", quantile=0.5)
    fig.tight_layout()
    # Keep savefig disabled, like the other plotting tests, so the test does
    # not leave a PNG artifact in the working directory; uncomment locally
    # to inspect the rendering.
    # fig.savefig("test_plot_profile_timeline.png")
    self.assertNotEmpty(fig)

    os.remove(prof)
333+
87334

88335
if __name__ == "__main__":
    import logging

    # Silence chatty matplotlib/PIL loggers so the test output stays readable.
    for noisy in (
        "matplotlib.font_manager",
        "PIL.PngImagePlugin",
        "matplotlib",
        "matplotlib.pyplot",
    ):
        logging.getLogger(noisy).setLevel(logging.ERROR)
    unittest.main(verbosity=2)

0 commit comments

Comments
 (0)