Skip to content

Commit 4ca6c9d

Browse files
authored
Implements option reset_names in side-by-side (#317)
* implements option reset_names in side-by-side * fix * fix * fix
1 parent 2419114 commit 4ca6c9d

File tree

4 files changed

+133
-4
lines changed

4 files changed

+133
-4
lines changed

CHANGELOGS.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ Change Logs
88
* :pr:`311`: use custom and local function to use PackedMultiHeadAttention from onnxruntime
99
* :pr:`310`: splits patches into multiple files
1010
* :pr:`308`: add option --save_ep to dump the exported program as well as torch input
11-
* :pr:`304`, :pr:`306`, :pr:`316`: improves side-by-side comparison, creates command line sbs
11+
* :pr:`304`, :pr:`306`, :pr:`316`, :pr:`317`: improves side-by-side comparison, creates command line sbs
1212

1313
0.8.2
1414
+++++

_unittests/ut_torch_onnx/test_sbs.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -464,6 +464,54 @@ def forward(self, x):
464464
)
465465
self.assertEqual(len(results), 5)
466466

467+
@hide_stdout()
468+
@ignore_warnings((DeprecationWarning, FutureWarning, UserWarning))
469+
def test_sbs_model_with_weights_custom_reset(self):
470+
torch = self.torch
471+
472+
class Model(self.torch.nn.Module):
473+
def __init__(self):
474+
super(Model, self).__init__()
475+
self.fc1 = torch.nn.Linear(10, 3200) # input size 10 → hidden size 32
476+
self.relu = torch.nn.ReLU()
477+
self.fc2 = torch.nn.Linear(3200, 1) # hidden → output
478+
with torch.no_grad():
479+
self.fc2.bias += 1999
480+
self.fc1.bias += 999
481+
482+
def forward(self, x):
483+
x = self.relu(self.fc1(x))
484+
x = self.fc2(x)
485+
return x
486+
487+
inputs = dict(x=self.torch.randn((5, 10), dtype=torch.float16))
488+
ds = dict(x={0: "batch"})
489+
model = Model()
490+
model = model.to(torch.float16)
491+
model(**inputs)
492+
ep = self.torch.export.export(
493+
model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds)
494+
)
495+
filename = self.get_dump_file("test_sbs_model_with_weights_custom_reset.onnx")
496+
to_onnx(ep, exporter="custom", filename=filename)
497+
onx = onnx.load(filename)
498+
results = list(
499+
run_aligned(
500+
ep,
501+
onx,
502+
kwargs=inputs,
503+
run_cls=OnnxruntimeEvaluator,
504+
verbose=11,
505+
use_tensor=True,
506+
reset_names=["linear"],
507+
),
508+
)
509+
df = pandas.DataFrame(list(results))
510+
df.to_excel(self.get_dump_file("test_sbs_model_with_weights_custom_reset.xlsx"))
511+
onnx_op_type = df["onnx_op_type"].tolist()
512+
self.assertEqual(onnx_op_type.count("reset"), 1)
513+
self.clean_dump()
514+
467515

468516
if __name__ == "__main__":
469517
unittest.main(verbosity=2)

onnx_diagnostic/_command_lines_parser.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1209,6 +1209,18 @@ def get_parser_sbs() -> ArgumentParser:
12091209
default=False,
12101210
help="First runs the whole model.",
12111211
)
1212+
parser.add_argument(
1213+
"--reset",
1214+
required=False,
1215+
default="",
1216+
help=textwrap.dedent(
1217+
"""
1218+
List of result names separated by a comma. For those results,
1219+
the side-by-side will take torch results instead of onnx results
1220+
to compute the rest of the onnx model.
1221+
"""
1222+
),
1223+
)
12121224
parser.add_argument(
12131225
"--gemmlinear",
12141226
action=BooleanOptionalAction,
@@ -1308,6 +1320,7 @@ def _size(name):
13081320
kwargs=mkwargs,
13091321
use_tensor=True,
13101322
gemmlinear=args.gemmlinear,
1323+
reset_names=args.reset.split(","),
13111324
exc=False,
13121325
):
13131326
data.append(obs)

onnx_diagnostic/torch_onnx/sbs.py

