Skip to content

Commit 917676f

Browse files
committed
tiny modif
1 parent 82657aa commit 917676f

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

onnx_diagnostic/helpers/helper.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1260,10 +1260,10 @@ def max_diff(
12601260
return dict(abs=np.inf, rel=np.inf, sum=np.inf, n=np.inf, dnan=np.inf)
12611261
# nan are replace by 1e10, any discrepancies in that order of magnitude
12621262
# is likely caused by nans
1263-
exp_cpu = expected.to(torch.float64).cpu().nan_to_num(1e10)
1264-
got_cpu = got.to(torch.float64).cpu().nan_to_num(1e10)
1263+
exp_cpu = expected.to(torch.float64).nan_to_num(1e10)
1264+
got_cpu = got.to(torch.float64).nan_to_num(1e10)
12651265
diff = (got_cpu - exp_cpu).abs()
1266-
ndiff = (expected.isnan().cpu().to(int) - got.isnan().cpu().to(int)).abs()
1266+
ndiff = (expected.isnan().cpu().to(int) - got.isnan().to(int)).abs()
12671267
rdiff = diff / (exp_cpu.abs() + 1e-3)
12681268
if diff.numel() > 0:
12691269
abs_diff, rel_diff, sum_diff, n_diff, nan_diff = (

0 commit comments

Comments
 (0)