Skip to content

Commit 389578a

Browse files
committed
fix
1 parent 6af4693 commit 389578a

File tree

3 files changed

+34
-15
lines changed

3 files changed

+34
-15
lines changed

CHANGELOGS.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ Change Logs
88
* :pr:`311`: use custom and local function to use PackedMultiHeadAttention from onnxruntime
99
* :pr:`310`: splits patches into multiple files
1010
* :pr:`308`: add option --save_ep to dump the exported program as well as torch input
11-
* :pr:`304`, :pr:`306`, :pr:`316`, :pr:`317`, :pr:`318`: improves side-by-side comparison, creates command line sbs
11+
* :pr:`304`, :pr:`306`, :pr:`316`, :pr:`317`, :pr:`318`, :pr:`319`: improves side-by-side comparison, creates command line sbs
1212

1313
0.8.2
1414
+++++

onnx_diagnostic/_command_lines_parser.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1234,7 +1234,7 @@ def get_parser_sbs() -> ArgumentParser:
12341234
"--replay-threshold",
12351235
type=float,
12361236
required=False,
1237-
default=1e6,
1237+
default=1e9,
12381238
help="Triggers the replay if the discrepancies are higher than this value.",
12391239
)
12401240
parser.add_argument(

onnx_diagnostic/torch_onnx/sbs.py

Lines changed: 32 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -214,13 +214,11 @@ def select(
214214
:param err_abs: measured discrepancy
215215
:return: True if this should be dumped
216216
"""
217-
if name and self.selected_names:
218-
if name in self.selected_names:
219-
return True
220-
if op_type and self.selected_op_types:
221-
if op_type in self.selected_op_types:
222-
return True
223-
if err_abs is not None and err_abs >= self.threshold:
217+
if name and self.selected_names and name in self.selected_names:
218+
return True
219+
if op_type and self.selected_op_types and op_type in self.selected_op_types:
220+
return True
221+
if err_abs is not None and self.threshold is not None and err_abs >= self.threshold:
224222
return True
225223
return False
226224

@@ -286,15 +284,23 @@ def dump(
286284
if n in onnx_name_to_ep_name:
287285
torch_inputs[n] = torch_results[onnx_name_to_ep_name[n]]
288286
else:
289-
raise AssertionError(f"n={n!r}, onnx_name_to_ep_name={onnx_name_to_ep_name}")
287+
# It is possible that this result only exists in the onnx worlds.
288+
pass
290289
onnx_inputs = {n: onnx_results[n] for n in input_names}
291290
assert (
292291
name in onnx_name_to_ep_name
293292
), f"Unable to find {name!r} in {onnx_name_to_ep_name}"
294-
expected_outputs = (torch_results[onnx_name_to_ep_name[name]],)
293+
expected_outputs_and_mapping = dict(
294+
expected=(torch_results[onnx_name_to_ep_name[name]],),
295+
mapping={
296+
k: onnx_name_to_ep_name[k] for k in input_names if k in onnx_name_to_ep_name
297+
},
298+
)
295299
torch.save(torch_inputs, os.path.join(folder, "torch_inputs.pt"))
296300
torch.save(onnx_inputs, os.path.join(folder, "onnx_inputs.pt"))
297-
torch.save(expected_outputs, os.path.join(folder, "torch_outputs.pt"))
301+
torch.save(
302+
expected_outputs_and_mapping, os.path.join(folder, "torch_outputs_and_mapping.pt")
303+
)
298304
if verbose:
299305
print(f"[ReplayConfiguration.dump] done {folder!r}")
300306
return folder
@@ -374,7 +380,9 @@ def set_diff(self, diff: Dict[str, Any]):
374380
return self
375381

376382
@property
377-
def key(self) -> Tuple[int, int, int, str, str]:
383+
def key(
384+
self,
385+
) -> Tuple[Optional[int], Optional[int], Optional[int], Optional[str], Optional[str]]:
378386
"Creates a unique identifier."
379387
return (
380388
self.ep_id_node,
@@ -384,8 +392,17 @@ def key(self) -> Tuple[int, int, int, str, str]:
384392
self.onnx_name,
385393
)
386394

387-
def check(self, already_yielded: Dict[Tuple[int, int, int, str, str], int]) -> Self:
395+
def check(
396+
self,
397+
already_yielded: Dict[
398+
Tuple[Optional[int], Optional[int], Optional[int], Optional[str], Optional[str]],
399+
int,
400+
],
401+
) -> Self:
388402
"Checks a record was not already yielded."
403+
if self.onnx_op_type == "reset":
404+
# no record for this one
405+
return self
389406
key = self.key
390407
assert key not in already_yielded, (
391408
f"Record with key={key} was already yielded, "
@@ -616,7 +633,7 @@ def forward(self, x):
616633
-v 1 --atol=0.1 --rtol=1
617634
"""
618635
assert callable(run_cls), f"run_cls={run_cls} not a callable"
619-
already_yielded = {}
636+
already_yielded = {} # type: ignore[var-annotated]
620637
reset_names = set(reset_names) if reset_names else set() # type: ignore[assignment]
621638
str_kws = dict(with_shape=True, with_device=True)
622639
has_cuda = any(
@@ -647,6 +664,8 @@ def forward(self, x):
647664
if verbose:
648665
print(f"[run_aligned] run_cls={run_cls}")
649666
print(f"[run_aligned] run_cls_kwargs={run_cls_kwargs}")
667+
if replay_configuration:
668+
print(f"[run_aligned] replay={replay_configuration}")
650669

651670
def _check_tensor_(name, obj, flip_type=False):
652671
if flip_type:

0 commit comments

Comments
 (0)