11import inspect
22import os
3+ import textwrap
34import time
45from dataclasses import dataclass
5- from typing import Any , Callable , Dict , Iterator , List , Optional , Self , Set , Tuple , Union
6+ from typing import Any , Callable , Dict , Iterator , List , Optional , Set , Tuple , Union
67import onnx
78import onnx .helper as oh
89import numpy as np
910import torch
1011from ..helpers import string_type , string_diff , max_diff , flatten_object
11- from ..helpers .onnx_helper import pretty_onnx , extract_subset_of_nodes , make_submodel
12+ from ..helpers .onnx_helper import (
13+ pretty_onnx ,
14+ extract_subset_of_nodes ,
15+ make_submodel ,
16+ from_array_extended ,
17+ )
1218from ..helpers .torch_helper import to_numpy , from_numpy , to_tensor , torch_dtype_to_onnx_dtype
1319
1420
@@ -222,6 +228,62 @@ def select(
222228 return True
223229 return False
224230
231+ def get_replay_code (self ) -> str :
232+ """
233+ Returns a code letting the user replay the onnx model.
234+ It looks like the following.
235+
236+ .. runpython::
237+ :showcode:
238+
239+ from onnx_diagnostic.torch_onnx.sbs import ReplayConfiguration
240+
241+ rc = ReplayConfiguration(dump_folder="unsued")
242+ print(rc.get_replay_code())
243+ """
244+ return textwrap .dedent (
245+ """
246+ import onnx
247+ import torch
248+ from onnx_diagnostic.helpers import max_diff, string_diff, string_type
249+ from onnx_diagnostic.helpers.onnx_helper import pretty_onnx
250+ from onnx_diagnostic.reference import OnnxruntimeEvaluator
251+
252+ skws = dict(with_shape=True, with_device=True)
253+
254+ torch_inputs = torch.load("torch_inputs.pt")
255+ onnx_inputs = torch.load("onnx_inputs.pt")
256+ expected_outputs_and_mapping = torch.load("torch_outputs_and_mapping.pt")
257+ expected = expected_outputs_and_mapping["expected"]
258+ mapping = expected_outputs_and_mapping["mapping"]
259+
260+ print(f"-- torch_inputs={string_type(torch_inputs, **skws)}")
261+ print(f"-- onnx_inputs={string_type(onnx_inputs, **skws)}")
262+ print(f"-- expected={string_type(expected, **skws)}")
263+ print(f"-- mapping={mapping}")
264+
265+ model = onnx.load("model.onnx")
266+ print("-- model.onnx")
267+ print(pretty_onnx(model))
268+ print("--")
269+
270+ print("-- run with onnx_inputs")
271+ sess = OnnxruntimeEvaluator(model, whole=True)
272+ feeds = onnx_inputs
273+ obtained = sess.run(None, feeds)
274+ print(f"-- obtained={string_type(obtained, **skws)}")
275+ diff = max_diff(expected, tuple(obtained))
276+ print(f"-- diff: {string_diff(diff)}")
277+
278+ print("-- run with torch_inputs")
279+ feeds = {k: torch_inputs[mapping[k]] for k in feeds}
280+ obtained = sess.run(None, feeds)
281+ print(f"-- obtained={string_type(obtained, **skws)}")
282+ diff = max_diff(expected, tuple(obtained))
283+ print(f"-- diff: {string_diff(diff)}")
284+ """
285+ )
286+
225287 def dump (
226288 self ,
227289 name : str ,
@@ -278,14 +340,27 @@ def dump(
278340 os .makedirs (folder , exist_ok = True )
279341 if verbose :
280342 print (f"[ReplayConfiguration.dump] dumps into folder { folder !r} " )
281- onnx . save ( submodel , os . path . join ( folder , "model.onnx" ))
343+
282344 torch_inputs = {}
345+ removed_inputs = set ()
283346 for n in input_names :
284347 if n in onnx_name_to_ep_name :
285348 torch_inputs [n ] = torch_results [onnx_name_to_ep_name [n ]]
286349 else :
287- # It is possible that this result only exists in the onnx worlds.
288- pass
350+ # We add that input as an initializer because it is probably a constant.
351+ submodel .graph .initializer .append (from_array_extended (onnx_results [n ], name = n ))
352+ removed_inputs .add (n )
353+
354+ if removed_inputs :
355+ input_names = [i for i in input_names if i not in removed_inputs ]
356+ new_inputs = [i for i in submodel .graph .input if i .name not in removed_inputs ]
357+ del submodel .graph .input [:]
358+ submodel .graph .input .extend (new_inputs )
359+ if verbose :
360+ print (f"[ReplayConfiguration.dump] removed input { removed_inputs } " )
361+ print (f"[ReplayConfiguration.dump] final model inputs { input_names } " )
362+
363+ onnx .save (submodel , os .path .join (folder , "model.onnx" ))
289364 onnx_inputs = {n : onnx_results [n ] for n in input_names }
290365 assert (
291366 name in onnx_name_to_ep_name
@@ -301,6 +376,8 @@ def dump(
301376 torch .save (
302377 expected_outputs_and_mapping , os .path .join (folder , "torch_outputs_and_mapping.pt" )
303378 )
379+ with open (os .path .join (folder , "replay.py" ), "w" ) as f :
380+ f .write (self .get_replay_code ())
304381 if verbose :
305382 print (f"[ReplayConfiguration.dump] done { folder !r} " )
306383 return folder
@@ -363,7 +440,7 @@ def __post_init__(self):
363440 f"ep_id_node={ self .ep_id_node } "
364441 )
365442
366- def set_diff (self , diff : Dict [str , Any ]):
443+ def set_diff (self , diff : Dict [str , Any ]) -> "Self" : # noqa: F821
367444 """Sets error."""
368445 if diff is None :
369446 return
@@ -398,7 +475,7 @@ def check(
398475 Tuple [Optional [int ], Optional [int ], Optional [int ], Optional [str ], Optional [str ]],
399476 int ,
400477 ],
401- ) -> Self :
478+ ) -> " Self" : # noqa: F821
402479 "Checks a record was not already yielded."
403480 if self .onnx_op_type == "reset" :
404481 # no record for this one
0 commit comments