
Commit 7578e3e

[mxfp] support EXPT_IS_INNER for MX (#8385)
# New contributor declaration
- [x] I am not making a trivial change, such as fixing a typo in a comment.
- [x] I have written a PR description following these [rules](https://cbea.ms/git-commit/#why-not-how).
- [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`.
- Select one of the following.
  - [x] I have added tests.
    - `/test` for `lit` tests
    - `/unittest` for C++ tests
    - `/python/test` for end-to-end tests
  - [ ] This PR does not need a test because `FILL THIS IN`.
- Select one of the following.
  - [x] I have not added any `lit` tests.
  - [ ] The `lit` tests I have added follow these [best practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices), including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.)
1 parent 11af53c commit 7578e3e

File tree

9 files changed (+107 lines, -46 lines)


python/triton_kernels/tests/test_matmul.py

Lines changed: 48 additions & 10 deletions
@@ -246,7 +246,8 @@ class Case:
         Case(300, 400, 400, "ragged", "bfloat16", "mxfloat8_e4m3fn", 8, 4, hbm_swizzling=True),
         Case(300, 400, 400, "batched", "bfloat16", "mxfloat8_e5m2", 32, 4),
         Case(1000, 700, 2, "batched", "bfloat16", "mxfloat4_e2m1", 8, 2),
-        Case(1, 2880, 2880, "ragged", "bfloat16", "mxfloat4_e2m1", 128, 4),
+        # Cover (N or K) % 128 == 64 (https://github.com/triton-lang/triton/pull/7203)
+        Case(1, 1472, 1472, "ragged", "bfloat16", "mxfloat4_e2m1", 128, 4),
         Case(16, 256, 256, "ragged", "float8_e5m2", "mxfloat4_e2m1", 128, 4, hbm_swizzling=True),
         Case(1000, 704, 832, "batched", "float8_e5m2", "mxfloat4_e2m1", 3, 1, hbm_swizzling=True),
         Case(1000, 704, 832, "batched", "float8_e5m2", "mxfloat4_e2m1", 3, 1, hbm_swizzling=True),
@@ -318,6 +319,24 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_o
             n_expts_act, mode, act_dtype_str, weight_dtype_str, block_m, hbm_swizzling, colmajor_mxfp_weight, epilogue_subtile,
             x_transpose, w_transpose, y_transpose,
             device, opt_flags_scope):
+    # We catch and re-invoke pytest.skip(), because otherwise pytest may hold a reference to
+    # the frame that called pytest.skip, including all the tensors, leading to OOM.
+    skip_message = None
+    try:
+        _test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_opt, has_y_gammas, is_persistent, n_expts_tot,
+                 n_expts_act, mode, act_dtype_str, weight_dtype_str, block_m, hbm_swizzling, colmajor_mxfp_weight, epilogue_subtile,
+                 x_transpose, w_transpose, y_transpose,
+                 device, opt_flags_scope)
+    except pytest.skip.Exception as e:
+        skip_message = str(e)
+
+    if skip_message is not None:
+        pytest.skip(skip_message)
+
+
+def _test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_opt, has_y_gammas, is_persistent, n_expts_tot,
+             n_expts_act, mode, act_dtype_str, weight_dtype_str, block_m, hbm_swizzling, colmajor_mxfp_weight, epilogue_subtile,
+             x_transpose, w_transpose, y_transpose,
+             device, opt_flags_scope):
     # TODO: remove when Triton FP8 supports proper RTNE
     if is_cuda():
         if "float8" in weight_dtype_str and torch.cuda.get_device_capability()[0] < 9:
@@ -327,8 +346,6 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_o
         if weight_dtype_str.startswith("mx"):
             if "float8" in act_dtype_str and torch.cuda.get_device_capability()[0] < 10:
                 pytest.skip("float8 x mx not supported with cuda capability < 10")
-            if n == 2880 and k == 2880 and torch.cuda.get_device_capability()[0] < 9:
-                pytest.skip("Not enough memory on A100")

    elif is_hip():
        if "float8" in act_dtype_str and "mx" in weight_dtype_str and not is_hip_cdna4():
@@ -366,8 +383,21 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_o
        pytest.skip("Hopper swizzling acts on a 64x64 tile (4x1 mma tiles).")

    expt_is_inner = (inner_expt_opt is not None)
-    if expt_is_inner and (mode != "ragged" or "mx" in act_dtype_str or "mx" in weight_dtype_str):
-        pytest.skip("Not supported yet")
+    if expt_is_inner:
+        if mode != "ragged":
+            pytest.skip("inner_expt_opt only meaningful with ragged")
+        if "mx" in act_dtype_str and inner_expt_opt != "pad_x":
+            pytest.skip("inner_expt_opt and act mx only supported with pad_x")
+        if "mx" in weight_dtype_str:
+            if inner_expt_opt != "pad_w":
+                pytest.skip("inner_expt_opt and weight mx only supported with pad_w")
+            if is_persistent and not hbm_swizzling:
+                pytest.skip("FIXME: Fatal Python error: Aborted")
+            if is_hip():
+                if act_dtype_str == "bfloat16":
+                    pytest.skip("FIXME: failed to translate module to LLVM IR")
+                if hbm_swizzling:
+                    pytest.skip("NYI: nner_expt_opt and HBM swizzling")

    # launch metadata for batched / mx types may not work yet.
    torch.manual_seed(0)
@@ -399,6 +429,7 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_o
    opt_flags.update_opt_flags_constraints(constraints)

    weight_mxfp = weight_dtype_str.startswith("mx")
+    weight_mxfp4 = weight_mxfp and "float4" in weight_dtype_str
    if weight_mxfp:
        weight_dtype_str = weight_dtype_str[2:]
    act_mxfp8 = act_dtype_str.startswith("mx")
@@ -422,6 +453,13 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_o
        rdata = gindx = sindx = None

    padding_block_k = 32
+    if hbm_swizzling:
+        if torch.cuda.get_device_capability()[0] >= 10:
+            # Blackwell scale swizzling constraint
+            # https://github.com/triton-lang/triton/blob/814b862166c756d9f33238844f4ac047e0243388/python/triton_kernels/triton_kernels/tensor_details/layout_details/blackwell_scale.py#L45
+            padding_block_k = 128
+        elif not is_persistent:
+            padding_block_k = 64
    x_tri, w_tri, bias_tri, gs0_tri, gs1_tri = init_compute_data(m, n, k, rdata, gindx, sindx, n_expts_tot, n_expts_act,
                                                                 mode, torch.bfloat16 if act_mxfp8 else act_dtype, #
                                                                 torch.bfloat16 if weight_mxfp else weight_dtype,
@@ -457,7 +495,7 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_o
    # compute layouts
    w_layout, w_layout_opts = layout.StridedLayout, dict()
    w_scale_layout, w_scale_layout_opts = layout.StridedLayout, dict()
-    if hbm_swizzling and "float4" in weight_dtype_str:
+    if hbm_swizzling and weight_mxfp4:
        w_layout, w_layout_opts = layout.make_default_matmul_mxfp4_w_layout(mx_axis=mx_axis)
        w_scale_layout, w_scale_layout_opts = layout.make_default_matmul_mxfp4_w_scale_layout(
            mx_axis=mx_axis, num_warps=8)
@@ -466,7 +504,7 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_o
        if colmajor_mxfp_weight:
            w_tri, w_scale_tri = downcast_to_mxfp(w_tri, weight_dtype, axis=mx_axis)
            w_ref = upcast_from_mxfp(w_tri, w_scale_tri, torch.bfloat16, axis=mx_axis)
-            w_tri_dtype = FP4 if "float4" in weight_dtype_str else weight_dtype
+            w_tri_dtype = FP4 if weight_mxfp4 else weight_dtype
            w_tri = wrap_torch_tensor(w_tri, w_tri_dtype)
            w_scale_tri = wrap_torch_tensor(w_scale_tri)
            # convert layouts
@@ -568,8 +606,8 @@ def _pad_and_block(x: torch.Tensor) -> torch.Tensor:
        tri_y = matmul_ogs(x_tri, w_tri, bias_tri, rdata, gindx, sindx, precision_opt,
                           gammas=gs1_ref, epilogue=epilogue, y=y_tri_in,
                           inner_routing_data=inner_routing_data)
-    except (opt_flags.InapplicableConstraint, NotImplementedError):
-        pytest.skip("inapplicable opt_flags constraint")
+    except (opt_flags.InapplicableConstraint, NotImplementedError) as e:
+        pytest.skip(f"inapplicable opt_flags constraint {e}")
    if y_tri_in is not None:
        assert tri_y.data_ptr() == y_tri_in.data_ptr()
        assert tri_y.shape == y_tri_in.shape
@@ -602,7 +640,7 @@ def scale(val, scal):
        ref_y = upcast_from_mxfp_torch(ref_y_quant, ref_y_scale, target_dtype=ref_y.dtype, axis=-1)
        maxtol = 4e-1
        rmstol = 4e-2
-    elif weight_mxfp and "float4_e2m1" in weight_dtype_str:
+    elif weight_mxfp4:
        if act_is_float8:
            maxtol = 8e-2
        else:
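
To illustrate the skip-forwarding pattern that `test_op` now uses, here is a minimal self-contained sketch. It is not part of the commit; `_run_heavy_test`, `test_heavy`, and the parameter `n` are hypothetical stand-ins for `_test_op` and its arguments.

```python
import pytest
import torch


def _run_heavy_test(n):  # hypothetical stand-in for _test_op
    x = torch.empty(n, 1024)  # tensor kept alive by this frame
    if n >= 4096:
        pytest.skip("hypothetical: problem size too large for this device")
    assert x.shape == (n, 1024)


@pytest.mark.parametrize("n", [64, 4096])
def test_heavy(n):
    # Catch the skip raised inside the heavy frame and re-raise it from this
    # slim wrapper, so pytest does not hold a reference to the frame that
    # allocated the tensors (which could otherwise accumulate into an OOM).
    skip_message = None
    try:
        _run_heavy_test(n)
    except pytest.skip.Exception as e:
        skip_message = str(e)
    if skip_message is not None:
        pytest.skip(skip_message)
```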

python/triton_kernels/triton_kernels/matmul_ogs.py

Lines changed: 10 additions & 3 deletions
@@ -646,9 +646,17 @@ def matmul_ogs(x, w, bias,
     w_has_tma = opt_flags.is_persistent
     w_tensor_or_tma = w_storage.make_tma([1, opt_flags.block_k, opt_flags.block_n], "dense") if w_has_tma else w_storage.data
     # create tma descriptor for w_scale
-    w_scale_tensor_or_tma = w_scale
     w_scale_has_tma = opt_flags.is_persistent and w_scale is not None
-    w_scale_tensor_or_tma = w_scale.storage.make_tma([opt_flags.block_n, opt_flags.block_k], "dense") if w_scale_has_tma else w_scale
+    w_transpose = w_storage.data.stride()[-2] == 1
+    if w_scale_has_tma:
+        w_scale_storage = w_scale.storage
+        w_scale_tma_block_size = [opt_flags.block_n, opt_flags.block_k] if w_transpose else [opt_flags.block_k, opt_flags.block_n]
+        if isinstance(w_scale.storage.layout, StridedLayout):
+            w_scale_storage = _canonicalize_storage(w_scale.storage, 3, None)
+            w_scale_tma_block_size = [1] + w_scale_tma_block_size
+        w_scale_tensor_or_tma = w_scale_storage.make_tma(w_scale_tma_block_size, "dense")
+    else:
+        w_scale_tensor_or_tma = w_scale
     # canonicalize strides
     x_strides = [0]*(3 - x_storage.data.ndim) + list(x_storage.data.stride())
     x_scale_strides = x_scale.stride() if x_has_mx else (None, None, None)
@@ -663,7 +671,6 @@ def matmul_ogs(x, w, bias,
     # (i.e. col-wise). Since this matters when w_has_mx is True and w_transpose
     # is True the fast code path, stride(-2) == 1 takes precedence, e.g., vs.
     # w_transpose = w_storage.data.stride()[-1] != 1
-    w_transpose = w_storage.data.stride()[-2] == 1
     fused_comm_kwargs = {
         "pYPtrs": fused_comm.out_handles,
         "ScatterShardIndx": fused_comm.scatter_shard_indx,

python/triton_kernels/triton_kernels/matmul_ogs_details/_common.py

Lines changed: 4 additions & 3 deletions
@@ -98,13 +98,14 @@ def _load_tile_attrs(
         tl.static_assert(M is not None)
         expt_id, pid_z, pid_z_out, start_m, block_id, eM = 0, 0, pid_e, 0, pid_m, M
         k_tiles = tl.cdiv(tl.load(ExptHist + pid_e), BLOCK_K)
-        padded_start_off = tl.load(ExptTileOffs + pid_e) * BLOCK_K
+        padded_start_off_raw = tl.load(ExptTileOffs + pid_e)
+        padded_start_off = padded_start_off_raw * BLOCK_K
         unpadded_start_off = tl.load(ExptOffs + pid_e)
         off_k_x = padded_start_off if X_IS_PADDED else unpadded_start_off
         # K_W is only used for non-TMA kernel (W bound is handled by TMA on TMA kernel).
         if W_IS_PADDED:
-            off_k_w = padded_start_off
-            K_W = tl.load(ExptTileOffs + pid_e + 1) * BLOCK_K
+            off_k_w = padded_start_off_raw * PACKED_BLOCK_K_W
+            K_W = tl.load(ExptTileOffs + pid_e + 1) * PACKED_BLOCK_K_W
         else:
             off_k_w = unpadded_start_off
             K_W = tl.load(ExptOffs + pid_e + 1)
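
A small numeric sketch of the offset change in `_load_tile_attrs`: expert tile offsets are stored in units of K-tiles, so the activation offset scales by `BLOCK_K` (unpacked elements) while the weight offset now scales by the packed block size (for mxfp4, two values per byte). The names mirror the kernel constants, but the list and values below are illustrative only.

```python
BLOCK_K = 128
W_PACK_DIVISOR = 2                      # 2 fp4 values per uint8
PACKED_BLOCK_K_W = BLOCK_K // W_PACK_DIVISOR

expt_tile_offs = [0, 3, 7]              # hypothetical ExptTileOffs (in tiles)
pid_e = 1                               # second expert

padded_start_off_raw = expt_tile_offs[pid_e]          # 3 tiles
off_k_x = padded_start_off_raw * BLOCK_K              # 384 activation elements
off_k_w = padded_start_off_raw * PACKED_BLOCK_K_W     # 192 packed weight bytes
K_W = expt_tile_offs[pid_e + 1] * PACKED_BLOCK_K_W    # 448: end bound, packed

print(off_k_x, off_k_w, K_W)            # 384 192 448
```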

python/triton_kernels/triton_kernels/matmul_ogs_details/_matmul_ogs.py

Lines changed: 4 additions & 4 deletions
@@ -131,7 +131,7 @@ def _matmul_ogs(
         tl.static_assert(w_type == tl.uint8 or (w_type == tl.float8e4nv or w_type == tl.float8e5),
                          "mx_weight_ptr must be uint8 or fp8")
         tl.static_assert(WMxScale.dtype.element_ty == tl.uint8, "mx_scale_ptr must be uint8")
-        tl.static_assert(BLOCK_K % MX_PACK_DIVISOR == 0, "BLOCK_K must be a multiple of MX_PACK_DIVISOR")
+        tl.static_assert(BLOCK_K % MX_PACK_DIVISOR == 0, f"{BLOCK_K=} must be a multiple of {MX_PACK_DIVISOR=}")
         tl.static_assert(SWIZZLE_MX_VALUE == "HOPPER_VALUE" or SWIZZLE_MX_VALUE is None, "Only Hopper swizzling is supported for values")

         # TODO: refactor if/else when triton front end improves
@@ -247,7 +247,6 @@ def _matmul_ogs(

     # TODO: refactor if/else when triton front end improves
     if is_w_microscaled:
-        tl.static_assert(not EXPT_IS_INNER, "Not supported yet")
         WMxScale += expt_id * stride_w_mx_e

         if SWIZZLE_MX_SCALE == "BLACKWELL_SCALE":
@@ -281,7 +280,8 @@ def _matmul_ogs(
             offs_n_scale = (pid_n * SCALE_BLOCK_N + tl.arange(0, SCALE_BLOCK_N)) % N
             offs_n_scale = tl.max_contiguous(tl.multiple_of(offs_n_scale, SCALE_BLOCK_N), SCALE_BLOCK_N)
             # K dimension must be the last dimension for the scales
-            offs_k_scale = PACKED_MX_BLOCK * pid_k + tl.arange(0, PACKED_MX_BLOCK)
+            tl.static_assert(not EXPT_IS_INNER or W_IS_PADDED)
+            offs_k_scale = off_k_w // PACKED_BLOCK_K_W * PACKED_MX_BLOCK + tl.arange(0, PACKED_MX_BLOCK)
             WMxScalePtrs = WMxScale + offs_k_scale.to(index_type)[None, :] * stride_scale_k + offs_n_scale.to(index_type)[:, None] * stride_w_mx_n
     else:
         WMxScalePtrs = None
@@ -295,7 +295,7 @@ def _matmul_ogs(
         XMxScale += start_z.to(index_type) * stride_x_mx_z
         if GatherIndx is None:
             XMxScale += start_m * stride_x_mx_m
-        offs_x_k_scale = MX_SCALE_BLOCK_K * pid_k + tl.arange(0, MX_SCALE_BLOCK_K)
+        offs_x_k_scale = off_k_x // MXFP_BLOCK_SIZE + tl.arange(0, MX_SCALE_BLOCK_K)
         XMxScalePtrs = XMxScale + offs_x_m.to(index_type)[:, None] * stride_x_mx_m + offs_x_k_scale.to(index_type)[None, :] * stride_x_mx_k
     else:
         XMxScalePtrs = None
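
The scale-offset change can be checked with plain arithmetic: the scale column index is now derived from the expert's (possibly padded) K start rather than from `pid_k`, which is what makes `EXPT_IS_INNER` work for MX inputs. A small sketch with illustrative values; the constants mirror the kernel names but are not imported from it.

```python
import numpy as np

MXFP_BLOCK_SIZE = 32                           # one scale per 32 K elements
BLOCK_K = 128
MX_SCALE_BLOCK_K = BLOCK_K // MXFP_BLOCK_SIZE  # 4 scale columns per K tile

off_k_x = 384  # this expert's padded K start, in elements (illustrative)
offs_x_k_scale = off_k_x // MXFP_BLOCK_SIZE + np.arange(MX_SCALE_BLOCK_K)
print(offs_x_k_scale)  # [12 13 14 15]: the scales covering K in [384, 512)
```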

python/triton_kernels/triton_kernels/matmul_ogs_details/_p_matmul_ogs.py

Lines changed: 1 addition & 2 deletions
@@ -125,7 +125,6 @@ def _p_matmul_ogs(
         tl.static_assert(get_dtype(WMxScale) == tl.uint8, "mx_scale_ptr must be uint8")
         tl.static_assert(BLOCK_K % MX_PACK_DIVISOR == 0, "BLOCK_K must be a multiple of MX_PACK_DIVISOR")
         tl.static_assert(SWIZZLE_MX_SCALE == "BLACKWELL_SCALE" or SWIZZLE_MX_SCALE is None, "Only Blackwell swizzling is supported for scales")
-        tl.static_assert(not EXPT_IS_INNER, "Not supported yet")

         # We have pack 2 fp4 values in a byte
         W_PACK_DIVISOR: tl.constexpr = 2 if w_type == tl.uint8 else 1
@@ -249,7 +248,7 @@ def _p_matmul_ogs(
         XMxScalePtrs = XMxScale + start_z.to(index_type) * stride_x_mx_z
         if GatherIndx is None:
             XMxScalePtrs += start_m * stride_x_mx_m
-        offs_k_scale = MX_SCALE_BLOCK_K * pid_k + tl.arange(0, MX_SCALE_BLOCK_K)
+        offs_k_scale = off_k_x0 // MXFP_BLOCK_SIZE + tl.arange(0, MX_SCALE_BLOCK_K)
         XMxScalePtrs += (offs_x_m if USE_GATHER_TMA else offs_m).to(index_type)[:, None] * stride_x_mx_m
         XMxScalePtrs += offs_k_scale.to(index_type)[None, :] * stride_x_mx_k
     else:

python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py

Lines changed: 3 additions & 4 deletions
@@ -229,14 +229,13 @@ def make_default_opt_flags_nvidia(
         is_persistent = False
     block_n = block_n_tma if is_persistent else block_n
     # block k
-    if constraints.get("block_k", None) is not None:
-        block_k = constraints["block_k"]
-    else:
-        block_k = opt_flags_nvidia.compute_block_k(m, k, is_persistent, lhs_dtype, rhs_dtype, precision_config, has_y_acc_in)
+    block_k = opt_flags_nvidia.compute_block_k(m, k, is_persistent, lhs_dtype, rhs_dtype, precision_config, has_y_acc_in)
     if block_n == 256 and block_k == 128 and block_m <= 64 and is_persistent and rhs_dtype == FP4 and k >= 4096 and tokens_per_expt > 1:
         # Swap block_n and block_k for mxfp4 weights so that block_k is a full cacheline, so long as K is sufficiently large.
         # TODO: swizzle the HBM layout of the weights instead
         block_n, block_k = block_k, block_n
+    if constraints.get("block_k", None) is not None:
+        block_k = constraints["block_k"]
     # split_k
     if constraints.get("max_allowable_mn", 0) > 0 and constraints.get("split_k") is not None:
         split_k = max_allowable_mn(constraints["max_allowable_mn"], m, n, constraints.get("split_k"))
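
The reordering above means an explicit `block_k` constraint is applied after the mxfp4 block_n/block_k swap, so the swap can no longer overwrite it. A minimal sketch of that ordering, assuming hypothetical stand-ins `compute_default_block_k` and `should_swap_for_mxfp4` for `opt_flags_nvidia.compute_block_k` and the full swap condition.

```python
def pick_block_k(constraints: dict, block_n: int,
                 compute_default_block_k, should_swap_for_mxfp4):
    block_k = compute_default_block_k()
    if should_swap_for_mxfp4(block_n, block_k):
        # swap so block_k covers a full cacheline of packed fp4 weights
        block_n, block_k = block_k, block_n
    if constraints.get("block_k") is not None:
        block_k = constraints["block_k"]   # explicit constraint wins, applied last
    return block_n, block_k


print(pick_block_k({"block_k": 64}, 256,
                   lambda: 128, lambda n, k: n == 256 and k == 128))
# -> (128, 64): the swap happened, but the constraint still decides block_k
```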

python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_nvidia.py

Lines changed: 7 additions & 4 deletions
@@ -1,9 +1,9 @@
 import torch
 import triton
 from triton_kernels import target_info
-from triton_kernels.tensor import get_layout, bitwidth, FP4
-from triton_kernels.tensor_details.layout import HopperAmpereMXScaleLayout
 from triton_kernels.numerics_details.mxfp_details._downcast_to_mxfp import MXFP_BLOCK_SIZE
+from triton_kernels.tensor import FP4, bitwidth, get_layout
+from triton_kernels.tensor_details.layout import HopperAmpereMXScaleLayout


 def compute_grid_size(routing_data, batch_size, m, n, block_m, block_n):
@@ -18,8 +18,11 @@ def compute_grid_size(routing_data, batch_size, m, n, block_m, block_n):
 def compute_block_n(n: int, arch, precision_config):
     # block_n:
     layout = get_layout(precision_config.weight_scale)
-    if isinstance(layout, HopperAmpereMXScaleLayout) and layout.num_warps == 4:
-        return 128, 128
+    if isinstance(layout, HopperAmpereMXScaleLayout):
+        if layout.num_warps in [4, 8]:
+            # https://github.com/triton-lang/triton/blob/814b862166c756d9f33238844f4ac047e0243388/python/triton_kernels/triton_kernels/matmul_ogs_details/_matmul_ogs.py#L265
+            block_n = 2 * layout.num_warps * 2 * 8
+            return block_n, block_n
     elif precision_config.max_num_imprecise_acc is None and n > 128:
         return 256, 256
     else:
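
The hard-coded `128, 128` is replaced by a formula tied to `layout.num_warps`. A quick arithmetic check of the two supported warp counts (no Triton imports required):

```python
# Mirrors the new expression in compute_block_n for HopperAmpereMXScaleLayout.
for num_warps in (4, 8):
    block_n = 2 * num_warps * 2 * 8
    print(num_warps, block_n)  # 4 -> 128 (the previous constant), 8 -> 256
```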

python/triton_kernels/triton_kernels/tensor.py

Lines changed: 15 additions & 11 deletions
@@ -2,12 +2,13 @@
 from typing import Type

 import torch
-from triton.tools.tensor_descriptor import TensorDescriptor
 from triton.tools.ragged_tma import create_ragged_descriptor
+from triton.tools.tensor_descriptor import TensorDescriptor
+
 from .target_info import cuda_capability_geq
-from .tensor_details.layout import Layout, StridedLayout
-from .tensor_details import ragged_tensor as ragged_tensor_details
 from .tensor_details import bitmatrix as bitmatrix_details
+from .tensor_details import ragged_tensor as ragged_tensor_details
+from .tensor_details.layout import BlackwellMXValueLayout, Layout, StridedLayout
 from .tensor_details.ragged_tensor import RaggedTensorMetadata


@@ -46,26 +47,28 @@ def is_tma_compliant(self):
         compliant = [strides[i] * bitwidth % 128 == 0 for i in range(ndim) if i != major_dim]
         return all(compliant)

-    def make_dense_tma(self, block_shape, transpose=False):
+    def make_dense_tma(self, block_shape):
         strides = list(self.data.stride())
         shape = list(self.data.shape)
-        transpose = self.data.stride()[-1] != 1
+        transpose = strides[-1] != 1
         if transpose:
             block_shape = block_shape[:-2] + [block_shape[-1], block_shape[-2]]
             shape = shape[:-2] + [shape[-1], shape[-2]]
             strides = strides[:-2] + [strides[-1], strides[-2]]
-        if self.data.dtype == torch.uint8 and self.layout.name == "BLACKWELL_VALUE":
+        if self.data.dtype == torch.uint8 and (self.layout.name is None or "_SCALE" not in self.layout.name):
             indx = strides.index(1)
             block_shape[indx] = block_shape[indx] // 2
-            if shape[-1] % 128 != 0:
-                raise ValueError("inner shape need to be multiple of 128 for "
-                                 "mxfp4 (CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN16B) TMAs.")
+            if isinstance(self.layout, BlackwellMXValueLayout):
+                if shape[-1] % 128 != 0:
+                    raise ValueError(
+                        "inner shape need to be multiple of 128 for mxfp4 (CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN16B) TMAs."
+                    )
         block_shape = self.layout.swizzle_block_shape(block_shape)
         return TensorDescriptor(self.data, shape, strides, block_shape)

-    def make_tma(self, block_shape, mode, transpose=False):
+    def make_tma(self, block_shape, mode):
         if mode in ["dense", "gather", "scatter"]:
-            return self.make_dense_tma(block_shape, transpose)
+            return self.make_dense_tma(block_shape)
         assert mode == "ragged"
         ragged_dim = len(self.data.shape) - 2
         return create_ragged_descriptor(self.data, block_shape, ragged_dim=ragged_dim)
@@ -195,6 +198,7 @@ class RaggedTensor:
     A ragged `tensor` is a collection of 2D tensors that share the same number of columns.
     Each tensor in this collection is called a `slice`.
     """
+
     # slice_sizes[i] is the number of rows in slice `i`
     slice_sizes: torch.Tensor
     # ragged tensors are stored in memory as (potentially padded) 2D tensors of shape
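
A pure-Python sketch of the block-shape handling now in `make_dense_tma` (no descriptor is built here): transpose is inferred from the strides, and packed uint8 values halve the block size along the contiguous axis. `dense_tma_block_shape` and the `layout_is_scale` flag are hypothetical simplifications of the layout-name check.

```python
import torch

def dense_tma_block_shape(data: torch.Tensor, block_shape: list[int],
                          layout_is_scale: bool) -> list[int]:
    strides = list(data.stride())
    block_shape = list(block_shape)
    if strides[-1] != 1:                       # transposed: swap the last two dims
        block_shape[-2], block_shape[-1] = block_shape[-1], block_shape[-2]
        strides[-2], strides[-1] = strides[-1], strides[-2]
    if data.dtype == torch.uint8 and not layout_is_scale:
        indx = strides.index(1)                # contiguous axis packs 2 fp4 per byte
        block_shape[indx] //= 2
    return block_shape


w = torch.zeros(256, 128, dtype=torch.uint8).t()   # column-major packed fp4 values
print(dense_tma_block_shape(w, [128, 64], layout_is_scale=False))
# -> [64, 64]: dims swapped for the transpose, contiguous dim halved for packing
```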
