Skip to content

Commit cd9408c

Browse files
authored
[AMD] Support lowering GPU async copy/commit/wait ops (#5729)
Support lowering of `ttg.async_copy_global_to_local` for `gfx9` GPUs. The lowering checks that the resulting writes are coalesced, which is a hardware requirement. The associated `ttg.async_commit_group` and `ttg.async_wait` ops are lowered as well. Note that we currently do not emit `AsyncCopyGlobalToLocal` for AMD targets; that will come in a follow-up PR.
1 parent ac79534 commit cd9408c

File tree

4 files changed

+479
-9
lines changed

4 files changed

+479
-9
lines changed
Lines changed: 217 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,217 @@
// FIX(review): the original RUN lines passed only --check-prefix=GFX950 /
// --check-prefix=GFX942, which disables the default "CHECK" prefix — every
// plain CHECK/CHECK-LABEL/CHECK-COUNT/CHECK-NOT directive in this file was
// silently never checked. Enable CHECK alongside the arch-specific prefix so
// the common checks run under both invocations.
// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch=gfx950 | FileCheck %s --check-prefixes=CHECK,GFX950
// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch=gfx942 --verify-diagnostics | FileCheck %s --check-prefixes=CHECK,GFX942

#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [4, 1], order = [1, 0]}>
#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
  // Scalar (sizePerThread = 1) async copy: expect one global.load.lds per element.
  // CHECK-LABEL: async_copy
  tt.func public @async_copy(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32},
                             %arg1: i32 {tt.divisibility = 16 : i32},
                             %arg2: !ttg.memdesc<32x64xf16, #shared, #smem, mutable>) {
    // We need the splat to allow the AxisAnalysis to work during lowering
    %1 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<32x64x!tt.ptr<f16>, #blocked>
    // Each thread needs to load 8 elements and we load 1 (sizePerThread) per global.load.lds
    // CHECK-COUNT-8: rocdl.global.load.lds
    // CHECK-NOT: rocdl.global.load.lds
    %2 = ttg.async_copy_global_to_local %1, %arg2 : tensor<32x64x!tt.ptr<f16>, #blocked> -> <32x64xf16, #shared, #smem, mutable>
    tt.return
  }
}

// -----

#blocked = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [2, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
  // 2-wide vectorized copy: sizePerThread = [1, 2] lets each global.load.lds
  // move two f16 elements, halving the instruction count versus the scalar case.
  // CHECK-LABEL: async_copy_vectorized_2xf16
  tt.func public @async_copy_vectorized_2xf16(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32},
                                              %arg1: i32 {tt.divisibility = 16 : i32},
                                              %arg2: !ttg.memdesc<32x64xf16, #shared, #smem, mutable>) {
    // Build contiguous per-row offsets so AxisAnalysis can prove the loads are
    // vectorizable during lowering.
    %offs = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
    %offsRow = tt.expand_dims %offs {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked>
    %offsTile = tt.broadcast %offsRow : tensor<1x64xi32, #blocked> -> tensor<32x64xi32, #blocked>
    %base = tt.splat %arg0 : !tt.ptr<f16> -> tensor<32x64x!tt.ptr<f16>, #blocked>
    %ptrs = tt.addptr %base, %offsTile : tensor<32x64x!tt.ptr<f16>, #blocked>, tensor<32x64xi32, #blocked>

    // Each thread needs to load 8 elements and we load 2 (sizePerThread) per global.load.lds
    // CHECK-COUNT-4: rocdl.global.load.lds
    // CHECK-NOT: rocdl.global.load.lds
    %token = ttg.async_copy_global_to_local %ptrs, %arg2 : tensor<32x64x!tt.ptr<f16>, #blocked> -> <32x64xf16, #shared, #smem, mutable>
    tt.return
  }
}

// -----

#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} {
  // 8-wide (16-byte) vectorized copy: legal on gfx950 only; gfx942 must reject it.
  // GFX950-LABEL: async_copy_vectorized_8xf16
  tt.func public @async_copy_vectorized_8xf16(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32},
                                              %arg1: i32 {tt.divisibility = 16 : i32},
                                              %arg2: !ttg.memdesc<32x64xf16, #shared, #smem, mutable>) {
    // We need the index calculation so AxisAnalysis sees that we can vectorize the load
    %1 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
    %2 = tt.expand_dims %1 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked>
    %3 = tt.broadcast %2 : tensor<1x64xi32, #blocked> -> tensor<32x64xi32, #blocked>
    %4 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<32x64x!tt.ptr<f16>, #blocked>
    %5 = tt.addptr %4, %3 : tensor<32x64x!tt.ptr<f16>, #blocked>, tensor<32x64xi32, #blocked>

    // Each thread needs to load 8 elements and we load 8 (sizePerThread) per global.load.lds
    // FIX(review): the original wrote "GFX950-next:"; FileCheck directive
    // suffixes are case-sensitive, so "-next" was silently ignored and the
    // check never ran. Corrected to "-NEXT".
    // GFX950: rocdl.global.load.lds
    // GFX950-NEXT: llvm.return

    // GFX942 does not support vectorization > 4bytes
    // expected-error@+1 {{failed to legalize operation 'ttg.async_copy_global_to_local' that was explicitly marked illegal}}
    %6 = ttg.async_copy_global_to_local %5, %arg2 : tensor<32x64x!tt.ptr<f16>, #blocked> -> <32x64xf16, #shared, #smem, mutable>
    tt.return
  }
}

