@@ -3330,7 +3330,7 @@ def kernel(X, stride_xm, stride_xk, Y, stride_yk, stride_yn, W, stride_wn, strid
     for M, N, K in itertools.product([32, 64, 128], [32, 64, 128], [64, 128])
     for col_a, col_b in itertools.product([True, False], repeat=2)
     for type_a in ["e2m1", "e4m3", "e5m2"]
-    for type_b in ["e4m3", "e5m2"]
+    for type_b in ["e4m3", "e5m2", "bf16"]
     for mma in ([32, 16] if is_hip() else [16])
     for kpack in ([1, 2] if is_hip() else [1])])
 def test_scaled_dot(M, N, K, col_a, col_b, type_a, type_b, num_warps, mma, kpack, device):
@@ -3351,7 +3351,7 @@ def test_scaled_dot(M, N, K, col_a, col_b, type_a, type_b, num_warps, mma, kpack
     def dot_scale_kernel(a_base, stride_a0, stride_a1, a_scale, b_base, stride_b0, stride_b1, out,
                          BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, type_a: tl.constexpr,
                          type_b: tl.constexpr):
-        tl.static_assert(type_b == "e4m3" or type_b == "e5m2", "type_b must be fp8")
+        tl.static_assert((type_b == "e4m3" or type_b == "e5m2") or type_b == "bf16", "type_b must be fp8 or bf16")
         IS_FP8: tl.constexpr = type_a == "e4m3" or type_a == "e5m2"
         DIV_FACTOR: tl.constexpr = 1 if IS_FP8 else 2
         PACKED_BLOCK_K_A: tl.constexpr = BLOCK_K // DIV_FACTOR
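The two constexprs at the end of this hunk encode how the A operand is stored: e4m3 and e5m2 take one byte per element, while e2m1 (fp4) packs two 4-bit values per byte, so the packed K extent is half the logical BLOCK_K. A standalone sketch of that arithmetic (the helper name `packed_k_bytes` is made up for illustration and is not part of the test):

```python
# Illustrative only: mirrors the DIV_FACTOR / PACKED_BLOCK_K_A logic above.
def packed_k_bytes(block_k: int, type_a: str) -> int:
    is_fp8 = type_a in ("e4m3", "e5m2")  # one byte per element
    div_factor = 1 if is_fp8 else 2      # e2m1 stores two 4-bit elements per byte
    return block_k // div_factor

assert packed_k_bytes(128, "e4m3") == 128
assert packed_k_bytes(128, "e2m1") == 64  # fp4 needs half the bytes along K
```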
@@ -3442,7 +3442,7 @@ def mxfp_to_bf16_kernel(
 
     def dot_scale_ref(x, scale, y, type_x, type_y):
         e_bits, m_bits = {"e2m1": (2, 1), "e4m3": (4, 3), "e5m2": (5, 2)}[type_x]
-        type_fp8_y = {"e4m3": torch.float8_e4m3fn, "e5m2": torch.float8_e5m2}[type_y]
+        type_y = {"e4m3": torch.float8_e4m3fn, "e5m2": torch.float8_e5m2, "bf16": torch.bfloat16}[type_y]
 
         comp_dtype = torch.bfloat16
 
@@ -3455,7 +3455,7 @@ def dot_scale_ref(x, scale, y, type_x, type_y):
         mxfp_to_bf16_kernel[grid](x, scale, x_upcast, scale.numel(), e_bits, m_bits, BLOCK_SIZE, num_warps=num_warps)
         assert x_upcast.isfinite().all()
 
-        y_upcast = y.view(type_fp8_y).to(comp_dtype)
+        y_upcast = y.view(type_y).to(comp_dtype)
 
         class AccumulateInFp32:
 
@@ -3467,28 +3467,30 @@ def __exit__(self, exc_type, exc_val, exc_tb):
                 torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = self.prev_value
 
         with AccumulateInFp32():
-            return torch.matmul(x_upcast.to(comp_dtype), y_upcast.to(comp_dtype))
+            return torch.matmul(x_upcast, y_upcast)
 
     torch.manual_seed(0)
 
-    def create_uint8(shape, col_major=False, max_val=255):
+    def make_arg(shape, ty, col_major=False, max_val=255):
         if col_major:
             shape = shape[:-2] + (shape[-1], shape[-2])
-        ret = torch.randint(max_val + 1, shape, dtype=torch.uint8, device=device)
+        if ty == "bf16":
+            ret = torch.randn(shape, dtype=torch.bfloat16, device=device)
+            # Clamp to avoid relative error issues
+            ret.clamp_(-2**15, 2**15 - 1)
+        else:
+            ret = torch.randint(max_val + 1, shape, dtype=torch.uint8, device=device)
         if col_major:
             ret = ret.mT
         return ret
 
     DIV_FACTOR = 2 if type_a == "e2m1" else 1
-    x = create_uint8((M, K // DIV_FACTOR), col_major=col_a)
-    y = create_uint8((K, N), col_major=col_b)
+    x = make_arg((M, K // DIV_FACTOR), type_a, col_major=col_a)
+    y = make_arg((K, N), type_b, col_major=col_b)
 
     # sample scales that don't overflow as otherwise it's implementation defined (underflowing is alright)
-    # We substract a reasonably high number (64) so that the sum of all the mxfp elements does not overflow
-    m_bytes = int(type_a[1])
-    bias_type_a = 1 << (m_bytes - 1) - 1
-    max_exponent_type_a = (1 << m_bytes) - 1 - bias_type_a
-    scale_x = create_uint8((M, K // 32), max_val=255 - max_exponent_type_a - 64)
+    # Max scale = 2**15
+    scale_x = make_arg((M, K // 32), "e8m0", max_val=127 + 15)
 
     def make_finite(x, dtype):
         # e5m2 has too many non-finite values when sampled uniformly (1 / 32) and
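The new scale sampling in the hunk above leans on the e8m0 encoding, where a scale byte s is read as the power of two 2**(s - 127); capping max_val at 127 + 15 therefore bounds every sampled scale at 2**15, lining up with the ±2**15 clamp applied to bf16 operands in make_arg so the bf16 reference accumulation stays finite. A small decoding sketch, assuming that power-of-two interpretation (the helper `e8m0_to_float` is illustrative, not part of the test):

```python
# Illustrative only: decode e8m0 scale bytes as powers of two with bias 127.
import torch

def e8m0_to_float(scale_bytes: torch.Tensor) -> torch.Tensor:
    return torch.exp2(scale_bytes.to(torch.float32) - 127)

s = torch.tensor([0, 127, 127 + 15], dtype=torch.uint8)
# -> roughly [5.9e-39, 1.0, 32768.0]: the smallest, unit, and largest sampled scales
print(e8m0_to_float(s))
```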
@@ -3513,7 +3515,6 @@ def make_finite(x, dtype):
 
     z_ref = dot_scale_ref(x, scale_x, y, type_a, type_b)
 
-    # generous rtol as we are sampling the whole range of floats
     torch.testing.assert_close(z, z_ref, atol=1e-5, rtol=1e-2)
 
     # make sure ld/st are vectorized