Skip to content

Commit 53fd113

Browse files
lezcano authored and zwu-2025 committed
[BACKEND] Support stmatrix.trans (triton-lang#6910)
With this we are able to lower pretty much anything that can be lowered to an stmatrix. We are just missing two niche cases: - Multi CTA - Lowering fp8 with stmatrix.trans (you need the first two bases of kReg to be `[[0, 1], [1, 0]]`). These can be supported in the future if necessary. Will use this to support `ldmatrix` in the next PR.
1 parent 5a14866 commit 53fd113

File tree

3 files changed

+156
-42
lines changed

3 files changed

+156
-42
lines changed

python/test/unit/language/test_core.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5958,10 +5958,6 @@ def compute_scratch_buffer_shape(src_layout, dst_layout, shape):
59585958
def test_convert2d(M, N, src_layout, interm_layout, dst_layout, dtype, device, tmp_path: pathlib.Path):
59595959
if str(src_layout) == str(dst_layout):
59605960
pytest.skip()
5961-
if (isinstance(src_layout, DotOperandLayout)
5962-
and isinstance(interm_layout, SharedLayout)) or (isinstance(dst_layout, DotOperandLayout)
5963-
and isinstance(interm_layout, SharedLayout)):
5964-
pytest.skip("DotOperandLayout <-> SharedLayout conversion is not completely supported")
59655961
if is_hip():
59665962
try:
59675963
scratch_shape = compute_scratch_buffer_shape(src_layout, dst_layout, (M, N))

test/Conversion/tritongpu_to_llvm_hopper.mlir

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -363,6 +363,45 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-
363363
}
364364
}
365365

366+
367+
// -----
368+
369+
#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0, 1]}>
370+
#linear = #ttg.linear<{register = [[0, 1], [8, 0], [0, 8], [0, 16]], lane = [[0, 2], [0, 4], [1, 0], [2, 0], [4, 0]], warp = [[16, 0], [32, 0]], block = []}>
371+
#smem = #ttg.shared_memory
372+
// CHECK-LABEL: linear_to_swizzled_st_matrix_trans_local_store
373+
module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
374+
tt.func @linear_to_swizzled_st_matrix_trans_local_store(%a: tensor<64x32xf16, #linear>) {
375+
// CHECK-COUNT-2: nvgpu.stmatrix %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} {trans}
376+
// CHECK: llvm.return
377+
%b = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<64x32xf16, #shared, #smem, mutable>
378+
ttg.local_store %a, %b : tensor<64x32xf16, #linear> -> !ttg.memdesc<64x32xf16, #shared, #smem, mutable>
379+
tt.return
380+
}
381+
}
382+
383+
// -----
384+
385+
// Stretching a bit the lowering. Feel free to kill this test if we restrain
386+
// the lowering a bit later on.
387+
// These layouts will have plenty of bank conflicts, so it'd make sense not to
388+
// lower them via stmatrix.
389+
// It is of course possible to design a shared memory layout that makes the lowering
390+
// via stmatrix not have any bank conflicts, but yeah.
391+
#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0, 1]}>
392+
#linear = #ttg.linear<{register = [[0, 2], [0, 8], [0, 0], [0, 16], [0, 1]], lane = [[0, 0], [0, 4], [1, 0], [2, 0], [4, 0]], warp = [[0, 0], [8, 0]], block = []}>
393+
#smem = #ttg.shared_memory
394+
// CHECK-LABEL: linear_to_swizzled_st_matrix_trans_local_store
395+
module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
396+
tt.func @linear_to_swizzled_st_matrix_trans_local_store(%a: tensor<16x32xf16, #linear>) {
397+
// CHECK-COUNT-2: nvgpu.stmatrix %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} {trans}
398+
// CHECK: llvm.return
399+
%b = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<16x32xf16, #shared, #smem, mutable>
400+
ttg.local_store %a, %b : tensor<16x32xf16, #linear> -> !ttg.memdesc<16x32xf16, #shared, #smem, mutable>
401+
tt.return
402+
}
403+
}
404+
366405
// -----
367406

368407
#blocked = #ttg.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/MemoryOpToLLVM.cpp

Lines changed: 117 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ struct LocalLoadOpConversion
139139