// -----

#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} {
  // ttg.async_wait lowers to a rocdl.waitcnt followed by a barrier. The waitcnt
  // immediate packs all hardware counters into one i32; only the vmcnt field
  // (bits 15:14 and 3:0) is constrained — all other counter fields stay all-ones
  // ("don't wait"), which is why the constants below are negative.
  // CHECK-LABEL: async_wait
  tt.func public @async_wait(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32},
                             %arg1: i32 {tt.divisibility = 16 : i32},
                             %arg2: !ttg.memdesc<32x64xf16, #shared, #smem, mutable>) {
    // vmcnt = 0: wait for all outstanding async copies.
    // CHECK: rocdl.waitcnt -49168
    // CHECK: rocdl.barrier
    ttg.async_wait {num = 0 : i32}
    // vmcnt = 1: allow one copy to remain in flight.
    // CHECK: rocdl.waitcnt -49167
    // CHECK: rocdl.barrier
    ttg.async_wait {num = 1 : i32}
    // vmcnt = 62: near the top of the encodable range.
    // CHECK: rocdl.waitcnt -2
    // CHECK: rocdl.barrier
    ttg.async_wait {num = 62 : i32}
    // vmcnt = 63: maximum encodable value (all vmcnt bits set).
    // CHECK: rocdl.waitcnt -1
    // CHECK: rocdl.barrier
    ttg.async_wait {num = 63 : i32}
    tt.return
  }
}

// -----

#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} {
  // ttg.async_commit_group emits no real instructions on this target: the body
  // below lowers to just the token constant and the return.
  // CHECK-LABEL: async_commit_group
  tt.func public @async_commit_group(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32},
                                     %arg1: i32 {tt.divisibility = 16 : i32},
                                     %arg2: !ttg.memdesc<32x64xf16, #shared, #smem, mutable>) {
    // CHECK-NEXT: llvm.mlir.constant(0 : i32) : i32
    // CHECK-NEXT: llvm.return
    ttg.async_commit_group
    tt.return
  }
}

// -----

#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [2, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
  // Masked async copy with an "other" fill value: every element load becomes a
  // conditional global.load.lds plus a predicated store of the fill value.
  // CHECK-LABEL: async_copy_mask_other
  tt.func public @async_copy_mask_other(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32},
                                        %arg1: i32 {tt.divisibility = 16 : i32},
                                        %arg2: !ttg.memdesc<32x32xf16, #shared, #smem, mutable>,
                                        %arg3: i32 {tt.divisibility = 16 : i32}) {
    // We need the splat to allow the AxisAnalysis to work during lowering
    %other = arith.constant dense<0.000000e+00> : tensor<32x32xf16, #blocked>
    %zero = arith.constant 0 : i32
    %c32 = arith.constant 32 : i32
    %c31 = arith.constant 31 : i32
    %ptrs = tt.splat %arg0 : !tt.ptr<f16> -> tensor<32x32x!tt.ptr<f16>, #blocked>
    // ceil-div of %arg3 by 32, then check there is at least one tile.
    %biased = arith.addi %arg3, %c31 : i32
    %numTiles = arith.divsi %biased, %c32 : i32
    %hasTiles = arith.cmpi sgt, %numTiles, %zero : i32

    // Per-row bounds mask: row index < %arg3.
    %rows = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
    %rowsCol = tt.expand_dims %rows {axis = 1 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<32x1xi32, #blocked>
    %bound = tt.splat %arg3 : i32 -> tensor<32x1xi32, #blocked>
    %inBounds = arith.cmpi slt, %rowsCol, %bound : tensor<32x1xi32, #blocked>
    %rowMask = tt.broadcast %inBounds : tensor<32x1xi1, #blocked> -> tensor<32x32xi1, #blocked>

    // NOTE(review): %combinedMask is computed but never used — the copy below
    // takes %rowMask directly. Presumably the combined mask was intended;
    // confirm against the lowering this test was written for.
    %tilesMask = tt.splat %hasTiles : i1 -> tensor<32x32xi1, #blocked>
    %combinedMask = arith.andi %tilesMask, %rowMask : tensor<32x32xi1, #blocked>

    // Each thread needs to load 4 elements and we load 1 (sizePerThread) per global.load.lds
    // Note that mask/other alignment is 1 so we need 4 conditionals

    // CHECK: llvm.cond_br
    // CHECK: rocdl.global.load.lds
    // CHECK-NEXT: llvm.br
    // CHECK: _predicated_store

    // CHECK: llvm.cond_br
    // CHECK: rocdl.global.load.lds
    // CHECK-NEXT: llvm.br
    // CHECK: _predicated_store

    // CHECK: llvm.cond_br
    // CHECK: rocdl.global.load.lds
    // CHECK-NEXT: llvm.br
    // CHECK: _predicated_store

    // CHECK: llvm.cond_br
    // CHECK: rocdl.global.load.lds
    // CHECK-NEXT: llvm.br
    // CHECK: _predicated_store

    %token = ttg.async_copy_global_to_local %ptrs, %arg2 mask %rowMask other %other : tensor<32x32x!tt.ptr<f16>, #blocked> -> <32x32xf16, #shared, #smem, mutable>
    tt.return
  }
}

