
Commit 14373ae

[Kernels] Support hopper hbm swizzling in persistent matmul (#8917)
1 parent f3a2261 commit 14373ae


7 files changed, +87 -28 lines


python/triton_kernels/triton_kernels/matmul.py

Lines changed: 5 additions & 8 deletions
@@ -11,6 +11,7 @@
 from triton_kernels import target_info
 from triton_kernels.numerics import InFlexData, OutFlexData
 from triton_kernels.target_info import is_cuda
+from triton_kernels.tensor_details.layout_details.hopper_scale import HopperMXScaleLayout
 # details
 from .matmul_details._matmul import _matmul
 from .matmul_details._p_matmul import _p_matmul, get_per_device_per_stream_alloc_fn
@@ -104,6 +105,8 @@ class PrecisionConfig:
 def get_swap_xw(precision_config, opt_flags):
     if target_info.cuda_capability_geq(10, 0):
         return precision_config.b_mx_scale is not None and opt_flags.block_m <= 64 and opt_flags.is_persistent
+    elif target_info.cuda_capability_geq(9, 0):
+        return precision_config.b_mx_scale is not None and opt_flags.is_persistent

     return False

@@ -296,8 +299,6 @@ def matmul(a, b, bias,
         # Currently we don't support tma if y is column major; may revisit later if this becomes an issue.
         (c is None or c.stride(-1) == 1) and
         (c_acc_in is None or c_acc_is_c) and
-        # for simulated MXFP, not supported
-        (b_scale is None or target_info.has_native_mxfp()) and
         # if ragged dimension is K, w must be either padded or row major to ensure alignment
         (ragged_dimension != "K" or b.stride(-1) == 1 or b_ragged_metadata.slice_sizes_divisibility is not None)
     )
@@ -308,8 +309,6 @@ def matmul(a, b, bias,
         # which is too big.
         can_use_tma = False
     has_gather_tma = has_gather and target_info.has_tma_gather()
-    # hopper w/ mxfp4 doesn't support TMA
-    can_use_tma = can_use_tma and (torch.cuda.get_device_capability()[0] > 9 or bitwidth(b.dtype) != 4)
     can_use_split_k = scatter_indx is None and not a_has_mx and not b_has_mx and ragged_dimension != "K"
     block_k = None
     if ragged_dimension == "K":
@@ -338,8 +337,6 @@ def matmul(a, b, bias,
     assert K == K_W
     a_has_tma = opt_flags.is_persistent and (has_gather_tma or not has_gather)
     even_K = (K % opt_flags.block_k == 0)
-    if b_scale is not None and opt_flags.is_persistent and not target_info.has_native_mxfp():
-        raise NotImplementedError("Must use non-persistent kernel for simulated MXFP")
     if b_scale is not None and b_scale.storage.layout.name is not None and not opt_flags.is_persistent and target_info.has_native_mxfp():
         raise NotImplementedError("Must use persistent kernel and be TMA-compliant for native MXFP")
     # fused activation
@@ -425,8 +422,8 @@ def matmul(a, b, bias,
     if b_scale_has_tma:
         scale_block_k = opt_flags.block_k // int(MXFP_BLOCK_SIZE)
         b_scale_storage = b_scale.storage
-        b_scale_tma_block_size = [opt_flags.block_n, scale_block_k] if b_transpose else [scale_block_k, opt_flags.block_n]
-        if isinstance(b_scale.storage.layout, StridedLayout):
+        b_scale_tma_block_size = [scale_block_k, opt_flags.block_n]
+        if isinstance(b_scale_storage.layout, (StridedLayout, HopperMXScaleLayout)):
            b_scale_storage = _canonicalize_storage(b_scale.storage, 3, None)
            b_scale_tma_block_size = [1] + b_scale_tma_block_size
        b_scale_tensor_or_tma = b_scale_storage.make_tma(b_scale_tma_block_size, "dense", is_scale=True)
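
Two things worth noting in this file: get_swap_xw now also returns True on Hopper (capability 9.x) whenever a b mx scale is present and the kernel is persistent, without the block_m <= 64 condition that gates Blackwell; and the b-scale TMA block size no longer depends on b_transpose. A rough sketch of the resulting block-size arithmetic, with illustrative block sizes, not part of the commit (MXFP_BLOCK_SIZE is 32, matching the hard-coded value this commit replaces in hopper_value.py):

    # Illustrative values; in the real code block_k/block_n come from opt_flags.
    MXFP_BLOCK_SIZE = 32
    block_k, block_n = 128, 128
    scale_block_k = block_k // MXFP_BLOCK_SIZE              # 4
    b_scale_tma_block_size = [scale_block_k, block_n]       # [4, 128]
    # For StridedLayout or HopperMXScaleLayout storage, the storage is canonicalized
    # to 3D and a leading 1 is prepended to the block size:
    b_scale_tma_block_size = [1] + b_scale_tma_block_size   # [1, 4, 128]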

python/triton_kernels/triton_kernels/matmul_details/_matmul.py

Lines changed: 0 additions & 1 deletion
@@ -116,7 +116,6 @@ def _matmul(
     tl.static_assert(BLOCK_K % MX_PACK_DIVISOR == 0, f"{BLOCK_K=} must be a multiple of {MX_PACK_DIVISOR=}")
     tl.static_assert(SWIZZLE_MX_VALUE == "HOPPER_VALUE" or SWIZZLE_MX_VALUE is None, "Only Hopper swizzling is supported for values")

-    # TODO: refactor if/else when triton front end improves
     if SWIZZLE_MX_VALUE == "HOPPER_VALUE":
         tl.static_assert(is_w_mxfp4, "Only mxfp4 is supported for HOPPER swizzling")
         tl.static_assert(not is_x_microscaled)

python/triton_kernels/triton_kernels/matmul_details/_p_matmul.py

Lines changed: 56 additions & 13 deletions
@@ -13,6 +13,8 @@
     compute_scale,
 )
 from triton_kernels.numerics_details.mxfp_details._downcast_to_mxfp import MXFP_BLOCK_SIZE
+from triton_kernels.tensor_details.layout_details.hopper_scale import unswizzle_mxfp4_scale_hopper
+from triton_kernels.tensor_details.layout_details.hopper_value import mxfp4_to_bf16_triton
 from ._common import (
     compute_offsets,
     get_scaled_dot_format_string,
@@ -112,25 +114,48 @@ def _p_matmul(
     if Y_TMA_MODE is not None:
         Y = tl.make_tensor_descriptor(YPtr, Y.shape, Y.strides[:-1] + (1,), Y.block_shape)

+    w_type: tl.constexpr = get_dtype(W)
     is_w_microscaled: tl.constexpr = WMxScale is not None
+    is_x_microscaled: tl.constexpr = XMxScale is not None
+    is_w_mxfp4: tl.constexpr = w_type == tl.uint8 and is_w_microscaled
     tl.static_assert(not is_w_microscaled or W_TRANSPOSE, "NYI. Non-transposed mxfp4 weights")
     MX_PACK_DIVISOR: tl.constexpr = MXFP_BLOCK_SIZE
     if is_w_microscaled:
-        w_type: tl.constexpr = get_dtype(W)
         tl.static_assert(w_type == tl.uint8 or (w_type == tl.float8e4nv or w_type == tl.float8e5),
                          "mx_weight_ptr must be uint8 or fp8")
         tl.static_assert(get_dtype(WMxScale) == tl.uint8, "mx_scale_ptr must be uint8")
         tl.static_assert(BLOCK_K % MX_PACK_DIVISOR == 0, "BLOCK_K must be a multiple of MX_PACK_DIVISOR")
-        tl.static_assert(SWIZZLE_MX_SCALE == "BLACKWELL_SCALE" or SWIZZLE_MX_SCALE is None, "Only Blackwell swizzling is supported for scales")

        # We have pack 2 fp4 values in a byte
-        W_PACK_DIVISOR: tl.constexpr = 2 if w_type == tl.uint8 else 1
-        PACKED_BLOCK_K_W: tl.constexpr = BLOCK_K // W_PACK_DIVISOR
         MX_SCALE_BLOCK_K: tl.constexpr = BLOCK_K // MX_PACK_DIVISOR
+        if SWIZZLE_MX_VALUE == "HOPPER_VALUE":
+            tl.static_assert(is_w_mxfp4, "Only mxfp4 is supported for HOPPER swizzling")
+            tl.static_assert(not is_x_microscaled)
+            # We have pack 2 fp4 values in a byte but we divide the dimension by 2
+            # when swizzling
+            W_K_DIVISOR: tl.constexpr = 1
+            W_K_MULTIPLIER: tl.constexpr = 2
+            W_N_DIVISOR: tl.constexpr = 4
+        else:
+            # We have pack 2 fp4 values in a byte
+            W_K_DIVISOR: tl.constexpr = 2 if is_w_mxfp4 else 1
+            W_K_MULTIPLIER: tl.constexpr = 1
+            W_N_DIVISOR: tl.constexpr = 1
+
+        if W_TRANSPOSE:
+            # When weight is transposed, 2 fp4 values are packed per Byte along
+            # the contiguous dimension, K.
+            PACKED_BLOCK_K_W: tl.constexpr = (BLOCK_K // W_K_DIVISOR) * W_K_MULTIPLIER
+            PACKED_BLOCK_N_W: tl.constexpr = BLOCK_N // W_N_DIVISOR
+        else:
+            # When weight is not transposed, fp4 values are *not* packed along
+            # the contiguous dimension, N.
+            PACKED_BLOCK_K_W: tl.constexpr = BLOCK_K
+            PACKED_BLOCK_N_W: tl.constexpr = BLOCK_N // W_K_DIVISOR
     else:
         PACKED_BLOCK_K_W: tl.constexpr = BLOCK_K
+        PACKED_BLOCK_N_W: tl.constexpr = BLOCK_N
         tl.static_assert(SWIZZLE_MX_SCALE is None)
-    is_x_microscaled: tl.constexpr = XMxScale is not None
     if is_x_microscaled:
         x_type: tl.constexpr = get_dtype(X)
         tl.static_assert(x_type == tl.float8e4nv, "mx_act_ptr must be float8e4nv")
@@ -202,6 +227,7 @@ def _p_matmul(
     else:
         shape_m = M
     off_n = BLOCK_N * pid_n
+    off_w_n = PACKED_BLOCK_N_W * pid_n

     # ---- offset x ------
     if USE_GATHER_TMA:
@@ -283,7 +309,7 @@ def _p_matmul(
     x_format: tl.constexpr = get_scaled_dot_format_string(x.dtype)
     if is_x_microscaled:
         if XMxScalePtrs is not None: # not using TMA for x scale load
-            off_k_mx = off_k_w // (MX_PACK_DIVISOR // W_PACK_DIVISOR)
+            off_k_mx = off_k_w // (MX_PACK_DIVISOR // W_K_DIVISOR)
            if EVEN_K:
                mask_k_scale = tl.full([MX_SCALE_BLOCK_K], True, dtype=tl.int1)
            else:
@@ -306,30 +332,47 @@ def _p_matmul(

     # --- load w ---
     if W_TRANSPOSE:
-        w = tl.reshape(W.load([off_w_z, off_n, off_k_w]), W.block_shape[1:]).T
+        w = tl.reshape(W.load([off_w_z, off_w_n, off_k_w]), W.block_shape[1:]).T
     else:
-        w = tl.reshape(W.load([off_w_z, off_k_w, off_n]), W.block_shape[1:])
+        w = tl.reshape(W.load([off_w_z, off_k_w, off_w_n]), W.block_shape[1:])

     # --- load w_scale ---
     w_format: tl.constexpr = get_scaled_dot_format_string(w.dtype)
     if is_w_microscaled:
-        off_k_mx = off_k_w // (MX_PACK_DIVISOR // W_PACK_DIVISOR)
-        tl.static_assert(MX_PACK_DIVISOR % W_PACK_DIVISOR == 0)
+        off_k_mx = off_k_w // (MX_PACK_DIVISOR // W_K_DIVISOR)
+        tl.static_assert(MX_PACK_DIVISOR % W_K_DIVISOR == 0)
         if SWIZZLE_MX_SCALE == "BLACKWELL_SCALE":
             flattened_expt_n_idx = off_w_z * ((N + 127) // 128) + (off_n // 128)
             w_scales = WMxScale.load([0, flattened_expt_n_idx, off_k_mx // 4, 0, 0])
             w_scales = w_scales.reshape((w_scales.shape[1], w_scales.shape[2] * w_scales.shape[-2] * w_scales.shape[-1]))
             w_scales = unswizzle_mx_scale_bw(w_scales)
+        elif SWIZZLE_MX_SCALE == "HOPPER_SCALE":
+            # NYI: Hopper swizzling with non-transposed W
+            tl.static_assert(W_TRANSPOSE)
+            off_n_scale = pid_n * (BLOCK_N // 32)
+            off_k_scale = (off_k_w // PACKED_BLOCK_K_W) * MX_SCALE_BLOCK_K * 32
+            w_scales = WMxScale.load([off_w_z, off_n_scale, off_k_scale])
+            w_scales = tl.reshape(w_scales, *w_scales.shape[1:])
+            num_warps: tl.constexpr = tl.extra.cuda.num_warps()
+            w_scales = unswizzle_mxfp4_scale_hopper(w_scales, mx_axis=1, num_warps=num_warps)
         else:
             w_scales = WMxScale.load([off_w_z, off_k_mx, off_n])
             w_scales = tl.reshape(w_scales, *w_scales.shape[1:]).T

     # --- update accumulator ---
     if is_w_microscaled:
-        if SWAP_XW:
-            acc = tl.dot_scaled(w.T, w_scales, w_format, x.T, x_scales, x_format, acc=acc, fast_math=True)
+        if SWIZZLE_MX_VALUE == "HOPPER_VALUE":
+            tl.static_assert(x_format == "bf16")
+            tl.static_assert(w_format == "e2m1")
+            tl.static_assert(SWAP_XW)
+            wT = mxfp4_to_bf16_triton(w.T, w_scales, mx_axis=1)
+            tl.static_assert(wT.dtype == tl.bfloat16)
+            acc = tl.dot(wT, x.T, acc, max_num_imprecise_acc=MAX_NUM_IMPRECISE_ACC, allow_tf32=ALLOW_TF32)
         else:
-            acc = tl.dot_scaled(x, x_scales, x_format, w, w_scales, w_format, acc=acc, fast_math=True)
+            if SWAP_XW:
+                acc = tl.dot_scaled(w.T, w_scales, w_format, x.T, x_scales, x_format, acc=acc, fast_math=True)
+            else:
+                acc = tl.dot_scaled(x, x_scales, x_format, w, w_scales, w_format, acc=acc, fast_math=True)
     else:
         if SWAP_XW:
             acc = tl.dot(w.T, x.T, acc, max_num_imprecise_acc=MAX_NUM_IMPRECISE_ACC, allow_tf32=ALLOW_TF32)
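
To make the new packing constants concrete, here is a small sketch of how PACKED_BLOCK_K_W and PACKED_BLOCK_N_W evaluate for transposed mxfp4 weights; the block sizes are illustrative and the sketch is not part of the kernel:

    # Illustrative: BLOCK_K = BLOCK_N = 128, W_TRANSPOSE = True, mxfp4 weights
    BLOCK_K, BLOCK_N = 128, 128

    # SWIZZLE_MX_VALUE == "HOPPER_VALUE": K grows by 2x, N shrinks by 4x
    W_K_DIVISOR, W_K_MULTIPLIER, W_N_DIVISOR = 1, 2, 4
    packed_k = (BLOCK_K // W_K_DIVISOR) * W_K_MULTIPLIER  # 256
    packed_n = BLOCK_N // W_N_DIVISOR                     # 32

    # No value swizzling: 2 fp4 values per byte along K
    W_K_DIVISOR, W_K_MULTIPLIER, W_N_DIVISOR = 2, 1, 1
    packed_k = (BLOCK_K // W_K_DIVISOR) * W_K_MULTIPLIER  # 64
    packed_n = BLOCK_N // W_N_DIVISOR                     # 128

This is also why the loads now index W with a separate off_w_n = PACKED_BLOCK_N_W * pid_n: under Hopper value swizzling the stored N extent of a tile is BLOCK_N // 4, so the unpacked off_n = BLOCK_N * pid_n would no longer line up with the packed layout.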

python/triton_kernels/triton_kernels/matmul_details/opt_flags.py

Lines changed: 8 additions & 1 deletion
@@ -7,6 +7,7 @@
 from triton_kernels.target_info import get_cdna_version
 from triton_kernels.tensor import FP4
 import torch
+from triton_kernels.tensor_details.layout_details.hopper_scale import HopperMXScaleLayout
 from .opt_flags_details import opt_flags_amd, opt_flags_nvidia
 from triton_kernels.tensor import bitwidth, get_layout

@@ -239,9 +240,15 @@ def make_default_opt_flags_nvidia(
     # TMA is slower for batched matmuls with small m/n/k.
     if m * n * k < 131072:
         is_persistent = False
+    if (
+        (b_scale_layout := get_layout(precision_config.b_mx_scale)) is not None and
+        isinstance(b_scale_layout, HopperMXScaleLayout)
+    ):
+        # TODO: persistent kernel is currently slower than non-persistent
+        is_persistent = False
     # adjust block_n based on is_persistent signal
     block_n = block_n_tma if is_persistent else block_n
-    # adjut block_m based on is_persistent signal
+    # adjust block_m based on is_persistent signal
     if is_persistent and opt_flags_nvidia.is_x_scale_swizzled(precision_config):
         # a mx scale has been swizzled to BlackwellActMXScaleLayout, enforce block_m=128 to align with swizzling layout
         block_m = 128

python/triton_kernels/triton_kernels/tensor_details/layout_details/blackwell_scale.py

Lines changed: 3 additions & 2 deletions
@@ -56,8 +56,9 @@ def unswizzle_data(self, data):
         return data[..., :self.K, :self.N]

     def swizzle_block_shape(self, block_shape):
-        assert block_shape[0] >= 128, f"{block_shape[0]=} must be >= 128"
-        return [1, block_shape[0] // 128, block_shape[1] // 4, 2, 256]
+        K, N = block_shape
+        assert N >= 128, f"{block_shape[1]=} must be >= 128"
+        return [1, N // 128, K // 4, 2, 256]


 @triton.jit
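
A quick numeric check of the reordered indices, with an illustrative scale block (not from the commit):

    # Scale block with K = 64, N = 256
    K, N = 64, 256
    assert N >= 128
    swizzled = [1, N // 128, K // 4, 2, 256]   # [1, 2, 16, 2, 256]

The >= 128 requirement now applies to N (the second block dimension), whereas the old code asserted it on block_shape[0] and swapped the roles of the two dimensions.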

python/triton_kernels/triton_kernels/tensor_details/layout_details/hopper_scale.py

Lines changed: 3 additions & 1 deletion
@@ -65,7 +65,9 @@ def unswizzle_data(self, data):
         return data[..., :self.M, :self.K]

     def swizzle_block_shape(self, block_shape):
-        return block_shape
+        N, K = block_shape[-2:]
+        assert N % 32 == 0
+        return [*block_shape[:-2], N // 32, K * 32]


 @triton.jit
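
For orientation, the same mapping in plain numbers (illustrative shape, not from the commit):

    block_shape = [1, 128, 4]        # leading dims pass through; trailing N = 128, K = 4
    N, K = block_shape[-2:]
    assert N % 32 == 0
    swizzled = [*block_shape[:-2], N // 32, K * 32]   # [1, 4, 128]

A factor of 32 moves from the second-to-last dimension into the contiguous last dimension, which is why the new assert requires N % 32 == 0; previously the block shape was returned unchanged.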

python/triton_kernels/triton_kernels/tensor_details/layout_details/hopper_value.py

Lines changed: 12 additions & 2 deletions
@@ -3,6 +3,8 @@
 import triton.language as tl
 from dataclasses import dataclass
 from .base import Layout
+
+from triton_kernels.numerics_details.mxfp_details._downcast_to_mxfp import MXFP_BLOCK_SIZE
 from triton_kernels.target_info import cuda_capability_geq


@@ -211,7 +213,9 @@ def unswizzle_data(self, data):
         return data[..., :self.K, :self.N]

     def swizzle_block_shape(self, block_shape):
-        return block_shape
+        N, K = block_shape[-2:]
+        assert N % 4 == 0
+        return [*block_shape[:-2], N // 4, K * 4]


@@ -329,9 +333,15 @@ def mxfp4_to_bf16_triton(x, scale, mx_axis: tl.constexpr):
         is_pure=True,
         pack=4,
     )
+    # Sanity check shape
+    for axis in tl.static_range(len(x.shape)):
+        if axis == mx_axis:
+            tl.static_assert(x.shape[axis] == MXFP_BLOCK_SIZE * scale.shape[axis])
+        else:
+            tl.static_assert(x.shape[axis] == scale.shape[axis])
     # Broadcast scale
     scale = scale.expand_dims(mx_axis + 1)
-    scale = scale.broadcast_to(scale.shape[:mx_axis + 1] + [32] + scale.shape[mx_axis + 2:])
+    scale = scale.broadcast_to(scale.shape[:mx_axis + 1] + [MXFP_BLOCK_SIZE] + scale.shape[mx_axis + 2:])
     scale = scale.reshape(x.shape)

     # Combine scale and x
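
Analogous arithmetic for the value layout, plus the effect of the new shape checks (shapes illustrative, not from the commit):

    block_shape = [1, 128, 64]       # trailing N = 128, K = 64
    N, K = block_shape[-2:]
    assert N % 4 == 0
    swizzled = [*block_shape[:-2], N // 4, K * 4]     # [1, 32, 256]

    # In mxfp4_to_bf16_triton, each scale element covers one MXFP_BLOCK_SIZE (= 32)
    # group of values along mx_axis, so the new static asserts require
    # x.shape[mx_axis] == MXFP_BLOCK_SIZE * scale.shape[mx_axis].

The N // 4 mirrors the W_N_DIVISOR = 4 used for PACKED_BLOCK_N_W in _p_matmul, and using MXFP_BLOCK_SIZE instead of a hard-coded 32 keeps the broadcast consistent with those asserts.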
