99 from typing import Self
1010except ImportError :
1111 # python <= 3.10
12- Self = "Self"
12+ Self = "Self" # type: ignore[assignment]
1313import onnx
1414import onnx .helper as oh
1515import numpy as np
@@ -848,7 +848,7 @@ def _loop_cmp(
848848 onnx_shape_type = string_type (r , ** str_kws ),
849849 )
850850 r .set_diff (d )
851- mapping_onnx_to_torch [to ] = o
851+ mapping_onnx_to_torch [o ] = to
852852 return r
853853 return None
854854
@@ -1035,7 +1035,8 @@ def _duplicated_values(d):
10351035
10361036 if verbose :
10371037 print (f"[run_aligned] ep: walks through { len (ep .graph .nodes )} nodes from torch" )
1038- positions : Dict [str , Any ] = {}
1038+ # dictionary mapping each result name to its positions ("fx", "onnx") in both graphs.
1039+ positions : Dict [str , Dict [str , int ]] = {}
10391040 ep_graph_nodes = list (ep .graph .nodes )
10401041 torch_results : Dict [str , Any ] = {}
10411042 last_position = 0
@@ -1087,6 +1088,18 @@ def _duplicated_values(d):
10871088 print (f"[run_aligned] ep: found inputs { torch_input_names } " )
10881089 print (f"[run_aligned] ep: found outputs { torch_output_names } " )
10891090 print (f"[run_aligned] nx: walks through { len (onx .graph .node )} nodes from onnx" )
1091+ for inp in onx .graph .input :
1092+ n = inp .name
1093+ if n in positions :
1094+ positions [n ]["onnx" ] = - 1
1095+ else :
1096+ positions [n ] = dict (onnx = - 1 )
1097+ for inp in onx .graph .initializer :
1098+ n = inp .name
1099+ if n in positions :
1100+ positions [n ]["onnx" ] = - 1
1101+ else :
1102+ positions [n ] = dict (onnx = - 1 )
10901103 for i , node in enumerate (onx .graph .node ):
10911104 for n in node .output :
10921105 if n in positions :
@@ -1156,7 +1169,6 @@ def _duplicated_values(d):
11561169 memory_cpu = 0
11571170 memory_cuda = 0
11581171 for init in onx .graph .initializer : # type: ignore
1159- positions [init .name ] = - 1
11601172 t = None
11611173 if init .name in torch_results :
11621174 if init .name not in skip_mapping_torch_onnx :
@@ -1215,15 +1227,15 @@ def _duplicated_values(d):
12151227 print (f"[run_aligned-nx] +ini: { k } : { string_type (v , ** str_kws )} " )
12161228
12171229 # starts the side-by-side
1230+ if verbose :
1231+ print (f"[run_aligned] ep: starts side-by-side with { len (ep_graph_nodes )} nodes" )
12181232 if verbose == 1 :
12191233 import tqdm
12201234
12211235 loop = tqdm .tqdm (list (enumerate (ep_graph_nodes )))
12221236 else :
12231237 loop = list (enumerate (ep_graph_nodes ))
12241238
1225- if verbose :
1226- print (f"[run_aligned] ep: starts side-by-side with { len (ep_graph_nodes )} nodes" )
12271239 already_run : Set [int ] = set ()
12281240 ep_durations = {}
12291241 status = StatusRunAligned ()
@@ -1345,15 +1357,33 @@ def _duplicated_values(d):
13451357
13461358 max_pos = - 2
13471359 for n in outputs :
1348- if n in positions and "onnx" in positions [n ]:
1349- max_pos = max (max_pos , positions [n ]["onnx" ])
1360+ if n in positions :
1361+ if "onnx" in positions [n ]:
1362+ max_pos = max (max_pos , positions [n ]["onnx" ])
1363+ if "fx" in positions [n ]:
1364+ if positions [n ]["fx" ] > i :
1365+ max_pos = - 2
1366+ break
13501367 if max_pos == - 2 :
13511368 # we skip.
13521369 continue
13531370
1371+ next_to_visit = last_position
13541372 for i_onnx in range (last_position , max_pos + 1 ):
13551373 if i_onnx in already_run :
13561374 continue
1375+ # The onnx node may produce more than one output; in that
1376+ # case, we need to check that the exported program is not behind.
1377+ node = onx .graph .node [i_onnx ]
1378+ ep_behind = False
1379+ for iname in node .output :
1380+ if iname in positions and "fx" in positions [iname ]:
1381+ if positions [iname ]["fx" ] > i :
1382+ ep_behind = True
1383+ break
1384+ if ep_behind :
1385+ break
1386+
13571387 for r in _loop_onnx_node (
13581388 onx ,
13591389 ep_graph_nodes ,
@@ -1374,8 +1404,9 @@ def _duplicated_values(d):
13741404 ):
13751405 if r :
13761406 yield r .check (already_yielded )
1407+ next_to_visit = i_onnx + 1
13771408
1378- last_position = max_pos + 1
1409+ last_position = next_to_visit
13791410
13801411 # complete the execution of the onnx graph
13811412 if verbose :
0 commit comments