
Commit 2bc85dc

plognjenoplavsic and Ognjen Plavsic authored
[AMD] Add gfx950 ds_read_b64_tr_b16 support (#5750)
This PR introduces initial support for using LDS transposed load intrinsics in Triton. It implements selection of the ds_read_b64_tr_b16 instruction for the gfx950 architecture. These intrinsics are designed to handle tensors that are non-K contiguous in LDS, enabling efficient loading in a transposed manner.

LDS transposed loading only works if:
1. The hardware supports transposed LDS load instructions.
2. The tensor we're loading is contiguous along the non-K dimension in LDS.

How the compiler handles LDS transpose loads (i.e., which intrinsics and linear layouts we use) depends on two things:
1. The type of data we're working with.
2. The type of MFMA instruction.

Currently, we only support transpose loads for 16-bit data on gfx950.

TODO: Add support for other data types (e.g., 8-bit types).
TODO: Support transpose loading on swizzled data to avoid bank conflicts.

---------

Co-authored-by: Ognjen Plavsic <[email protected]>
1 parent df66eb5 commit 2bc85dc
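
The applicability conditions described in the commit message can be summed up in a small sketch. The helper below is illustrative only (hypothetical names, rank-2 operands assumed); it is not code from this commit:

// Hypothetical sketch of the selection rule described above; names are
// illustrative and not the actual Triton implementation.
#include <vector>

// `sharedOrder[0]` is the fastest-varying (contiguous) dimension in LDS.
// For a rank-2 dot operand, A (opIdx = 0) has K = dim 1 and non-K = dim 0;
// B (opIdx = 1) has K = dim 0 and non-K = dim 1.
bool canUseDsReadB64Tr16(bool hwSupportsLdsTranspose /* e.g. gfx950 */,
                         int opIdx, const std::vector<unsigned> &sharedOrder,
                         int elemBitWidth) {
  unsigned nonKDim = (opIdx == 0) ? 0 : 1;
  bool nonKContiguous = !sharedOrder.empty() && sharedOrder[0] == nonKDim;
  bool supported16Bit = (elemBitWidth == 16); // current gfx950 restriction
  return hwSupportsLdsTranspose && nonKContiguous && supported16Bit;
}

This matches the lit tests further down: an operand whose shared-memory order puts the non-K dimension fastest lowers to rocdl.ds.read.tr16.b64, otherwise it falls back to plain vectorized loads.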

File tree

9 files changed: +601, -2 lines changed


include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h

Lines changed: 5 additions & 0 deletions
@@ -256,6 +256,11 @@ LinearLayout chooseStMatrixLayout(MLIRContext *ctx, RankedTensorType tensorTy,
 // tensor into shared memory using the `ldmatrix` instruction.
 LinearLayout chooseLdMatrixLayout(Attribute enc, ArrayRef<int64_t> shape,
                                   bool needTrans, int32_t elemBitWidth);
+
+// The primary goal of this function is to efficiently load 2D tiles of a
+// tensor from shared memory using the `ds_read_tr` instruction for AMD GPUs.
+LinearLayout chooseDsReadB64Tr16Layout(Attribute enc, ArrayRef<int64_t> shape,
+                                       int32_t elemBitWidth);
 } // namespace mlir::triton::gpu
 
 #endif // TRITON_DIALECT_TRITONGPU_IR_LINEARLAYOUTCONVERSIONS_H

lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp

Lines changed: 135 additions & 0 deletions
@@ -390,6 +390,135 @@ AMDMfmaEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
   return combineCtaCgaWithShape(ctaLayout, getCTALayout(), shape);
 }
 
+LinearLayout chooseDotDsReadB64Tr16Layout(DotOperandEncodingAttr dotMfmaLayout,
+                                          ArrayRef<int64_t> shape,
+                                          int32_t elemBitWidth) {
+  auto mfmaLayout = llvm::cast<AMDMfmaEncodingAttr>(dotMfmaLayout.getParent());
+  assert(mfmaLayout.getMDim() == 16 || mfmaLayout.getNDim() == 32);
+  assert(elemBitWidth == 16);
+
+  auto rank = shape.size();
+  bool hasBatchDim = rank == 3;
+  int32_t kWidthDot = dotMfmaLayout.getKWidth();
+  // Number of bits loaded by an LDS read. ds_read_tr primarily supports 64-bit
+  // loads for most element sizes (16b, 8b, 4b).
+  const int32_t ldsReadWidth = 64;
+  int32_t kWidthTransRead = ldsReadWidth / elemBitWidth;
+  auto kDim = dotMfmaLayout.getOpIdx() == 0 ? rank - 1 : rank - 2;
+
+  int32_t kSize = shape[kDim];
+  auto warpsPerCTA = mfmaLayout.getWarpsPerCTA();
+
+  MLIRContext *ctx = dotMfmaLayout.getContext();
+  SmallVector<StringAttr> outDimNames = standardOutDimNames(ctx, rank);
+
+  StringAttr kRegister = S("register");
+  StringAttr kLane = S("lane");
+  StringAttr kWarp = S("warp");
+
+  // register order
+  // operand A: [1, 0] / [2, 1, 0]
+  // operand B: [0, 1] / [1, 2, 0]
+  // Regular dot mfma order for both cases is [k, nonk]/[k, nonk, batch]
+  // For LDS transpose layout swap order to [nonk, k]/[nonk, k, batch]
+  SmallVector<unsigned> order = triton::gpu::getOrder(dotMfmaLayout);
+  std::swap(order[0], order[1]);
+
+  // In the LDS transpose logic, each thread accesses 64 bits (8 bytes) of data.
+  // The smallest unit for transposing is a 4x4 sub-tile of threads, where each
+  // thread reads 4 16-bit elements along the non-K dimension, resulting in a
+  // [non-K, K] = {16, 4} sub-tile of elements. Because of transposing
+  // mechanism, thread ends up with 4 16-bit elements along K dim.
+  //
+  // The MFMA selection logic prioritizes double-rate MFMA instructions whenever
+  // possible. Specifically:
+  // - For MFMA operations that are non-K = 16, when blockK > 16, mfma16x16x32
+  //   is selected; otherwise (blockK ≤ 16), mfma16x16x16 remains the choice.
+  // - For MFMA operations that are non-K = 32, when blockK > 8, mfma32x32x16 is
+  //   selected; otherwise (blockK ≤ 8), mfma32x32x8 is used.
+  //
+  // In double-rate MFMA instructions, each thread holds 8 elements along the K
+  // dimension.
+  // - The first 4 elements belong to the first sub-tile.
+  // - The next 4 elements belong to the second sub-tile.
+  //
+  // We then group these into larger tiles, each consisting of 8 of these 16x4
+  // sub-tiles. These tiles correspond to data for one mfma instruction. The
+  // shapes of these tiles depend on the MFMA instruction used:
+  // 1. For mfma32x32x16, the tile shape is [non-K, K] = {32, 16}.
+  // 2. For mfma16x16x32, the tile shape is [non-K, K] = {16, 32}.
+  //
+  // For single-rate mfma instructions, each thread holds 4 elements along K
+  // dimension. This means larger tile (that corresponds to one mfma
+  // instruction) consists of 4 16x4 sub-tiles.
+  std::vector<std::vector<int32_t>> registerBase = {{1, 0},
+                                                    {2, 0}}; // first sub-tile
+  std::vector<std::vector<int32_t>> laneBase = {{kWidthTransRead, 0},
+                                                {2 * kWidthTransRead, 0},
+                                                {0, 1},
+                                                {0, 2}}; // first sub-tile
+
+  // Extend register base for multiple tiles in K dimension (corresponding to
+  // multiple mfma instructions accross k dim).
+  auto populateRegisterBase = [&](int kTileSize) {
+    const int regsPerTile = 8;
+    int numRegs = (kSize / kTileSize) * regsPerTile;
+    for (int reg = regsPerTile; reg < numRegs; reg *= 2) {
+      registerBase.push_back({0, (reg / regsPerTile) * kTileSize});
+    }
+  };
+
+  const bool isMfma32 = (mfmaLayout.getMDim() == 32);
+  const bool isMfma16 = (mfmaLayout.getMDim() == 16);
+  const int kTileSize = isMfma32 ? 16 : 32;
+
+  if (kSize >= kTileSize) {
+    // Handles mfma32x32x16 and mfma16x16x32 cases
+    assert(kWidthDot == 8);
+    registerBase.push_back({0, 4}); // second sub-tile
+    populateRegisterBase(kTileSize);
+    auto laneBaseExt = isMfma32
+                           ? std::vector<std::vector<int32_t>>{{16, 0}, {0, 8}}
+                           : std::vector<std::vector<int32_t>>{{0, 8}, {0, 16}};
+    laneBase.insert(laneBase.end(), laneBaseExt.begin(), laneBaseExt.end());
+  } else {
+    // Handles mfma32x32x8 and mfma16x16x16 cases
+    assert(kWidthDot == 4);
+    auto laneBaseExt = isMfma32
+                           ? std::vector<std::vector<int32_t>>{{16, 0}, {0, 4}}
+                           : std::vector<std::vector<int32_t>>{{0, 4}, {0, 8}};
+    laneBase.insert(laneBase.end(), laneBaseExt.begin(), laneBaseExt.end());
+  }
+
+  // Base vectors above are defined in a fixed order [non-k-dim, k-dim].
+  // To assign them to actual matrix dimensions `order` array is used.
+  // For operand A: non-k-dim -> dim0, k-dim -> dim1
+  // For operand B: non-k-dim -> dim1, k-dim -> dim0
+  LinearLayout tileLayout({{kRegister, registerBase}, {kLane, laneBase}},
+                          {outDimNames[order[0]], outDimNames[order[1]]});
+
+  if (hasBatchDim) {
+    assert(order[2] == 0);
+    // Extend the base vector with one value to accommodate for the batch
+    // dimension, which appears at the last.
+    tileLayout *= LinearLayout::identity1D(1, kRegister, outDimNames[order[2]]);
+    tileLayout *= LinearLayout::identity1D(1, kLane, outDimNames[order[2]]);
+  }
+
+  // warp order
+  // common for both operand A and B: [0, 1] / [0, 1, 2]
+  // in both cases it is [M dim, N dim]/[batch, M dim, N dim]
+  SmallVector<unsigned> warpOrder = triton::gpu::getWarpOrder(dotMfmaLayout);
+  LinearLayout warpLayout = identityStandardND(kWarp, warpsPerCTA, warpOrder);
+
+  LinearLayout ctaLayout = tileLayout.transposeOuts(outDimNames) *
+                           warpLayout.transposeOuts(outDimNames);
+  auto finalLayout =
+      combineCtaCgaWithShape(ctaLayout, mfmaLayout.getCTALayout(), shape);
+
+  return finalLayout;
+}
+
 LinearLayout mfmaDotToLinearLayout(DotOperandEncodingAttr dotMfmaLayout,
                                    ArrayRef<int64_t> shape) {
 
@@ -1200,4 +1329,10 @@ LinearLayout chooseLdMatrixLayout(Attribute enc, ArrayRef<int64_t> shape,
   return chooseDotLdMatrixLayout(dot, shape, needTrans, elemBitWidth);
 }
 
+LinearLayout chooseDsReadB64Tr16Layout(Attribute enc, ArrayRef<int64_t> shape,
+                                       int32_t elemBitWidth) {
+  auto dot = cast<DotOperandEncodingAttr>(enc);
+  return chooseDotDsReadB64Tr16Layout(dot, shape, elemBitWidth);
+}
+
 } // namespace mlir::triton::gpu
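
To make the base-vector construction above concrete, here is a standalone sketch (not part of the commit) that replays it for one configuration: a 16-bit operand with mfma16x16x32 (kTileSize = 32) and kSize = 64. The values in the comments are what the program prints; the actual compiler builds the same vectors through the lambda and branches shown in the diff.

// Standalone sketch, not part of the commit: reproduces the register/lane
// base construction for fp16, mfma16x16x32, kSize = 64, in [non-K, K] coords.
#include <cstdio>
#include <vector>

int main() {
  const int elemBitWidth = 16;
  const int ldsReadWidth = 64;                            // one ds_read_b64_tr_b16
  const int kWidthTransRead = ldsReadWidth / elemBitWidth; // 4 fp16 per read
  const int kSize = 64;                                    // K extent of the operand
  const int kTileSize = 32;                                // mfma16x16x32 tile in K

  // First 16x4 sub-tile: two register bits cover 4 elements along non-K.
  std::vector<std::vector<int>> registerBase = {{1, 0}, {2, 0}};
  // Lanes: 16 threads along non-K (in steps of 4 elements), 4 threads along K.
  std::vector<std::vector<int>> laneBase = {
      {kWidthTransRead, 0}, {2 * kWidthTransRead, 0}, {0, 1}, {0, 2}};

  // Double-rate case (kSize >= kTileSize): second sub-tile, plus one register
  // bit per additional mfma tile along K (mirrors populateRegisterBase).
  registerBase.push_back({0, 4});
  const int regsPerTile = 8;
  int numRegs = (kSize / kTileSize) * regsPerTile;         // 16 registers total
  for (int reg = regsPerTile; reg < numRegs; reg *= 2)
    registerBase.push_back({0, (reg / regsPerTile) * kTileSize});
  // mfma16 lane extension: remaining lane bits advance along K.
  laneBase.push_back({0, 8});
  laneBase.push_back({0, 16});

  auto dump = [](const char *name, const std::vector<std::vector<int>> &base) {
    std::printf("%s:", name);
    for (const auto &v : base)
      std::printf(" {%d,%d}", v[0], v[1]);
    std::printf("\n");
  };
  dump("registerBase", registerBase); // {1,0} {2,0} {0,4} {0,32}
  dump("laneBase", laneBase);         // {4,0} {8,0} {0,1} {0,2} {0,8} {0,16}
  return 0;
}

With these bases, 64 lanes and 16 registers per thread together cover a 16 (non-K) x 64 (K) tile per warp, which is what the warp and CTA layouts then replicate across the full tensor shape.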
Lines changed: 77 additions & 0 deletions
@@ -0,0 +1,77 @@
+// RUN: triton-opt %s --split-input-file --convert-triton-amdgpu-to-llvm=arch=gfx950 | FileCheck %s
+
+#mma16 = #ttg.amd_mfma<{versionMajor = 4, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 16], isTransposed = true}>
+#mma32 = #ttg.amd_mfma<{versionMajor = 4, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [32, 32], isTransposed = true}>
+#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0, 1]}>
+#shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
+#smem = #ttg.shared_memory
+
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} {
+  // CHECK-LABEL: ds_transpose_n_t_fp16_mfma_16
+  tt.func @ds_transpose_n_t_fp16_mfma_16(%arg0: !ttg.memdesc<128x64xf16, #shared, #smem, mutable>, %arg1: !ttg.memdesc<64x128xf16, #shared1, #smem, mutable>) {
+    // CHECK-COUNT-32: rocdl.ds.read.tr16.b64 %{{.*}} : <3> -> vector<4xf16>
+    %1 = ttg.local_load %arg0 : !ttg.memdesc<128x64xf16, #shared, #smem, mutable> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma16, kWidth = 8}>>
+    %2 = ttg.local_load %arg1 : !ttg.memdesc<64x128xf16, #shared1, #smem, mutable> -> tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma16, kWidth = 8}>>
+    tt.return
+  }
+
+  // CHECK-LABEL: ds_transpose_t_t_fp16_mfma_16
+  tt.func @ds_transpose_t_t_fp16_mfma_16(%arg0: !ttg.memdesc<128x64xf16, #shared1, #smem, mutable>, %arg1: !ttg.memdesc<64x128xf16, #shared1, #smem, mutable>) {
+    // CHECK-COUNT-8: llvm.load %{{.*}} : !llvm.ptr<3> -> vector<8xf16>
+    %1 = ttg.local_load %arg0 : !ttg.memdesc<128x64xf16, #shared1, #smem, mutable> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma16, kWidth = 8}>>
+    // CHECK-COUNT-16: rocdl.ds.read.tr16.b64 %{{.*}} : <3> -> vector<4xf16>
+    %2 = ttg.local_load %arg1 : !ttg.memdesc<64x128xf16, #shared1, #smem, mutable> -> tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma16, kWidth = 8}>>
+    tt.return
+  }
+
+  // CHECK-LABEL: ds_transpose_n_n_fp16_mfma_16
+  tt.func @ds_transpose_n_n_fp16_mfma_16(%arg0: !ttg.memdesc<128x64xf16, #shared, #smem, mutable>, %arg1: !ttg.memdesc<64x128xf16, #shared, #smem, mutable>) {
+    // CHECK-COUNT-16: rocdl.ds.read.tr16.b64 %{{.*}} : <3> -> vector<4xf16>
+    %1 = ttg.local_load %arg0 : !ttg.memdesc<128x64xf16, #shared, #smem, mutable> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma16, kWidth = 8}>>
+    // CHECK-COUNT-8: llvm.load %{{.*}} : !llvm.ptr<3> -> vector<8xf16>
+    %2 = ttg.local_load %arg1 : !ttg.memdesc<64x128xf16, #shared, #smem, mutable> -> tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma16, kWidth = 8}>>
+    tt.return
+  }
+
+  // CHECK-LABEL: ds_transpose_t_n_fp16_mfma_16
+  tt.func @ds_transpose_t_n_fp16_mfma_16(%arg0: !ttg.memdesc<128x64xf16, #shared1, #smem, mutable>, %arg1: !ttg.memdesc<64x128xf16, #shared, #smem, mutable>) {
+    // CHECK-NOT: rocdl.ds.read.tr16.b64 %{{.*}} : <3> -> vector<4xf16>
+    %1 = ttg.local_load %arg0 : !ttg.memdesc<128x64xf16, #shared1, #smem, mutable> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma16, kWidth = 8}>>
+    %2 = ttg.local_load %arg1 : !ttg.memdesc<64x128xf16, #shared, #smem, mutable> -> tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma16, kWidth = 8}>>
+    tt.return
+  }
+
+  // CHECK-LABEL: ds_transpose_n_t_fp16_mfma32
+  tt.func @ds_transpose_n_t_fp16_mfma32(%arg0: !ttg.memdesc<128x64xf16, #shared, #smem, mutable>, %arg1: !ttg.memdesc<64x128xf16, #shared1, #smem, mutable>) {
+    // CHECK-COUNT-32: rocdl.ds.read.tr16.b64 %{{.*}} : <3> -> vector<4xf16>
+    %1 = ttg.local_load %arg0 : !ttg.memdesc<128x64xf16, #shared, #smem, mutable> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma32, kWidth = 8}>>
+    %2 = ttg.local_load %arg1 : !ttg.memdesc<64x128xf16, #shared1, #smem, mutable> -> tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma32, kWidth = 8}>>
+    tt.return
+  }
+
+  // CHECK-LABEL: ds_transpose_t_t_fp16_mfma32
+  tt.func @ds_transpose_t_t_fp16_mfma32(%arg0: !ttg.memdesc<128x64xf16, #shared1, #smem, mutable>, %arg1: !ttg.memdesc<64x128xf16, #shared1, #smem, mutable>) {
+    // CHECK-COUNT-8: llvm.load %{{.*}} : !llvm.ptr<3> -> vector<8xf16>
+    %1 = ttg.local_load %arg0 : !ttg.memdesc<128x64xf16, #shared1, #smem, mutable> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma32, kWidth = 8}>>
+    // CHECK-COUNT-16: rocdl.ds.read.tr16.b64 %{{.*}} : <3> -> vector<4xf16>
+    %2 = ttg.local_load %arg1 : !ttg.memdesc<64x128xf16, #shared1, #smem, mutable> -> tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma32, kWidth = 8}>>
+    tt.return
+  }
+
+  // CHECK-LABEL: ds_transpose_n_n_fp16_mfma32
+  tt.func @ds_transpose_n_n_fp16_mfma32(%arg0: !ttg.memdesc<128x64xf16, #shared, #smem, mutable>, %arg1: !ttg.memdesc<64x128xf16, #shared, #smem, mutable>) {
+    // CHECK-COUNT-16: rocdl.ds.read.tr16.b64 %{{.*}} : <3> -> vector<4xf16>
+    %1 = ttg.local_load %arg0 : !ttg.memdesc<128x64xf16, #shared, #smem, mutable> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma32, kWidth = 8}>>
+    // CHECK-COUNT-8: llvm.load %{{.*}} : !llvm.ptr<3> -> vector<8xf16>
+    %2 = ttg.local_load %arg1 : !ttg.memdesc<64x128xf16, #shared, #smem, mutable> -> tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma32, kWidth = 8}>>
+    tt.return
+  }
+
+  // CHECK-LABEL: ds_transpose_t_n_fp16_mfma32
+  tt.func @ds_transpose_t_n_fp16_mfma32(%arg0: !ttg.memdesc<128x64xf16, #shared1, #smem, mutable>, %arg1: !ttg.memdesc<64x128xf16, #shared, #smem, mutable>) {
+    // CHECK-NOT: rocdl.ds.read.tr16.b64 %{{.*}} : <3> -> vector<4xf16>
+    %1 = ttg.local_load %arg0 : !ttg.memdesc<128x64xf16, #shared1, #smem, mutable> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma32, kWidth = 8}>>
+    %2 = ttg.local_load %arg1 : !ttg.memdesc<64x128xf16, #shared, #smem, mutable> -> tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma32, kWidth = 8}>>
+    tt.return
+  }
+}
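
As a rough sanity check of the CHECK-COUNT values above (my arithmetic from the shapes in the test, not taken from the commit): each rocdl.ds.read.tr16.b64 returns vector<4xf16> and each plain llvm.load here returns vector<8xf16>, and with warpsPerCTA = [2, 2] each 128x64 (or 64x128) operand is split over 2 warps x 64 lanes.

// Sanity-check sketch for the FileCheck counts; assumes the shapes and warp
// configuration used in the test above.
#include <cstdio>

int main() {
  const int m = 128, n = 128, k = 64;         // A is m x k, B is k x n
  const int lanesPerWarp = 64;
  const int warpsAlongM = 2, warpsAlongN = 2; // warpsPerCTA = [2, 2]

  // A is replicated across warps along N and B across warps along M, so each
  // operand is distributed over 2 warps * 64 lanes = 128 threads.
  int elemsPerThreadA = (m * k) / (warpsAlongM * lanesPerWarp); // 64 f16
  int elemsPerThreadB = (k * n) / (warpsAlongN * lanesPerWarp); // 64 f16

  std::printf("ds_read_tr per operand: %d\n", elemsPerThreadA / 4);   // 16
  std::printf("both operands transposed: %d\n",
              elemsPerThreadA / 4 + elemsPerThreadB / 4);             // 32
  std::printf("vector<8xf16> loads per operand: %d\n",
              elemsPerThreadA / 8);                                   // 8
  return 0;
}

This lines up with CHECK-COUNT-32 when both operands take the transposed path, CHECK-COUNT-16 when only one does, and CHECK-COUNT-8 for the fallback vectorized loads.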
