@@ -1262,8 +1262,19 @@ def max_diff(
12621262 # is likely caused by nans
12631263 exp_cpu = expected .to (torch .float64 ).nan_to_num (1e10 )
12641264 got_cpu = got .to (torch .float64 ).nan_to_num (1e10 )
1265+ if got_cpu .device != exp_cpu .device :
1266+ if torch .device ("cuda:0" ) in {got_cpu .device , exp_cpu .device }:
1267+ got_cpu = got_cpu .to ("cuda:0" )
1268+ exp_cpu = exp_cpu .to ("cuda:0" )
1269+ expected = expected .to ("cuda:0" )
1270+ got = got .to ("cuda:0" )
1271+ else :
1272+ got_cpu = got_cpu .detach ().to ("cpu" )
1273+ exp_cpu = exp_cpu .detach ().to ("cpu" )
1274+ expected = expected .to ("cpu" )
1275+ got = got .to ("cpu" )
12651276 diff = (got_cpu - exp_cpu ).abs ()
1266- ndiff = (expected .isnan ().cpu (). to (int ) - got .isnan ().to (int )).abs ()
1277+ ndiff = (expected .isnan ().to (int ) - got .isnan ().to (int )).abs ()
12671278 rdiff = diff / (exp_cpu .abs () + 1e-3 )
12681279 if diff .numel () > 0 :
12691280 abs_diff , rel_diff , sum_diff , n_diff , nan_diff = (
@@ -1320,6 +1331,7 @@ def max_diff(
13201331 hist = torch .tensor (
13211332 [0 , 0.0001 , 0.001 , 0.01 , 0.1 , 1 , 10 , 100 ], dtype = diff .dtype
13221333 )
1334+ hist = hist .to (diff .device )
13231335 ind = torch .bucketize (diff .reshape ((- 1 ,)), hist , right = False )
13241336 cou = torch .bincount (ind , minlength = ind .shape [0 ] + 1 )
13251337 res ["rep" ] = dict (
0 commit comments