@@ -1181,6 +1181,7 @@ def max_diff(
11811181 if exp_cpu .size == got_cpu .size
11821182 else (np .inf , np .inf , np .inf , 0 , np .inf )
11831183 )
1184+ argm = None
11841185 else :
11851186 abs_diff , rel_diff , sum_diff , n_diff , nan_diff = (
11861187 float (diff .max ()),
@@ -1189,6 +1190,7 @@ def max_diff(
11891190 float (diff .size ),
11901191 float (ndiff .sum ()),
11911192 )
1193+ argm = tuple (map (int , np .unravel_index (diff .argmax (), diff .shape )))
11921194 if verbose >= 10 and (abs_diff >= 10 or rel_diff >= 10 ):
11931195 # To understand the value it comes from.
11941196 if debug_info :
@@ -1219,7 +1221,9 @@ def max_diff(
12191221 f"_index={ _index } "
12201222 )
12211223
1222- res = dict (abs = abs_diff , rel = rel_diff , sum = sum_diff , n = n_diff , dnan = nan_diff )
1224+ res = dict (
1225+ abs = abs_diff , rel = rel_diff , sum = sum_diff , n = n_diff , dnan = nan_diff , argm = argm
1226+ )
12231227 if hist :
12241228 if isinstance (hist , bool ):
12251229 hist = np .array ([0 , 0.0001 , 0.001 , 0.01 , 0.1 , 1 , 10 , 100 ], dtype = diff .dtype )
@@ -1284,8 +1288,10 @@ def max_diff(
12841288 float (diff .numel ()),
12851289 float (ndiff .sum ()),
12861290 )
1291+ argm = tuple (map (int , torch .unravel_index (diff .argmax (), diff .shape )))
12871292 elif got_cpu .numel () == exp_cpu .numel ():
12881293 abs_diff , rel_diff , sum_diff , n_diff , nan_diff = (0.0 , 0.0 , 0.0 , 0.0 , 0.0 )
1294+ argm = None
12891295 else :
12901296 abs_diff , rel_diff , sum_diff , n_diff , nan_diff = (
12911297 np .inf ,
@@ -1294,6 +1300,7 @@ def max_diff(
12941300 np .inf ,
12951301 np .inf ,
12961302 )
1303+ argm = None
12971304
12981305 if verbose >= 10 and (abs_diff >= 10 or rel_diff >= 10 ):
12991306 # To understand the value it comes from.
@@ -1325,7 +1332,9 @@ def max_diff(
13251332 f"_index={ _index } "
13261333 )
13271334
1328- res = dict (abs = abs_diff , rel = rel_diff , sum = sum_diff , n = n_diff , dnan = nan_diff )
1335+ res = dict (
1336+ abs = abs_diff , rel = rel_diff , sum = sum_diff , n = n_diff , dnan = nan_diff , argm = argm
1337+ )
13291338 if hist :
13301339 if isinstance (hist , bool ):
13311340 hist = torch .tensor (
@@ -1478,6 +1487,8 @@ def string_diff(diff: Dict[str, Any]) -> str:
14781487 rows .append (f"#{ v } { k } " )
14791488 suffix = "-" .join (rows )
14801489 suffix = f"/{ suffix } "
1490+ if "argm" in diff :
1491+ suffix += f", argmax={ diff ['argm' ]} "
14811492 if diff .get ("dnan" , None ):
14821493 if diff ["abs" ] == 0 or diff ["rel" ] == 0 :
14831494 return f"abs={ diff ['abs' ]} , rel={ diff ['rel' ]} , dnan={ diff ['dnan' ]} { suffix } "
0 commit comments