22import os
33import time
44from dataclasses import dataclass
5- from typing import Any , Callable , Dict , Iterator , List , Optional , Set , Tuple , Union
5+ from typing import Any , Callable , Dict , Iterator , List , Optional , Self , Set , Tuple , Union
66import onnx
77import onnx .helper as oh
88import numpy as np
@@ -371,6 +371,29 @@ def set_diff(self, diff: Dict[str, Any]):
371371 self .err_nan = diff ["nan" ]
372372 if "rep" in diff :
373373 self .err_h01 = diff ["rep" ][">0.1" ]
374+ return self
375+
376+ @property
377+ def key (self ) -> Tuple [int , int , int , str , str ]:
378+ "Creates a unique identifier."
379+ return (
380+ self .ep_id_node ,
381+ self .onnx_id_node ,
382+ self .onnx_id_output ,
383+ self .ep_name ,
384+ self .onnx_name ,
385+ )
386+
387+ def check (self , already_yielded : Dict [Tuple [int , int , int , str , str ], int ]) -> Self :
388+ "Checks a record was not already yielded."
389+ key = self .key
390+ assert key not in already_yielded , (
391+ f"Record with key={ key } was already yielded, "
392+ f"number of records={ len (already_yielded )} and previous "
393+ f"record at position { already_yielded [key ]} (self={ self } )"
394+ )
395+ already_yielded [key ] = len (already_yielded )
396+ return self
374397
375398
376399@dataclass
@@ -451,8 +474,6 @@ def run_aligned(
451474 for the onnx runtime
452475 :param atol: absolute tolerance
453476 :param rtol: relative tolerance
454- :param gemmlinear: if True, replaces ``Gemm(A,X.T,B)`` by
455- ``torch.nn.functional.linear(A,X,B)`` on onnx side
456477 :param verbose: verbosity level
457478 :param exc: stops if an exception
458479 :param reset_names: list of names, the onnx execution takes the torch outputs instead
@@ -595,6 +616,7 @@ def forward(self, x):
595616 -v 1 --atol=0.1 --rtol=1
596617 """
597618 assert callable (run_cls ), f"run_cls={ run_cls } not a callable"
619+ already_yielded = {}
598620 reset_names = set (reset_names ) if reset_names else set () # type: ignore[assignment]
599621 str_kws = dict (with_shape = True , with_device = True )
600622 has_cuda = any (
@@ -774,7 +796,7 @@ def _loop_onnx_node(
774796 list_node_output = list (node .output )
775797 node_output = [o for o in list_node_output if o ]
776798 for o , r in zip (node_output , res ):
777- if r is None or o is None :
799+ if r is None or not o :
778800 continue
779801 tmp = _loop_cmp (
780802 mapping_onnx_to_torch ,
@@ -1033,7 +1055,7 @@ def _duplicated_values(d):
10331055 onnx_name = init .name ,
10341056 onnx_op_type = "initializer" ,
10351057 onnx_shape_type = string_type (t , ** str_kws ),
1036- )
1058+ ). check ( already_yielded )
10371059
10381060 size = t .element_size () * t .numel ()
10391061 if t .is_cuda :
@@ -1115,7 +1137,7 @@ def _duplicated_values(d):
11151137 onnx_results [torch_names_to_onnx_names [node .name ]], ** str_kws
11161138 ),
11171139 )
1118- yield record
1140+ yield record . check ( already_yielded )
11191141 else :
11201142 assert node .name in placeholders_to_state_dict , (
11211143 f"Unable to find placeholder { node .name !r} (node.op={ node .op !r} ), "
@@ -1155,7 +1177,7 @@ def _duplicated_values(d):
11551177 hist = [0.1 ],
11561178 )
11571179 )
1158- yield record
1180+ yield record . check ( already_yielded )
11591181 else :
11601182 if verbose > 1 :
11611183 print (
@@ -1166,7 +1188,7 @@ def _duplicated_values(d):
11661188 ep_name = node .name ,
11671189 ep_target = "placeholder" ,
11681190 ep_shape_type = string_type (t , ** str_kws ),
1169- )
1191+ ). check ( already_yielded )
11701192 continue
11711193
11721194 outputs = [node .name ] if isinstance (node .name , str ) else list (node .name )
@@ -1197,6 +1219,8 @@ def _duplicated_values(d):
11971219 continue
11981220
11991221 for i_onnx in range (last_position , max_pos + 1 ):
1222+ if i_onnx in already_run :
1223+ continue
12001224 for r in _loop_onnx_node (
12011225 onx ,
12021226 ep_graph_nodes ,
@@ -1216,7 +1240,7 @@ def _duplicated_values(d):
12161240 verbose ,
12171241 ):
12181242 if r :
1219- yield r
1243+ yield r . check ( already_yielded )
12201244
12211245 last_position = max_pos + 1
12221246
@@ -1227,6 +1251,8 @@ def _duplicated_values(d):
12271251 f"to { len (onx .graph .node )} "
12281252 )
12291253 for i_onnx in range (last_position , len (onx .graph .node )):
1254+ if i_onnx in already_run :
1255+ continue
12301256 for r in _loop_onnx_node (
12311257 onx ,
12321258 ep_graph_nodes ,
@@ -1246,9 +1272,7 @@ def _duplicated_values(d):
12461272 verbose ,
12471273 ):
12481274 if r :
1249- yield r
1250-
1251- already_run .add (i_onnx )
1275+ yield r .check (already_yielded )
12521276
12531277 if verbose :
12541278 print (f"[run_aligned] done with status={ status .to_str ()} " )
0 commit comments