3
3
from tilelang import tvm as tvm
4
4
import tilelang .language as T
5
5
from tilelang .intrinsics import make_mfma_swizzle_layout as make_swizzle_layout
6
- from tilelang .intrinsics .mfma_macro_generator import (
7
- MatrixCoreIntrinEmitter ,)
6
+ from tilelang .intrinsics .mfma_macro_generator import MatrixCorePreshuffleIntrinEmitter
8
7
from tilelang .transform import simplify_prim_func
9
8
10
9
tilelang .testing .set_random_seed (0 )
@@ -22,16 +21,8 @@ def tl_matmul(
22
21
b_transposed = True ,
23
22
k_pack = 1 ,
24
23
b_preshuffle = False ,
24
+ b_g2l_load = False ,
25
25
):
26
- assert in_dtype in [
27
- "float16" ,
28
- "int8" ,
29
- ], "Currently only float16 and int8 are supported"
30
- assert out_dtype in [
31
- "float16" ,
32
- "float32" ,
33
- "int32" ,
34
- ], "Currently only float16, float32 and int32 are supported"
35
26
36
27
micro_size_x = micro_size_y = micro_size_k = 16
37
28
@@ -47,15 +38,14 @@ def tl_matmul(
47
38
if b_preshuffle :
48
39
block_row_warps = 1
49
40
block_col_warps = 4
50
- warp_row_tiles = 128
51
- warp_col_tiles = 32
41
+ warp_row_tiles = 64
42
+ warp_col_tiles = 16
52
43
53
- chunk = 32 * k_pack
44
+ chunk = 256 * k_pack
54
45
55
46
pack_size_k = micro_size_k * k_pack
56
47
57
48
shared_scope = "shared"
58
- cache_write_shared = False
59
49
60
50
block_M = block_row_warps * warp_row_tiles
61
51
block_N = block_col_warps * warp_col_tiles
@@ -68,6 +58,7 @@ def tl_matmul(
68
58
pack_size_k , micro_size_y )
69
59
else :
70
60
B_shape = (N , K ) if b_transposed else (K , N )
61
+
71
62
A_shared_shape = (block_K , block_M ) if a_transposed else (block_M , block_K )
72
63
if b_preshuffle :
73
64
B_shared_shape = (block_N // micro_size_y , block_K // pack_size_k , micro_size_y ,
@@ -76,12 +67,6 @@ def tl_matmul(
76
67
micro_size_y )
77
68
else :
78
69
B_shared_shape = (block_N , block_K ) if b_transposed else (block_K , block_N )
79
- C_shared_shape = (
80
- block_M // micro_size_x ,
81
- block_N // micro_size_y ,
82
- micro_size_x ,
83
- micro_size_y ,
84
- )
85
70
86
71
warp_size = 64
87
72
threads = warp_size * (block_row_warps * block_col_warps )
@@ -92,7 +77,7 @@ def tl_matmul(
92
77
warp_cols = warp_col_tiles // micro_size_y
93
78
94
79
# MMA Wrapper to Auto Generate Code for MMA
95
- mfma_emitter = MatrixCoreIntrinEmitter (
80
+ mfma_emitter = MatrixCorePreshuffleIntrinEmitter (
96
81
a_dtype = in_dtype ,
97
82
b_dtype = in_dtype ,
98
83
accum_dtype = accum_dtype ,
@@ -117,7 +102,6 @@ def main(
117
102
118
103
A_shared = T .alloc_shared (A_shared_shape , in_dtype , scope = shared_scope )
119
104
B_shared = T .alloc_shared (B_shared_shape , in_dtype , scope = shared_scope )
120
- C_shared = T .alloc_shared (C_shared_shape , out_dtype , scope = shared_scope )
121
105
A_local = T .alloc_local ((warp_rows * local_size_a ), in_dtype )
122
106
B_local = T .alloc_local ((warp_cols * local_size_b ), in_dtype )
123
107
C_local = T .alloc_local ((warp_rows * warp_cols * local_size_c ), accum_dtype )
@@ -126,12 +110,15 @@ def main(
126
110
A_shared : make_swizzle_layout (A_shared ),
127
111
})
128
112
113
+ num_ko = K // block_K
114
+ num_ki = block_K // (k_pack * micro_size_k )
115
+
129
116
# Improve L2 Cache
130
117
T .use_swizzle (panel_size = 10 )
131
118
132
119
T .clear (C_local )
133
120
134
- for ko in T .Pipelined (( K // block_K ) , num_stages = 0 ):
121
+ for ko in T .Pipelined (num_ko , num_stages = 0 ):
135
122
136
123
# Load A into shared memory
137
124
if a_transposed :
@@ -140,7 +127,7 @@ def main(
140
127
T .copy (A [by * block_M , ko * block_K ], A_shared )
141
128
142
129
# Load B into shared memory
143
- if b_preshuffle :
130
+ if b_g2l_load is False :
144
131
if b_transposed :
145
132
for j , k , jj , kk in T .Parallel (block_N // micro_size_y ,
146
133
block_K // pack_size_k , micro_size_y ,
@@ -153,53 +140,37 @@ def main(
153
140
micro_size_y ):
154
141
B_shared [k , j , kk , jj ] = B [ko * block_K // pack_size_k + k ,
155
142
bx * block_N // micro_size_y + j , kk , jj ]
156
- else :
157
- if b_transposed :
158
- T .copy (B [bx * block_N , ko * block_K ], B_shared )
159
- else :
160
- T .copy (B [ko * block_K , bx * block_N ], B_shared )
161
143
162
- for ki in T .serial (0 , ( block_K // ( k_pack * micro_size_k )) ):
144
+ for ki in T .serial (0 , num_ki ):
163
145
164
- # Load A into fragment
146
+ # Load A S2L
165
147
mfma_emitter .ldmatrix_a (
166
148
A_local ,
167
149
A_shared ,
168
150
ki ,
169
151
)
170
152
171
- # Load B into fragment
172
- mfma_emitter .ldmatrix_b (
173
- B_local ,
174
- B_shared ,
175
- ki ,
176
- )
153
+ if b_g2l_load :
154
+ # Load B G2L
155
+ mfma_emitter .ldmatrix_b (B_local , B , ki + ko * num_ki , pid_m = by , pid_n = bx )
156
+ else :
157
+ # Load B S2L
158
+ mfma_emitter .ldmatrix_b (
159
+ B_local ,
160
+ B_shared ,
161
+ ki ,
162
+ )
177
163
178
164
# Perform Matrix Multiplication
179
165
mfma_emitter .mfma (A_local , B_local , C_local )
180
166
181
167
# Perform STMatrix
182
- if cache_write_shared :
183
- mfma_emitter .stmatrix (
184
- C_local ,
185
- C_shared ,
186
- )
187
-
188
- # Store shared into global
189
- for i , j in T .Parallel (block_M , block_N ):
190
- C [by * block_M + i , bx * block_N + j ] = C_shared [
191
- i // micro_size_x ,
192
- j // micro_size_y ,
193
- i % micro_size_x ,
194
- j % micro_size_y ,
195
- ]
196
- else :
197
- mfma_emitter .stmatrix (
198
- C_local ,
199
- C ,
200
- pid_m = by ,
201
- pid_n = bx ,
202
- )
168
+ mfma_emitter .stmatrix (
169
+ C_local ,
170
+ C ,
171
+ pid_m = by ,
172
+ pid_n = bx ,
173
+ )
203
174
204
175
return main
205
176
@@ -232,9 +203,10 @@ def assert_tl_matmul_correctness(M,
232
203
a_transposed = False ,
233
204
b_transposed = True ,
234
205
k_pack = 1 ,
235
- b_preshuffle = False ):
206
+ b_preshuffle = False ,
207
+ b_g2l_load = False ):
236
208
matmul = tl_matmul (M , N , K , in_dtype , out_dtype , accum_dtype , a_transposed , b_transposed ,
237
- k_pack , b_preshuffle )
209
+ k_pack , b_preshuffle , b_g2l_load )
238
210
print (matmul )
239
211
kernel = tilelang .compile (matmul )
240
212
src_code = kernel .get_kernel_source ()
@@ -285,30 +257,25 @@ def assert_tl_matmul_correctness(M,
285
257
286
258
print (C )
287
259
print (ref_c )
260
+
288
261
torch .testing .assert_close (C , ref_c , rtol = 1e-2 , atol = 1e-2 )
289
262
290
263
291
264
@tilelang .testing .requires_rocm
292
265
def test_assert_tl_matmul ():
293
- assert_tl_matmul_correctness (128 , 128 , 128 , "int8" , "int32" , accum_dtype = "int32" )
294
- assert_tl_matmul_correctness (128 , 256 , 256 , "int8" , "int32" , accum_dtype = "int32" )
295
- assert_tl_matmul_correctness (
296
- 128 , 256 , 256 , "int8" , "int32" , b_transposed = False , accum_dtype = "int32" )
297
- assert_tl_matmul_correctness (128 , 256 , 256 , "int8" , "int32" , accum_dtype = "int32" , k_pack = 2 )
298
-
299
266
assert_tl_matmul_correctness (
300
- 128 , 128 , 128 , "int8" , "int32" , accum_dtype = "int32" , b_preshuffle = True )
267
+ 256 , 256 , 256 , "int8" , "int32" , accum_dtype = "int32" , b_preshuffle = True )
301
268
assert_tl_matmul_correctness (
302
- 128 , 256 , 256 , "int8" , "int32" , accum_dtype = "int32" , b_preshuffle = True )
269
+ 256 , 256 , 256 , "int8" , "int32" , accum_dtype = "int32" , b_preshuffle = True )
303
270
assert_tl_matmul_correctness (
304
- 128 , 256 , 256 , "int8" , "int32" , b_transposed = False , accum_dtype = "int32" , b_preshuffle = True )
271
+ 256 , 256 , 256 , "int8" , "int32" , b_transposed = False , accum_dtype = "int32" , b_preshuffle = True )
305
272
306
273
assert_tl_matmul_correctness (
307
- 128 , 256 , 256 , "int8" , "int32" , accum_dtype = "int32" , k_pack = 2 , b_preshuffle = True )
274
+ 256 , 256 , 512 , "int8" , "int32" , accum_dtype = "int32" , k_pack = 2 , b_preshuffle = True )
308
275
assert_tl_matmul_correctness (
309
- 128 ,
310
276
256 ,
311
277
256 ,
278
+ 512 ,
312
279
"int8" ,
313
280
"int32" ,
314
281
b_transposed = False ,
0 commit comments