Skip to content

Commit b80e7b6

Browse files
committed
fix
1 parent f4fea2c commit b80e7b6

File tree

4 files changed

+124
-4
lines changed

4 files changed

+124
-4
lines changed

_unittests/ut_xrun_doc/test_command_lines.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
get_main_parser,
77
get_parser_agg,
88
get_parser_config,
9+
get_parser_dot,
910
get_parser_find,
1011
get_parser_lighten,
1112
get_parser_print,
@@ -23,6 +24,7 @@ def test_main_parser(self):
2324
get_main_parser().print_help()
2425
text = st.getvalue()
2526
self.assertIn("lighten", text)
27+
self.assertIn("dot", text)
2628

2729
def test_parser_lighten(self):
2830
st = StringIO()
@@ -87,6 +89,13 @@ def test_parser_sbs(self):
8789
text = st.getvalue()
8890
self.assertIn("--onnx", text)
8991

92+
def test_parser_dot(self):
93+
st = StringIO()
94+
with redirect_stdout(st):
95+
get_parser_dot().print_help()
96+
text = st.getvalue()
97+
self.assertIn("--run", text)
98+
9099

91100
if __name__ == "__main__":
92101
unittest.main(verbosity=2)

_unittests/ut_xrun_doc/test_command_lines_exe.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,45 @@ def forward(self, x):
162162
sdf = df[(df.ep_target == "placeholder") & (df.onnx_op_type == "initializer")]
163163
self.assertEqual(sdf.shape[0], 4)
164164

165+
@ignore_warnings(UserWarning)
166+
@requires_transformers("4.53")
167+
def test_i_parser_dot(self):
168+
import torch
169+
170+
class Model(torch.nn.Module):
171+
def __init__(self):
172+
super(Model, self).__init__()
173+
self.fc1 = torch.nn.Linear(10, 32) # input size 10 → hidden size 32
174+
self.relu = torch.nn.ReLU()
175+
self.fc2 = torch.nn.Linear(32, 1) # hidden → output
176+
177+
def forward(self, x):
178+
x = self.relu(self.fc1(x))
179+
x = self.fc2(x)
180+
return x
181+
182+
inputs = dict(x=torch.randn((5, 10)))
183+
ds = dict(x={0: "batch"})
184+
onnx_file = self.get_dump_file("test_i_parser_dot.model.onnx")
185+
to_onnx(
186+
Model(),
187+
kwargs=inputs,
188+
dynamic_shapes=ds,
189+
exporter="custom",
190+
filename=onnx_file,
191+
)
192+
193+
output = self.get_dump_file("test_i_parser_dot.dot")
194+
args = ["dot", onnx_file, "-v", "1", "-o", output]
195+
if not self.unit_test_going():
196+
args.extend(["--run", "svg"])
197+
198+
st = StringIO()
199+
with redirect_stdout(st):
200+
main(args)
201+
text = st.getvalue()
202+
print(text)
203+
165204

166205
if __name__ == "__main__":
167206
unittest.main(verbosity=2)

onnx_diagnostic/_command_lines_parser.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,77 @@
1111
from argparse import ArgumentParser, RawTextHelpFormatter, BooleanOptionalAction
1212

1313