Lines changed: 71 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,37 @@ def prepare_args_kwargs(
174174

175175
@dataclass
176176
class RunAlignedRecord:
177+
"""
178+
The side-by-side ran by function :func:`run_aligned
179+
<onnx_diagnostic.torch_onnx.sbs.run_aligned>`
180+
yields instances of this type. If both `ep_name`
181+
and `onnx_name` are specified, then both results
182+
appear in the exported program (torch) and the onnx model.
183+
184+
:param ep_id_node: node index in the exported program
185+
:param onnx_id_node: node index in the onnx model, -1 for an initializer
186+
:param ep_name: result name in the exported program
187+
:param onnx_name: result name in the onnx model, usually same as `ep_name`
188+
except for initializer
189+
:param ep_target: target name in the exported program producing the result
190+
:param onnx_op_type: operator type in the onnx model producing the result
191+
:param onnx_id_output: usually 0 unless this node has multiple output,
192+
in that case, it is the output index
193+
:param ep_shape_type: shape and type of the results in the exported program
194+
:param onnx_shape_type: shape and type of the results in the onnx mode,
195+
it should be the same as `ep_shape_type`, anything different probably
196+
means a bug
197+
:param err_abs: maximum absolute error for the considered result
198+
between the exported program and the onnx model
199+
:param err_rel: maximum relative error
200+
:param err_dev: 0 if the device is the same, 1 if not
201+
:param err_nan: number of nan values disagreeing
202+
:param err_h01: number of values for which the discrepancy is above 0.1
203+
:param ep_time_run: execution time for the exported program
204+
:param onnx_time_run: execution time for the onnx model, that includes
205+
the creation of the onnx model so that's probably not very usable
206+
"""
207+
177208
ep_id_node: Optional[int] = None
178209
onnx_id_node: Optional[int] = None
179210
ep_name: Optional[str] = None
@@ -208,7 +239,14 @@ def set_diff(self, diff: Dict[str, Any]):
208239

209240
@dataclass
210241
class StatusRunAligned:
211-
"Information to display while running the side-by-side"
242+
"""
243+
Information to display while running the side-by-side
244+
245+
:param max_abs: maximum absolute seen so far
246+
:param n_inf: number of infinite values seen so far
247+
:param n_nan: number of nan values seen so for
248+
:param yielded_nodes: number of yielded pair of nodes seen so far
249+
"""
212250

213251
max_abs: float = 0.0
214252
n_inf: int = 0
@@ -223,6 +261,7 @@ def to_str(self) -> str:
223261
)
224262

225263
def update(self, err_abs: float):
264+
"Updates all attributes with the latest measure."
226265
if np.isinf(err_abs) or np.isnan(err_abs):
227266
self.n_inf += 1
228267
elif err_abs > 1e6:
@@ -253,6 +292,7 @@ def run_aligned(
253292
gemmlinear: bool = False,
254293
verbose: int = 0,
255294
exc: bool = True,
295+
reset_names: Optional[List[str]] = None,
256296
) -> Iterator[RunAlignedRecord]:
257297
"""
258298
Runs in parallel both the exported program
@@ -274,6 +314,8 @@ def run_aligned(
274314
``torch.nn.functional.linear(A,X,B)`` on onnx side
275315
:param verbose: verbosity level
276316
:param exc: stops if an exception
317+
:param reset_names: list of names, the onnx execution takes the torch outputs instead
318+
of its own result if the names falls into that set
277319
:return: a list of :class:`RunAlignedRecord`
278320
279321
Example:
@@ -408,6 +450,7 @@ def forward(self, x):
408450
-v 1 --atol=0.1 --rtol=1
409451
"""
410452
assert callable(run_cls), f"run_cls={run_cls} not a callable"
453+
reset_names = set(reset_names) if reset_names else set() # type: ignore[assignment]
411454
str_kws = dict(with_shape=True, with_device=True)
412455
has_cuda = any(
413456
(isinstance(t, torch.Tensor) and t.is_cuda)
@@ -618,6 +661,31 @@ def _loop_onnx_node(
618661
if tmp.err_abs is not None:
619662
status.update(tmp.err_abs)
620663
yield tmp
664+
if reset_names and tmp.ep_name in reset_names:
665+
assert (
666+
tmp.ep_name in torch_results
667+
), f"name {tmp.ep_name!r} set to be reset is missing in torch_results."
668+
assert (
669+
tmp.onnx_name in onnx_results
670+
), f"name {tmp.onnx_name!r} set to be reset is missing in onnx_results."
671+
onnx_results[tmp.onnx_name] = torch_results[tmp.ep_name]
672+
tmp = _loop_cmp(
673+
mapping_onnx_to_torch,
674+
torch_results,
675+
onnx_results,
676+
o,
677+
r,
678+
verbose,
679+
atol,
680+
rtol,
681+
i,
682+
i_onnx,
683+
)
684+
if tmp is not None:
685+
tmp.onnx_op_type = "reset"
686+
tmp.onnx_id_output = list_node_output.index(o)
687+
status.yielded_nodes += 1
688+
yield tmp
621689
already_run.add(i_onnx)
622690

623691
def _duplicated_values(d):
@@ -799,13 +867,13 @@ def _gemm_linear(node, feeds, sess):
799867
t = torch_results[init.name]
800868
torch_names_to_onnx_names[init.name] = init.name
801869
elif init.name not in skip_onnx_name and init.name in rev_init_aliases:
802-
new_names = [
870+
new_names = [ # type: ignore[assignment]
803871
k
804872
for k in rev_init_aliases[init.name]
805873
if k in torch_results and k not in skip_mapping_torch_onnx
806874
]
807875
if new_names and len(new_names) == 1:
808-
new_name = new_names[0]
876+
new_name = new_names[0] # type: ignore[assignment, index]
809877
t = torch_results[new_name]
810878
if (
811879
t.shape == tuple(init.dims)

0 commit comments

Comments
 (0)