@@ -38,27 +38,26 @@ def wmma_scaled(a, a_scale, a_format, b, b_scale, b_format, acc, _semantic=None)
3838 Args:
3939 a (tensor): The operand A to be multiplied.
4040 a_scale (tensor): Scale factor for operand A.
41- a_format (str): Format of the operand A. Available formats: `e2m1' .
41+ a_format (str): Format of the operand A. Available formats: `e2m1`, `e4m3`, `e5m2` .
4242 b (tensor): The operand B to be multiplied.
4343 b_scale (tensor): Scale factor for operand B.
44- b_format (str): Format of the operand B. Available formats: `e2m1' .
44+ b_format (str): Format of the operand B. Available formats: `e2m1`, `e4m3`, `e5m2` .
4545 acc (tensor): Accumulator tensor.
4646 """
4747 _verify_wmma (3 , a , b , acc )
4848 if a_format .value == "e2m1" :
4949 wmma_layout = a .type .layout .parent
50- assert isinstance (wmma_layout , AMDWMMALayout ) and wmma_layout .instr_shape == ( 16 , 16 , 64 ) , \
51- "e2m1 format expects instr_shape to be ( 16, 16, 64) "
50+ assert isinstance (wmma_layout , AMDWMMALayout ) and wmma_layout .instr_shape == [ 16 , 16 , 64 ] , \
51+ "e2m1 format expects instr_shape to be [ 16, 16, 64] "
5252 if b_format .value == "e2m1" :
5353 wmma_layout = b .type .layout .parent
54- assert isinstance (wmma_layout , AMDWMMALayout ) and wmma_layout .instr_shape == ( 16 , 16 , 64 ) , \
55- "e2m1 format expects instr_shape to be ( 16, 16, 64) "
54+ assert isinstance (wmma_layout , AMDWMMALayout ) and wmma_layout .instr_shape == [ 16 , 16 , 64 ] , \
55+ "e2m1 format expects instr_shape to be [ 16, 16, 64] "
5656
5757 acc_layout = acc .type .layout
58- assert isinstance (acc_layout , AMDWMMALayout ) and acc_layout .instr_shape == ( 16 , 16 , 128 ) , \
59- "accumulator tensor's layout must be ( 16, 16, 128) "
58+ assert isinstance (acc_layout , AMDWMMALayout ) and acc_layout .instr_shape == [ 16 , 16 , 128 ] , \
59+ "accumulator tensor's layout must be [ 16, 16, 128] "
6060
61- # TODO: Add more formats
6261 assert a_format .value in {"e2m1" , "e4m3" , "e5m2" }, f"Unsupported lhs_format: { a_format .value } "
6362 assert b_format .value in {"e2m1" , "e4m3" , "e5m2" }, f"Unsupported rhs_format: { b_format .value } "
6463
0 commit comments