
Commit da3d437

[Kernels] Support A100 upcasting for mxfp4 (#8428)
1 parent 4e0d041 commit da3d437

File tree

  • python/triton_kernels/triton_kernels/tensor_details/layout_details/hopper_value.py

1 file changed: +13 -11 lines

python/triton_kernels/triton_kernels/tensor_details/layout_details/hopper_value.py

Lines changed: 13 additions & 11 deletions
@@ -2,6 +2,7 @@
 import triton
 import triton.language as tl
 from .base import Layout
+from triton_kernels.target_info import cuda_capability_geq


 def right_shift_unsigned(x, shift):
@@ -228,23 +229,25 @@ def _unshuffle_triton(x, mma_version: tl.constexpr):

 @triton.jit
 def _unpack_fp4_to_bf16_triton(x):
-    # For now we implement just H100 support (mul.bf16x2)
-    # A100 support is possible via fma
+    # Use fma on a100 as there is no mul.bf16x2.
+    use_mul: tl.constexpr = cuda_capability_geq(9)
+    op_instr: tl.constexpr = "mul.bf16x2" if use_mul else "fma.rn.bf16x2"
+    op_suffix: tl.constexpr = "" if use_mul else ", z"
     r0, r1 = tl.inline_asm_elementwise(
-        r"""
-        {
-            .reg .b32 b, c, d<7>, scale;
+        asm=f"""{{
+            .reg .b32 b, c, z, d<7>, scale;
             .reg .b32 bias;
+            mov.b32 z, 0;
             mov.b32 bias, 0x7e807e80; // 2 ** 126 == 2 ** (bias_bf16 - bias_fp2)
             // We add the missing bias to the scale directly
             and.b32 $0, $4, 0b10000001110000001000000111000000;
-            mul.bf16x2 $0, $0, bias;
+            {op_instr} $0, $0, bias{op_suffix};
             shl.b32 b, $4, 3;
             and.b32 $1, b, 0b10000001110000001000000111000000;
-            mul.bf16x2 $1, $1, bias;
+            {op_instr} $1, $1, bias{op_suffix};
             shl.b32 c, $4, 6;
             and.b32 $2, c, 0b10000001110000001000000111000000;
-            mul.bf16x2 $2, $2, bias;
+            {op_instr} $2, $2, bias{op_suffix};
             // Unpack last two elements
             shl.b32 d0, $4, 1;
             and.b32 d1, d0, 0b10000000000000001000000000000000;
@@ -254,9 +257,8 @@ def _unpack_fp4_to_bf16_triton(x):
             shr.b32 d5, $4, 7;
             and.b32 d6, d5, 0b00000000010000000000000001000000;
             or.b32 $3, d4, d6;
-            mul.bf16x2 $3, $3, bias;
-        }
-        """,
+            {op_instr} $3, $3, bias{op_suffix};
+        }}""",
         constraints="=r,=r,=r,=r,r",
         args=[x],
         dtype=(tl.bfloat16, tl.bfloat16),
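
The change boils down to one instruction-selection trick: mul.bf16x2 only exists on sm90+ (Hopper), so on sm80 (A100) the same packed multiply is expressed as fma.rn.bf16x2 with a zero addend. Below is a minimal, self-contained sketch of that idea; the helper name _bf16x2_mul, the two-operand form, and the packed-int32 input convention are illustrative assumptions rather than code from this commit.

# Illustrative sketch (not part of the commit): multiply two bf16x2 pairs,
# using mul.bf16x2 on sm90+ and falling back to fma.rn.bf16x2 with a zero
# addend on sm80, where the packed multiply is unavailable.
import triton
import triton.language as tl
from triton_kernels.target_info import cuda_capability_geq


@triton.jit
def _bf16x2_mul(x, y):
    # x and y are assumed to be int32 tensors, each lane packing two bf16 values.
    use_mul: tl.constexpr = cuda_capability_geq(9)
    op_instr: tl.constexpr = "mul.bf16x2" if use_mul else "fma.rn.bf16x2"
    op_suffix: tl.constexpr = "" if use_mul else ", z"
    return tl.inline_asm_elementwise(
        asm=f"""{{
        .reg .b32 z;
        mov.b32 z, 0;
        {op_instr} $0, $1, $2{op_suffix};
        }}""",
        constraints="=r,r,r",
        args=[x, y],
        dtype=tl.int32,
        is_pure=True,
        pack=1,
    )

The constexpr branches are resolved at compile time, so only one instruction form appears in the generated PTX; on sm90 the z register is declared but unused, which PTX tolerates.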
