Skip to content

Commit 48c9a35

Browse files
authored
[AMD] refactor MatrixCoreIntrinEmitter (#860)
1 parent b12a63c commit 48c9a35

File tree

3 files changed

+269
-116
lines changed

3 files changed

+269
-116
lines changed

testing/python/amd/test_tilelang_gemm_mfma_intrinsic.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,10 @@ def test_assert_tl_matmul():
234234
assert_tl_matmul_correctness(128, 128, 128, "int8", "int32", accum_dtype="int32")
235235
assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", accum_dtype="int32")
236236
assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", accum_dtype="int32", k_pack=2)
237+
assert_tl_matmul_correctness(
238+
128, 256, 256, "int8", "int32", b_transposed=False, accum_dtype="int32")
239+
assert_tl_matmul_correctness(
240+
128, 256, 256, "int8", "int32", b_transposed=False, accum_dtype="int32", k_pack=2)
237241

238242

239243
if __name__ == "__main__":

testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py

Lines changed: 39 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,7 @@
33
from tilelang import tvm as tvm
44
import tilelang.language as T
55
from tilelang.intrinsics import make_mfma_swizzle_layout as make_swizzle_layout
6-
from tilelang.intrinsics.mfma_macro_generator import (
7-
MatrixCoreIntrinEmitter,)
6+
from tilelang.intrinsics.mfma_macro_generator import MatrixCorePreshuffleIntrinEmitter
87
from tilelang.transform import simplify_prim_func
98

109
tilelang.testing.set_random_seed(0)
@@ -22,16 +21,8 @@ def tl_matmul(
2221
b_transposed=True,
2322
k_pack=1,
2423
b_preshuffle=False,
24+
b_g2l_load=False,
2525
):
26-
assert in_dtype in [
27-
"float16",
28-
"int8",
29-
], "Currently only float16 and int8 are supported"
30-
assert out_dtype in [
31-
"float16",
32-
"float32",
33-
"int32",
34-
], "Currently only float16, float32 and int32 are supported"
3526

3627
micro_size_x = micro_size_y = micro_size_k = 16
3728

@@ -47,15 +38,14 @@ def tl_matmul(
4738
if b_preshuffle:
4839
block_row_warps = 1
4940
block_col_warps = 4
50-
warp_row_tiles = 128
51-
warp_col_tiles = 32
41+
warp_row_tiles = 64
42+
warp_col_tiles = 16
5243

53-
chunk = 32 * k_pack
44+
chunk = 256 * k_pack
5445

5546
pack_size_k = micro_size_k * k_pack
5647

5748
shared_scope = "shared"
58-
cache_write_shared = False
5949

6050
block_M = block_row_warps * warp_row_tiles
6151
block_N = block_col_warps * warp_col_tiles
@@ -68,6 +58,7 @@ def tl_matmul(
6858
pack_size_k, micro_size_y)
6959
else:
7060
B_shape = (N, K) if b_transposed else (K, N)
61+
7162
A_shared_shape = (block_K, block_M) if a_transposed else (block_M, block_K)
7263
if b_preshuffle:
7364
B_shared_shape = (block_N // micro_size_y, block_K // pack_size_k, micro_size_y,
@@ -76,12 +67,6 @@ def tl_matmul(
7667
micro_size_y)
7768
else:
7869
B_shared_shape = (block_N, block_K) if b_transposed else (block_K, block_N)
79-
C_shared_shape = (
80-
block_M // micro_size_x,
81-
block_N // micro_size_y,
82-
micro_size_x,
83-
micro_size_y,
84-
)
8570

8671
warp_size = 64
8772
threads = warp_size * (block_row_warps * block_col_warps)
@@ -92,7 +77,7 @@ def tl_matmul(
9277
warp_cols = warp_col_tiles // micro_size_y
9378

9479
# MMA Wrapper to Auto Generate Code for MMA
95-
mfma_emitter = MatrixCoreIntrinEmitter(
80+
mfma_emitter = MatrixCorePreshuffleIntrinEmitter(
9681
a_dtype=in_dtype,
9782
b_dtype=in_dtype,
9883
accum_dtype=accum_dtype,
@@ -117,7 +102,6 @@ def main(
117102

118103
A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
119104
B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope)
120-
C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)
121105
A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
122106
B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
123107
C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)
@@ -126,12 +110,15 @@ def main(
126110
A_shared: make_swizzle_layout(A_shared),
127111
})
128112

113+
num_ko = K // block_K
114+
num_ki = block_K // (k_pack * micro_size_k)
115+
129116
# Improve L2 Cache
130117
T.use_swizzle(panel_size=10)
131118

132119
T.clear(C_local)
133120

134-
for ko in T.Pipelined((K // block_K), num_stages=0):
121+
for ko in T.Pipelined(num_ko, num_stages=0):
135122

136123
# Load A into shared memory
137124
if a_transposed:
@@ -140,7 +127,7 @@ def main(
140127
T.copy(A[by * block_M, ko * block_K], A_shared)
141128

142129
# Load B into shared memory
143-
if b_preshuffle:
130+
if b_g2l_load is False:
144131
if b_transposed:
145132
for j, k, jj, kk in T.Parallel(block_N // micro_size_y,
146133
block_K // pack_size_k, micro_size_y,
@@ -153,53 +140,37 @@ def main(
153140
micro_size_y):
154141
B_shared[k, j, kk, jj] = B[ko * block_K // pack_size_k + k,
155142
bx * block_N // micro_size_y + j, kk, jj]
156-
else:
157-
if b_transposed:
158-
T.copy(B[bx * block_N, ko * block_K], B_shared)
159-
else:
160-
T.copy(B[ko * block_K, bx * block_N], B_shared)
161143

162-
for ki in T.serial(0, (block_K // (k_pack * micro_size_k))):
144+
for ki in T.serial(0, num_ki):
163145

164-
# Load A into fragment
146+
# Load A S2L
165147
mfma_emitter.ldmatrix_a(
166148
A_local,
167149
A_shared,
168150
ki,
169151
)
170152

171-
# Load B into fragment
172-
mfma_emitter.ldmatrix_b(
173-
B_local,
174-
B_shared,
175-
ki,
176-
)
153+
if b_g2l_load:
154+
# Load B G2L
155+
mfma_emitter.ldmatrix_b(B_local, B, ki + ko * num_ki, pid_m=by, pid_n=bx)
156+
else:
157+
# Load B S2L
158+
mfma_emitter.ldmatrix_b(
159+
B_local,
160+
B_shared,
161+
ki,
162+
)
177163

178164
# Perform Matrix Multiplication
179165
mfma_emitter.mfma(A_local, B_local, C_local)
180166

181167
# Perform STMatrix
182-
if cache_write_shared:
183-
mfma_emitter.stmatrix(
184-
C_local,
185-
C_shared,
186-
)
187-
188-
# Store shared into global
189-
for i, j in T.Parallel(block_M, block_N):
190-
C[by * block_M + i, bx * block_N + j] = C_shared[
191-
i // micro_size_x,
192-
j // micro_size_y,
193-
i % micro_size_x,
194-
j % micro_size_y,
195-
]
196-
else:
197-
mfma_emitter.stmatrix(
198-
C_local,
199-
C,
200-
pid_m=by,
201-
pid_n=bx,
202-
)
168+
mfma_emitter.stmatrix(
169+
C_local,
170+
C,
171+
pid_m=by,
172+
pid_n=bx,
173+
)
203174

204175
return main
205176

@@ -232,9 +203,10 @@ def assert_tl_matmul_correctness(M,
232203
a_transposed=False,
233204
b_transposed=True,
234205
k_pack=1,
235-
b_preshuffle=False):
206+
b_preshuffle=False,
207+
b_g2l_load=False):
236208
matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype, a_transposed, b_transposed,
237-
k_pack, b_preshuffle)
209+
k_pack, b_preshuffle, b_g2l_load)
238210
print(matmul)
239211
kernel = tilelang.compile(matmul)
240212
src_code = kernel.get_kernel_source()
@@ -285,30 +257,25 @@ def assert_tl_matmul_correctness(M,
285257

286258
print(C)
287259
print(ref_c)
260+
288261
torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2)
289262

290263

291264
@tilelang.testing.requires_rocm
292265
def test_assert_tl_matmul():
293-
assert_tl_matmul_correctness(128, 128, 128, "int8", "int32", accum_dtype="int32")
294-
assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", accum_dtype="int32")
295-
assert_tl_matmul_correctness(
296-
128, 256, 256, "int8", "int32", b_transposed=False, accum_dtype="int32")
297-
assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", accum_dtype="int32", k_pack=2)
298-
299266
assert_tl_matmul_correctness(
300-
128, 128, 128, "int8", "int32", accum_dtype="int32", b_preshuffle=True)
267+
256, 256, 256, "int8", "int32", accum_dtype="int32", b_preshuffle=True)
301268
assert_tl_matmul_correctness(
302-
128, 256, 256, "int8", "int32", accum_dtype="int32", b_preshuffle=True)
269+
256, 256, 256, "int8", "int32", accum_dtype="int32", b_preshuffle=True)
303270
assert_tl_matmul_correctness(
304-
128, 256, 256, "int8", "int32", b_transposed=False, accum_dtype="int32", b_preshuffle=True)
271+
256, 256, 256, "int8", "int32", b_transposed=False, accum_dtype="int32", b_preshuffle=True)
305272

306273
assert_tl_matmul_correctness(
307-
128, 256, 256, "int8", "int32", accum_dtype="int32", k_pack=2, b_preshuffle=True)
274+
256, 256, 512, "int8", "int32", accum_dtype="int32", k_pack=2, b_preshuffle=True)
308275
assert_tl_matmul_correctness(
309-
128,
310276
256,
311277
256,
278+
512,
312279
"int8",
313280
"int32",
314281
b_transposed=False,

0 commit comments

Comments (0)