 import triton
 import triton.language as tl
 from .base import Layout
+from triton_kernels.target_info import cuda_capability_geq


 def right_shift_unsigned(x, shift):
@@ -228,23 +229,25 @@ def _unshuffle_triton(x, mma_version: tl.constexpr):

 @triton.jit
 def _unpack_fp4_to_bf16_triton(x):
-    # For now we implement just H100 support (mul.bf16x2)
-    # A100 support is possible via fma
+    # Use fma on A100 as there is no mul.bf16x2.
+    use_mul: tl.constexpr = cuda_capability_geq(9)
+    op_instr: tl.constexpr = "mul.bf16x2" if use_mul else "fma.rn.bf16x2"
+    op_suffix: tl.constexpr = "" if use_mul else ", z"
     r0, r1 = tl.inline_asm_elementwise(
-        r"""
-        {
-            .reg .b32 b, c, d<7>, scale;
+        asm=f"""{{
+            .reg .b32 b, c, z, d<7>, scale;
             .reg .b32 bias;
+            mov.b32 z, 0;
             mov.b32 bias, 0x7e807e80; // 2 ** 126 == 2 ** (bias_bf16 - bias_fp2)
             // We add the missing bias to the scale directly
             and.b32 $0, $4, 0b10000001110000001000000111000000;
-            mul.bf16x2 $0, $0, bias;
+            {op_instr} $0, $0, bias{op_suffix};
             shl.b32 b, $4, 3;
             and.b32 $1, b, 0b10000001110000001000000111000000;
-            mul.bf16x2 $1, $1, bias;
+            {op_instr} $1, $1, bias{op_suffix};
             shl.b32 c, $4, 6;
             and.b32 $2, c, 0b10000001110000001000000111000000;
-            mul.bf16x2 $2, $2, bias;
+            {op_instr} $2, $2, bias{op_suffix};
             // Unpack last two elements
             shl.b32 d0, $4, 1;
             and.b32 d1, d0, 0b10000000000000001000000000000000;
@@ -254,9 +257,8 @@ def _unpack_fp4_to_bf16_triton(x):
             shr.b32 d5, $4, 7;
             and.b32 d6, d5, 0b00000000010000000000000001000000;
             or.b32 $3, d4, d6;
-            mul.bf16x2 $3, $3, bias;
-        }
-        """,
+            {op_instr} $3, $3, bias{op_suffix};
+        }}""",
         constraints="=r,=r,=r,=r,r",
         args=[x],
         dtype=(tl.bfloat16, tl.bfloat16),
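
For reference, here is a small standalone Python sketch (not part of the patch; the helper names `decode_e2m1_nibble` and `e2m1_reference` are made up for illustration) of the bit trick the PTX relies on: the fp4 sign/exponent/mantissa bits are masked into the bf16 sign bit, the low exponent bits, and the top mantissa bit, and the result is scaled by 0x7e80 (2 ** 126 as bf16) to fix up the exponent bias. The fma substitution is safe because with a zero addend `fma.rn.bf16x2` computes the same product as `mul.bf16x2`, and (to my understanding) it is available on sm_80 parts such as A100, whereas `mul.bf16x2` requires sm_90.

```python
import struct


def bf16_bits_to_float(bits16: int) -> float:
    # A bf16 pattern is the top 16 bits of an IEEE-754 float32.
    return struct.unpack("<f", struct.pack("<I", (bits16 & 0xFFFF) << 16))[0]


def decode_e2m1_nibble(nibble: int) -> float:
    # Place the fp4 fields where the PTX masks expect them:
    # sign -> bit 15, 2-bit exponent -> bits 8..7, mantissa -> bit 6.
    sign = (nibble >> 3) & 0x1
    exp = (nibble >> 1) & 0x3
    man = nibble & 0x1
    raw = (sign << 15) | (exp << 7) | (man << 6)
    # Reinterpreted as bf16 this is 2**(exp - 127) * 1.man (or a subnormal
    # when exp == 0). Multiplying by 2**126 == 2**(bias_bf16 - bias_fp4)
    # restores the e2m1 value. The kernel does this multiply with
    # mul.bf16x2, or with fma.rn.bf16x2 and a zero addend on pre-Hopper
    # targets, since x * bias + 0 == x * bias.
    return bf16_bits_to_float(raw) * 2.0 ** 126


def e2m1_reference(nibble: int) -> float:
    # Direct e2m1 decode: 1 sign bit, 2 exponent bits (bias 1), 1 mantissa bit.
    sign = -1.0 if nibble & 0x8 else 1.0
    exp = (nibble >> 1) & 0x3
    man = nibble & 0x1
    if exp == 0:
        return sign * man * 0.5                      # subnormal: 0.m * 2**0
    return sign * (1.0 + man * 0.5) * 2.0 ** (exp - 1)


for n in range(16):
    assert decode_e2m1_nibble(n) == e2m1_reference(n), n
print("bias trick matches direct e2m1 decode for all 16 nibbles")
```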