Skip to content

Commit d7f7515

Browse files
committed
Improve command line find
1 parent 3ff0c54 commit d7f7515

File tree

4 files changed

+387
-9
lines changed

4 files changed

+387
-9
lines changed

_unittests/ut_helpers/test_onnx_helper.py

Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,12 @@
1616
iterator_initializer_constant,
1717
from_array_extended,
1818
tensor_statistics,
19+
enumerate_results,
1920
)
2021

2122

2223
TFLOAT = TensorProto.FLOAT
24+
TINT64 = TensorProto.INT64
2325

2426

2527
class TestOnnxHelper(ExtTestCase):
@@ -251,6 +253,155 @@ def test_statistics(self):
251253
stat = tensor_statistics(rnd)
252254
self.assertEqual(stat["stype"], "FLOAT")
253255

256+
@hide_stdout()
257+
def test_enumerate_results(self):
258+
model = oh.make_model(
259+
oh.make_graph(
260+
[
261+
oh.make_node("Unsqueeze", ["X", "zero"], ["xu1"]),
262+
oh.make_node("Unsqueeze", ["xu1", "un"], ["xu2"]),
263+
oh.make_node("Reshape", ["xu2", "shape1"], ["xm1"]),
264+
oh.make_node("Reshape", ["Y", "shape2"], ["xm2c"]),
265+
oh.make_node("Cast", ["xm2c"], ["xm2"], to=1),
266+
oh.make_node("MatMul", ["xm1", "xm2"], ["xm"]),
267+
oh.make_node("Reshape", ["xm", "shape3"], ["Z"]),
268+
],
269+
"dummy",
270+
[oh.make_tensor_value_info("X", TFLOAT, [320, 1280])],
271+
[oh.make_tensor_value_info("Z", TFLOAT, [3, 5, 320, 640])],
272+
[
273+
onh.from_array(
274+
np.random.rand(3, 5, 1280, 640).astype(np.float32), name="Y"
275+
),
276+
onh.from_array(np.array([0], dtype=np.int64), name="zero"),
277+
onh.from_array(np.array([1], dtype=np.int64), name="un"),
278+
onh.from_array(np.array([1, 320, 1280], dtype=np.int64), name="shape1"),
279+
onh.from_array(np.array([15, 1280, 640], dtype=np.int64), name="shape2"),
280+
onh.from_array(np.array([3, 5, 320, 640], dtype=np.int64), name="shape3"),
281+
],
282+
),
283+
opset_imports=[oh.make_opsetid("", 18)],
284+
ir_version=9,
285+
)
286+
res = list(enumerate_results(model, "xu1", verbose=2))
287+
ress = ";".join(str(r) for r in res)
288+
self.assertEqual(
289+
"<< xu1 - (0:Unsqueeze:) :: Unsqueeze(X, zero) -> xu1;"
290+
">> xu1 - (1:Unsqueeze:) :: Unsqueeze(xu1, un) -> xu2",
291+
ress,
292+
)
293+
self.assertEqual(2, len(list(enumerate_results(model, "shape1", verbose=2))))
294+
self.assertEqual(2, len(list(enumerate_results(model, "X", verbose=2))))
295+
self.assertEqual(2, len(list(enumerate_results(model, "Z", verbose=2))))
296+
297+
def test_enumerate_results_loop(self):
298+
x = np.array([1, 2, 3, 4, 5]).astype(np.float32)
299+
300+
model = oh.make_model(
301+
graph=oh.make_graph(
302+
name="loop_test",
303+
inputs=[
304+
oh.make_tensor_value_info("trip_count", TINT64, ["a"]),
305+
oh.make_tensor_value_info("cond", TensorProto.BOOL, [1]),
306+
],
307+
outputs=[oh.make_tensor_value_info("res", TFLOAT, [])],
308+
nodes=[
309+
oh.make_node("SequenceEmpty", [], ["seq_empty"], dtype=TFLOAT),
310+
oh.make_node(
311+
"Loop",
312+
inputs=["trip_count", "cond", "seq_empty"],
313+
outputs=["seq_res"],
314+
body=oh.make_graph(
315+
[
316+
oh.make_node(
317+
"Identity", inputs=["cond_in"], outputs=["cond_out"]
318+
),
319+
oh.make_node(
320+
"Constant",
321+
inputs=[],
322+
outputs=["x"],
323+
value=oh.make_tensor(
324+
name="const_tensor_x",
325+
data_type=TFLOAT,
326+
dims=x.shape,
327+
vals=x.flatten().astype(float),
328+
),
329+
),
330+
oh.make_node(
331+
"Constant",
332+
inputs=[],
333+
outputs=["one"],
334+
value=oh.make_tensor(
335+
name="const_tensor_one",
336+
data_type=TINT64,
337+
dims=(),
338+
vals=[1],
339+
),
340+
),
341+
oh.make_node(
342+
"Constant",
343+
inputs=[],
344+
outputs=["slice_start"],
345+
value=oh.make_tensor(
346+
name="const_tensor_zero",
347+
data_type=TINT64,
348+
dims=(1,),
349+
vals=[0],
350+
),
351+
),
352+
oh.make_node(
353+
"Add", inputs=["iter_count", "one"], outputs=["end"]
354+
),
355+
oh.make_node(
356+
"Constant",
357+
inputs=[],
358+
outputs=["axes"],
359+
value=oh.make_tensor(
360+
name="const_tensor_axes",
361+
data_type=TINT64,
362+
dims=(1,),
363+
vals=[0],
364+
),
365+
),
366+
oh.make_node(
367+
"Unsqueeze", inputs=["end", "axes"], outputs=["slice_end"]
368+
),
369+
oh.make_node(
370+
"Slice",
371+
inputs=["x", "slice_start", "slice_end"],
372+
outputs=["slice_out"],
373+
),
374+
oh.make_node(
375+
"SequenceInsert",
376+
inputs=["seq_in", "slice_out"],
377+
outputs=["seq_out"],
378+
),
379+
],
380+
"loop_body",
381+
[
382+
oh.make_tensor_value_info("iter_count", TINT64, []),
383+
oh.make_tensor_value_info("cond_in", TensorProto.BOOL, []),
384+
oh.make_tensor_sequence_value_info("seq_in", TFLOAT, None),
385+
],
386+
[
387+
oh.make_tensor_value_info("cond_out", TensorProto.BOOL, []),
388+
oh.make_tensor_sequence_value_info("seq_out", TFLOAT, None),
389+
],
390+
),
391+
),
392+
oh.make_node(
393+
"ConcatFromSequence",
394+
inputs=["seq_res"],
395+
outputs=["res"],
396+
axis=0,
397+
new_axis=0,
398+
),
399+
],
400+
)
401+
)
402+
res = list(enumerate_results(model, "slice_start", verbose=2))
403+
print(res)
404+
254405

