Skip to content

Commit 85894da

Browse files
committed
better doc
1 parent 389578a commit 85894da

File tree

1 file changed

+84
-7
lines changed
  • onnx_diagnostic/torch_onnx

1 file changed

+84
-7
lines changed

onnx_diagnostic/torch_onnx/sbs.py

Lines changed: 84 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,20 @@
11
import inspect
22
import os
3+
import textwrap
34
import time
45
from dataclasses import dataclass
5-
from typing import Any, Callable, Dict, Iterator, List, Optional, Self, Set, Tuple, Union
6+
from typing import Any, Callable, Dict, Iterator, List, Optional, Set, Tuple, Union
67
import onnx
78
import onnx.helper as oh
89
import numpy as np
910
import torch
1011
from ..helpers import string_type, string_diff, max_diff, flatten_object
11-
from ..helpers.onnx_helper import pretty_onnx, extract_subset_of_nodes, make_submodel
12+
from ..helpers.onnx_helper import (
13+
pretty_onnx,
14+
extract_subset_of_nodes,
15+
make_submodel,
16+
from_array_extended,
17+
)
1218
from ..helpers.torch_helper import to_numpy, from_numpy, to_tensor, torch_dtype_to_onnx_dtype
1319

1420

@@ -222,6 +228,62 @@ def select(
222228
return True
223229
return False
224230

231+
def get_replay_code(self) -> str:
232+
"""
233+
Returns a code letting the user replay the onnx model.
234+
It looks like the following.
235+
236+
.. runpython::
237+
:showcode:
238+
239+
from onnx_diagnostic.torch_onnx.sbs import ReplayConfiguration
240+
241+
rc = ReplayConfiguration(dump_folder="unsued")
242+
print(rc.get_replay_code())
243+
"""
244+
return textwrap.dedent(
245+
"""
246+
import onnx
247+
import torch
248+
from onnx_diagnostic.helpers import max_diff, string_diff, string_type
249+
from onnx_diagnostic.helpers.onnx_helper import pretty_onnx
250+
from onnx_diagnostic.reference import OnnxruntimeEvaluator
251+
252+
skws = dict(with_shape=True, with_device=True)
253+
254+
torch_inputs = torch.load("torch_inputs.pt")
255+
onnx_inputs = torch.load("onnx_inputs.pt")
256+
expected_outputs_and_mapping = torch.load("torch_outputs_and_mapping.pt")
257+
expected = expected_outputs_and_mapping["expected"]
258+
mapping = expected_outputs_and_mapping["mapping"]
259+
260+
print(f"-- torch_inputs={string_type(torch_inputs, **skws)}")
261+
print(f"-- onnx_inputs={string_type(onnx_inputs, **skws)}")
262+
print(f"-- expected={string_type(expected, **skws)}")
263+
print(f"-- mapping={mapping}")
264+
265+
model = onnx.load("model.onnx")
266+
print("-- model.onnx")
267+
print(pretty_onnx(model))
268+
print("--")
269+
270+
print("-- run with onnx_inputs")
271+
sess = OnnxruntimeEvaluator(model, whole=True)
272+
feeds = onnx_inputs
273+
obtained = sess.run(None, feeds)
274+
print(f"-- obtained={string_type(obtained, **skws)}")
275+
diff = max_diff(expected, tuple(obtained))
276+
print(f"-- diff: {string_diff(diff)}")
277+
278+
print("-- run with torch_inputs")
279+
feeds = {k: torch_inputs[mapping[k]] for k in feeds}
280+
obtained = sess.run(None, feeds)
281+
print(f"-- obtained={string_type(obtained, **skws)}")
282+
diff = max_diff(expected, tuple(obtained))
283+
print(f"-- diff: {string_diff(diff)}")
284+
"""
285+
)
286+
225287
def dump(
226288
self,
227289
name: str,
@@ -278,14 +340,27 @@ def dump(
278340
os.makedirs(folder, exist_ok=True)
279341
if verbose:
280342
print(f"[ReplayConfiguration.dump] dumps into folder {folder!r}")
281-
onnx.save(submodel, os.path.join(folder, "model.onnx"))
343+
282344
torch_inputs = {}
345+
removed_inputs = set()
283346
for n in input_names:
284347
if n in onnx_name_to_ep_name:
285348
torch_inputs[n] = torch_results[onnx_name_to_ep_name[n]]
286349
else:
287-
# It is possible that this result only exists in the onnx worlds.
288-
pass
350+
# We add that input as an initializer because it is probably a constant.
351+
submodel.graph.initializer.append(from_array_extended(onnx_results[n], name=n))
352+
removed_inputs.add(n)
353+
354+
if removed_inputs:
355+
input_names = [i for i in input_names if i not in removed_inputs]
356+
new_inputs = [i for i in submodel.graph.input if i.name not in removed_inputs]
357+
del submodel.graph.input[:]
358+
submodel.graph.input.extend(new_inputs)
359+
if verbose:
360+
print(f"[ReplayConfiguration.dump] removed input {removed_inputs}")
361+
print(f"[ReplayConfiguration.dump] final model inputs {input_names}")
362+
363+
onnx.save(submodel, os.path.join(folder, "model.onnx"))
289364
onnx_inputs = {n: onnx_results[n] for n in input_names}
290365
assert (
291366
name in onnx_name_to_ep_name
@@ -301,6 +376,8 @@ def dump(
301376
torch.save(
302377
expected_outputs_and_mapping, os.path.join(folder, "torch_outputs_and_mapping.pt")
303378
)
379+
with open(os.path.join(folder, "replay.py"), "w") as f:
380+
f.write(self.get_replay_code())
304381
if verbose:
305382
print(f"[ReplayConfiguration.dump] done {folder!r}")
306383
return folder
@@ -363,7 +440,7 @@ def __post_init__(self):
363440
f"ep_id_node={self.ep_id_node}"
364441
)
365442

366-
def set_diff(self, diff: Dict[str, Any]):
443+
def set_diff(self, diff: Dict[str, Any]) -> "Self": # noqa: F821
367444
"""Sets error."""
368445
if diff is None:
369446
return
@@ -398,7 +475,7 @@ def check(
398475
Tuple[Optional[int], Optional[int], Optional[int], Optional[str], Optional[str]],
399476
int,
400477
],
401-
) -> Self:
478+
) -> "Self": # noqa: F821
402479
"Checks a record was not already yielded."
403480
if self.onnx_op_type == "reset":
404481
# no record for this one

0 commit comments

Comments
 (0)