14+
def get_parser_dot() -> ArgumentParser:
15+
parser = ArgumentParser(
16+
prog="dot",
17+
description=textwrap.dedent(
18+
"""
19+
Converts a model into a dot file dot can draw into a graph.
20+
"""
21+
),
22+
)
23+
parser.add_argument("input", type=str, help="onnx model to lighten")
24+
parser.add_argument(
25+
"-o",
26+
"--output",
27+
default="",
28+
type=str,
29+
required=False,
30+
help="dot model to output or empty to print out the result",
31+
)
32+
parser.add_argument(
33+
"-v",
34+
"--verbose",
35+
type=int,
36+
default=0,
37+
required=False,
38+
help="verbosity",
39+
)
40+
parser.add_argument(
41+
"-r",
42+
"--run",
43+
default="",
44+
required=False,
45+
help="run dot, in that case, format must be given (svg, png)",
46+
)
47+
return parser
48+
49+
50+
def _cmd_dot(argv: List[Any]):
51+
import subprocess
52+
from .helpers.dot_helper import to_dot
53+
54+
parser = get_parser_dot()
55+
args = parser.parse_args(argv[1:])
56+
if args.verbose:
57+
print(f"-- loads {args.input!r}")
58+
onx = onnx.load(args.input, load_external_data=False)
59+
if args.verbose:
60+
print("-- converts into dot")
61+
dot = to_dot(onx)
62+
if args.output:
63+
if args.verbose:
64+
print(f"-- saves into {args.output}")
65+
with open(args.output, "w") as f:
66+
f.write(dot)
67+
else:
68+
print(dot)
69+
if args.run:
70+
assert args.output, "Cannot run dot without an output file."
71+
cmds = ["dot", f"-T{args.run}", args.output, "-o", f"{args.output}.{args.run}"]
72+
if args.verbose:
73+
print(f"-- run {' '.join(cmds)}")
74+
p = subprocess.Popen(cmds, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
75+
res = p.communicate()
76+
out, err = res
77+
if out:
78+
print("--")
79+
print(out)
80+
if err:
81+
print("--")
82+
print(err)
83+
84+
1485
def get_parser_lighten() -> ArgumentParser:
1586
parser = ArgumentParser(
1687
prog="lighten",
@@ -1412,6 +1483,7 @@ def get_main_parser() -> ArgumentParser:
14121483
14131484
agg - aggregates statistics from multiple files
14141485
config - prints a configuration for a model id
1486+
dot - converts an onnx model into dot format
14151487
exportsample - produces a code to export a model
14161488
find - find node consuming or producing a result
14171489
lighten - makes an onnx model lighter by removing the weights,
@@ -1428,6 +1500,7 @@ def get_main_parser() -> ArgumentParser:
14281500
choices=[
14291501
"agg",
14301502
"config",
1503+
"dot",
14311504
"exportsample",
14321505
"find",
14331506
"lighten",
@@ -1446,6 +1519,7 @@ def main(argv: Optional[List[Any]] = None):
14461519
fcts = dict(
14471520
agg=_cmd_agg,
14481521
config=_cmd_config,
1522+
dot=_cmd_dot,
14491523
exportsample=_cmd_export_sample,
14501524
find=_cmd_find,
14511525
lighten=_cmd_lighten,
@@ -1470,6 +1544,7 @@ def main(argv: Optional[List[Any]] = None):
14701544
parsers = dict(
14711545
agg=get_parser_agg,
14721546
config=get_parser_config,
1547+
dot=get_parser_dot,
14731548
exportsample=lambda: get_parser_validate("exportsample"), # type: ignore[operator]
14741549
find=get_parser_find,
14751550
lighten=get_parser_lighten,

onnx_diagnostic/helpers/dot_helper.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -153,15 +153,12 @@ def _mkn(obj: object) -> int:
153153
shape = tuple(init.dims)
154154
if len(shape) == 0 or (len(shape) == 1 and shape[0] < 10):
155155
a = onh.to_array(init)
156-
vals = f" = {a}" if len(shape) == 0 else f"\\n=[{', '.join([str(i) for i in a])}]"
157156
tiny_inits[init.name] = (
158157
str(a) if len(shape) == 0 else f"[{', '.join([str(i) for i in a])}]"
159158
)
160159
else:
161160
ls = f"{onnx_dtype_name(init.data_type)}({', '.join(map(str,shape))})"
162-
rows.append(
163-
f' i_{_mkn(init)} [label="{init.name}\\n{ls}{vals}", fillcolor="#cccc00"];'
164-
)
161+
rows.append(f' i_{_mkn(init)} [label="{init.name}\\n{ls}", fillcolor="#cccc00"];')
165162
name_to_ids[init.name] = f"i_{_mkn(init)}"
166163
edge_label[init.name] = ls
167164
for node in nodes:

0 commit comments

Comments
 (0)