Skip to content

Commit 5dc3032

Browse files
authored
[AMD] Emit llvm.amdgcn.wave.id for GFX12 (#8817)
1 parent 7b29378 commit 5dc3032

File tree

3 files changed

+37
-37
lines changed

3 files changed

+37
-37
lines changed

test/Conversion/amd/buffer_load_to_local_to_llvm.mlir

Lines changed: 0 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -331,42 +331,6 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, ttg.sha
331331

332332
// -----
333333

334-
#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [1], order = [0]}>
335-
#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}>
336-
#smem = #ttg.shared_memory
337-
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 64 : i32} {
338-
// COMMON-LABEL: buffer_load_to_local_wave_id
339-
tt.func public @buffer_load_to_local_wave_id(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32},
340-
%arg2: !ttg.memdesc<64xf32, #shared, #smem, mutable>, %arg3: i32) {
341-
// COMMON: %[[C64:.+]] = llvm.mlir.constant(64 : i32) : i32
342-
// COMMON-NEXT: %[[IDX:.+]] = rocdl.workitem.id.x : i32
343-
// COMMON-NEXT: %[[C63:.+]] = llvm.mlir.constant(63 : i32) : i32
344-
// COMMON-NEXT: %[[AND:.+]] = llvm.and %[[IDX]], %[[C63]] : i32
345-
// COMMON-NEXT: %[[DIV:.+]] = llvm.udiv %[[AND]], %[[C64]] : i32
346-
// COMMON-NEXT: %{{.+}} = rocdl.readfirstlane %[[DIV]] : i32
347-
348-
// COMMON: %[[C64:.+]] = llvm.mlir.constant(64 : i32) : i32
349-
// COMMON-NEXT: %[[IDX:.+]] = rocdl.workitem.id.x : i32
350-
// COMMON-NEXT: %[[C63:.+]] = llvm.mlir.constant(63 : i32) : i32
351-
// COMMON-NEXT: %[[AND:.+]] = llvm.and %[[IDX]], %[[C63]] : i32
352-
// COMMON-NEXT: %[[DIV:.+]] = llvm.udiv %[[AND]], %[[C64]] : i32
353-
// COMMON-NEXT: %{{.+}} = rocdl.readfirstlane %[[DIV]] : i32
354-
355-
%0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #blocked>
356-
%1 = amdg.buffer_load_to_local %arg0[%0] into %arg2: <f32>[tensor<64xi32, #blocked>] -> <64xf32, #shared, #smem, mutable>
357-
%c0_i32 = arith.constant 0 : i32
358-
%cond = llvm.icmp "eq" %arg3, %c0_i32 : i32
359-
cf.cond_br %cond, ^bb1, ^bb2
360-
^bb1:
361-
amdg.buffer_load_to_local %arg0[%0] into %arg2: <f32>[tensor<64xi32, #blocked>] -> <64xf32, #shared, #smem, mutable>
362-
cf.br ^bb1
363-
^bb2:
364-
tt.return
365-
}
366-
}
367-
368-
// -----
369-
370334
#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
371335
#shared1D = #ttg.swizzled_shared<{vec = 2, perPhase = 1, maxPhase = 8, order = [0]}>
372336
#smem = #ttg.shared_memory
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
// RUN: triton-opt %s -split-input-file --convert-triton-amdgpu-to-llvm=arch=gfx942 | FileCheck %s --check-prefixes=CHECK,GFX9
2+
// RUN: triton-opt %s -split-input-file --convert-triton-amdgpu-to-llvm=arch=gfx950 | FileCheck %s --check-prefixes=CHECK,GFX9
3+
// RUN: triton-opt %s -split-input-file --convert-triton-amdgpu-to-llvm=arch=gfx1200 | FileCheck %s --check-prefixes=CHECK,GFX12
4+
// RUN: triton-opt %s -split-input-file --convert-triton-amdgpu-to-llvm=arch=gfx1250 | FileCheck %s --check-prefixes=CHECK,GFX12
5+
6+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shared = 0 : i32, "ttg.threads-per-warp" = 64 : i32} {
7+
// Lowering test for the ttg.warp_id op used inside @wave_id below.
// The gfx942/gfx950 invocations expect the value to be derived from
// rocdl.workitem.id.x (and / udiv / readfirstlane), while the
// gfx1200/gfx1250 invocations expect a call to the llvm.amdgcn.wave.id
// intrinsic instead.
8+
// CHECK-LABEL: @wave_id
9+
tt.func public @wave_id() {
10+
// GFX9: %[[C64:.+]] = llvm.mlir.constant(64 : i32) : i32
11+
// GFX9-NEXT: %[[IDX:.+]] = rocdl.workitem.id.x : i32
12+
// GFX9-NEXT: %[[C63:.+]] = llvm.mlir.constant(63 : i32) : i32
13+
// GFX9-NEXT: %[[AND:.+]] = llvm.and %[[IDX]], %[[C63]] : i32
14+
// GFX9-NEXT: %[[DIV:.+]] = llvm.udiv %[[AND]], %[[C64]] : i32
15+
// GFX9-NEXT: %{{.+}} = rocdl.readfirstlane %[[DIV]] : i32
16+
17+
// GFX12-NEXT: llvm.call_intrinsic "llvm.amdgcn.wave.id"
18+
// CHECK: scf.for
19+
// The op under test sits inside the loop body, yet every expectation above
// is matched before the scf.for line: the lowered warp-id computation is
// expected ahead of the loop in the output.
20+
%c0 = arith.constant 0 : index
21+
%c1 = arith.constant 1 : index
22+
scf.for %i = %c0 to %c1 step %c1 {
23+
%1 = "ttg.warp_id"() : () -> i32
24+
scf.yield
25+
}
26+
tt.return
27+
}
28+
29+
}

third_party/amd/lib/TritonAMDGPUToLLVM/WarpIdOpToLLVM.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,15 @@ class WarpIdOpPattern : public ConvertOpToLLVMPattern<WarpIdOp> {
3333
&funcOp.getFunctionBody().getBlocks().front().front());
3434

3535
auto loc = op.getLoc();
36-
auto b = TritonLLVMOpBuilder(loc, rewriter);
3736
auto isaFamily = targetInfo.getISAFamily();
37+
if (ISAFamily::RDNA4 == isaFamily || ISAFamily::GFX1250 == isaFamily) {
38+
auto warpIdOp = LLVM::createLLVMIntrinsicCallOp(
39+
rewriter, loc, "llvm.amdgcn.wave.id", {i32_ty}, ValueRange{});
40+
rewriter.replaceOp(op, warpIdOp.getResult(0));
41+
return success();
42+
}
43+
44+
auto b = TritonLLVMOpBuilder(loc, rewriter);
3845
int threadsPerWarp = triton::gpu::lookupThreadsPerWarp(rewriter);
3946
Value warpSizeVal = b.i32_val(threadsPerWarp);
4047
Value tid = getThreadId(rewriter, loc);

0 commit comments

Comments
 (0)