1 change: 1 addition & 0 deletions CHANGELOGS.rst
@@ -4,6 +4,7 @@ Change Logs
0.8.3
+++++

* :pr:`331`: adds a helper to convert an onnx model into dot
* :pr:`330`: fixes access rope_parameters for ``transformers>=5``
* :pr:`329`: supports lists with OnnxruntimeEvaluator
* :pr:`326`: use ConcatFromSequence in LoopMHA with the loop
7 changes: 7 additions & 0 deletions _doc/api/helpers/dot_helper.rst
@@ -0,0 +1,7 @@

onnx_diagnostic.helpers.dot_helper
==================================

.. automodule:: onnx_diagnostic.helpers.dot_helper
:members:
:no-undoc-members:
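
The new module is also exercised directly in the unit tests further down. For readers who want to try it outside the test suite, a minimal sketch of that usage, assuming a local ``model.onnx`` file (the file names are illustrative, not part of this PR):

import onnx
from onnx_diagnostic.helpers.dot_helper import to_dot

# Load the model without external data and convert it to dot source.
model = onnx.load("model.onnx", load_external_data=False)
dot = to_dot(model)

# Write the dot source; it can then be rendered with graphviz, e.g.
#   dot -Tsvg model.dot -o model.svg
with open("model.dot", "w") as f:
    f.write(dot)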
1 change: 1 addition & 0 deletions _doc/api/helpers/index.rst
@@ -11,6 +11,7 @@ onnx_diagnostic.helpers
cache_helper
config_helper
doc_helper
dot_helper
fake_tensor_helper
graph_helper
helper
2 changes: 1 addition & 1 deletion _doc/recipes/plot_dynamic_shapes_json.py
@@ -74,7 +74,7 @@ def flatten_unflatten_like_dynamic_shapes(obj):
start = 0
end = 0
subtrees = []
for subspec in spec.children_specs:
for subspec in (spec.children() if hasattr(spec, "children") else spec.children_specs):
end += subspec.num_leaves
value = subspec.unflatten(flat[start:end])
value = flatten_unflatten_like_dynamic_shapes(value)
80 changes: 80 additions & 0 deletions _unittests/ut_helpers/test_dot_helper.py
@@ -0,0 +1,80 @@
import textwrap
import unittest
import onnx
import onnx.helper as oh
from onnx_diagnostic.ext_test_case import ExtTestCase, requires_transformers
from onnx_diagnostic.helpers.dot_helper import to_dot
from onnx_diagnostic.export.api import to_onnx
from onnx_diagnostic.torch_export_patches import torch_export_patches
from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs


class TestDotHelper(ExtTestCase):
def test_custom_doc_kernels_layer_normalization(self):
TFLOAT16 = onnx.TensorProto.FLOAT16
model = oh.make_model(
oh.make_graph(
[
oh.make_node(
"LayerNormalization",
["X", "W", "B"],
["ln"],
axis=-1,
epsilon=9.999999974752427e-7,
),
oh.make_node(
"Add", ["ln", "W"], ["Z"], axis=-1, epsilon=9.999999974752427e-7
),
],
"dummy",
[
oh.make_tensor_value_info("X", TFLOAT16, ["b", "c", "d"]),
oh.make_tensor_value_info("W", TFLOAT16, ["d"]),
oh.make_tensor_value_info("B", TFLOAT16, ["d"]),
],
[oh.make_tensor_value_info("Z", TFLOAT16, ["b", "c", "d"])],
),
ir_version=9,
opset_imports=[oh.make_opsetid("", 18)],
)
dot = to_dot(model)
expected = textwrap.dedent(
"""
digraph {
graph [rankdir=TB, splines=true, overlap=false, nodesep=0.2, ranksep=0.2, fontsize=8];
node [style="rounded,filled", color="#888888", fontcolor="#222222", shape=box];
edge [arrowhead=vee, fontsize=7, labeldistance=-5, labelangle=0];
I_0 [label="X\\nFLOAT16(b,c,d)", fillcolor="#aaeeaa"];
I_1 [label="W\\nFLOAT16(d)", fillcolor="#aaeeaa"];
I_2 [label="B\\nFLOAT16(d)", fillcolor="#aaeeaa"];
LayerNormalization_3 [label="LayerNormalization(., ., ., axis=-1)", fillcolor="#cccccc"];
Add_4 [label="Add(., ., axis=-1)", fillcolor="#cccccc"];
I_0 -> LayerNormalization_3 [label="FLOAT16(b,c,d)"];
I_1 -> LayerNormalization_3 [label="FLOAT16(d)"];
I_2 -> LayerNormalization_3 [label="FLOAT16(d)"];
LayerNormalization_3 -> Add_4 [label="FLOAT16(b,c,d)"];
I_1 -> Add_4 [label="FLOAT16(d)"];
O_5 [label="Z\\nFLOAT16(d)", fillcolor="#aaaaee"];
Add_4 -> O_5;
}
"""
)
self.maxDiff = None
self.assertEqual(expected.strip("\n "), dot.strip("\n "))

@requires_transformers("4.57")
def test_dot_plot_tiny(self):
data = get_untrained_model_with_inputs("arnir0/Tiny-LLM")
model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
with torch_export_patches(patch_transformers=True):
em = to_onnx(model, inputs, dynamic_shapes=ds, exporter="custom")
dot = to_dot(em.model_proto)
name = self.get_dump_file("test_dot_plot_tiny.dot")
with open(name, "w") as f:
f.write(dot)
# dot -Tpng dump_test/test_dot_plot_tiny.dot -o dump_test/test_dot_plot_tiny.png
self.assertIn("-> Add", dot)


if __name__ == "__main__":
unittest.main(verbosity=2)
9 changes: 9 additions & 0 deletions _unittests/ut_xrun_doc/test_command_lines.py
@@ -6,6 +6,7 @@
get_main_parser,
get_parser_agg,
get_parser_config,
get_parser_dot,
get_parser_find,
get_parser_lighten,
get_parser_print,
@@ -23,6 +24,7 @@ def test_main_parser(self):
get_main_parser().print_help()
text = st.getvalue()
self.assertIn("lighten", text)
self.assertIn("dot", text)

def test_parser_lighten(self):
st = StringIO()
@@ -87,6 +89,13 @@ def test_parser_sbs(self):
text = st.getvalue()
self.assertIn("--onnx", text)

def test_parser_dot(self):
st = StringIO()
with redirect_stdout(st):
get_parser_dot().print_help()
text = st.getvalue()
self.assertIn("--run", text)


if __name__ == "__main__":
unittest.main(verbosity=2)
39 changes: 39 additions & 0 deletions _unittests/ut_xrun_doc/test_command_lines_exe.py
@@ -162,6 +162,45 @@ def forward(self, x):
sdf = df[(df.ep_target == "placeholder") & (df.onnx_op_type == "initializer")]
self.assertEqual(sdf.shape[0], 4)

@ignore_warnings(UserWarning)
@requires_transformers("4.53")
def test_i_parser_dot(self):
import torch

class Model(torch.nn.Module):
def __init__(self):
super(Model, self).__init__()
self.fc1 = torch.nn.Linear(10, 32) # input size 10 → hidden size 32
self.relu = torch.nn.ReLU()
self.fc2 = torch.nn.Linear(32, 1) # hidden → output

def forward(self, x):
x = self.relu(self.fc1(x))
x = self.fc2(x)
return x

inputs = dict(x=torch.randn((5, 10)))
ds = dict(x={0: "batch"})
onnx_file = self.get_dump_file("test_i_parser_dot.model.onnx")
to_onnx(
Model(),
kwargs=inputs,
dynamic_shapes=ds,
exporter="custom",
filename=onnx_file,
)

output = self.get_dump_file("test_i_parser_dot.dot")
args = ["dot", onnx_file, "-v", "1", "-o", output]
if not self.unit_test_going():
args.extend(["--run", "svg"])

st = StringIO()
with redirect_stdout(st):
main(args)
text = st.getvalue()
print(text)


if __name__ == "__main__":
unittest.main(verbosity=2)
75 changes: 75 additions & 0 deletions onnx_diagnostic/_command_lines_parser.py
@@ -11,6 +11,77 @@
from argparse import ArgumentParser, RawTextHelpFormatter, BooleanOptionalAction


def get_parser_dot() -> ArgumentParser:
parser = ArgumentParser(
prog="dot",
description=textwrap.dedent(
"""
Converts an ONNX model into a dot file that ``dot`` can render as a graph.
"""
),
)
parser.add_argument("input", type=str, help="onnx model to lighten")
parser.add_argument(
"-o",
"--output",
default="",
type=str,
required=False,
help="dot model to output or empty to print out the result",
)
parser.add_argument(
"-v",
"--verbose",
type=int,
default=0,
required=False,
help="verbosity",
)
parser.add_argument(
"-r",
"--run",
default="",
required=False,
help="run dot, in that case, format must be given (svg, png)",
)
return parser


def _cmd_dot(argv: List[Any]):
import subprocess
from .helpers.dot_helper import to_dot

parser = get_parser_dot()
args = parser.parse_args(argv[1:])
if args.verbose:
print(f"-- loads {args.input!r}")
onx = onnx.load(args.input, load_external_data=False)
if args.verbose:
print("-- converts into dot")
dot = to_dot(onx)
if args.output:
if args.verbose:
print(f"-- saves into {args.output}")
with open(args.output, "w") as f:
f.write(dot)
else:
print(dot)
if args.run:
assert args.output, "Cannot run dot without an output file."
cmds = ["dot", f"-T{args.run}", args.output, "-o", f"{args.output}.{args.run}"]
if args.verbose:
print(f"-- run {' '.join(cmds)}")
p = subprocess.Popen(cmds, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
res = p.communicate()
out, err = res
if out:
print("--")
print(out)
if err:
print("--")
print(err)


def get_parser_lighten() -> ArgumentParser:
parser = ArgumentParser(
prog="lighten",
@@ -1412,6 +1483,7 @@ def get_main_parser() -> ArgumentParser:

agg - aggregates statistics from multiple files
config - prints a configuration for a model id
dot - converts an onnx model into dot format
exportsample - produces a code to export a model
find - find node consuming or producing a result
lighten - makes an onnx model lighter by removing the weights,
@@ -1428,6 +1500,7 @@
choices=[
"agg",
"config",
"dot",
"exportsample",
"find",
"lighten",
@@ -1446,6 +1519,7 @@ def main(argv: Optional[List[Any]] = None):
fcts = dict(
agg=_cmd_agg,
config=_cmd_config,
dot=_cmd_dot,
exportsample=_cmd_export_sample,
find=_cmd_find,
lighten=_cmd_lighten,
@@ -1470,6 +1544,7 @@
parsers = dict(
agg=get_parser_agg,
config=get_parser_config,
dot=get_parser_dot,
exportsample=lambda: get_parser_validate("exportsample"), # type: ignore[operator]
find=get_parser_find,
lighten=get_parser_lighten,
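
Putting the new subcommand together: the test above drives it through ``main``, and a sketch of the equivalent call is below, assuming your own file names (the shell form ``python -m onnx_diagnostic dot ...`` should behave the same way if the package exposes its usual entry point).

from onnx_diagnostic._command_lines_parser import main

# Convert model.onnx into dot, write it to model.dot, then run
# ``dot -Tsvg model.dot -o model.dot.svg`` (graphviz must be installed).
main(["dot", "model.onnx", "-v", "1", "-o", "model.dot", "--run", "svg"])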
2 changes: 1 addition & 1 deletion onnx_diagnostic/helpers/cache_helper.py
@@ -80,7 +80,7 @@ def flatten_unflatten_for_dynamic_shapes(
start = 0
end = 0
subtrees = []
for subspec in spec.children_specs:
for subspec in (spec.children() if hasattr(spec, "children") else spec.children_specs):
end += subspec.num_leaves
value = subspec.unflatten(flat[start:end])
value = flatten_unflatten_for_dynamic_shapes(
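
This is the same one-line compatibility change as in ``plot_dynamic_shapes_json.py`` above: both use ``spec.children()`` when it exists and fall back to ``spec.children_specs`` otherwise, presumably to cover both newer and older ``torch.utils._pytree.TreeSpec`` APIs. A small helper capturing the pattern (the name ``_child_specs`` is hypothetical, not part of the PR):

def _child_specs(spec):
    # Newer torch exposes ``TreeSpec.children()``; older versions only
    # provide the ``children_specs`` attribute. Support both.
    return spec.children() if hasattr(spec, "children") else spec.children_specs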