Skip to content

Commit 0a4d2c1

Browse files
authored
Adds input/output to the text rendering (#109)
* Adds input/output to the text rendering * mypy
1 parent 770ef84 commit 0a4d2c1

File tree

3 files changed

+68
-21
lines changed

3 files changed

+68
-21
lines changed

_unittests/ut_helpers/test_graph_helper.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -126,21 +126,31 @@ def test_text_rendering(self):
126126
)
127127
graph = GraphRendering(proto)
128128
text = textwrap.dedent(graph.text_rendering(prefix="|")).strip("\n")
129+
print()
130+
print(text)
129131
expected = textwrap.dedent(
130132
"""
131133
|
132134
|
133135
|
134136
|
135-
| Add Neg
136-
| | |
137-
| +-------+-------+
138-
| |
139-
| Mul
140-
| |
141-
| +-------+
142-
| |
143-
| Mul
137+
| X Y
138+
| | |
139+
| +------+-----+-----------+
140+
| | |
141+
| Add Neg
142+
| | |
143+
| +-----+-----------+
144+
| |
145+
| Mul
146+
| |
147+
| +-----------+
148+
| |
149+
| Mul
150+
| |
151+
| +----------------------+
152+
| |
153+
| Z
144154
|
145155
|
146156
"""

_unittests/ut_torch_models/test_tiny_llms_onnx.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,14 +28,15 @@ def test_onnx_export_tiny_llm_official(self):
2828
self.assertEqual(
2929
{"attention_mask", "past_key_values", "input_ids", "position_ids"}, set(inputs)
3030
)
31-
ep = torch.onnx.export(
32-
model,
33-
(),
34-
kwargs=inputs,
35-
dynamic_shapes=data["dynamic_shapes"],
36-
dynamo=True,
37-
optimize=True,
38-
)
31+
with torch_export_patches(patch_transformers=True):
32+
ep = torch.onnx.export(
33+
model,
34+
(),
35+
kwargs=inputs,
36+
dynamic_shapes=data["dynamic_shapes"],
37+
dynamo=True,
38+
optimize=True,
39+
)
3940
# There are some discrepancies with torch==2.6
4041
if not has_torch("2.7"):
4142
raise unittest.SkipTest("discrepancies observed with torch<2.7")

onnx_diagnostic/helpers/graph_helper.py

Lines changed: 40 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from typing import Dict, List, Optional, Sequence, Set, Tuple, Union
22
import onnx
3+
import onnx.helper as oh
34

45

56
class GraphRendering:
@@ -146,6 +147,34 @@ def start_names(self) -> List[onnx.NodeProto]:
146147
)
147148
return [*input_names, *init_names]
148149

150+
@property
151+
def input_names(self) -> List[str]:
152+
"Returns the list of input names."
153+
return (
154+
self.proto.input
155+
if isinstance(self.proto, onnx.FunctionProto)
156+
else [
157+
i.name
158+
for i in (
159+
self.proto if isinstance(self.proto, onnx.GraphProto) else self.proto.graph
160+
).input
161+
]
162+
)
163+
164+
@property
165+
def output_names(self) -> List[str]:
166+
"Returns the list of output names."
167+
return (
168+
self.proto.output
169+
if isinstance(self.proto, onnx.FunctionProto)
170+
else [
171+
i.name
172+
for i in (
173+
self.proto if isinstance(self.proto, onnx.GraphProto) else self.proto.graph
174+
).output
175+
]
176+
)
177+
149178
@classmethod
150179
def build_node_edges(cls, nodes: Sequence[onnx.NodeProto]) -> Set[Tuple[int, int]]:
151180
"""Builds the list of edges between nodes."""
@@ -210,7 +239,7 @@ def text_edge(
210239
else:
211240
middle = (p1[1] + p2[1]) // 2
212241
a, b = (p1[1] + 1, middle - 1) if p1[1] < middle else (middle + 1, p1[1] - 1)
213-
grid[p1[0] + 2][a : b + 1] = ["-" * (b - a + 1)]
242+
grid[p1[0] + 2][a : b + 1] = ["-"] * (b - a + 1)
214243
a, b = (p1[1] + 1, middle - 1) if p1[1] < middle else (middle + 1, p1[1] - 1)
215244
grid[p1[0] + 2][a : b + 1] = ["-"] * (b - a + 1)
216245

@@ -223,7 +252,6 @@ def text_edge(
223252
def text_rendering(self, prefix="") -> str:
224253
"""
225254
Renders a model in text.
226-
It only renders nodes.
227255
228256
.. runpython::
229257
:showcode:
@@ -257,8 +285,14 @@ def text_rendering(self, prefix="") -> str:
257285
text = textwrap.dedent(graph.text_rendering()).strip("\\n")
258286
print(text)
259287
"""
260-
nodes = self.nodes
261-
existing = self.start_names
288+
nodes = [
289+
*[oh.make_node(i, ["BEGIN"], [i]) for i in self.input_names],
290+
*self.nodes,
291+
*[oh.make_node(i, [i], ["END"]) for i in self.output_names],
292+
]
293+
exist = set(self.start_names) - set(self.input_names)
294+
exist |= {"BEGIN"}
295+
existing = sorted(exist)
262296
order = self.computation_order(nodes, existing)
263297
positions = self.graph_positions(nodes, order, existing)
264298
text_pos = self.text_positions(nodes, positions)
@@ -269,8 +303,10 @@ def text_rendering(self, prefix="") -> str:
269303

270304
for n1, n2 in edges:
271305
self.text_edge(grid, text_pos[n1], text_pos[n2])
306+
assert len(set(len(g) for g in grid)) == 1, f"lengths={[len(g) for g in grid]}"
272307
for node, pos in zip(nodes, text_pos):
273308
self.text_grid(grid, pos, node.op_type)
309+
assert len(set(len(g) for g in grid)) == 1, f"lengths={[len(g) for g in grid]}"
274310

275311
return "\n".join(
276312
f"{prefix}{line.rstrip()}" for line in ["".join(line) for line in grid]

0 commit comments

Comments
 (0)