Skip to content

Commit 358159a

Browse files
committed
Merge branch 'main' into titaiwang/fix_modelbuilder_discrepancy
2 parents 21355b5 + d6ea09a commit 358159a

File tree

13 files changed

+318
-129
lines changed

13 files changed

+318
-129
lines changed

CHANGELOGS.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,14 @@
11
Change Logs
22
===========
33

4+
0.7.13
5+
++++++
6+
47
0.7.12
58
++++++
69

10+
* :pr:`232`: fixes ``--patch`` argument so that ``--patch=0`` works
11+
* :pr:`231`: better statistics about fusions
712
* :pr:`227`: better support for ``model_id//pretrained``, adds speed up when running command validate
813
* :pr:`226`: fix input order for models created with modelbuilder
914

_doc/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,7 @@ The function replaces dynamic dimensions defined as strings by
239239
Older versions
240240
==============
241241

242+
* `0.7.13 <../v0.7.13/index.html>`_
242243
* `0.7.12 <../v0.7.12/index.html>`_
243244
* `0.7.11 <../v0.7.11/index.html>`_
244245
* `0.6.3 <../v0.6.3/index.html>`_

_unittests/ut_helpers/test_log_helper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,7 @@ def test_cube_logs_performance_cube_time(self):
268268
cube = CubeLogsPerformance(dfs, keep_last_date=True)
269269
cube.load()
270270
ct = cube.clone()
271-
self.assertEqual((52, 106), ct.shape)
271+
self.assertEqual((52, 111), ct.shape)
272272

273273
def test_duplicate(self):
274274
df = pandas.DataFrame(

_unittests/ut_torch_models/test_hghub_api.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@ def test_enumerate_model_list(self):
4040
verbose=1,
4141
dump="test_enumerate_model_list.csv",
4242
filter="image-classification",
43-
library="transformers",
4443
)
4544
)
4645
self.assertEqual(len(models), 2)

_unittests/ut_xrun_doc/test_check_ort_float16.py

Lines changed: 92 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
ExtTestCase,
1111
ignore_warnings,
1212
requires_cuda,
13+
requires_onnxruntime,
1314
)
1415

1516

@@ -18,7 +19,9 @@ def common_scatter(self, opset, providers, dtype, reduction, expected_names):
1819
import onnxruntime
1920
from onnxruntime import InferenceSession, SessionOptions
2021

21-
op_type = "ScatterElements" if "ScatterElements" in expected_names else "ScatterND"
22+
op_type = (
23+
"ScatterElements" if "ScatterElements" in str(expected_names) else "ScatterND"
24+
)
2225
ndim = 2 if op_type == "ScatterElements" else 3
2326

2427
assert dtype in (np.float16, np.float32)
@@ -61,7 +64,10 @@ def common_scatter(self, opset, providers, dtype, reduction, expected_names):
6164
# onnxruntime might introduces some intermediate cast.
6265
if pv.Version(onnxruntime.__version__) <= pv.Version("1.17.1"):
6366
raise unittest.SkipTest("float16 not supported on cpu")
64-
self.assertEqual(expected_names, names)
67+
if isinstance(expected_names, list):
68+
self.assertEqual(names, expected_names)
69+
else:
70+
self.assertIn(names, expected_names)
6571

6672
sonx = str(onx).replace(" ", "").replace("\n", "|")
6773
sexp = 'op_type:"Cast"|attribute{|name:"to"|type:INT|i:%d|}' % itype
@@ -126,24 +132,47 @@ def common_scatter(self, opset, providers, dtype, reduction, expected_names):
126132
(row.get("args_provider", None), row.get("args_op_name", None))
127133
)
128134
short_list = [(a, b) for a, b in exe_providers if a is not None and b is not None]
129-
self.assertEqual(short_list, [("CUDAExecutionProvider", o) for o in expected_names])
135+
if isinstance(expected_names, list):
136+
self.assertEqual(
137+
short_list, [("CUDAExecutionProvider", o) for o in expected_names]
138+
)
139+
else:
140+
self.assertIn(
141+
short_list,
142+
tuple([("CUDAExecutionProvider", o) for o in en] for en in expected_names),
143+
)
130144

