@@ -214,13 +214,11 @@ def select(
214214 :param err_abs: measured discrepancy
215215 :return: True if this should be dumped
216216 """
217- if name and self .selected_names :
218- if name in self .selected_names :
219- return True
220- if op_type and self .selected_op_types :
221- if op_type in self .selected_op_types :
222- return True
223- if err_abs is not None and err_abs >= self .threshold :
217+ if name and self .selected_names and name in self .selected_names :
218+ return True
219+ if op_type and self .selected_op_types and op_type in self .selected_op_types :
220+ return True
221+ if err_abs is not None and self .threshold is not None and err_abs >= self .threshold :
224222 return True
225223 return False
226224
@@ -286,15 +284,23 @@ def dump(
286284 if n in onnx_name_to_ep_name :
287285 torch_inputs [n ] = torch_results [onnx_name_to_ep_name [n ]]
288286 else :
289- raise AssertionError (f"n={ n !r} , onnx_name_to_ep_name={ onnx_name_to_ep_name } " )
287+ # It is possible that this result only exists in the onnx worlds.
288+ pass
290289 onnx_inputs = {n : onnx_results [n ] for n in input_names }
291290 assert (
292291 name in onnx_name_to_ep_name
293292 ), f"Unable to find { name !r} in { onnx_name_to_ep_name } "
294- expected_outputs = (torch_results [onnx_name_to_ep_name [name ]],)
293+ expected_outputs_and_mapping = dict (
294+ expected = (torch_results [onnx_name_to_ep_name [name ]],),
295+ mapping = {
296+ k : onnx_name_to_ep_name [k ] for k in input_names if k in onnx_name_to_ep_name
297+ },
298+ )
295299 torch .save (torch_inputs , os .path .join (folder , "torch_inputs.pt" ))
296300 torch .save (onnx_inputs , os .path .join (folder , "onnx_inputs.pt" ))
297- torch .save (expected_outputs , os .path .join (folder , "torch_outputs.pt" ))
301+ torch .save (
302+ expected_outputs_and_mapping , os .path .join (folder , "torch_outputs_and_mapping.pt" )
303+ )
298304 if verbose :
299305 print (f"[ReplayConfiguration.dump] done { folder !r} " )
300306 return folder
@@ -374,7 +380,9 @@ def set_diff(self, diff: Dict[str, Any]):
374380 return self
375381
376382 @property
377- def key (self ) -> Tuple [int , int , int , str , str ]:
383+ def key (
384+ self ,
385+ ) -> Tuple [Optional [int ], Optional [int ], Optional [int ], Optional [str ], Optional [str ]]:
378386 "Creates a unique identifier."
379387 return (
380388 self .ep_id_node ,
@@ -384,8 +392,17 @@ def key(self) -> Tuple[int, int, int, str, str]:
384392 self .onnx_name ,
385393 )
386394
387- def check (self , already_yielded : Dict [Tuple [int , int , int , str , str ], int ]) -> Self :
395+ def check (
396+ self ,
397+ already_yielded : Dict [
398+ Tuple [Optional [int ], Optional [int ], Optional [int ], Optional [str ], Optional [str ]],
399+ int ,
400+ ],
401+ ) -> Self :
388402 "Checks a record was not already yielded."
403+ if self .onnx_op_type == "reset" :
404+ # no record for this one
405+ return self
389406 key = self .key
390407 assert key not in already_yielded , (
391408 f"Record with key={ key } was already yielded, "
@@ -616,7 +633,7 @@ def forward(self, x):
616633 -v 1 --atol=0.1 --rtol=1
617634 """
618635 assert callable (run_cls ), f"run_cls={ run_cls } not a callable"
619- already_yielded = {}
636+ already_yielded = {} # type: ignore[var-annotated]
620637 reset_names = set (reset_names ) if reset_names else set () # type: ignore[assignment]
621638 str_kws = dict (with_shape = True , with_device = True )
622639 has_cuda = any (
@@ -647,6 +664,8 @@ def forward(self, x):
647664 if verbose :
648665 print (f"[run_aligned] run_cls={ run_cls } " )
649666 print (f"[run_aligned] run_cls_kwargs={ run_cls_kwargs } " )
667+ if replay_configuration :
668+ print (f"[run_aligned] replay={ replay_configuration } " )
650669
651670 def _check_tensor_ (name , obj , flip_type = False ):
652671 if flip_type :
0 commit comments