
Commit 1ab2efb

changes
1 parent 81bd664 commit 1ab2efb

1 file changed: +40 -9 lines changed
onnx_diagnostic/torch_onnx/sbs.py

Lines changed: 40 additions & 9 deletions
@@ -9,7 +9,7 @@
     from typing import Self
 except ImportError:
     # python <= 3.10
-    Self = "Self"
+    Self = "Self"  # type: ignore[assignment]
 import onnx
 import onnx.helper as oh
 import numpy as np
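
For readers skimming the hunk: the `# type: ignore[assignment]` is needed because the fallback rebinds a name that was imported as a type. A minimal self-contained sketch of the pattern (the `try:` line sits just above the hunk and is assumed, not shown in the diff):

try:
    from typing import Self  # available in Python >= 3.11
except ImportError:
    # python <= 3.10: fall back to a plain string; mypy would otherwise
    # flag assigning a str to a name previously bound to a type.
    Self = "Self"  # type: ignore[assignment]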
@@ -848,7 +848,7 @@ def _loop_cmp(
             onnx_shape_type=string_type(r, **str_kws),
         )
         r.set_diff(d)
-        mapping_onnx_to_torch[to] = o
+        mapping_onnx_to_torch[o] = to
         return r
     return None

@@ -1035,7 +1035,8 @@ def _duplicated_values(d):
 
     if verbose:
         print(f"[run_aligned] ep: walks through {len(ep.graph.nodes)} nodes from torch")
-    positions: Dict[str, Any] = {}
+    # dictionary mapping result names and their position in both graphs.
+    positions: Dict[str, Dict[str, int]] = {}
     ep_graph_nodes = list(ep.graph.nodes)
     torch_results: Dict[str, Any] = {}
     last_position = 0
@@ -1087,6 +1088,18 @@ def _duplicated_values(d):
         print(f"[run_aligned] ep: found inputs {torch_input_names}")
         print(f"[run_aligned] ep: found outputs {torch_output_names}")
         print(f"[run_aligned] nx: walks through {len(onx.graph.node)} nodes from onnx")
+    for inp in onx.graph.input:
+        n = inp.name
+        if n in positions:
+            positions[n]["onnx"] = -1
+        else:
+            positions[n] = dict(onnx=-1)
+    for inp in onx.graph.initializer:
+        n = inp.name
+        if n in positions:
+            positions[n]["onnx"] = -1
+        else:
+            positions[n] = dict(onnx=-1)
     for i, node in enumerate(onx.graph.node):
         for n in node.output:
             if n in positions:
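
With this hunk, graph inputs and initializers are registered in `positions` up front at index -1, i.e. available before any ONNX node runs, alongside the node outputs recorded by the loop that follows. A hypothetical snapshot of the structure (names invented for illustration):

# illustrative only: each result name maps to its position in the
# exported program ("fx") and/or in the onnx graph ("onnx").
positions = {
    "input_ids": {"onnx": -1},              # onnx graph input, available up front
    "model.layers.0.weight": {"onnx": -1},  # initializer, also position -1
    "add_3": {"fx": 12, "onnx": 7},         # produced by fx node 12 / onnx node 7
}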
@@ -1156,7 +1169,6 @@ def _duplicated_values(d):
     memory_cpu = 0
     memory_cuda = 0
     for init in onx.graph.initializer:  # type: ignore
-        positions[init.name] = -1
         t = None
         if init.name in torch_results:
             if init.name not in skip_mapping_torch_onnx:
@@ -1215,15 +1227,15 @@ def _duplicated_values(d):
             print(f"[run_aligned-nx] +ini: {k}: {string_type(v, **str_kws)}")
 
     # starts the side-by-side
+    if verbose:
+        print(f"[run_aligned] ep: starts side-by-side with {len(ep_graph_nodes)} nodes")
     if verbose == 1:
         import tqdm
 
         loop = tqdm.tqdm(list(enumerate(ep_graph_nodes)))
     else:
         loop = list(enumerate(ep_graph_nodes))
 
-    if verbose:
-        print(f"[run_aligned] ep: starts side-by-side with {len(ep_graph_nodes)} nodes")
     already_run: Set[int] = set()
     ep_durations = {}
     status = StatusRunAligned()
@@ -1345,15 +1357,33 @@ def _duplicated_values(d):
 
         max_pos = -2
         for n in outputs:
-            if n in positions and "onnx" in positions[n]:
-                max_pos = max(max_pos, positions[n]["onnx"])
+            if n in positions:
+                if "onnx" in positions[n]:
+                    max_pos = max(max_pos, positions[n]["onnx"])
+                if "fx" in positions[n]:
+                    if positions[n]["fx"] > i:
+                        max_pos = -2
+                        break
         if max_pos == -2:
             # we skip.
             continue
 
+        next_to_visit = last_position
         for i_onnx in range(last_position, max_pos + 1):
             if i_onnx in already_run:
                 continue
+            # The onnx node may produce more than one output, in that
+            # case, we need to check the exported program is not behind.
+            node = onx.graph.node[i_onnx]
+            ep_behind = False
+            for iname in node.output:
+                if iname in positions and "fx" in positions[iname]:
+                    if positions[iname]["fx"] > i:
+                        ep_behind = True
+                        break
+            if ep_behind:
+                break
+
             for r in _loop_onnx_node(
                 onx,
                 ep_graph_nodes,
@@ -1374,8 +1404,9 @@ def _duplicated_values(d):
             ):
                 if r:
                     yield r.check(already_yielded)
+                    next_to_visit = i_onnx + 1
 
-        last_position = max_pos + 1
+        last_position = next_to_visit
 
         # complete the execution of the onnx graph
         if verbose:
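
Putting the last two hunks together: the cursor into the ONNX graph (`last_position`) no longer jumps straight to `max_pos + 1`; it only advances past nodes that actually produced a comparison, and the inner loop stops early when an ONNX node emits an output whose exported-program ("fx") producer has not been visited yet. A condensed sketch of that control flow, with the per-node execution abstracted behind a hypothetical `run_and_compare` callable (not the module's API):

def advance_onnx_cursor(onx, positions, i, last_position, max_pos, already_run, run_and_compare):
    # Condensed, hypothetical restatement of the loop shown in the diff above.
    next_to_visit = last_position
    for i_onnx in range(last_position, max_pos + 1):
        if i_onnx in already_run:
            continue
        node = onx.graph.node[i_onnx]
        # An onnx node may produce several outputs; if any of them is only
        # produced later than the current fx node i, stop so the two walks
        # stay aligned.
        if any(
            o in positions and positions[o].get("fx", -1) > i
            for o in node.output
        ):
            break
        if run_and_compare(i_onnx):   # truthy when a comparison was yielded
            next_to_visit = i_onnx + 1
    return next_to_visit  # the caller stores this back into last_position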