131145
@unittest.skip("https://github.com/sdpython/onnx-diagnostic/issues/240")
132146
@requires_cuda()
133147
@ignore_warnings(DeprecationWarning)
148+
@requires_onnxruntime("1.23")
134149
def test_scatterels_cuda(self):
135-
default_value = [
136-
"Cast",
137-
# "MemcpyToHost",
138-
"ScatterElements",
139-
# "MemcpyFromHost",
140-
"Sub",
141-
]
150+
default_value = (
151+
[
152+
"Cast",
153+
# "MemcpyToHost",
154+
"ScatterElements",
155+
# "MemcpyFromHost",
156+
"Sub",
157+
],
158+
[
159+
"Cast",
160+
"Cast",
161+
# "MemcpyToHost",
162+
"ScatterElements",
163+
# "MemcpyFromHost",
164+
"Sub",
165+
],
166+
)
142167
expected = {
143168
(np.float32, "none"): default_value,
144169
(np.float16, "none"): default_value,
145170
(np.float32, "add"): default_value,
146171
(np.float16, "add"): default_value,
172+
(np.float32, "min"): default_value,
173+
(np.float16, "min"): default_value,
174+
(np.float32, "max"): default_value,
175+
(np.float16, "max"): default_value,
147176
}
148177
for opset, dtype, reduction in itertools.product(
149178
[16, 18], [np.float32, np.float16], ["none", "add", "min", "max"]
@@ -161,13 +190,23 @@ def test_scatterels_cuda(self):
161190
@requires_cuda()
162191
@ignore_warnings(DeprecationWarning)
163192
def test_scatternd_cuda(self):
164-
default_value = [
165-
"Cast",
166-
# "MemcpyToHost",
167-
"ScatterND",
168-
# "MemcpyFromHost",
169-
"Sub",
170-
]
193+
default_value = (
194+
[
195+
"Cast",
196+
# "MemcpyToHost",
197+
"ScatterND",
198+
# "MemcpyFromHost",
199+
"Sub",
200+
],
201+
[
202+
"Cast",
203+
"Cast",
204+
# "MemcpyToHost",
205+
"ScatterND",
206+
# "MemcpyFromHost",
207+
"Sub",
208+
],
209+
)
171210
expected = {
172211
(np.float32, "none"): default_value,
173212
(np.float16, "none"): default_value,
@@ -188,20 +227,30 @@ def test_scatternd_cuda(self):
188227

189228
@unittest.skip("https://github.com/sdpython/onnx-diagnostic/issues/240")
190229
@ignore_warnings(DeprecationWarning)
230+
@requires_onnxruntime("1.23")
191231
def test_scatterels_cpu(self):
192232
default_value = [
193233
"Cast",
194234
"ScatterElements",
195235
"Sub",
196236
]
197-
default_value_16 = [
198-
"Cast",
199-
"Cast",
200-
"ScatterElements",
201-
"Cast",
202-
"Sub",
203-
"Cast",
204-
]
237+
default_value_16 = (
238+
[
239+
"Cast",
240+
"ScatterElements",
241+
"Cast",
242+
"Sub",
243+
"Cast",
244+
],
245+
[
246+
"Cast",
247+
"Cast",
248+
"ScatterElements",
249+
"Cast",
250+
"Sub",
251+
"Cast",
252+
],
253+
)
205254
expected = {
206255
(np.float32, "none"): default_value,
207256
(np.float16, "none"): default_value_16,
@@ -222,20 +271,30 @@ def test_scatterels_cpu(self):
222271

223272
@unittest.skip("https://github.com/sdpython/onnx-diagnostic/issues/240")
224273
@ignore_warnings(DeprecationWarning)
274+
@requires_onnxruntime("1.23")
225275
def test_scatternd_cpu(self):
226276
default_value = [
227277
"Cast",
228278
"ScatterND",
229279
"Sub",
230280
]
231-
default_value_16 = [
232-
"Cast",
233-
"Cast",
234-
"ScatterND",
235-
"Cast",
236-
"Sub",
237-
"Cast",
238-
]
281+
default_value_16 = (
282+
[
283+
"Cast",
284+
"ScatterND",
285+
"Cast",
286+
"Sub",
287+
"Cast",
288+
],
289+
[
290+
"Cast",
291+
"Cast",
292+
"ScatterND",
293+
"Cast",
294+
"Sub",
295+
"Cast",
296+
],
297+
)
239298
expected = {
240299
(np.float32, "none"): default_value,
241300
(np.float16, "none"): default_value_16,

onnx_diagnostic/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,5 @@
33
Functions, classes to dig into a model when this one is right, slow, wrong...
44
"""
55

6-
__version__ = "0.7.12"
6+
__version__ = "0.7.13"
77
__author__ = "Xavier Dupré"

onnx_diagnostic/_command_lines_parser.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -581,6 +581,7 @@ def _cmd_validate(argv: List[Any]):
581581
):
582582
print(f"validate - unsupported args: export={args.export!r}, opt={args.opt!r}")
583583
return
584+
patch_dict = args.patch if isinstance(args.patch, dict) else {"patch": args.patch}
584585
summary, _data = validate_model(
585586
model_id=args.mid,
586587
task=args.task,
@@ -591,8 +592,8 @@ def _cmd_validate(argv: List[Any]):
591592
use_pretrained=args.trained,
592593
dtype=args.dtype,
593594
device=args.device,
594-
patch=args.patch,
595-
rewrite=args.rewrite,
595+
patch=patch_dict,
596+
rewrite=args.rewrite and patch_dict.get("patch", True),
596597
stop_if_static=args.stop_if_static,
597598
optimization=args.opt,
598599
exporter=args.export,
@@ -827,6 +828,8 @@ def get_parser_agg() -> ArgumentParser:
827828
"n_model_running,n_model_acc01,n_model_acc001,n_model_dynamic,"
828829
"n_model_pass,n_model_faster,"
829830
"n_model_faster2x,n_model_faster3x,n_model_faster4x,n_node_attention,"
831+
"n_node_attention23,n_node_rotary_embedding,n_node_rotary_embedding23,"
832+
"n_node_layer_normalization,n_node_layer_normalization23,"
830833
"peak_gpu_torch,peak_gpu_nvidia,n_node_control_flow,"
831834
"n_node_constant,n_node_shape,n_node_expand,"
832835
"n_node_function,n_node_initializer,n_node_scatter,"

0 commit comments

Comments
 (0)