Commit 618277b

ThomasRaoux authored and zwu-2025 committed
[BACKEND] Add support for tmem load/store 16x256b (triton-lang#6897)
The logic to pick this new layout is currently not implemented but this will allow us to choose between different layouts in the future.
1 parent ca24b10 commit 618277b
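
The new 16x256b tmem load/store layout is opt-in through the TRITON_PREFER_TMEM_16x256_LAYOUT environment variable added in this change, and the unit test below only exercises it on Blackwell (compute capability 10.x). The following is a minimal sketch, not part of the commit, of how one might opt in and inspect the generated PTX; the kernel is a hypothetical stand-in loosely following test_simple_matmul, and whether tmem and the 16x256b message are actually used still depends on the block shapes and the rest of the compilation pipeline.

import os
import torch
import triton
import triton.language as tl

# Opt in before the kernel is compiled; the variable is part of Triton's
# CACHE_INVALIDATING_ENV_VARS (see the GetEnv.hpp hunk below).
os.environ["TRITON_PREFER_TMEM_16x256_LAYOUT"] = "1"


@triton.jit
def simple_matmul(a_ptr, b_ptr, c_ptr, M, N, K, stride_am, stride_ak, stride_bk, stride_bn,
                  stride_cm, stride_cn, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
    pid = tl.program_id(0)
    pid_m = pid % tl.cdiv(M, BLOCK_M)
    pid_n = pid // tl.cdiv(M, BLOCK_M)
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)
    a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for _ in range(tl.cdiv(K, BLOCK_K)):
        acc += tl.dot(tl.load(a_ptrs), tl.load(b_ptrs))
        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk
    c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
    tl.store(c_ptrs, acc.to(tl.float16))


assert torch.cuda.get_device_capability()[0] >= 10, "the 16x256b tmem layout targets Blackwell"
M, N, K, BLOCK_M, BLOCK_N, BLOCK_K = 1024, 512, 256, 128, 128, 64
a = torch.randn(M, K, dtype=torch.float16, device="cuda")
b = torch.randn(K, N, dtype=torch.float16, device="cuda")
c = torch.empty((M, N), dtype=torch.float16, device="cuda")
grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), 1)
k = simple_matmul[grid](a, b, c, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1),
                        c.stride(0), c.stride(1), BLOCK_M, BLOCK_N, BLOCK_K)

# When the new layout is picked, tmem loads/stores are emitted with the 16x256b message shape.
print("16x256b" in k.asm["ptx"])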

8 files changed (+212, -21 lines)


include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h

Lines changed: 4 additions & 0 deletions
@@ -271,6 +271,10 @@ LinearLayout chooseDsReadB64TrLayout(Attribute enc, ArrayRef<int64_t> shape,
 LinearLayout getScaleTMEMStoreLinearLayout(RankedTensorType scaleType,
                                            int numWarps);

+std::optional<LinearLayout>
+getTmemLoadStoreLayout16x256(int M, int N, RankedTensorType oldType,
+                             int numWarps);
+
 // Return a layout valid for TMemLoad op for a tmem layout of block MxN that
 // distribute the data long M for the warp groups. This doesn't affect the TMem
 // layout it just returns a distributed layout compatible for tmem_load.

include/triton/Tools/Sys/GetEnv.hpp

Lines changed: 2 additions & 1 deletion
@@ -43,7 +43,8 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
     "NVPTX_ENABLE_DUMP",
     "STORE_TMEM_TO_GLOBAL_BYPASS_SMEM",
     "ALLOW_LHS_TMEM_LAYOUT_CONVERSION",
-    "TRITON_F32_DEFAULT"
+    "TRITON_F32_DEFAULT",
+    "TRITON_PREFER_TMEM_16x256_LAYOUT",
     // clang-format on
 };


lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp

Lines changed: 86 additions & 0 deletions
@@ -1682,6 +1682,92 @@ LinearLayout getScaleTMEMStoreLinearLayout(RankedTensorType scaleType,
   return combineCtaCgaWithShape(regLanes, CTALayout, scaleType.getShape());
 }

+std::optional<LinearLayout>
+getTmemLoadStoreLayout16x256(int M, int N, RankedTensorType oldType,
+                             int numWarps) {
+  // Too small to distribute on two warp groups while using 16x256 message.
+  if (numWarps == 8 && M == 64 && N <= 16 &&
+      oldType.getElementTypeBitWidth() < 32) {
+    return {};
+  }
+  assert(numWarps == 4 || numWarps == 8);
+  auto ctaLayout = getCTALayout(oldType.getEncoding());
+  SmallVector<int64_t> shape = getShapePerCTA(oldType);
+  MLIRContext *ctx = ctaLayout.getContext();
+
+  using basisT = std::vector<std::vector<int32_t>>;
+  StringAttr kRegister = StringAttr::get(ctx, "register");
+  StringAttr kLane = StringAttr::get(ctx, "lane");
+  StringAttr kWarp = StringAttr::get(ctx, "warp");
+  SmallVector<StringAttr> outDimNames = standardOutDimNames(ctx, 2);
+
+  unsigned numElementsPerThread = 256 / oldType.getElementTypeBitWidth();
+  int kWidth = 64 / oldType.getElementTypeBitWidth();
+  // Follow the layout given by a tmem load using this layout for the inner
+  // shape:
+  // https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-matrix-fragments-shape-16256b
+  LinearLayout innerTile =
+      nvidiaMmaTile(ctx, {8, numElementsPerThread}, kWidth, {1, 0}, {0, 1});
+  innerTile =
+      innerTile * LinearLayout::identity1D(2, kRegister, outDimNames[0]);
+  // Then distribute the rest along warpgroups and registers.
+  // Then the last warp distribute along M or N following the same order as
+  // in getTmemLoadStoreLayout32x32b. This allows us to use the same lowering to
+  // tmem for load and store. This part could be generalized by making the
+  // lowering of tmem load and store rely more on linear layout.
+  bool distributeMAlongWarps = false;
+  bool distributeNAlongWarps = false;
+  // Figure out how to distribute acorss warpgroups.
+  if (numWarps == 8) {
+    if (shape[0] > 128) {
+      distributeMAlongWarps = true;
+    } else {
+      distributeNAlongWarps = true;
+    }
+  }
+  int nBase = numElementsPerThread;
+  int maxRegN =
+      std::min(N, distributeNAlongWarps ? (int)shape[1] / 2 : (int)shape[1]);
+  if (maxRegN / nBase > 1) {
+    innerTile = innerTile * LinearLayout::identity1D(maxRegN / nBase, kRegister,
+                                                     outDimNames[1]);
+  }
+  if (M != 64) {
+    innerTile =
+        innerTile * LinearLayout::identity1D(2, kRegister, outDimNames[0]);
+  }
+  // Distribute M along 4 warps to satisfy TMEM requirements.
+  innerTile = innerTile * LinearLayout::identity1D(4, kWarp, outDimNames[0]);
+
+  // Fill out the rest of the shape with M first then N.
+  int numMRegDim = std::min(128, (int)shape[0]) / M;
+  if (numMRegDim > 1) {
+    innerTile = innerTile *
+                LinearLayout::identity1D(numMRegDim, kRegister, outDimNames[0]);
+  }
+  // Dim M=128 should be distributed on the second warp group.
+  int nextDim = 128;
+  if (distributeMAlongWarps) {
+    innerTile = innerTile * LinearLayout::identity1D(2, kWarp, outDimNames[0]);
+    nextDim <<= 1;
+  }
+  numMRegDim = shape[0] / nextDim;
+  if (numMRegDim > 1) {
+    innerTile = innerTile *
+                LinearLayout::identity1D(numMRegDim, kRegister, outDimNames[0]);
+  }
+  int maxN = distributeNAlongWarps ? shape[1] / 2 : shape[1];
+  int numNRegDim = maxN / maxRegN;
+  if (numNRegDim > 1) {
+    innerTile = innerTile *
+                LinearLayout::identity1D(numNRegDim, kRegister, outDimNames[1]);
+  }
+  if (distributeNAlongWarps) {
+    innerTile = innerTile * LinearLayout::identity1D(2, kWarp, outDimNames[1]);
+  }
+  return combineCtaCgaWithShape(innerTile, ctaLayout, oldType.getShape());
+}
+
 LinearLayout getTmemLoadLayoutSplitLongM(int M, int N, RankedTensorType oldType,
                                          int numWarps) {
   assert(numWarps == 8);
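
As a quick cross-check of the tile-sizing arithmetic at the top of getTmemLoadStoreLayout16x256 (numElementsPerThread and kWidth), here is a purely illustrative Python helper; it is not part of the commit and the name is invented.

# Hypothetical helper mirroring the arithmetic in getTmemLoadStoreLayout16x256:
# each lane of a 16x256b message covers 256 bits of a row, packed in 64-bit chunks.
def inner_tile_params(element_bitwidth):
    num_elements_per_thread = 256 // element_bitwidth  # elements per 256-bit row chunk
    k_width = 64 // element_bitwidth                   # contiguous elements per 64-bit unit
    return num_elements_per_thread, k_width

assert inner_tile_params(32) == (8, 2)    # f32
assert inner_tile_params(16) == (16, 4)   # f16 / bf16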

lib/Dialect/TritonNvidiaGPU/IR/Dialect.cpp

Lines changed: 32 additions & 5 deletions
@@ -23,13 +23,15 @@

 #include "triton/Dialect/Triton/IR/Dialect.h"
 #include "triton/Dialect/Triton/IR/Utility.h"
+#include "triton/Tools/Sys/GetEnv.hpp"

 #include <numeric>

 #include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/OpImplementation.h"
 #include "triton/Analysis/Utility.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
+#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h"
 #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/Debug.h"
@@ -96,8 +98,9 @@ TMemAllocation getTmemAllocSizes(MemDescType memDescType) {
   return TMemAllocation(numColumn, numRows);
 }

-Attribute getTmemCompatibleLayout(unsigned M, unsigned N,
-                                  RankedTensorType oldType, unsigned numWarps) {
+Attribute getTmemLoadStoreLayout32x32b(unsigned M, unsigned N,
+                                       RankedTensorType oldType,
+                                       unsigned numWarps) {
   assert(numWarps == 4 || numWarps == 8);
   auto shape = getShapePerCTA(oldType);
   assert(shape.size() == 2);
@@ -146,6 +149,20 @@ Attribute getTmemCompatibleLayout(unsigned M, unsigned N,
                                    warpsPerCTA, order, ctaLayout);
 }

+Attribute getTmemCompatibleLayout(unsigned M, unsigned N,
+                                  RankedTensorType oldType, unsigned numWarps) {
+  bool prefer16x256 =
+      triton::tools::getBoolEnv("TRITON_PREFER_TMEM_16x256_LAYOUT");
+  if (prefer16x256) {
+    std::optional<LinearLayout> ll =
+        getTmemLoadStoreLayout16x256(M, N, oldType, numWarps);
+    if (ll) {
+      return LinearEncodingAttr::get(oldType.getContext(), *ll);
+    }
+  }
+  return getTmemLoadStoreLayout32x32b(M, N, oldType, numWarps);
+}
+
 bool isDistributedLayoutSplitMTmemLoadStore(RankedTensorType tensorType,
                                             MemDescType memType, int numWarps) {
   auto tmemEnc = dyn_cast<triton::nvidia_gpu::TensorMemoryEncodingAttr>(
@@ -159,6 +176,8 @@ bool isDistributedLayoutSplitMTmemLoadStore(RankedTensorType tensorType,
     return false;
   auto CTALayout = getCTALayout(tensorType.getEncoding());
   auto shapePerCTA = mlir::triton::gpu::getShapePerCTA(tensorType);
+  if (numWarps != 8)
+    return false;
   LinearLayout llLayout =
       getTmemLoadLayoutSplitLongM(M, N, tensorType, numWarps);
   return llEncoding.getLinearLayout() == llLayout;
@@ -170,7 +189,6 @@ bool isDistributedLayoutTMemCompatible(Operation *op,
                                        MemDescType memType) {
   int numWarps = lookupNumWarps(op);
   assert(numWarps % 4 == 0);
-  int numWarpGroups = numWarps / 4;
   if (isa<triton::nvidia_gpu::TensorMemoryScalesEncodingAttr>(
           memType.getEncoding())) {
     return tensorType.getEncoding() ==
@@ -184,8 +202,17 @@ bool isDistributedLayoutTMemCompatible(Operation *op,
   int blockN = attr.getBlockN();
   if (isDistributedLayoutSplitMTmemLoadStore(tensorType, memType, numWarps))
     return true;
-  Attribute layout =
-      nvidia_gpu::getTmemCompatibleLayout(blockM, blockN, tensorType, numWarps);
+
+  auto ll16x256 =
+      getTmemLoadStoreLayout16x256(blockM, blockN, tensorType, numWarps);
+  if (ll16x256.has_value() &&
+      areLayoutsEquivalent(
+          tensorType.getShape(),
+          LinearEncodingAttr::get(tensorType.getContext(), ll16x256.value()),
+          tensorType.getEncoding()))
+    return true;
+  Attribute layout = nvidia_gpu::getTmemLoadStoreLayout32x32b(
+      blockM, blockN, tensorType, numWarps);
   // TODO: Add support for more layout compatible with tmem load/store. There
   // will only be a discret set of layout possible due to the limiations of
   // tmem_load/store.
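
The new getTmemCompatibleLayout body above is an opt-in with fallback: the 16x256b layout is only tried when the environment variable is set and the shape supports it, otherwise the existing 32x32b layout is returned. Below is a hypothetical Python sketch of that decision; the function name and string return values are invented for illustration, and the real code returns MLIR encoding attributes and reads the variable through triton::tools::getBoolEnv.

import os

def pick_tmem_layout(block_m, block_n, num_warps, element_bitwidth):
    def layout_16x256():
        # Mirrors the early-out in getTmemLoadStoreLayout16x256: the tile is too
        # small to distribute over two warp groups with a 16x256b message.
        if num_warps == 8 and block_m == 64 and block_n <= 16 and element_bitwidth < 32:
            return None
        return "16x256b"

    # Simplified truthiness check; the real code uses getBoolEnv.
    if os.environ.get("TRITON_PREFER_TMEM_16x256_LAYOUT") == "1":
        layout = layout_16x256()
        if layout is not None:
            return layout
    return "32x32b"  # existing default, getTmemLoadStoreLayout32x32b

print(pick_tmem_layout(128, 128, 4, 32))  # -> "16x256b" when the env var is set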

python/test/unit/language/test_matmul.py

Lines changed: 17 additions & 3 deletions
@@ -34,7 +34,7 @@ def matmul_kernel( #
         stride_cm, stride_cn, #
         BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, #
         NUM_STAGES: tl.constexpr, SCALE_A: tl.constexpr = None, PRECISION: tl.constexpr = "ieee",
-        A_TRANS: tl.constexpr = False, EPILOGUE_SUBTILE: tl.constexpr = False):
+        A_TRANS: tl.constexpr = False, EPILOGUE_SUBTILE: tl.constexpr = False, dummy: tl.constexpr = 0):
     pid = tl.program_id(axis=0)
     num_pid_m = tl.cdiv(M, BLOCK_M)
     pid_m = pid % num_pid_m
@@ -93,8 +93,9 @@ def get_src_element_ty_size(dtype_str):
 @pytest.mark.parametrize("NUM_CTAS", [1, 2])
 @pytest.mark.parametrize("NUM_WARPS", [4, 8])
 @pytest.mark.parametrize("EPILOGUE_SUBTILE", [True, False])
+@pytest.mark.parametrize("LAYOUT_16x256", [True, False])
 def test_simple_matmul(dtype_src_str, dtype_dst_str, BLOCK_M, BLOCK_N, BLOCK_K, NUM_STAGES, NUM_WARPS, NUM_CTAS, device,
-                       EPILOGUE_SUBTILE):
+                       EPILOGUE_SUBTILE, LAYOUT_16x256, monkeypatch):
     if NUM_CTAS > 1 and (not is_cuda() or torch.cuda.get_device_capability()[0] < 9):
         pytest.skip("Clusters requires nvidia compute capability >= 9")
     if is_hip() and ((BLOCK_K * BLOCK_M + BLOCK_K * BLOCK_N) * NUM_STAGES * get_src_element_ty_size(dtype_src_str)
@@ -114,6 +115,8 @@ def test_simple_matmul(dtype_src_str, dtype_dst_str, BLOCK_M, BLOCK_N, BLOCK_K,
         pytest.skip("multi-CTAs is broken for mmav2")
     if EPILOGUE_SUBTILE and (is_hip() or NUM_CTAS > 1 or BLOCK_N >= 512):
         pytest.skip("creates convert layout too big to fit in smem")
+    if LAYOUT_16x256 and (not is_cuda() or torch.cuda.get_device_capability()[0] < 10):
+        pytest.skip("skip forcing tmem layout on non blackwell targets.")
     M, N, K = 1024, 512, 256
     torch.manual_seed(42)
     precision = "tf32" if dtype_src_str == "tensorfloat32" else "ieee"
@@ -129,12 +132,16 @@ def test_simple_matmul(dtype_src_str, dtype_dst_str, BLOCK_M, BLOCK_N, BLOCK_K,
     b = torch.randn(K, N, dtype=dtype_src, device=device)
     A = a
     B = b
+    # pass a dummy constexpr argument to force recompilation.
+    if LAYOUT_16x256:
+        monkeypatch.setenv("TRITON_PREFER_TMEM_16x256_LAYOUT", "1")
     dtype_dst = getattr(torch, dtype_dst_str)
     output = torch.empty((M, N), dtype=dtype_dst, device=device)
     grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), 1)
     k = matmul_kernel[grid](a, b, output, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), output.stride(0),
                             output.stride(1), BLOCK_M, BLOCK_N, BLOCK_K, NUM_STAGES=NUM_STAGES, PRECISION=precision,
-                            num_warps=NUM_WARPS, num_ctas=NUM_CTAS, EPILOGUE_SUBTILE=EPILOGUE_SUBTILE)
+                            num_warps=NUM_WARPS, num_ctas=NUM_CTAS, EPILOGUE_SUBTILE=EPILOGUE_SUBTILE,
+                            dummy=LAYOUT_16x256)
     ref_out = torch.matmul(A, B).to(torch.float32)
     output = output.to(torch.float32)
     if dtype_src_str == "float32":
@@ -157,6 +164,13 @@ def test_simple_matmul(dtype_src_str, dtype_dst_str, BLOCK_M, BLOCK_N, BLOCK_K,
         ttgir = k.asm["ttgir"]
         count = ttgir.count("ttng.tc_gen5_mma")
         assert count == 2, "The TTGIR does not match the expected pattern."
+    ptx = k.asm["ptx"]
+    if LAYOUT_16x256:
+        assert "16x256b" in ptx, "PTX does not contain 16x256b"
+    else:
+        if "32x32b" not in ptx and "16x32b" not in ptx:
+            print(ptx)
+        assert ("32x32b" in ptx) or ("16x32b" in ptx), "PTX does not contain 32x32b or 16x32b"


 # persistent matmul with fused loops

test/Conversion/tritongpu_to_llvm_blackwell.mlir

Lines changed: 18 additions & 0 deletions
@@ -120,6 +120,24 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar

 // -----

+#linear = #ttg.linear<{register = [[0, 1], [8, 0], [0, 8], [0, 16], [0, 32], [0, 64], [16, 0]], lane = [[0, 2], [0, 4], [1, 0], [2, 0], [4, 0]], warp = [[32, 0], [64, 0]], block = []}>
+#tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 128, unpacked = true>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 65544 : i32, ttg.target = "cuda:100", ttg.tensor_memory_size = 128 : i32, "ttg.threads-per-warp" = 32 : i32} {
+  // CHECK-LABEL: @tensor_memory_ld_16x256
+  // CHECK: tcgen05.st.sync.aligned.16x256b.x16.b32
+  // CHECK: tcgen05.st.sync.aligned.16x256b.x16.b32
+  // CHECK: tcgen05.ld.sync.aligned.16x256b.x16.b32
+  // CHECK: tcgen05.ld.sync.aligned.16x256b.x16.b32
+  tt.func public @tensor_memory_ld_16x256(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16>, %arg2: !tt.ptr<f16>) {
+    %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #linear>
+    %0 = ttng.tmem_alloc %cst_0 {tensor_memory_col_offset = 0 : i32, tensor_memory_row_offset = 0 : i32} : (tensor<128x128xf32, #linear>) -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
+    %20 = ttng.tmem_load %0 : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #linear>
+    tt.return
+  }
+}
+
+// -----
+
 #tmem = #ttng.tensor_memory_encoding<blockM = 64, blockN = 128, unpacked = true>
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 65544 : i32, ttg.target = "cuda:100", ttg.tensor_memory_size = 128 : i32, "ttg.threads-per-warp" = 32 : i32} {
   // CHECK-LABEL: @tensor_memory_allocation
test/TritonGPU/accelerate-matmul.mlir

Lines changed: 6 additions & 1 deletion
@@ -1,4 +1,5 @@
 // RUN: triton-opt %s -split-input-file --tritongpu-accelerate-matmul -verify-diagnostics=only-expected | FileCheck %s
+// RUN: TRITON_PREFER_TMEM_16x256_LAYOUT=1 triton-opt %s -split-input-file --tritongpu-accelerate-matmul | FileCheck %s --check-prefix=LAYOUT_16x256

 // CHECK: #[[MMA:.+]] = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16, 16]}>
 // CHECK: #[[MMA1:.+]] = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}>
@@ -195,6 +196,7 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-
 #blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}>
 #blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}>
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} {
+  // LAYOUT_16x256{LITERAL}: #ttg.linear<{register = [[0, 1], [8, 0], [0, 8], [0, 16], [0, 32], [0, 64], [0, 128], [16, 0]], lane = [[0, 2], [0, 4], [1, 0], [2, 0], [4, 0]], warp = [[32, 0], [64, 0]], block = []}>
   // CHECK-DAG: #[[$TMEM:.+]] = #ttng.tensor_memory_encoding<blockM = 128, blockN = 256, unpacked = true>
   // CHECK-DAG: #[[$B:.+]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}>
   // CHECK-DAG: #[[$T:.+]] = #ttg.blocked<{sizePerThread = [1, 256], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
@@ -257,6 +259,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
 #blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [2, 1], CTASplitNum = [2, 1], CTAOrder = [1, 0]}>
 #blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [2, 1], CTASplitNum = [2, 1], CTAOrder = [1, 0]}>
 module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} {
+  // LAYOUT_16x256{LITERAL}: #ttg.linear<{register = [[0, 1], [8, 0], [0, 8], [0, 16], [0, 32], [0, 64], [0, 128]], lane = [[0, 2], [0, 4], [1, 0], [2, 0], [4, 0]], warp = [[16, 0], [32, 0]], block = [[64, 0]]}>
   // CHECK-DAG: #[[$TMEM:.+]] = #ttng.tensor_memory_encoding<blockM = 64, blockN = 256, unpacked = true, CTASplitM = 2>
   // CHECK-DAG: #[[$B:.+]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [2, 1], CTASplitNum = [2, 1], CTAOrder = [1, 0]}>
   // CHECK-DAG: #[[$T:.+]] = #ttg.blocked<{sizePerThread = [1, 128], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [0, 1], CTAsPerCGA = [2, 1], CTASplitNum = [2, 1], CTAOrder = [1, 0]}>
@@ -496,8 +499,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
 #blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [8, 1], order = [1, 0]}>
 #blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}>
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} {
+  // LAYOUT_16x256{LITERAL}: #ttg.linear<{register = [[0, 1], [8, 0], [0, 8], [0, 16], [0, 32], [16, 0]], lane = [[0, 2], [0, 4], [1, 0], [2, 0], [4, 0]], warp = [[32, 0], [64, 0], [0, 64]], block = []}>
   // CHECK-DAG: #[[$TMEM1:.+]] = #ttng.tensor_memory_scales_encoding
-  // CHECK{LITERALE}-DAG: #linear = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [32, 0], [64, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], warp = [[0, 0], [0, 0]], block = []}>
+  // CHECK{LITERAL}-DAG: #linear = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [32, 0], [64, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], warp = [[0, 0], [0, 0]], block = []}>
   // CHECK-LABEL: mmav5_block_scaled_8_warps
   // CHECK: ttng.tmem_alloc %{{.*}} : (tensor<128x8xi8, #linear>) -> !ttg.memdesc<128x8xi8, #[[$TMEM1]], #ttng.tensor_memory>
   // CHECK: ttng.tmem_alloc %{{.*}} : (tensor<128x8xi8, #linear>) -> !ttg.memdesc<128x8xi8, #[[$TMEM1]], #ttng.tensor_memory>
@@ -511,6 +515,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ

 // -----

+// LAYOUT_16x256{LITERAL}: #ttg.linear<{register = [[0, 1], [8, 0], [0, 8], [0, 16], [0, 32], [0, 64], [0, 128], [16, 0]], lane = [[0, 2], [0, 4], [1, 0], [2, 0], [4, 0]], warp = [[32, 0], [64, 0]], block = []}>
 // CHECK-DAG: #[[$SHARED_A:.+]] = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 8}>
 // CHECK-DAG: #[[$SHARED_B:.+]] = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 8, fp4Padded = true}>
 #blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}>
