
Commit 694ccb0

Improve command line find (#140)
* Improve command line find
* version
* spell
* style
* fix atol
* add shadowing
* shadow
* post-shadow
1 parent 3ff0c54 commit 694ccb0

9 files changed: +546, -19 lines changed


CHANGELOGS.rst

Lines changed: 5 additions & 0 deletions
@@ -1,6 +1,11 @@
 Change Logs
 ===========

+0.6.3
++++++
+
+* :pr:`140`: improves command line find
+
 0.6.2
 +++++

_doc/index.rst

Lines changed: 1 addition & 6 deletions
@@ -206,19 +206,14 @@ The function replaces dynamic dimensions defined as strings by
 Older versions
 ++++++++++++++

+* `0.6.3 <../v0.6.3/index.html>`_
 * `0.6.2 <../v0.6.2/index.html>`_
 * `0.6.1 <../v0.6.1/index.html>`_
 * `0.6.0 <../v0.6.0/index.html>`_
 * `0.5.0 <../v0.5.0/index.html>`_
 * `0.4.4 <../v0.4.4/index.html>`_
-* `0.4.3 <../v0.4.3/index.html>`_
-* `0.4.2 <../v0.4.2/index.html>`_
-* `0.4.1 <../v0.4.1/index.html>`_
-* `0.4.0 <../v0.4.0/index.html>`_
 * `0.3.0 <../v0.3.0/index.html>`_
 * `0.2.2 <../v0.2.2/index.html>`_
-* `0.2.1 <../v0.2.1/index.html>`_
-* `0.2.0 <../v0.2.0/index.html>`_
 * `0.1.0 <../v0.1.0/index.html>`_

 The documentation was updated on:

_unittests/ut_helpers/test_onnx_helper.py

Lines changed: 216 additions & 0 deletions
@@ -16,10 +16,13 @@
     iterator_initializer_constant,
     from_array_extended,
     tensor_statistics,
+    enumerate_results,
+    shadowing_names,
 )


 TFLOAT = TensorProto.FLOAT
+TINT64 = TensorProto.INT64


 class TestOnnxHelper(ExtTestCase):
@@ -251,6 +254,219 @@ def test_statistics(self):
         stat = tensor_statistics(rnd)
         self.assertEqual(stat["stype"], "FLOAT")

+    @hide_stdout()
+    def test_enumerate_results(self):
+        model = oh.make_model(
+            oh.make_graph(
+                [
+                    oh.make_node("Unsqueeze", ["X", "zero"], ["xu1"]),
+                    oh.make_node("Unsqueeze", ["xu1", "un"], ["xu2"]),
+                    oh.make_node("Reshape", ["xu2", "shape1"], ["xm1"]),
+                    oh.make_node("Reshape", ["Y", "shape2"], ["xm2c"]),
+                    oh.make_node("Cast", ["xm2c"], ["xm2"], to=1),
+                    oh.make_node("MatMul", ["xm1", "xm2"], ["xm"]),
+                    oh.make_node("Reshape", ["xm", "shape3"], ["Z"]),
+                ],
+                "dummy",
+                [oh.make_tensor_value_info("X", TFLOAT, [320, 1280])],
+                [oh.make_tensor_value_info("Z", TFLOAT, [3, 5, 320, 640])],
+                [
+                    onh.from_array(
+                        np.random.rand(3, 5, 1280, 640).astype(np.float32), name="Y"
+                    ),
+                    onh.from_array(np.array([0], dtype=np.int64), name="zero"),
+                    onh.from_array(np.array([1], dtype=np.int64), name="un"),
+                    onh.from_array(np.array([1, 320, 1280], dtype=np.int64), name="shape1"),
+                    onh.from_array(np.array([15, 1280, 640], dtype=np.int64), name="shape2"),
+                    onh.from_array(np.array([3, 5, 320, 640], dtype=np.int64), name="shape3"),
+                ],
+            ),
+            opset_imports=[oh.make_opsetid("", 18)],
+            ir_version=9,
+        )
+        res = list(enumerate_results(model, "xu1", verbose=2))
+        ress = ";".join(str(r) for r in res)
+        self.assertEqual(
+            "<< xu1 - (0:Unsqueeze:) :: Unsqueeze(X, zero) -> xu1;"
+            ">> xu1 - (1:Unsqueeze:) :: Unsqueeze(xu1, un) -> xu2",
+            ress,
+        )
+        self.assertEqual(2, len(list(enumerate_results(model, "shape1", verbose=2))))
+        self.assertEqual(2, len(list(enumerate_results(model, "X", verbose=2))))
+        self.assertEqual(2, len(list(enumerate_results(model, "Z", verbose=2))))
+
+    def test_enumerate_results_loop(self):
+        x = np.array([1, 2, 3, 4, 5]).astype(np.float32)
+
+        model = oh.make_model(
+            graph=oh.make_graph(
+                name="loop_test",
+                inputs=[
+                    oh.make_tensor_value_info("trip_count", TINT64, ["a"]),
+                    oh.make_tensor_value_info("cond", TensorProto.BOOL, [1]),
+                ],
+                outputs=[oh.make_tensor_value_info("res", TFLOAT, [])],
+                nodes=[
+                    oh.make_node("SequenceEmpty", [], ["seq_empty"], dtype=TFLOAT),
+                    oh.make_node(
+                        "Loop",
+                        inputs=["trip_count", "cond", "seq_empty"],
+                        outputs=["seq_res"],
+                        body=oh.make_graph(
+                            [
+                                oh.make_node(
+                                    "Identity", inputs=["cond_in"], outputs=["cond_out"]
+                                ),
+                                oh.make_node(
+                                    "Constant",
+                                    inputs=[],
+                                    outputs=["x"],
+                                    value=oh.make_tensor(
+                                        name="const_tensor_x",
+                                        data_type=TFLOAT,
+                                        dims=x.shape,
+                                        vals=x.flatten().astype(float),
+                                    ),
+                                ),
+                                oh.make_node(
+                                    "Constant",
+                                    inputs=[],
+                                    outputs=["one"],
+                                    value=oh.make_tensor(
+                                        name="const_tensor_one",
+                                        data_type=TINT64,
+                                        dims=(),
+                                        vals=[1],
+                                    ),
+                                ),
+                                oh.make_node(
+                                    "Constant",
+                                    inputs=[],
+                                    outputs=["slice_start"],
+                                    value=oh.make_tensor(
+                                        name="const_tensor_zero",
+                                        data_type=TINT64,
+                                        dims=(1,),
+                                        vals=[0],
+                                    ),
+                                ),
+                                oh.make_node(
+                                    "Add", inputs=["iter_count", "one"], outputs=["end"]
+                                ),
+                                oh.make_node(
+                                    "Constant",
+                                    inputs=[],
+                                    outputs=["axes"],
+                                    value=oh.make_tensor(
+                                        name="const_tensor_axes",
+                                        data_type=TINT64,
+                                        dims=(1,),
+                                        vals=[0],
+                                    ),
+                                ),
+                                oh.make_node(
+                                    "Unsqueeze", inputs=["end", "axes"], outputs=["slice_end"]
+                                ),
+                                oh.make_node(
+                                    "Slice",
+                                    inputs=["x", "slice_start", "slice_end"],
+                                    outputs=["slice_out"],
+                                ),
+                                oh.make_node(
+                                    "SequenceInsert",
+                                    inputs=["seq_in", "slice_out"],
+                                    outputs=["seq_out"],
+                                ),
+                            ],
+                            "loop_body",
+                            [
+                                oh.make_tensor_value_info("iter_count", TINT64, []),
+                                oh.make_tensor_value_info("cond_in", TensorProto.BOOL, []),
+                                oh.make_tensor_sequence_value_info("seq_in", TFLOAT, None),
+                            ],
+                            [
+                                oh.make_tensor_value_info("cond_out", TensorProto.BOOL, []),
+                                oh.make_tensor_sequence_value_info("seq_out", TFLOAT, None),
+                            ],
+                        ),
+                    ),
+                    oh.make_node(
+                        "ConcatFromSequence",
+                        inputs=["seq_res"],
+                        outputs=["res"],
+                        axis=0,
+                        new_axis=0,
+                    ),
+                ],
+            )
+        )
+        res = list(enumerate_results(model, "slice_start", verbose=2))
+        self.assertEqual(len(res), 2)
+
+    def test_shadowing_names(self):
+        def _mkv_(name):
+            value_info_proto = ValueInfoProto()
+            value_info_proto.name = name
+            return value_info_proto
+
+        model = oh.make_model(
+            oh.make_graph(
+                [
+                    oh.make_node("ReduceSum", ["X"], ["Xred"]),
+                    oh.make_node("Add", ["X", "two"], ["X0"]),
+                    oh.make_node("Add", ["X0", "zero"], ["X00"]),
+                    oh.make_node("CastLike", ["one", "Xred"], ["one_c"]),
+                    oh.make_node("Greater", ["Xred", "one_c"], ["cond"]),
+                    oh.make_node("Identity", ["two"], ["three"]),
+                    oh.make_node(
+                        "If",
+                        ["cond"],
+                        ["Z_c"],
+                        then_branch=oh.make_graph(
+                            [
+                                # shadowing
+                                oh.make_node("Constant", [], ["three"], value_floats=[2.1]),
+                                oh.make_node("Add", ["X00", "three"], ["Y"]),
+                            ],
+                            "then",
+                            [],
+                            [_mkv_("Y")],
+                        ),
+                        else_branch=oh.make_graph(
+                            [
+                                # not shadowing
+                                oh.make_node("Sub", ["X0", "three"], ["Y"]),
+                            ],
+                            "else",
+                            [],
+                            [_mkv_("Y")],
+                        ),
+                    ),
+                    oh.make_node("CastLike", ["Z_c", "X"], ["Z"]),
+                ],
+                "test",
+                [
+                    oh.make_tensor_value_info("X", TensorProto.FLOAT, ["N"]),
+                    oh.make_tensor_value_info("one", TensorProto.FLOAT, ["N"]),
+                ],
+                [oh.make_tensor_value_info("Z", TensorProto.UNDEFINED, ["N"])],
+                [
+                    onh.from_array(np.array([0], dtype=np.float32), name="zero"),
+                    onh.from_array(np.array([2], dtype=np.float32), name="two"),
+                ],
+            ),
+            opset_imports=[oh.make_operatorsetid("", 18)],
+            ir_version=10,
+        )
+        self.assertEqual(
+            (
+                {"three"},
+                set(),
+                {"cond", "Z", "X0", "Z_c", "three", "one_c", "Xred", "X00", "Y"},
+            ),
+            shadowing_names(model),
+        )
+

 if __name__ == "__main__":
     unittest.main(verbosity=2)

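A quick orientation for readers of this diff, not part of the commit: the sketch below calls the two helpers exercised by the new tests directly. The model path "model.onnx" and the result name "xu1" are placeholders, and the producer/consumer reading of the "<<" and ">>" markers is inferred from the expected strings in test_enumerate_results.

import onnx

from onnx_diagnostic.helpers.onnx_helper import enumerate_results, shadowing_names

# Placeholder model file; any ONNX model works.
onx = onnx.load("model.onnx", load_external_data=False)

# Yields one entry per node touching the result "xu1": the test above suggests
# "<<" marks the node producing it and ">>" the nodes consuming it.
for r in enumerate_results(onx, "xu1", verbose=0):
    print(r)

# shadowing_names returns a tuple; per the test and the command line changes
# below, the first two elements are the shadowing and post-shadowing names.
shadow, post_shadow = shadowing_names(onx)[:2]
print("shadowing names:", shadow)
print("post-shadowing names:", post_shadow)
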
_unittests/ut_reference/test_ort_evaluator.py

Lines changed: 0 additions & 1 deletion
@@ -261,7 +261,6 @@ def test_init_torch_bfloat16(self):

     @hide_stdout()
     def test_if(self):
-
         def _mkv_(name):
             value_info_proto = onnx.ValueInfoProto()
             value_info_proto.name = name

_unittests/ut_reference/test_torch_onnx_evaluator.py

Lines changed: 1 addition & 1 deletion
@@ -1468,7 +1468,7 @@ def run(self, x, scale, bias=None):
         )
         expected = torch_sess.run(None, feeds)
         got = torch_sess_custom.run(None, feeds)
-        self.assertEqualAny(expected, got)
+        self.assertEqualAny(expected, got, atol=1e-3)
         self.assertEqual([1], LayerNormalizationOrt._shared)


_unittests/ut_xrun_doc/test_command_lines_exe.py

Lines changed: 7 additions & 0 deletions
@@ -38,6 +38,13 @@ def test_parser_find(self):
         text = st.getvalue()
         self.assertIsInstance(text, str)

+    def test_parser_find_v2(self):
+        st = StringIO()
+        with redirect_stdout(st):
+            main(["find", "-i", self.dummy_path, "-n", "node_Add_188", "--v2"])
+        text = st.getvalue()
+        self.assertIsInstance(text, str)
+
     def test_parser_config(self):
         st = StringIO()
         with redirect_stdout(st):

onnx_diagnostic/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -3,5 +3,5 @@
 Functions, classes to dig into a model when this one is right, slow, wrong...
 """

-__version__ = "0.6.2"
+__version__ = "0.6.3"
 __author__ = "Xavier Dupré"

onnx_diagnostic/_command_lines_parser.py

Lines changed: 24 additions & 3 deletions
@@ -191,24 +191,45 @@ def get_parser_find() -> ArgumentParser:
         "--names",
         type=str,
         required=False,
-        help="names to look at comma separated values",
+        help="names to look at comma separated values, if 'SHADOW', "
+        "search for shadowing names",
     )
     parser.add_argument(
         "-v",
         "--verbose",
        default=0,
+        type=int,
         required=False,
         help="verbosity",
     )
+    parser.add_argument(
+        "--v2",
+        default=False,
+        action=BooleanOptionalAction,
+        help="use enumerate_results instead of onnx_find",
+    )
     return parser


 def _cmd_find(argv: List[Any]):
-    from .helpers.onnx_helper import onnx_find
+    from .helpers.onnx_helper import onnx_find, enumerate_results, shadowing_names

     parser = get_parser_find()
     args = parser.parse_args(argv[1:])
-    onnx_find(args.input, verbose=args.verbose, watch=set(args.names.split(",")))
+    if args.names == "SHADOW":
+        onx = onnx.load(args.input, load_external_data=False)
+        s, ps = shadowing_names(onx)[:2]
+        print(f"shadowing names: {s}")
+        print(f"post-shadowing names: {ps}")
+    elif args.v2:
+        onx = onnx.load(args.input, load_external_data=False)
+        res = list(
+            enumerate_results(onx, name=set(args.names.split(",")), verbose=args.verbose)
+        )
+        if not args.verbose:
+            print("\n".join(map(str, res)))
+    else:
+        onnx_find(args.input, verbose=args.verbose, watch=set(args.names.split(",")))


 def get_parser_config() -> ArgumentParser:

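Usage sketch, not part of the commit: the calls below mirror test_parser_find_v2 and the branches added to _cmd_find. They assume main is the command-line entry point imported by the tests (its exact module is not shown in this diff); "model.onnx" and "node_Add_188" are placeholders.

# Assumed import location for the CLI entry point used in the tests.
from onnx_diagnostic._command_lines_parser import main

# Default behaviour: onnx_find scans the model for the given names.
main(["find", "-i", "model.onnx", "-n", "node_Add_188", "-v", "1"])

# New in this commit: --v2 switches the search to enumerate_results.
main(["find", "-i", "model.onnx", "-n", "node_Add_188", "--v2"])

# New in this commit: -n SHADOW prints shadowing and post-shadowing names.
main(["find", "-i", "model.onnx", "-n", "SHADOW"])
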
0 commit comments
