Skip to content

Commit fb75628

Browse files
committed
fix device
1 parent 917676f commit fb75628

File tree

1 file changed

+13
-1
lines changed

1 file changed

+13
-1
lines changed

onnx_diagnostic/helpers/helper.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1262,8 +1262,19 @@ def max_diff(
12621262
# is likely caused by nans
12631263
exp_cpu = expected.to(torch.float64).nan_to_num(1e10)
12641264
got_cpu = got.to(torch.float64).nan_to_num(1e10)
1265+
if got_cpu.device != exp_cpu.device:
1266+
if torch.device("cuda:0") in {got_cpu.device, exp_cpu.device}:
1267+
got_cpu = got_cpu.to("cuda:0")
1268+
exp_cpu = exp_cpu.to("cuda:0")
1269+
expected = expected.to("cuda:0")
1270+
got = got.to("cuda:0")
1271+
else:
1272+
got_cpu = got_cpu.detach().to("cpu")
1273+
exp_cpu = exp_cpu.detach().to("cpu")
1274+
expected = expected.to("cpu")
1275+
got = got.to("cpu")
12651276
diff = (got_cpu - exp_cpu).abs()
1266-
ndiff = (expected.isnan().cpu().to(int) - got.isnan().to(int)).abs()
1277+
ndiff = (expected.isnan().to(int) - got.isnan().to(int)).abs()
12671278
rdiff = diff / (exp_cpu.abs() + 1e-3)
12681279
if diff.numel() > 0:
12691280
abs_diff, rel_diff, sum_diff, n_diff, nan_diff = (
@@ -1320,6 +1331,7 @@ def max_diff(
13201331
hist = torch.tensor(
13211332
[0, 0.0001, 0.001, 0.01, 0.1, 1, 10, 100], dtype=diff.dtype
13221333
)
1334+
hist = hist.to(diff.device)
13231335
ind = torch.bucketize(diff.reshape((-1,)), hist, right=False)
13241336
cou = torch.bincount(ind, minlength=ind.shape[0] + 1)
13251337
res["rep"] = dict(

0 commit comments

Comments
 (0)