255406
if __name__ == "__main__":
256407
unittest.main(verbosity=2)

_unittests/ut_xrun_doc/test_command_lines_exe.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,13 @@ def test_parser_find(self):
3838
text = st.getvalue()
3939
self.assertIsInstance(text, str)
4040

41+
def test_parser_find_v2(self):
42+
st = StringIO()
43+
with redirect_stdout(st):
44+
main(["find", "-i", self.dummy_path, "-n", "node_Add_188", "--v2"])
45+
text = st.getvalue()
46+
self.assertIsInstance(text, str)
47+
4148
def test_parser_config(self):
4249
st = StringIO()
4350
with redirect_stdout(st):

onnx_diagnostic/_command_lines_parser.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -197,18 +197,29 @@ def get_parser_find() -> ArgumentParser:
197197
"-v",
198198
"--verbose",
199199
default=0,
200+
type=int,
200201
required=False,
201202
help="verbosity",
202203
)
204+
parser.add_argument(
205+
"--v2",
206+
default=False,
207+
action=BooleanOptionalAction,
208+
help="use enumerate_results instead of onnx_find",
209+
)
203210
return parser
204211

205212

206213
def _cmd_find(argv: List[Any]):
207-
from .helpers.onnx_helper import onnx_find
214+
from .helpers.onnx_helper import onnx_find, enumerate_results
208215

209216
parser = get_parser_find()
210217
args = parser.parse_args(argv[1:])
211-
onnx_find(args.input, verbose=args.verbose, watch=set(args.names.split(",")))
218+
if args.v2:
219+
onx = onnx.load(args.input, load_external_data=False)
220+
list(enumerate_results(onx, name=set(args.names.split(",")), verbose=args.verbose))
221+
else:
222+
onnx_find(args.input, verbose=args.verbose, watch=set(args.names.split(",")))
212223

213224

214225
def get_parser_config() -> ArgumentParser:

0 commit comments

Comments
 (0)