Skip to content

Commit 6af4693

Browse files
committed
fix a few inefficiencies in sbs
1 parent 1d7735f commit 6af4693

File tree

1 file changed

+36
-12
lines changed
  • onnx_diagnostic/torch_onnx

1 file changed

+36
-12
lines changed

onnx_diagnostic/torch_onnx/sbs.py

Lines changed: 36 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,7 @@
22
import os
33
import time
44
from dataclasses import dataclass
5-
from typing import Any, Callable, Dict, Iterator, List, Optional, Set, Tuple, Union
5+
from typing import Any, Callable, Dict, Iterator, List, Optional, Self, Set, Tuple, Union
66
import onnx
77
import onnx.helper as oh
88
import numpy as np
@@ -371,6 +371,29 @@ def set_diff(self, diff: Dict[str, Any]):
371371
self.err_nan = diff["nan"]
372372
if "rep" in diff:
373373
self.err_h01 = diff["rep"][">0.1"]
374+
return self
375+
376+
@property
377+
def key(self) -> Tuple[int, int, int, str, str]:
378+
"Creates a unique identifier."
379+
return (
380+
self.ep_id_node,
381+
self.onnx_id_node,
382+
self.onnx_id_output,
383+
self.ep_name,
384+
self.onnx_name,
385+
)
386+
387+
def check(self, already_yielded: Dict[Tuple[int, int, int, str, str], int]) -> Self:
388+
"Checks a record was not already yielded."
389+
key = self.key
390+
assert key not in already_yielded, (
391+
f"Record with key={key} was already yielded, "
392+
f"number of records={len(already_yielded)} and previous "
393+
f"record at position {already_yielded[key]} (self={self})"
394+
)
395+
already_yielded[key] = len(already_yielded)
396+
return self
374397

375398

376399
@dataclass
@@ -451,8 +474,6 @@ def run_aligned(
451474
for the onnx runtime
452475
:param atol: absolute tolerance
453476
:param rtol: relative tolerance
454-
:param gemmlinear: if True, replaces ``Gemm(A,X.T,B)`` by
455-
``torch.nn.functional.linear(A,X,B)`` on onnx side
456477
:param verbose: verbosity level
457478
:param exc: stops if an exception
458479
:param reset_names: list of names, the onnx execution takes the torch outputs instead
@@ -595,6 +616,7 @@ def forward(self, x):
595616
-v 1 --atol=0.1 --rtol=1
596617
"""
597618
assert callable(run_cls), f"run_cls={run_cls} not a callable"
619+
already_yielded = {}
598620
reset_names = set(reset_names) if reset_names else set() # type: ignore[assignment]
599621
str_kws = dict(with_shape=True, with_device=True)
600622
has_cuda = any(
@@ -774,7 +796,7 @@ def _loop_onnx_node(
774796
list_node_output = list(node.output)
775797
node_output = [o for o in list_node_output if o]
776798
for o, r in zip(node_output, res):
777-
if r is None or o is None:
799+
if r is None or not o:
778800
continue
779801
tmp = _loop_cmp(
780802
mapping_onnx_to_torch,
@@ -1033,7 +1055,7 @@ def _duplicated_values(d):
10331055
onnx_name=init.name,
10341056
onnx_op_type="initializer",
10351057
onnx_shape_type=string_type(t, **str_kws),
1036-
)
1058+
).check(already_yielded)
10371059

10381060
size = t.element_size() * t.numel()
10391061
if t.is_cuda:
@@ -1115,7 +1137,7 @@ def _duplicated_values(d):
11151137
onnx_results[torch_names_to_onnx_names[node.name]], **str_kws
11161138
),
11171139
)
1118-
yield record
1140+
yield record.check(already_yielded)
11191141
else:
11201142
assert node.name in placeholders_to_state_dict, (
11211143
f"Unable to find placeholder {node.name!r} (node.op={node.op!r}), "
@@ -1155,7 +1177,7 @@ def _duplicated_values(d):
11551177
hist=[0.1],
11561178
)
11571179
)
1158-
yield record
1180+
yield record.check(already_yielded)
11591181
else:
11601182
if verbose > 1:
11611183
print(
@@ -1166,7 +1188,7 @@ def _duplicated_values(d):
11661188
ep_name=node.name,
11671189
ep_target="placeholder",
11681190
ep_shape_type=string_type(t, **str_kws),
1169-
)
1191+
).check(already_yielded)
11701192
continue
11711193

11721194
outputs = [node.name] if isinstance(node.name, str) else list(node.name)
@@ -1197,6 +1219,8 @@ def _duplicated_values(d):
11971219
continue
11981220

11991221
for i_onnx in range(last_position, max_pos + 1):
1222+
if i_onnx in already_run:
1223+
continue
12001224
for r in _loop_onnx_node(
12011225
onx,
12021226
ep_graph_nodes,
@@ -1216,7 +1240,7 @@ def _duplicated_values(d):
12161240
verbose,
12171241
):
12181242
if r:
1219-
yield r
1243+
yield r.check(already_yielded)
12201244

12211245
last_position = max_pos + 1
12221246

@@ -1227,6 +1251,8 @@ def _duplicated_values(d):
12271251
f"to {len(onx.graph.node)}"
12281252
)
12291253
for i_onnx in range(last_position, len(onx.graph.node)):
1254+
if i_onnx in already_run:
1255+
continue
12301256
for r in _loop_onnx_node(
12311257
onx,
12321258
ep_graph_nodes,
@@ -1246,9 +1272,7 @@ def _duplicated_values(d):
12461272
verbose,
12471273
):
12481274
if r:
1249-
yield r
1250-
1251-
already_run.add(i_onnx)
1275+
yield r.check(already_yielded)
12521276

12531277
if verbose:
12541278
print(f"[run_aligned] done with status={status.to_str()}")

0 commit comments

Comments
 (0)