
Commit ec6b435

[Kernels] Force mxfp4->bf16 conversion to use mul.bf16x2 for scaling (#8967)
LLVM doesn't auto-vectorize this very well and ends up with a mix of vector and scalar muls. I think the cost heuristic gets tripped up by the scale broadcasting, which requires unpacking and duplicating the scales; for that we generate PTX like

```
mov.b32 {%rs0, %rs1}, %packed_scales
mov.b32 %r1, {%rs0, %rs0}
mov.b32 %r2, {%rs1, %rs1}
```

However, ptxas can fuse this into the multiply, e.g.

```
HMUL2.BF16_V2 R90, R90, R100.H0_H0
HMUL2.BF16_V2 R91, R91, R100.H1_H1
```

where the movs have become the register modifier in the instruction. This gives a modest 1% speedup on non-persistent bf16xmxfp4 MoE.
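For reference, the fix routes the scale multiply through `tl.inline_asm_elementwise` so the packed bf16x2 multiply is emitted explicitly instead of being left to LLVM's vectorizer. A minimal sketch of that mechanism (illustrative only; the helper name is made up here, and the code actually added by this commit is in the diff further down):

```python
import triton
import triton.language as tl


@triton.jit
def _mul_bf16x2_sketch(a, b):
    # pack=2 maps two bf16 lanes onto each 32-bit "r" operand, so a single
    # mul.bf16x2 multiplies both lanes at once (mul.bf16x2 needs sm_90+).
    return tl.inline_asm_elementwise(
        asm="mul.bf16x2 $0, $1, $2;",
        constraints="=r,r,r",
        args=[a, b],
        dtype=tl.bfloat16,
        is_pure=True,
        pack=2,
    )
```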
1 parent 81526ff commit ec6b435

File tree

1 file changed: +17 -1 lines changed

  • python/triton_kernels/triton_kernels/tensor_details/layout_details


python/triton_kernels/triton_kernels/tensor_details/layout_details/hopper_value.py

Lines changed: 17 additions & 1 deletion
@@ -291,6 +291,22 @@ def _unpack_fp4_to_bf16_triton(x):
     return x


+@triton.jit
+def mul_bf16x2(a, b):
+    use_mul: tl.constexpr = cuda_capability_geq(9)
+    op_instr: tl.constexpr = "mul.bf16x2" if use_mul else "fma.rn.bf16x2"
+    op_suffix: tl.constexpr = "" if use_mul else ", z"
+
+    return tl.inline_asm_elementwise(
+        asm=f"{op_instr} $0, $1, $2{op_suffix};",
+        constraints="=r,r,r",
+        args=[a, b],
+        dtype=tl.bfloat16,
+        is_pure=True,
+        pack=2,
+    )
+
+
 @triton.jit
 def mxfp4_to_bf16_triton(x, scale, mx_axis: tl.constexpr):
     """
@@ -345,5 +361,5 @@ def mxfp4_to_bf16_triton(x, scale, mx_axis: tl.constexpr):
     scale = scale.reshape(x.shape)

     # Combine scale and x
-    x = x * scale
+    x = mul_bf16x2(x, scale)
     return x
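A standalone usage sketch, not part of the commit: a toy kernel that applies the same inline-asm multiply and checks it against a plain bf16 multiply. It assumes an sm_90+ GPU and CUDA-enabled PyTorch; the kernel and tensor names are made up for illustration.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def _scaled_mul_kernel(x_ptr, s_ptr, out_ptr, n, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    s = tl.load(s_ptr + offs, mask=mask)
    # Same packed bf16x2 multiply the commit forces for the mxfp4->bf16 scaling.
    y = tl.inline_asm_elementwise(
        asm="mul.bf16x2 $0, $1, $2;",
        constraints="=r,r,r",
        args=[x, s],
        dtype=tl.bfloat16,
        is_pure=True,
        pack=2,
    )
    tl.store(out_ptr + offs, y, mask=mask)


x = torch.randn(1024, device="cuda", dtype=torch.bfloat16)
s = torch.rand(1024, device="cuda", dtype=torch.bfloat16)
out = torch.empty_like(x)
grid = (triton.cdiv(x.numel(), 256),)
_scaled_mul_kernel[grid](x, s, out, x.numel(), BLOCK=256)
torch.testing.assert_close(out, x * s)
```

Since the multiply is emitted via inline asm, the compiled kernel's PTX contains `mul.bf16x2` directly rather than the scalar/vector mix described in the commit message.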

0 commit comments