140140
LogicalResult lowerDistributedToSharedStmatrix(
141141
Location loc, RankedTensorType tensorTy, MemDescType memDescType,
142-
Value adaptorSrc, Value smemBase, Type llvmElemTy,
142+
bool transpose, Value adaptorSrc, Value smemBase, Type llvmElemTy,
143143
ConversionPatternRewriter &rewriter, const TargetInfo &targetInfo,
144144
std::pair<size_t, Type> *const llvmOpCount = nullptr) {
145145
if (!targetInfo.supportLdStMatrix())
@@ -160,7 +160,11 @@ LogicalResult lowerDistributedToSharedStmatrix(
160160
auto kOffset = S("offset");
161161
auto smemPtrTy = ptr_ty(ctx, 3);
162162
auto bitwidth = tensorTy.getElementTypeBitWidth();
163-
if (bitwidth > 32)
163+
// In the transpose case, consecutive elements are not stored contiguously
164+
// so we cannot split an fp32
165+
// We could support bitwidth == 8, but it'd be a rather weird layout
166+
// so we don't do that for now
167+
if ((!transpose && bitwidth > 32) || (transpose && bitwidth != 16))
164168
return failure();
165169
// Inter block stmatrix is not supported
166170
if (cvt.hasInDim(kBlock))
@@ -173,31 +177,75 @@ LogicalResult lowerDistributedToSharedStmatrix(
173177
cvt = removeBroadcast.apply(cvt);
174178
srcVals = removeBroadcast.apply(srcVals);
175179

176-
auto tile = LinearLayout::identity1D(32 / bitwidth, kReg, kOffset) *
177-
LinearLayout::identity1D(4, kLane, kOffset);
178-
// Find if there is a register permutation that allows us to divideLeft
179-
auto maybeAction = regPermForDivideLeft(cvt, tile);
180-
if (!maybeAction.has_value()) {
181-
return failure();
180+
LinearLayout reps;
181+
if (!transpose) {
182+
auto tile = LinearLayout::identity1D(32 / bitwidth, kReg, kOffset) *
183+
LinearLayout::identity1D(4, kLane, kOffset);
184+
185+
// Find if there is a register permutation that allows us to divideLeft
186+
// We need to pass the map from regs to offsets, as is cvt
187+
auto maybeAction = regPermForDivideLeft(cvt, tile);
188+
if (!maybeAction.has_value()) {
189+
return failure();
190+
}
191+
auto action = maybeAction.value();
192+
// Check if the action indeed allows us to divideLeft
193+
cvt = action.apply(cvt);
194+
srcVals = action.apply(srcVals);
195+
196+
auto maybeQuot = divideLeft(cvt, tile);
197+
if (!maybeQuot.has_value()) {
198+
return failure();
199+
}
200+
reps = zerosLike(tile) * maybeQuot.value();
201+
} else {
202+
// Division does not quite work here. To define this properly, we would need
203+
// to define a different multiplication that does:
204+
// A *' B = [[0, A], [B, 0]] and define leftDivision for it
205+
// We do it ad-hoc for now, as I believe there's not much demand for this op
206+
// outside of this lowering
207+
208+
// Divisibility in the sense above is the same as regular divisibility
209+
// You need to see that the tile A is a sublayout of the matrix, and that
210+
// it has zeros above it and to its right.
211+
212+
// In particular, offsets lanes 4, 8, 16 map to offsets 1, 2, 4...
213+
const auto &laneBases = cvt.getBases().find(kLane)->second;
214+
for (int i = 0; i < 3; ++i) {
215+
if (laneBases[i + 2][0] != (1 << i))
216+
return failure();
217+
}
218+
// ... and no other basis should depend on 1, 2, 4
219+
// Note that this gives us the usual alignment condition, but we have
220+
// translated it to checking that the matrix to the left of A is all zeros
221+
for (auto dim : cvt.getInDimNames()) {
222+
const auto &bases = cvt.getBases().find(dim)->second;
223+
for (auto [i, basis] : llvm::enumerate(bases)) {
224+
if (dim == kLane && i >= 2)
225+
continue;
226+
if (basis[0] & 0b111)
227+
return failure();
228+
}
229+
}
230+
231+
// Hack: We are not going to use in the rest of the function reps[kLane][2:]
232+
// so we don't need to zero them out
233+
reps = cvt;
182234
}
183-
auto action = maybeAction.value();
184-
// Check if the action indeed allows us to divideLeft
185-
cvt = action.apply(cvt);
186-
auto maybeQuot = divideLeft(cvt, tile);
187-
if (!maybeQuot.has_value()) {
235+
236+
// We must have at least 2 register elements to use stmatrix.trans
237+
if (transpose && reps.getInDimSizeLog2(kReg) < llvm::Log2_32(32 / bitwidth)) {
188238
return failure();
189239
}
190-
auto quot = maybeQuot.value();
191-
srcVals = action.apply(srcVals);
192-
// Map from kReg, kLane, kWarp to beginning of each tile
193-
auto reps = zerosLike(tile) * quot;
194-
assert(reps.getOutDimSize(kOffset) == cvt.getOutDimSize(kOffset));
195240

196-
// Choose up to 4 packs of 32-bit elements indexed by the next to bases
197-
// as the vectorisation factor
198-
auto vec = std::min(2, quot.getInDimSizeLog2(kReg));
241+
// Choose up to 4 packs of 32-bit elements indexed by the next (at most) two
242+
// bases as the vectorisation factor. We don't consider the basis of the tile
243+
// for vectorisation so we subtract them
244+
auto vec = std::min<int32_t>(2, reps.getInDimSizeLog2(kReg) -
245+
llvm::Log2_32(32 / bitwidth));
199246

200-
// FIXME(Lezcano): Should we bail if any of the other 3 lane bases is zero?
247+
// Map from kReg, kLane, kWarp to beginning of each tile
248+
assert(reps.getOutDimSize(kOffset) == cvt.getOutDimSize(kOffset));
201249

202250
auto [laneId, warpId] = getLaneAndWarpId(rewriter, loc);
203251
// Compute the addresses for the 0th tile
@@ -212,12 +260,24 @@ LogicalResult lowerDistributedToSharedStmatrix(
212260
// given
213261
// by the first `vec` reg bases that are not part of the tile
214262
std::vector<std::vector<int32_t>> laneBases;
215-
assert(tile.getInDimSizeLog2(kLane) == 2);
216-
for (int i = 0; i < 3; ++i) {
217-
laneBases.push_back(reps.getBasis(kLane, tile.getInDimSizeLog2(kLane) + i));
218-
}
219-
for (int i = 0; i < vec; ++i) {
220-
laneBases.push_back(reps.getBasis(kReg, tile.getInDimSizeLog2(kReg) + i));
263+
if (!transpose) {
264+
auto tileDimSizeReg = llvm::Log2_32(32 / bitwidth);
265+
auto tileDimSizeLane = 2;
266+
for (int i = 0; i < 3; ++i) {
267+
laneBases.push_back(reps.getBasis(kLane, tileDimSizeLane + i));
268+
}
269+
for (int i = 0; i < vec; ++i) {
270+
laneBases.push_back(reps.getBasis(kReg, tileDimSizeReg + i));
271+
}
272+
} else {
273+
// We choose the first basis of the register. In the future we could choose
274+
// a basis that minimises the bank conflicts
275+
laneBases.push_back(reps.getBasis(kReg, 0));
276+
laneBases.push_back(reps.getBasis(kLane, 0));
277+
laneBases.push_back(reps.getBasis(kLane, 1));
278+
for (int i = 0; i < vec; ++i) {
279+
laneBases.push_back(reps.getBasis(kReg, i + 1));
280+
}
221281
}
222282

223283
LinearLayout addrLayout =
@@ -247,7 +307,8 @@ LogicalResult lowerDistributedToSharedStmatrix(
247307
}
248308
inputs.push_back(b.bitcast(input, i32_ty));
249309
}
250-
rewriter.create<triton::nvgpu::StoreMatrixOp>(loc, vecAddr, inputs);
310+
rewriter.create<triton::nvgpu::StoreMatrixOp>(loc, vecAddr, inputs,
311+
/*needTrans=*/transpose);
251312
}
252313
return success();
253314
}
@@ -271,10 +332,19 @@ struct LocalAllocOpConversion
271332
Value smemBase =
272333
LLVM::getSharedMemoryBase(op.getLoc(), rewriter, targetInfo, op);
273334

274-
if (lowerDistributedToSharedStmatrix(op.getLoc(), srcTy, memDescType,
275-
adaptor.getSrc(), smemBase, llvmElemTy,
276-
rewriter, targetInfo)
277-
.failed()) {
335+
// Try to lower transposed or not
336+
bool lowered = false;
337+
for (bool transpose : {false, true}) {
338+
lowered =
339+
lowerDistributedToSharedStmatrix(
340+
op.getLoc(), srcTy, memDescType, transpose, adaptor.getSrc(),
341+
smemBase, llvmElemTy, rewriter, targetInfo)
342+
.succeeded();
343+
if (lowered) {
344+
break;
345+
}
346+
}
347+
if (!lowered) {
278348
return failure();
279349
}
280350

@@ -306,11 +376,20 @@ struct LocalStoreOpConversion
306376
getTypeConverter()->convertType(op.getDst().getType().getElementType());
307377
SharedMemoryObject smemObj = LLVM::getSharedMemoryObjectFromStruct(
308378
op.getLoc(), adaptor.getDst(), llvmElemTy, rewriter);
309-
if (lowerDistributedToSharedStmatrix(op.getLoc(), op.getSrc().getType(),
310-
op.getDst().getType(),
311-
adaptor.getSrc(), smemObj.getBase(),
312-
llvmElemTy, rewriter, targetInfo)
313-
.failed()) {
379+
380+
// Try to lower transposed or not
381+
bool lowered = false;
382+
for (bool transpose : {false, true}) {
383+
lowered = lowerDistributedToSharedStmatrix(
384+
op.getLoc(), op.getSrc().getType(), op.getDst().getType(),
385+
transpose, adaptor.getSrc(), smemObj.getBase(), llvmElemTy,
386+
rewriter, targetInfo)
387+
.succeeded();
388+
if (lowered) {
389+
break;
390+
}
391+
}
392+
if (!lowered) {
314393
return failure();
315394
}
316395
rewriter.eraseOp(op);

0 commit comments

Comments
 (0)