// -----

#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [2, 32], warpsPerCTA = [16, 1], order = [1, 0]}>
#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 16 : i32, ttg.shared = 8192 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
  // Verifies that each Triton cache modifier maps to the expected aux/cachepolicy
  // immediate on the rocdl.global.load.lds intrinsic (gfx942 encoding).
  // GFX942-LABEL: async_copy_cache_mods
  tt.func public @async_copy_cache_mods(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32},
                                        %arg1: i32 {tt.divisibility = 16 : i32},
                                        %arg2: !ttg.memdesc<32x32xf16, #shared, #smem, mutable>) {
    // We need the splat to allow the AxisAnalysis to work during lowering
    %ptrs = tt.splat %arg0 : !tt.ptr<f16> -> tensor<32x32x!tt.ptr<f16>, #blocked>
    // Each thread needs to load 1 element and we load 1 (sizePerThread) per global.load.lds

    // GFX942: llvm.getelementptr
    // GFX942: %[[aux_ca:.*]] = llvm.mlir.constant(0 : i32) : i32
    // GFX942: rocdl.global.load.lds {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[aux_ca]]
    %tokCa = ttg.async_copy_global_to_local %ptrs, %arg2 cacheModifier = ca: tensor<32x32x!tt.ptr<f16>, #blocked> -> <32x32xf16, #shared, #smem, mutable>
    // GFX942: llvm.getelementptr
    // GFX942: %[[aux_cg:.*]] = llvm.mlir.constant(0 : i32) : i32
    // GFX942: rocdl.global.load.lds {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[aux_cg]]
    %tokCg = ttg.async_copy_global_to_local %ptrs, %arg2 cacheModifier = cg: tensor<32x32x!tt.ptr<f16>, #blocked> -> <32x32xf16, #shared, #smem, mutable>
    // GFX942: llvm.getelementptr
    // GFX942: %[[aux_cs:.*]] = llvm.mlir.constant(3 : i32) : i32
    // GFX942: rocdl.global.load.lds {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[aux_cs]]
    %tokCs = ttg.async_copy_global_to_local %ptrs, %arg2 cacheModifier = cs: tensor<32x32x!tt.ptr<f16>, #blocked> -> <32x32xf16, #shared, #smem, mutable>
    // GFX942: llvm.getelementptr
    // GFX942: %[[aux_cv:.*]] = llvm.mlir.constant(9 : i32) : i32
    // GFX942: rocdl.global.load.lds {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[aux_cv]]
    %tokCv = ttg.async_copy_global_to_local %ptrs, %arg2 cacheModifier = cv: tensor<32x32x!tt.ptr<f16>, #blocked> -> <32x32xf16, #shared, #smem, mutable>
    // GFX942: llvm.getelementptr
    // GFX942: %[[aux_wb:.*]] = llvm.mlir.constant(0 : i32) : i32
    // GFX942: rocdl.global.load.lds {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[aux_wb]]
    %tokWb = ttg.async_copy_global_to_local %ptrs, %arg2 cacheModifier = wb: tensor<32x32x!tt.ptr<f16>, #blocked> -> <32x32xf16, #shared, #smem, mutable>
    // GFX942: llvm.getelementptr
    // GFX942: %[[aux_wt:.*]] = llvm.mlir.constant(8 : i32) : i32
    // GFX942: rocdl.global.load.lds {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[aux_wt]]
    %tokWt = ttg.async_copy_global_to_local %ptrs, %arg2 cacheModifier = wt: tensor<32x32x!tt.ptr<f16>, #blocked> -> <32x32xf16, #shared, #smem, mutable>
    tt.return
  }
}

0 commit comments

Comments
 (0)