4
4
#include < cute/tensor.hpp>
5
5
#include < float.h>
6
6
7
- // TODO: thread block swizzle, cute hgemm nn
7
+ // BlockSwizzle: when true, apply thread block swizzle across the N dimension
8
8
template <
9
9
typename T,
10
10
int BM,
@@ -21,8 +21,9 @@ template <
21
21
typename S2RCopyAtomB,
22
22
typename R2SCopyAtomC,
23
23
typename S2GCopyAtomC,
24
- typename S2GCopyC>
25
- __global__ void hgemm_mma_stages_tn_cute_kernel (
24
+ typename S2GCopyC,
25
+ const bool BlockSwizzle>
26
+ __global__ void hgemm_mma_stages_block_swizzle_tn_cute_kernel (
26
27
const T *Aptr, const T *Bptr, T *Dptr, int m, int n, int k) {
27
28
using namespace cute ;
28
29
// Initialize shared memory
@@ -33,9 +34,12 @@ __global__ void hgemm_mma_stages_tn_cute_kernel(
33
34
34
35
// Initialize thread block
35
36
int idx = threadIdx .x ;
36
- int ix = blockIdx .x ;
37
+ // BlockSwizzle (0/1) controls whether thread block swizzle is used.
38
+ int ix = ((int ) BlockSwizzle) * blockIdx .z * gridDim .x + blockIdx .x ;
37
39
int iy = blockIdx .y ;
38
40
41
+ if (iy * BM >= m || ix * BN >= n) return ;
42
+
39
43
// use Tensor notation to represent device pointer + dimension
40
44
Tensor A = make_tensor (make_gmem_ptr (Aptr), make_shape (m, k), make_stride (k, Int<1 >{}));
41
45
Tensor B = make_tensor (make_gmem_ptr (Bptr), make_shape (n, k), make_stride (k, Int<1 >{}));
@@ -131,18 +135,19 @@ __global__ void hgemm_mma_stages_tn_cute_kernel(
131
135
}
132
136
133
137
// shm -> reg s[itile][ik + 1] -> r[ik + 1]
134
- cute::copy (s2r_tiled_copy_a, tAsA (_, _, ik_next, ismem_read), // tAsA: (CPY, CPY_M, CPY_K, kStage)
135
- tCrA_view (_, _, ik_next)); // tCrA_view: (CPY, CPY_M, CPY_K)
136
- cute::copy (s2r_tiled_copy_b, tBsB (_, _, ik_next, ismem_read), // tBsB: (CPY, CPY_M, CPY_K, kStage)
137
- tCrB_view (_, _, ik_next)); // tCrB_view: (CPY, CPY_M, CPY_K)
138
+ // tAsA: (CPY, CPY_M, CPY_K, kStage), tCrA_view: (CPY, CPY_M, CPY_K)
139
+ cute::copy (s2r_tiled_copy_a, tAsA (_, _, ik_next, ismem_read),
140
+ tCrA_view (_, _, ik_next));
141
+ // tBsB: (CPY, CPY_M, CPY_K, kStage), tCrB_view: (CPY, CPY_M, CPY_K)
142
+ cute::copy (s2r_tiled_copy_b, tBsB (_, _, ik_next, ismem_read),
143
+ tCrB_view (_, _, ik_next));
138
144
139
145
if (ik == 0 ) {
140
146
if (itile_to_read < ntile) {
141
147
cute::copy (g2s_tiled_copy_a, tAgA_copy (_, _, _, itile_to_read),
142
- tAsA_copy (_, _, _, ismem_write));
148
+ tAsA_copy (_, _, _, ismem_write));
143
149
cute::copy (g2s_tiled_copy_b, tBgB_copy (_, _, _, itile_to_read),
144
- tBsB_copy (_, _, _, ismem_write));
145
-
150
+ tBsB_copy (_, _, _, ismem_write));
146
151
++itile_to_read;
147
152
ismem_write = (ismem_write + 1 ) % kStage ;
148
153
}
@@ -195,28 +200,39 @@ __global__ void hgemm_mma_stages_tn_cute_kernel(
195
200
} // end for
196
201
}
197
202
198
- template <typename T, const int K_STAGE = 2 >
199
- void launch_hgemm_mma_stages_tn_cute (const T *a, const T *b, T *c, int M, int N, int K) {
203
+ // The torch binding needs a dynamic (runtime) block swizzle stride.
204
+ template <typename T, const int Stages = 2 , const bool BlockSwizzle = false >
205
+ void launch_hgemm_mma_stages_block_swizzle_tn_cute (const T *a,
206
+ const T *b,
207
+ T *c,
208
+ int M,
209
+ int N,
210
+ int K,
211
+ int swizzle_stride) {
212
+ // block swizzle_stride: 1024/2048/..., etc.
200
213
using namespace cute ;
201
214
202
215
auto BM = Int<128 >{};
203
216
auto BN = Int<256 >{};
204
217
auto BK = Int<32 >{};
205
- auto KStage = Int<K_STAGE >{}; // default 2
218
+ auto KStage = Int<Stages >{}; // default 2
206
219
auto kSmemLayoutCBatch = Int<4 >{};
207
220
208
221
// Define the smem layouts
209
222
using SmemLayoutAtom = decltype (
210
223
composition (
211
224
Swizzle<3 , 3 , 3 >{},
212
- make_layout (make_shape (Int<8 >{}, Int<BK>{}), make_stride (Int<BK>{}, Int<1 >{}))
225
+ make_layout (make_shape (Int<8 >{}, Int<BK>{}),
226
+ make_stride (Int<BK>{}, Int<1 >{}))
213
227
)
214
228
);
215
229
using SmemLayoutA = decltype (
216
- tile_to_shape (SmemLayoutAtom{}, make_shape (Int<BM>{}, Int<BK>{}, Int<KStage>{}))
230
+ tile_to_shape (SmemLayoutAtom{},
231
+ make_shape (Int<BM>{}, Int<BK>{}, Int<KStage>{}))
217
232
);
218
233
using SmemLayoutB = decltype (
219
- tile_to_shape (SmemLayoutAtom{}, make_shape (Int<BN>{}, Int<BK>{}, Int<KStage>{}))
234
+ tile_to_shape (SmemLayoutAtom{},
235
+ make_shape (Int<BN>{}, Int<BK>{}, Int<KStage>{}))
220
236
); // (m,n) -> smem_idx
221
237
222
238
// mma
@@ -259,7 +275,8 @@ void launch_hgemm_mma_stages_tn_cute(const T *a, const T *b, T *c, int M, int N,
259
275
using SmemLayoutAtomC = decltype (
260
276
composition (
261
277
Swizzle<3 , 3 , 3 >{},
262
- make_layout (make_shape (Int<kMmaPM >{}, Int<kMmaPN >{}), make_stride (Int<kMmaPN >{}, Int<1 >{})))
278
+ make_layout (make_shape (Int<kMmaPM >{}, Int<kMmaPN >{}),
279
+ make_stride (Int<kMmaPN >{}, Int<1 >{})))
263
280
);
264
281
using SmemLayoutC = decltype (
265
282
tile_to_shape (
@@ -279,17 +296,20 @@ void launch_hgemm_mma_stages_tn_cute(const T *a, const T *b, T *c, int M, int N,
279
296
using S2GCopyC = decltype (
280
297
make_tiled_copy (
281
298
S2GCopyAtomC{},
282
- make_layout (make_shape (Int<32 >{}, Int<4 >{}), make_stride (Int<4 >{}, Int<1 >{})),
299
+ make_layout (make_shape (Int<32 >{}, Int<4 >{}),
300
+ make_stride (Int<4 >{}, Int<1 >{})),
283
301
make_layout (make_shape (Int<1 >{}, Int<8 >{}))
284
302
)
285
303
);
286
304
287
305
int BX = (N + BN - 1 ) / BN;
288
306
int BY = (M + BM - 1 ) / BM;
307
+ // NOTE: Apply thread block swizzle across N dim.
308
+ int BZ = BlockSwizzle ? (N + (swizzle_stride) - 1 ) / (swizzle_stride) : 1 ;
309
+ BX = BlockSwizzle ? (BX + BZ - 1 ) / BZ : BX;
289
310
290
- // TODO: thread block swizzle
291
311
dim3 block (size (MMA{}));
292
- dim3 grid (BX, BY);
312
+ dim3 grid (BX, BY, BZ );
293
313
294
314
// C_shm is shared with A_shm and B_shm
295
315
static constexpr int shm_size_AB =
@@ -301,7 +321,7 @@ void launch_hgemm_mma_stages_tn_cute(const T *a, const T *b, T *c, int M, int N,
301
321
int shm_size = kShmSize ;
302
322
303
323
cudaFuncSetAttribute (
304
- hgemm_mma_stages_tn_cute_kernel <
324
+ hgemm_mma_stages_block_swizzle_tn_cute_kernel <
305
325
T,
306
326
BM, BN, BK,
307
327
KStage,
@@ -315,13 +335,14 @@ void launch_hgemm_mma_stages_tn_cute(const T *a, const T *b, T *c, int M, int N,
315
335
S2RCopyAtomB,
316
336
R2SCopyAtomC,
317
337
S2GCopyAtomC,
318
- S2GCopyC
338
+ S2GCopyC,
339
+ BlockSwizzle
319
340
>,
320
341
cudaFuncAttributeMaxDynamicSharedMemorySize,
321
342
shm_size
322
343
);
323
344
324
- hgemm_mma_stages_tn_cute_kernel <
345
+ hgemm_mma_stages_block_swizzle_tn_cute_kernel <
325
346
T,
326
347
BM, BN, BK,
327
348
KStage,
@@ -335,7 +356,8 @@ void launch_hgemm_mma_stages_tn_cute(const T *a, const T *b, T *c, int M, int N,
335
356
S2RCopyAtomB,
336
357
R2SCopyAtomC,
337
358
S2GCopyAtomC,
338
- S2GCopyC
359
+ S2GCopyC,
360
+ BlockSwizzle
339
361
><<<grid, block, shm_size>>> (a, b, c, M, N, K);
340
362
}
341
363
@@ -360,12 +382,14 @@ int main() {
360
382
}
361
383
362
384
const int outer_repeat = 10 , inner_repeat = 1 ;
385
+ const int thread_block_swizzle_stride = 2048 ; // thread block swizzle stride
363
386
364
- printf (" ALGO = CuTe HGEMM TN STAGES=2\n " );
387
+ printf (" ALGO = CuTe HGEMM, TN, STAGES=2, SMEM SWIZZLE=<3, 3, 3>, BLOCK SWIZZLE=2048 \n " );
365
388
for (int j = 0 ; j < 5 ; j++) {
366
389
int M = M_list[j], N = N_list[j], K = K_list[j];
367
- float max_error = gemm_error_check_tn<T>(
368
- launch_hgemm_mma_stages_tn_cute, M, N, K);
390
+ float max_error = gemm_error_check_tn_swizzle<T>(
391
+ launch_hgemm_mma_stages_block_swizzle_tn_cute<T, 2 , true >,
392
+ M, N, K, thread_block_swizzle_stride);
369
393
printf (" M N K = %6d %6d %6d, " , M, N, K);
370
394
printf (" Max Error = %f\n " , max_error);
371
395
}
@@ -378,8 +402,9 @@ int main() {
378
402
double total_sec = 0.0 ;
379
403
380
404
for (int k = 0 ; k < outer_repeat; k++) {
381
- double this_sec = perf_gemm<T>(
382
- launch_hgemm_mma_stages_tn_cute, M, N, K, inner_repeat);
405
+ double this_sec = perf_gemm_swizzle<T>(
406
+ launch_hgemm_mma_stages_block_swizzle_tn_cute<T, 2 , true >,
407
+ M, N, K, thread_block_swizzle_stride, inner_repeat);
383
408
max_sec = max (max_sec, this_sec);
384
409
min_sec = min (min_sec, this_sec);
385
410
total_sec += this_sec;
@@ -395,8 +420,10 @@ int main() {
395
420
396
421
return 0 ;
397
422
}
398
- // build torch python binding
423
+
399
424
#else
425
+ // build torch python binding
426
+
400
427
#include < torch/types.h>
401
428
#include < torch/extension.h>
402
429
// --------------------- PyTorch bindings for custom kernel -----------------------
@@ -415,20 +442,29 @@ if (((T).size(0) != (S0)) || ((T).size(1) != (S1))) { \
415
442
throw std::runtime_error (" Tensor size mismatch!" ); \
416
443
}
417
444
418
- #define LAUNCH_HGEMM_MMA_STAGES_CUTE_TN (stages ) \
419
- launch_hgemm_mma_stages_tn_cute<half, (stages)>( \
420
- reinterpret_cast <half*>(a.data_ptr()), \
421
- reinterpret_cast <half*>(b.data_ptr()), \
422
- reinterpret_cast <half*>(c.data_ptr()), \
423
- M, N, K \
445
+ #define LAUNCH_HGEMM_MMA_STAGES_CUTE_NO_SWIZZLE_TN (stages ) \
446
+ launch_hgemm_mma_stages_block_swizzle_tn_cute< \
447
+ half, (stages), false >( \
448
+ reinterpret_cast <half*>(a.data_ptr()), \
449
+ reinterpret_cast <half*>(b.data_ptr()), \
450
+ reinterpret_cast <half*>(c.data_ptr()), \
451
+ M, N, K, 2048 \
452
+ );
453
+
454
+ #define LAUNCH_HGEMM_MMA_STAGES_CUTE_SWIZZLE_TN (stages, stride ) \
455
+ launch_hgemm_mma_stages_block_swizzle_tn_cute< \
456
+ half, (stages), true >( \
457
+ reinterpret_cast <half*>(a.data_ptr()), \
458
+ reinterpret_cast <half*>(b.data_ptr()), \
459
+ reinterpret_cast <half*>(c.data_ptr()), \
460
+ M, N, K, (stride) \
424
461
);
425
462
426
463
427
- // TODO: support thread block swizzle
464
+ // Multi stages CuTe HGEMM with smem and block swizzle.
428
465
void hgemm_mma_stages_tn_cute (
429
466
torch::Tensor a, torch::Tensor b, torch::Tensor c,
430
467
int stages, bool swizzle, int swizzle_stride) {
431
- // swizzle, swizzle_stride unused now
432
468
CHECK_TORCH_TENSOR_DTYPE (a, torch::kHalf )
433
469
CHECK_TORCH_TENSOR_DTYPE (b, torch::kHalf )
434
470
CHECK_TORCH_TENSOR_DTYPE (c, torch::kHalf )
@@ -439,20 +475,37 @@ void hgemm_mma_stages_tn_cute(
439
475
CHECK_TORCH_TENSOR_SHAPE (b, K, N)
440
476
CHECK_TORCH_TENSOR_SHAPE (c, M, N)
441
477
442
- switch (stages) {
443
- case 2 :
444
- LAUNCH_HGEMM_MMA_STAGES_CUTE_TN (2 )
445
- break ;
446
- case 3 :
447
- LAUNCH_HGEMM_MMA_STAGES_CUTE_TN (3 )
448
- break ;
449
- case 4 :
450
- LAUNCH_HGEMM_MMA_STAGES_CUTE_TN (4 )
451
- break ;
452
- default :
453
- LAUNCH_HGEMM_MMA_STAGES_CUTE_TN (2 )
454
- break ;
478
+ if (swizzle) {
479
+ switch (stages)
480
+ {
481
+ case 2 :
482
+ LAUNCH_HGEMM_MMA_STAGES_CUTE_SWIZZLE_TN (2 , swizzle_stride);
483
+ break ;
484
+ case 3 :
485
+ LAUNCH_HGEMM_MMA_STAGES_CUTE_SWIZZLE_TN (3 , swizzle_stride);
486
+ break ;
487
+ case 4 :
488
+ LAUNCH_HGEMM_MMA_STAGES_CUTE_SWIZZLE_TN (4 , swizzle_stride);
489
+ break ;
490
+ default :
491
+ LAUNCH_HGEMM_MMA_STAGES_CUTE_SWIZZLE_TN (2 , swizzle_stride);
492
+ break ;
493
+ }
494
+ } else {
495
+ switch (stages) {
496
+ case 2 :
497
+ LAUNCH_HGEMM_MMA_STAGES_CUTE_NO_SWIZZLE_TN (2 )
498
+ break ;
499
+ case 3 :
500
+ LAUNCH_HGEMM_MMA_STAGES_CUTE_NO_SWIZZLE_TN (3 )
501
+ break ;
502
+ case 4 :
503
+ LAUNCH_HGEMM_MMA_STAGES_CUTE_NO_SWIZZLE_TN (4 )
504
+ break ;
505
+ default :
506
+ LAUNCH_HGEMM_MMA_STAGES_CUTE_NO_SWIZZLE_TN (2 )
507
+ break ;
508
+ }
455
509
}
456
-
457
510
}
458
511
#endif
0 commit comments