@@ -22,42 +22,49 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
2222
2323// -----
2424
25- #blocked = #ttg.blocked <{sizePerThread = [1 , 2 ], threadsPerWarp = [2 , 32 ], warpsPerCTA = [4 , 1 ], order = [1 , 0 ]}>
26- #shared = #ttg.swizzled_shared <{vec = 1 , perPhase = 1 , maxPhase = 1 , order = [1 , 0 ]}>
25+ #blocked = #ttg.blocked <{sizePerThread = [2 , 1 ], threadsPerWarp = [32 , 2 ], warpsPerCTA = [1 , 32 ], order = [0 , 1 ]}>
26+ #shared = #ttg.swizzled_shared <{vec = 2 , perPhase = 1 , maxPhase = 1 , order = [0 , 1 ]}>
2727#smem = #ttg.shared_memory
28- module attributes {" ttg.num-ctas" = 1 : i32 , " ttg.num-warps" = 4 : i32 , ttg.shared = 8192 : i32 , ttg.target = " hip:gfx942" , " ttg.threads-per-warp" = 64 : i32 } {
28+ module attributes {" ttg.num-ctas" = 1 : i32 , " ttg.num-warps" = 32 : i32 , ttg.shared = 0 : i32 , ttg.target = " hip:gfx942" , " ttg.threads-per-warp" = 64 : i32 } {
2929 // COMMON-LABEL: buffer_load_to_local_vectorized_2xf16
30- tt.func public @buffer_load_to_local_vectorized_2xf16 (
31- %arg1: !tt.ptr <f16 > {tt.divisibility = 16 : i32 , tt.pointer_range = 32 : i32 },
32- %arg2: !ttg.memdesc <32 x64 xf16 , #shared , #smem , mutable >,
33- %arg3: i32 ) {
34- %1 = tt.splat %arg3: i32 -> tensor <64 xi32 , #ttg.slice <{dim = 0 , parent = #blocked }>>
35- %2 = tt.expand_dims %1 {axis = 0 : i32 } : tensor <64 xi32 , #ttg.slice <{dim = 0 , parent = #blocked }>> -> tensor <1 x64 xi32 , #blocked >
36- %3 = tt.broadcast %2 : tensor <1 x64 xi32 , #blocked > -> tensor <32 x64 xi32 , #blocked >
37- // Each thread needs to load 8 elements and we load 2 (sizePerThread) per buffer load instruction
30+ tt.func public @buffer_load_to_local_vectorized_2xf16 (%arg1: !tt.ptr <f16 > {tt.divisibility = 16 : i32 , tt.pointer_range = 32 : i32 }, %arg2: !ttg.memdesc <64 x64 xf16 , #shared , #smem , mutable >) {
31+ %cst = arith.constant dense <64 > : tensor <1 x64 xi32 , #blocked >
32+ %0 = tt.make_range {end = 64 : i32 , start = 0 : i32 } : tensor <64 xi32 , #ttg.slice <{dim = 1 , parent = #blocked }>>
33+ %1 = tt.make_range {end = 64 : i32 , start = 0 : i32 } : tensor <64 xi32 , #ttg.slice <{dim = 0 , parent = #blocked }>>
34+ %2 = tt.expand_dims %0 {axis = 1 : i32 } : tensor <64 xi32 , #ttg.slice <{dim = 1 , parent = #blocked }>> -> tensor <64 x1 xi32 , #blocked >
35+ %3 = tt.broadcast %2 : tensor <64 x1 xi32 , #blocked > -> tensor <64 x64 xi32 , #blocked >
36+ %4 = tt.expand_dims %1 {axis = 0 : i32 } : tensor <64 xi32 , #ttg.slice <{dim = 0 , parent = #blocked }>> -> tensor <1 x64 xi32 , #blocked >
37+ %5 = arith.muli %4 , %cst : tensor <1 x64 xi32 , #blocked >
38+ %6 = tt.broadcast %5 : tensor <1 x64 xi32 , #blocked > -> tensor <64 x64 xi32 , #blocked >
39+ %7 = arith.addi %3 , %6 : tensor <64 x64 xi32 , #blocked >
40+
41+ // Each thread needs to load 2 elements and we load 2 (sizePerThread) per buffer load instruction
3842 // COMMON: rocdl.make.buffer.rsrc
3943 // COMMON-NOT: rocdl.make.buffer.rsrc
40- // COMMON-COUNT-4 : rocdl.raw.ptr.buffer.load.lds
44+ // COMMON: rocdl.raw.ptr.buffer.load.lds
4145 // COMMON-NOT: rocdl.raw.ptr.buffer.load.lds
42- %65 = amdgpu.buffer_load_to_local %arg1 [%3 ] into %arg2 : <f16 >[tensor <32 x 64 x i32 , #blocked >] -> <32 x 64 x f16 , #shared , #smem , mutable >
46+ %8 = amdgpu.buffer_load_to_local %arg1 [%7 ] into %arg2 : <f16 >[tensor <64 x 64 x i32 , #blocked >] -> <64 x 64 x f16 , #shared , #smem , mutable >
4347 tt.return
4448 }
4549}
4650
4751// -----
4852
49- #blocked = #ttg.blocked <{sizePerThread = [1 , 8 ], threadsPerWarp = [8 , 8 ], warpsPerCTA = [4 , 1 ], order = [1 , 0 ]}>
50- #shared = #ttg.swizzled_shared <{vec = 1 , perPhase = 1 , maxPhase = 1 , order = [1 , 0 ]}>
53+ #blocked = #ttg.blocked <{sizePerThread = [8 , 1 ], threadsPerWarp = [8 , 8 ], warpsPerCTA = [1 , 32 ], order = [0 , 1 ]}>
54+ #shared = #ttg.swizzled_shared <{vec = 2 , perPhase = 1 , maxPhase = 1 , order = [0 , 1 ]}>
5155#smem = #ttg.shared_memory
52- module attributes {" ttg.num-ctas" = 1 : i32 , " ttg.num-warps" = 4 : i32 , ttg.shared = 8192 : i32 , ttg.target = " hip:gfx942" , " ttg.threads-per-warp" = 64 : i32 } {
56+ module attributes {" ttg.num-ctas" = 1 : i32 , " ttg.num-warps" = 32 : i32 , ttg.shared = 0 : i32 , ttg.target = " hip:gfx942" , " ttg.threads-per-warp" = 64 : i32 } {
5357 // COMMON-LABEL: buffer_load_to_local_vectorized_8xf16
54- tt.func public @buffer_load_to_local_vectorized_8xf16 (
55- %arg1: !tt.ptr <f16 > {tt.divisibility = 16 : i32 , tt.pointer_range = 32 : i32 },
56- %arg2: !ttg.memdesc <32 x64 xf16 , #shared , #smem , mutable >,
57- %arg3: i32 ) {
58- %1 = tt.splat %arg3: i32 -> tensor <64 xi32 , #ttg.slice <{dim = 0 , parent = #blocked }>>
59- %2 = tt.expand_dims %1 {axis = 0 : i32 } : tensor <64 xi32 , #ttg.slice <{dim = 0 , parent = #blocked }>> -> tensor <1 x64 xi32 , #blocked >
60- %3 = tt.broadcast %2 : tensor <1 x64 xi32 , #blocked > -> tensor <32 x64 xi32 , #blocked >
58+ tt.func public @buffer_load_to_local_vectorized_8xf16 (%arg1: !tt.ptr <f16 > {tt.divisibility = 16 : i32 , tt.pointer_range = 32 : i32 }, %arg2: !ttg.memdesc <64 x64 xf16 , #shared , #smem , mutable >) {
59+ %cst = arith.constant dense <64 > : tensor <1 x64 xi32 , #blocked >
60+ %0 = tt.make_range {end = 64 : i32 , start = 0 : i32 } : tensor <64 xi32 , #ttg.slice <{dim = 1 , parent = #blocked }>>
61+ %1 = tt.make_range {end = 64 : i32 , start = 0 : i32 } : tensor <64 xi32 , #ttg.slice <{dim = 0 , parent = #blocked }>>
62+ %2 = tt.expand_dims %0 {axis = 1 : i32 } : tensor <64 xi32 , #ttg.slice <{dim = 1 , parent = #blocked }>> -> tensor <64 x1 xi32 , #blocked >
63+ %3 = tt.broadcast %2 : tensor <64 x1 xi32 , #blocked > -> tensor <64 x64 xi32 , #blocked >
64+ %4 = tt.expand_dims %1 {axis = 0 : i32 } : tensor <64 xi32 , #ttg.slice <{dim = 0 , parent = #blocked }>> -> tensor <1 x64 xi32 , #blocked >
65+ %5 = arith.muli %4 , %cst : tensor <1 x64 xi32 , #blocked >
66+ %6 = tt.broadcast %5 : tensor <1 x64 xi32 , #blocked > -> tensor <64 x64 xi32 , #blocked >
67+ %7 = arith.addi %3 , %6 : tensor <64 x64 xi32 , #blocked >
6168
6269 // Each thread needs to load 8 elements and we load 8 (sizePerThread) per buffer load instruction
6370 // GFX950: rocdl.make.buffer.rsrc
@@ -68,7 +75,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
6875 // GFX942 does not support vectorization > 4 bytes, so we cannot lower it
6976 // GFX942-NOT: rocdl.raw.ptr.buffer.load.lds
7077 // GFX942: amdgpu.buffer_load_to_local
71- %65 = amdgpu.buffer_load_to_local %arg1 [%3 ] into %arg2 : <f16 >[tensor <32 x 64 x i32 , #blocked >] -> <32 x 64 x f16 , #shared , #smem , mutable >
78+ %8 = amdgpu.buffer_load_to_local %arg1 [%7 ] into %arg2 : <f16 >[tensor <64 x 64 x i32 , #blocked >] -> <64 x 64 x f16 , #shared , #smem , mutable >
7279 tt.return
7380 }
7481}
@@ -129,30 +136,28 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
129136
130137// -----
131138
132-
133- #blocked = #ttg.blocked <{sizePerThread = [1 , 1 ], threadsPerWarp = [2 , 32 ], warpsPerCTA = [16 , 1 ], order = [1 , 0 ]}>
134- #shared = #ttg.swizzled_shared <{vec = 1 , perPhase = 1 , maxPhase = 1 , order = [1 , 0 ]}>
139+ #blocked = #ttg.blocked <{sizePerThread = [1 ], threadsPerWarp = [64 ], warpsPerCTA = [1 ], order = [0 ]}>
140+ #shared = #ttg.swizzled_shared <{vec = 1 , perPhase = 1 , maxPhase = 1 , order = [0 ]}>
135141#smem = #ttg.shared_memory
136- module attributes {" ttg.num-ctas" = 1 : i32 , " ttg.num-warps" = 16 : i32 , ttg.shared = 8192 : i32 , ttg.target = " hip:gfx942 " , " ttg.threads-per-warp" = 64 : i32 } {
142+ module attributes {" ttg.num-ctas" = 1 : i32 , " ttg.num-warps" = 1 : i32 , " ttg.threads-per-warp" = 64 : i32 } {
137143 // COMMON-LABEL: buffer_load_to_local_cache_mods
138- tt.func public @buffer_load_to_local_cache_mods (%arg0: !tt.ptr <f16 > {tt.divisibility = 16 : i32 , tt.pointer_range = 32 : i32 },
139- %arg1: !tt.ptr <f16 >,
140- %arg2: tensor <32 x32 xi32 , #blocked >,
141- %arg3: !ttg.memdesc <32 x32 xf16 , #shared , #smem , mutable >) {
144+ tt.func public @buffer_load_to_local_cache_mods (%arg0: !tt.ptr <f32 > {tt.divisibility = 16 : i32 , tt.pointer_range = 32 : i32 },
145+ %arg2: !ttg.memdesc <64 xf32 , #shared , #smem , mutable >) {
146+ %0 = tt.make_range {end = 64 : i32 , start = 0 : i32 } : tensor <64 xi32 , #blocked >
142147 // The first constant-0 check skips past the LDS offset (which is also 0) so that the next one captures the aux value
143148 // COMMON: llvm.getelementptr
144149 // COMMON: llvm.mlir.constant(0 : i32) : i32
145150 // COMMON: %[[aux_ca:.*]] = llvm.mlir.constant(0 : i32) : i32
146151 // COMMON: rocdl.raw.ptr.buffer.load.lds {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[aux_ca]]
147- %1 = amdgpu.buffer_load_to_local %arg1 [ %arg2 ] cacheModifier = ca into %arg3 : <f16 >[tensor <32 x 32 x i32 , #blocked >] -> <32 x 32 x f16 , #shared , #smem , mutable >
152+ %1 = amdgpu.buffer_load_to_local %arg0 [% 0 ] cacheModifier = ca into %arg2 : <f32 >[tensor <64 x i32 , #blocked >] -> <64 x f32 , #shared , #smem , mutable >
148153 // COMMON: llvm.getelementptr
149154 // COMMON: %[[aux_cg:.*]] = llvm.mlir.constant(3 : i32) : i32
150155 // COMMON: rocdl.raw.ptr.buffer.load.lds {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[aux_cg]]
151- %2 = amdgpu.buffer_load_to_local %arg1 [ %arg2 ] cacheModifier = cg into %arg3 : <f16 >[tensor <32 x 32 x i32 , #blocked >] -> <32 x 32 x f16 , #shared , #smem , mutable >
156+ %2 = amdgpu.buffer_load_to_local %arg0 [% 0 ] cacheModifier = cg into %arg2 : <f32 >[tensor <64 x i32 , #blocked >] -> <64 x f32 , #shared , #smem , mutable >
152157 // COMMON: llvm.getelementptr
153158 // COMMON: %[[aux_cv:.*]] = llvm.mlir.constant(17 : i32) : i32
154159 // COMMON: rocdl.raw.ptr.buffer.load.lds {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[aux_cv]]
155- %3 = amdgpu.buffer_load_to_local %arg1 [ %arg2 ] cacheModifier = cv into %arg3 : <f16 >[tensor <32 x 32 x i32 , #blocked >] -> <32 x 32 x f16 , #shared , #smem , mutable >
160+ %3 = amdgpu.buffer_load_to_local %arg0 [% 0 ] cacheModifier = cv into %arg2 : <f32 >[tensor <64 x i32 , #blocked >] -> <64 x f32 , #shared , #smem , mutable >
156161
157162 tt.return
158163 }
0 commit comments