@@ -1486,19 +1486,27 @@ def max_diff(
14861486 dev = dev ,
14871487 )
14881488 if hist :
1489- if isinstance (hist , bool ):
1490- hist = torch .tensor (
1491- [0 , 0.0001 , 0.001 , 0.01 , 0.1 , 1 , 10 , 100 ], dtype = diff .dtype
1492- )
1493- hist = hist .to (diff .device )
1494- ind = torch .bucketize (diff .reshape ((- 1 ,)), hist , right = False )
1495- cou = torch .bincount (ind , minlength = ind .shape [0 ] + 1 )
1496- res ["rep" ] = dict (
1497- zip (
1498- [f">{ x } " for x in hist ],
1499- [int (i ) for i in (cou .sum () - torch .cumsum (cou , 0 ))],
1489+ if isinstance (hist , list ) and len (hist ) == 1 :
1490+ res ["rep" ] = {f">{ hist [0 ]} " : (diff > hist [0 ]).sum ().item ()}
1491+ elif isinstance (hist , list ) and len (hist ) == 2 :
1492+ res ["rep" ] = {
1493+ f">{ hist [0 ]} " : (diff > hist [0 ]).sum ().item (),
1494+ f">{ hist [1 ]} " : (diff > hist [1 ]).sum ().item (),
1495+ }
1496+ else :
1497+ if isinstance (hist , bool ):
1498+ hist = torch .tensor (
1499+ [0 , 0.0001 , 0.001 , 0.01 , 0.1 , 1 , 10 , 100 ], dtype = diff .dtype
1500+ )
1501+ hist = torch .tensor (hist ).to (diff .device )
1502+ ind = torch .bucketize (diff .reshape ((- 1 ,)), hist , right = False )
1503+ cou = torch .bincount (ind , minlength = ind .shape [0 ] + 1 )
1504+ res ["rep" ] = dict (
1505+ zip (
1506+ [f">{ x } " for x in hist ],
1507+ [int (i ) for i in (cou .sum () - torch .cumsum (cou , 0 ))],
1508+ )
15001509 )
1501- )
15021510 return res # type: ignore
15031511
15041512 if isinstance (expected , int ) and isinstance (got , torch .Tensor ):
0 commit comments