Skip to content

Commit 06c0b20

Browse files
yiqian1zhanglx13
andauthored
[AMD] Enable v_permlane16_swap for convert_layout and reduceOp on GFX1250 (#8724)
--------- Co-authored-by: Lixun Zhang <[email protected]>
1 parent ca003a0 commit 06c0b20

File tree

3 files changed

+57
-1
lines changed

3 files changed

+57
-1
lines changed
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch="gfx1250" | FileCheck %s --check-prefix=GFX1250
2+
#linear = #ttg.linear<{register = [[0, 1], [0, 2], [0, 8], [0, 16]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 4]], warp = [[16, 0]], block = []}>
3+
#mma = #ttg.amd_wmma<{version = 3, warpsPerCTA = [2, 1], isTranspose = true, instrShape = [16, 16, 32]}>
4+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32, "ttg.threads-per-warp" = 32 : i32} {
5+
// GFX1250-LABEL: wmma_permlane16_swap
6+
tt.func @wmma_permlane16_swap(%arg0: tensor<32x32xf16, #mma>) {
7+
// GFX1250-NOT: store
8+
// GFX1250-NOT: load
9+
// GFX1250-COUNT-4: llvm.call_intrinsic "llvm.amdgcn.permlane16.swap"
10+
// GFX1250-NOT: llvm.call_intrinsic "llvm.amdgcn.permlane16.swap"
11+
%0 = ttg.convert_layout %arg0 : tensor<32x32xf16, #mma> -> tensor<32x32xf16, #linear>
12+
tt.return
13+
}
14+
}
15+
16+
// -----
17+
18+
#mma = #ttg.amd_wmma<{version = 3, warpsPerCTA = [4, 1], isTranspose = true, instrShape = [16, 16, 32]}>
19+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
20+
// GFX1250-LABEL: reduce_16x16
21+
tt.func @reduce_16x16(%input: tensor<128x128xf32, #mma>) {
22+
// GFX1250-COUNT-2: llvm.call_intrinsic "llvm.amdgcn.permlane16.swap"
23+
%0 = "tt.reduce"(%input) <{axis = 1 : i32}> ({
24+
^bb0(%arg1: f32 , %arg2: f32):
25+
%2 = "arith.maxnumf"(%arg1, %arg2) : (f32, f32) -> f32
26+
tt.reduce.return %2 : f32 }) : (tensor<128x128xf32, #mma>) -> tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>
27+
tt.return
28+
}
29+
}

third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@ class ConvertLayoutOpPermlaneSwap
2424
ConversionPatternRewriter &rewriter) const override {
2525
auto &amdTargInfo =
2626
static_cast<const mlir::triton::AMD::TargetInfo &>(targetInfo);
27-
if (amdTargInfo.getISAFamily() != AMD::ISAFamily::CDNA4)
27+
if (!(amdTargInfo.getISAFamily() == AMD::ISAFamily::CDNA4 ||
28+
amdTargInfo.getISAFamily() == AMD::ISAFamily::GFX1250))
2829
return failure();
2930

3031
auto srcTy = cast<RankedTensorType>(op.getSrc().getType());

third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -324,6 +324,29 @@ static bool warpReduceSwap16or32(RewriterBase &rewriter, Location loc,
324324
return true;
325325
}
326326

327+
static bool warpReduceSwap16(RewriterBase &rewriter, Location loc,
328+
SmallVector<Value> &acc, triton::ReduceOp op,
329+
unsigned numLaneToReduce, unsigned interleave) {
330+
Operation *reduxOp = op.getSingleCombiner();
331+
if (!reduxOp)
332+
return false;
333+
334+
bool mfma16Case = numLaneToReduce == 2 && interleave == 16;
335+
if (!mfma16Case)
336+
return false;
337+
338+
Value val = acc[0];
339+
unsigned bits = val.getType().getIntOrFloatBitWidth();
340+
if (bits > 32)
341+
return false;
342+
343+
StringRef intrinsic = "llvm.amdgcn.permlane16.swap";
344+
for (auto i = 0; i < acc.size(); i++) {
345+
acc[i] = permuteAndReduce(rewriter, loc, intrinsic, acc[i], reduxOp);
346+
}
347+
return true;
348+
}
349+
327350
bool TargetInfo::warpReduce(RewriterBase &rewriter, Location loc,
328351
SmallVector<Value> &acc, triton::ReduceOp op,
329352
unsigned numLaneToReduce,
@@ -333,6 +356,9 @@ bool TargetInfo::warpReduce(RewriterBase &rewriter, Location loc,
333356
if (getISAFamily() == ISAFamily::CDNA4 &&
334357
warpReduceSwap16or32(rewriter, loc, acc, op, numLaneToReduce, interleave))
335358
return true;
359+
if ((getISAFamily() == ISAFamily::GFX1250) &&
360+
warpReduceSwap16(rewriter, loc, acc, op, numLaneToReduce, interleave))
361+
return true;
336362
if (numLaneToReduce != getWarpSize())
337363
return false;
338364
if (isCDNA(getISAFamily()) && getISAFamily() == ISAFamily::CDNA1)

0 commit comments

Comments
 (0)