@@ -174,6 +174,37 @@ def prepare_args_kwargs(
174174
175175@dataclass
176176class RunAlignedRecord :
177+ """
178+ The side-by-side run by function :func:`run_aligned
179+ <onnx_diagnostic.torch_onnx.sbs.run_aligned>`
180+ yields instances of this type. If both `ep_name`
181+ and `onnx_name` are specified, then both results
182+ appear in the exported program (torch) and the onnx model.
183+
184+ :param ep_id_node: node index in the exported program
185+ :param onnx_id_node: node index in the onnx model, -1 for an initializer
186+ :param ep_name: result name in the exported program
187+ :param onnx_name: result name in the onnx model, usually same as `ep_name`
188+ except for initializer
189+ :param ep_target: target name in the exported program producing the result
190+ :param onnx_op_type: operator type in the onnx model producing the result
191+ :param onnx_id_output: usually 0 unless this node has multiple outputs,
192+ in that case, it is the output index
193+ :param ep_shape_type: shape and type of the results in the exported program
194+ :param onnx_shape_type: shape and type of the results in the onnx model,
195+ it should be the same as `ep_shape_type`, anything different probably
196+ means a bug
197+ :param err_abs: maximum absolute error for the considered result
198+ between the exported program and the onnx model
199+ :param err_rel: maximum relative error
200+ :param err_dev: 0 if the device is the same, 1 if not
201+ :param err_nan: number of nan values disagreeing
202+ :param err_h01: number of values for which the discrepancy is above 0.1
203+ :param ep_time_run: execution time for the exported program
204+ :param onnx_time_run: execution time for the onnx model, that includes
205+ the creation of the onnx model so that's probably not very usable
206+ """
207+
177208 ep_id_node : Optional [int ] = None
178209 onnx_id_node : Optional [int ] = None
179210 ep_name : Optional [str ] = None
@@ -208,7 +239,14 @@ def set_diff(self, diff: Dict[str, Any]):
208239
209240@dataclass
210241class StatusRunAligned :
211- "Information to display while running the side-by-side"
242+ """
243+ Information to display while running the side-by-side
244+
245+ :param max_abs: maximum absolute error seen so far
246+ :param n_inf: number of infinite values seen so far
247+ :param n_nan: number of nan values seen so far
248+ :param yielded_nodes: number of yielded pairs of nodes seen so far
249+ """
212250
213251 max_abs : float = 0.0
214252 n_inf : int = 0
@@ -223,6 +261,7 @@ def to_str(self) -> str:
223261 )
224262
225263 def update (self , err_abs : float ):
264+ "Updates all attributes with the latest measure."
226265 if np .isinf (err_abs ) or np .isnan (err_abs ):
227266 self .n_inf += 1
228267 elif err_abs > 1e6 :
@@ -253,6 +292,7 @@ def run_aligned(
253292 gemmlinear : bool = False ,
254293 verbose : int = 0 ,
255294 exc : bool = True ,
295+ reset_names : Optional [List [str ]] = None ,
256296) -> Iterator [RunAlignedRecord ]:
257297 """
258298 Runs in parallel both the exported program
@@ -274,6 +314,8 @@ def run_aligned(
274314 ``torch.nn.functional.linear(A,X,B)`` on onnx side
275315 :param verbose: verbosity level
276316 :param exc: stops if an exception
317+ :param reset_names: list of names, the onnx execution takes the torch outputs instead
318+ of its own result if the name falls into that set
277319 :return: a list of :class:`RunAlignedRecord`
278320
279321 Example:
@@ -408,6 +450,7 @@ def forward(self, x):
408450 -v 1 --atol=0.1 --rtol=1
409451 """
410452 assert callable (run_cls ), f"run_cls={ run_cls } not a callable"
453+ reset_names = set (reset_names ) if reset_names else set () # type: ignore[assignment]
411454 str_kws = dict (with_shape = True , with_device = True )
412455 has_cuda = any (
413456 (isinstance (t , torch .Tensor ) and t .is_cuda )
@@ -618,6 +661,31 @@ def _loop_onnx_node(
618661 if tmp .err_abs is not None :
619662 status .update (tmp .err_abs )
620663 yield tmp
664+ if reset_names and tmp .ep_name in reset_names :
665+ assert (
666+ tmp .ep_name in torch_results
667+ ), f"name { tmp .ep_name !r} set to be reset is missing in torch_results."
668+ assert (
669+ tmp .onnx_name in onnx_results
670+ ), f"name { tmp .onnx_name !r} set to be reset is missing in onnx_results."
671+ onnx_results [tmp .onnx_name ] = torch_results [tmp .ep_name ]
672+ tmp = _loop_cmp (
673+ mapping_onnx_to_torch ,
674+ torch_results ,
675+ onnx_results ,
676+ o ,
677+ r ,
678+ verbose ,
679+ atol ,
680+ rtol ,
681+ i ,
682+ i_onnx ,
683+ )
684+ if tmp is not None :
685+ tmp .onnx_op_type = "reset"
686+ tmp .onnx_id_output = list_node_output .index (o )
687+ status .yielded_nodes += 1
688+ yield tmp
621689 already_run .add (i_onnx )
622690
623691 def _duplicated_values (d ):
@@ -799,13 +867,13 @@ def _gemm_linear(node, feeds, sess):
799867 t = torch_results [init .name ]
800868 torch_names_to_onnx_names [init .name ] = init .name
801869 elif init .name not in skip_onnx_name and init .name in rev_init_aliases :
802- new_names = [
870+ new_names = [ # type: ignore[assignment]
803871 k
804872 for k in rev_init_aliases [init .name ]
805873 if k in torch_results and k not in skip_mapping_torch_onnx
806874 ]
807875 if new_names and len (new_names ) == 1 :
808- new_name = new_names [0 ]
876+ new_name = new_names [0 ] # type: ignore[assignment, index]
809877 t = torch_results [new_name ]
810878 if (
811879 t .shape == tuple (init .dims )
0 commit comments