Skip to content

Commit 82b2898

Browse files
authored
[HGEMM] CuTe HGEMM with Thread Block Swizzle (#140)
* Update hgemm_mma_stage_tn.cu * Update hgemm_mma_stage_tn_cute.cu * Update hgemm.py * Update hgemm.py * Update hgemm_mma_stage_tn_cute.cu * Update utils.h * Update hgemm.py * Update hgemm.py
1 parent 2655b52 commit 82b2898

File tree

4 files changed

+242
-68
lines changed

4 files changed

+242
-68
lines changed

hgemm/hgemm.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -108,18 +108,17 @@ def get_build_cuda_cflags():
108108
# spill stores: 指的是在执行过程中,数据因为寄存器不足而被存储到了栈上。
109109
# spill loads: 则是指将之前溢出到栈上的数据重新加载回寄存器。
110110
# diag 177: variable was declared but never referenced
111-
extra_cuda_cflags=[
112-
"-O2",
113-
"-U__CUDA_NO_HALF_OPERATORS__",
114-
"-U__CUDA_NO_HALF_CONVERSIONS__",
115-
"-U__CUDA_NO_HALF2_OPERATORS__",
116-
"-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
117-
"--expt-relaxed-constexpr",
118-
"--expt-extended-lambda",
119-
"--use_fast_math",
120-
"-diag-suppress 177",
121-
"-Xptxas -v",
122-
]
111+
extra_cuda_cflags = []
112+
extra_cuda_cflags.append("-O3")
113+
extra_cuda_cflags.append("-U__CUDA_NO_HALF_OPERATORS__")
114+
extra_cuda_cflags.append("-U__CUDA_NO_HALF_CONVERSIONS__")
115+
extra_cuda_cflags.append("-U__CUDA_NO_HALF2_OPERATORS__")
116+
extra_cuda_cflags.append("-U__CUDA_NO_BFLOAT16_CONVERSIONS__")
117+
extra_cuda_cflags.append("--expt-relaxed-constexpr")
118+
extra_cuda_cflags.append("--expt-extended-lambda")
119+
extra_cuda_cflags.append("--use_fast_math")
120+
extra_cuda_cflags.append("-diag-suppress 177")
121+
extra_cuda_cflags.append("-Xptxas -v")
123122
# extra cuda flags for cute hgemm
124123
project_dir = get_project_dir()
125124
extra_cuda_cflags.append('-DENBLE_CUTE_HGEMM')
@@ -217,7 +216,9 @@ def run_benchmark(perf_func: callable,
217216
total_time = (end - start) * 1000 # ms
218217
mean_time = total_time / iters
219218
out_info = f"{tag}"
220-
out_val = out.flatten()[:2].detach().cpu().numpy().tolist()
219+
out_val_first = out.flatten()[:2].detach().cpu().numpy().tolist()
220+
out_val_last = out.flatten()[-2:].detach().cpu().numpy().tolist()
221+
out_val = [out_val_first[0], out_val_last[-1]]
221222
out_val = [round(v, 8) for v in out_val]
222223
out_val = [f"{v:<12}"[:10] for v in out_val]
223224
TFLOPS = (2 * M * N * K) * 1e-9 / (mean_time)
@@ -372,7 +373,7 @@ def row2col(x: torch.Tensor):
372373
C = torch.randn((MAX_M, MAX_N), dtype=torch.half).cuda()
373374
torch.cuda.synchronize()
374375
end = time.time()
375-
pretty_print_line(f"pre allocate for fast profiling done, time: {(end - start) * 1000} ms")
376+
pretty_print_line(f"pre allocate for fast profiling done, time: {(end - start)} s")
376377

377378
PERF_COUNT = 0
378379
for (M, N, K) in zip(Ms, Ns, Ks):
@@ -464,6 +465,9 @@ def row2col(x: torch.Tensor):
464465
run_benchmark(lib.hgemm_mma_stages_tn_cute, a, b_col_major, "tn(cute+stage4+swizzle<smem>)", c, stages=4)
465466
run_benchmark(lib.hgemm_mma_stages_tn_cute, a, b_col_major, "tn(cute+stage3+swizzle<smem>)", c, stages=3)
466467
run_benchmark(lib.hgemm_mma_stages_tn_cute, a, b_col_major, "tn(cute+stage2+swizzle<smem>)", c, stages=2)
468+
run_benchmark(lib.hgemm_mma_stages_tn_cute, a, b_col_major, "tn(cute+stage4+swizzle<smem+block>)", c, stages=4, swizzle=True)
469+
run_benchmark(lib.hgemm_mma_stages_tn_cute, a, b_col_major, "tn(cute+stage3+swizzle<smem+block>)", c, stages=3, swizzle=True)
470+
run_benchmark(lib.hgemm_mma_stages_tn_cute, a, b_col_major, "tn(cute+stage2+swizzle<smem+block>)", c, stages=2, swizzle=True)
467471
# TN layout: cublas
468472
if not args.disable_cublas_tn and any((args.enable_mma_tn, args.enable_cute_tn)):
469473
run_benchmark(lib.hgemm_cublas_tensor_op_tn, a, b_col_major, "tn(cublas)", c)

hgemm/hgemm_mma_stage_tn.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,6 @@ __global__ void __launch_bounds__(256)
7272
hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem_tn_kernel(
7373
half* A, half* B, half* C, int M, int N, int K) {
7474
// BLOCK_SWIZZLE 0/1 control use block swizzle or not.
75-
// COLLECTIVE_STORE true/false control use stmatrix or not.
7675
const int bx = ((int) BLOCK_SWIZZLE) * blockIdx.z * gridDim.x + blockIdx.x;
7776
const int by = blockIdx.y;
7877
const int NUM_K_TILES = div_ceil(K, MMA_K);
@@ -98,6 +97,7 @@ hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem_tn_kernel(
9897
int load_smem_b_k = (tid % 2 == 0) ? 0 : 8; // col 0,8
9998
int load_gmem_a_m = by * BM + load_smem_a_m; // global row of c
10099
int load_gmem_b_n = bx * BN + load_smem_b_n; // global col of c
100+
if (load_gmem_a_m >= M || load_gmem_b_n >= N) return;
101101

102102
uint32_t RC[WARP_TILE_M][WARP_TILE_N][2];
103103
#pragma unroll

hgemm/hgemm_mma_stage_tn_cute.cu

Lines changed: 106 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
#include <cute/tensor.hpp>
55
#include <float.h>
66

7-
// TODO: thread block swizzle, cute hgemm nn
7+
// BlockSwizzle: means apply thread block swizzle across N dim
88
template <
99
typename T,
1010
int BM,
@@ -21,8 +21,9 @@ template <
2121
typename S2RCopyAtomB,
2222
typename R2SCopyAtomC,
2323
typename S2GCopyAtomC,
24-
typename S2GCopyC>
25-
__global__ void hgemm_mma_stages_tn_cute_kernel(
24+
typename S2GCopyC,
25+
const bool BlockSwizzle>
26+
__global__ void hgemm_mma_stages_block_swizzle_tn_cute_kernel(
2627
const T *Aptr, const T *Bptr, T *Dptr, int m, int n, int k) {
2728
using namespace cute;
2829
// Initilize shared memory
@@ -33,9 +34,12 @@ __global__ void hgemm_mma_stages_tn_cute_kernel(
3334

3435
// Initilize thread block
3536
int idx = threadIdx.x;
36-
int ix = blockIdx.x;
37+
// BlockSwizzle 0/1 control use block swizzle or not.
38+
int ix = ((int) BlockSwizzle) * blockIdx.z * gridDim.x + blockIdx.x;
3739
int iy = blockIdx.y;
3840

41+
if (iy * BM >= m || ix * BN >= n) return;
42+
3943
// use Tensor notation to represent device pointer + dimension
4044
Tensor A = make_tensor(make_gmem_ptr(Aptr), make_shape(m, k), make_stride(k, Int<1>{}));
4145
Tensor B = make_tensor(make_gmem_ptr(Bptr), make_shape(n, k), make_stride(k, Int<1>{}));
@@ -131,18 +135,19 @@ __global__ void hgemm_mma_stages_tn_cute_kernel(
131135
}
132136

133137
// shm -> reg s[itile][ik + 1] -> r[ik + 1]
134-
cute::copy(s2r_tiled_copy_a, tAsA(_, _, ik_next, ismem_read), // tAsA: (CPY, CPY_M, CPY_K, kStage)
135-
tCrA_view(_, _, ik_next)); // tCrA_view: (CPY, CPY_M, CPY_K)
136-
cute::copy(s2r_tiled_copy_b, tBsB(_, _, ik_next, ismem_read), // tBsB: (CPY, CPY_M, CPY_K, kStage)
137-
tCrB_view(_, _, ik_next)); // tCrB_view: (CPY, CPY_M, CPY_K)
138+
// tAsA: (CPY, CPY_M, CPY_K, kStage), tCrA_view: (CPY, CPY_M, CPY_K)
139+
cute::copy(s2r_tiled_copy_a, tAsA(_, _, ik_next, ismem_read),
140+
tCrA_view(_, _, ik_next));
141+
// tBsB: (CPY, CPY_M, CPY_K, kStage), tCrB_view: (CPY, CPY_M, CPY_K)
142+
cute::copy(s2r_tiled_copy_b, tBsB(_, _, ik_next, ismem_read),
143+
tCrB_view(_, _, ik_next));
138144

139145
if (ik == 0) {
140146
if (itile_to_read < ntile) {
141147
cute::copy(g2s_tiled_copy_a, tAgA_copy(_, _, _, itile_to_read),
142-
tAsA_copy(_, _, _, ismem_write));
148+
tAsA_copy(_, _, _, ismem_write));
143149
cute::copy(g2s_tiled_copy_b, tBgB_copy(_, _, _, itile_to_read),
144-
tBsB_copy(_, _, _, ismem_write));
145-
150+
tBsB_copy(_, _, _, ismem_write));
146151
++itile_to_read;
147152
ismem_write = (ismem_write + 1) % kStage;
148153
}
@@ -195,28 +200,39 @@ __global__ void hgemm_mma_stages_tn_cute_kernel(
195200
} // end for
196201
}
197202

198-
template <typename T, const int K_STAGE = 2>
199-
void launch_hgemm_mma_stages_tn_cute(const T *a, const T *b, T *c, int M, int N, int K) {
203+
// For torch binding, need dynamic block swizzle stride
204+
template <typename T, const int Stages = 2, const bool BlockSwizzle = false>
205+
void launch_hgemm_mma_stages_block_swizzle_tn_cute(const T *a,
206+
const T *b,
207+
T *c,
208+
int M,
209+
int N,
210+
int K,
211+
int swizzle_stride) {
212+
// block swizzle_stride: 1024/2048/..., etc.
200213
using namespace cute;
201214

202215
auto BM = Int<128>{};
203216
auto BN = Int<256>{};
204217
auto BK = Int<32>{};
205-
auto KStage = Int<K_STAGE>{}; // default 2
218+
auto KStage = Int<Stages>{}; // default 2
206219
auto kSmemLayoutCBatch = Int<4>{};
207220

208221
// Define the smem layouts
209222
using SmemLayoutAtom = decltype(
210223
composition(
211224
Swizzle<3, 3, 3>{},
212-
make_layout(make_shape(Int<8>{}, Int<BK>{}), make_stride(Int<BK>{}, Int<1>{}))
225+
make_layout(make_shape(Int<8>{}, Int<BK>{}),
226+
make_stride(Int<BK>{}, Int<1>{}))
213227
)
214228
);
215229
using SmemLayoutA = decltype(
216-
tile_to_shape(SmemLayoutAtom{}, make_shape(Int<BM>{}, Int<BK>{}, Int<KStage>{}))
230+
tile_to_shape(SmemLayoutAtom{},
231+
make_shape(Int<BM>{}, Int<BK>{}, Int<KStage>{}))
217232
);
218233
using SmemLayoutB = decltype(
219-
tile_to_shape(SmemLayoutAtom{}, make_shape(Int<BN>{}, Int<BK>{}, Int<KStage>{}))
234+
tile_to_shape(SmemLayoutAtom{},
235+
make_shape(Int<BN>{}, Int<BK>{}, Int<KStage>{}))
220236
); // (m,n) -> smem_idx
221237

222238
// mma
@@ -259,7 +275,8 @@ void launch_hgemm_mma_stages_tn_cute(const T *a, const T *b, T *c, int M, int N,
259275
using SmemLayoutAtomC = decltype(
260276
composition(
261277
Swizzle<3, 3, 3>{},
262-
make_layout(make_shape(Int<kMmaPM>{}, Int<kMmaPN>{}), make_stride(Int<kMmaPN>{}, Int<1>{})))
278+
make_layout(make_shape(Int<kMmaPM>{}, Int<kMmaPN>{}),
279+
make_stride(Int<kMmaPN>{}, Int<1>{})))
263280
);
264281
using SmemLayoutC = decltype(
265282
tile_to_shape(
@@ -279,17 +296,20 @@ void launch_hgemm_mma_stages_tn_cute(const T *a, const T *b, T *c, int M, int N,
279296
using S2GCopyC = decltype(
280297
make_tiled_copy(
281298
S2GCopyAtomC{},
282-
make_layout(make_shape(Int<32>{}, Int<4>{}), make_stride(Int<4>{}, Int<1>{})),
299+
make_layout(make_shape(Int<32>{}, Int<4>{}),
300+
make_stride(Int<4>{}, Int<1>{})),
283301
make_layout(make_shape(Int<1>{}, Int<8>{}))
284302
)
285303
);
286304

287305
int BX = (N + BN - 1) / BN;
288306
int BY = (M + BM - 1) / BM;
307+
// NOTE: Apply thread block swizzle across N dim.
308+
int BZ = BlockSwizzle ? (N + (swizzle_stride) - 1) / (swizzle_stride) : 1;
309+
BX = BlockSwizzle ? (BX + BZ - 1) / BZ : BX;
289310

290-
// TODO: thread block swizzle
291311
dim3 block(size(MMA{}));
292-
dim3 grid(BX, BY);
312+
dim3 grid(BX, BY, BZ);
293313

294314
// C_shm is shared with A_shm and B_shm
295315
static constexpr int shm_size_AB =
@@ -301,7 +321,7 @@ void launch_hgemm_mma_stages_tn_cute(const T *a, const T *b, T *c, int M, int N,
301321
int shm_size = kShmSize;
302322

303323
cudaFuncSetAttribute(
304-
hgemm_mma_stages_tn_cute_kernel<
324+
hgemm_mma_stages_block_swizzle_tn_cute_kernel<
305325
T,
306326
BM, BN, BK,
307327
KStage,
@@ -315,13 +335,14 @@ void launch_hgemm_mma_stages_tn_cute(const T *a, const T *b, T *c, int M, int N,
315335
S2RCopyAtomB,
316336
R2SCopyAtomC,
317337
S2GCopyAtomC,
318-
S2GCopyC
338+
S2GCopyC,
339+
BlockSwizzle
319340
>,
320341
cudaFuncAttributeMaxDynamicSharedMemorySize,
321342
shm_size
322343
);
323344

324-
hgemm_mma_stages_tn_cute_kernel<
345+
hgemm_mma_stages_block_swizzle_tn_cute_kernel<
325346
T,
326347
BM, BN, BK,
327348
KStage,
@@ -335,7 +356,8 @@ void launch_hgemm_mma_stages_tn_cute(const T *a, const T *b, T *c, int M, int N,
335356
S2RCopyAtomB,
336357
R2SCopyAtomC,
337358
S2GCopyAtomC,
338-
S2GCopyC
359+
S2GCopyC,
360+
BlockSwizzle
339361
><<<grid, block, shm_size>>>(a, b, c, M, N, K);
340362
}
341363

@@ -360,12 +382,14 @@ int main() {
360382
}
361383

362384
const int outer_repeat = 10, inner_repeat = 1;
385+
const int thread_block_swizzle_stride = 2048; // thread block swizzle stride
363386

364-
printf("ALGO = CuTe HGEMM TN STAGES=2\n");
387+
printf("ALGO = CuTe HGEMM, TN, STAGES=2, SMEM SWIZZLE=<3, 3, 3>, BLOCK SWIZZLE=2048\n");
365388
for (int j = 0; j < 5; j++) {
366389
int M = M_list[j], N = N_list[j], K = K_list[j];
367-
float max_error = gemm_error_check_tn<T>(
368-
launch_hgemm_mma_stages_tn_cute, M, N, K);
390+
float max_error = gemm_error_check_tn_swizzle<T>(
391+
launch_hgemm_mma_stages_block_swizzle_tn_cute<T, 2, true>,
392+
M, N, K, thread_block_swizzle_stride);
369393
printf("M N K = %6d %6d %6d, ", M, N, K);
370394
printf("Max Error = %f\n", max_error);
371395
}
@@ -378,8 +402,9 @@ int main() {
378402
double total_sec = 0.0;
379403

380404
for (int k = 0; k < outer_repeat; k++) {
381-
double this_sec = perf_gemm<T>(
382-
launch_hgemm_mma_stages_tn_cute, M, N, K, inner_repeat);
405+
double this_sec = perf_gemm_swizzle<T>(
406+
launch_hgemm_mma_stages_block_swizzle_tn_cute<T, 2, true>,
407+
M, N, K, thread_block_swizzle_stride, inner_repeat);
383408
max_sec = max(max_sec, this_sec);
384409
min_sec = min(min_sec, this_sec);
385410
total_sec += this_sec;
@@ -395,8 +420,10 @@ int main() {
395420

396421
return 0;
397422
}
398-
// build torch python binding
423+
399424
#else
425+
// build torch python binding
426+
400427
#include <torch/types.h>
401428
#include <torch/extension.h>
402429
// --------------------- PyTorch bindings for custom kernel -----------------------
@@ -415,20 +442,29 @@ if (((T).size(0) != (S0)) || ((T).size(1) != (S1))) { \
415442
throw std::runtime_error("Tensor size mismatch!"); \
416443
}
417444

418-
#define LAUNCH_HGEMM_MMA_STAGES_CUTE_TN(stages) \
419-
launch_hgemm_mma_stages_tn_cute<half, (stages)>( \
420-
reinterpret_cast<half*>(a.data_ptr()), \
421-
reinterpret_cast<half*>(b.data_ptr()), \
422-
reinterpret_cast<half*>(c.data_ptr()), \
423-
M, N, K \
445+
#define LAUNCH_HGEMM_MMA_STAGES_CUTE_NO_SWIZZLE_TN(stages) \
446+
launch_hgemm_mma_stages_block_swizzle_tn_cute< \
447+
half, (stages), false>( \
448+
reinterpret_cast<half*>(a.data_ptr()), \
449+
reinterpret_cast<half*>(b.data_ptr()), \
450+
reinterpret_cast<half*>(c.data_ptr()), \
451+
M, N, K, 2048 \
452+
);
453+
454+
#define LAUNCH_HGEMM_MMA_STAGES_CUTE_SWIZZLE_TN(stages, stride) \
455+
launch_hgemm_mma_stages_block_swizzle_tn_cute< \
456+
half, (stages), true>( \
457+
reinterpret_cast<half*>(a.data_ptr()), \
458+
reinterpret_cast<half*>(b.data_ptr()), \
459+
reinterpret_cast<half*>(c.data_ptr()), \
460+
M, N, K, (stride) \
424461
);
425462

426463

427-
// TODO: support thread block swizzle
464+
// Multi stages CuTe HGEMM with smem and block swizzle.
428465
void hgemm_mma_stages_tn_cute(
429466
torch::Tensor a, torch::Tensor b, torch::Tensor c,
430467
int stages, bool swizzle, int swizzle_stride) {
431-
// swizzle, swizzle_stride unused now
432468
CHECK_TORCH_TENSOR_DTYPE(a, torch::kHalf)
433469
CHECK_TORCH_TENSOR_DTYPE(b, torch::kHalf)
434470
CHECK_TORCH_TENSOR_DTYPE(c, torch::kHalf)
@@ -439,20 +475,37 @@ void hgemm_mma_stages_tn_cute(
439475
CHECK_TORCH_TENSOR_SHAPE(b, K, N)
440476
CHECK_TORCH_TENSOR_SHAPE(c, M, N)
441477

442-
switch (stages) {
443-
case 2:
444-
LAUNCH_HGEMM_MMA_STAGES_CUTE_TN(2)
445-
break;
446-
case 3:
447-
LAUNCH_HGEMM_MMA_STAGES_CUTE_TN(3)
448-
break;
449-
case 4:
450-
LAUNCH_HGEMM_MMA_STAGES_CUTE_TN(4)
451-
break;
452-
default:
453-
LAUNCH_HGEMM_MMA_STAGES_CUTE_TN(2)
454-
break;
478+
if (swizzle) {
479+
switch (stages)
480+
{
481+
case 2:
482+
LAUNCH_HGEMM_MMA_STAGES_CUTE_SWIZZLE_TN(2, swizzle_stride);
483+
break;
484+
case 3:
485+
LAUNCH_HGEMM_MMA_STAGES_CUTE_SWIZZLE_TN(3, swizzle_stride);
486+
break;
487+
case 4:
488+
LAUNCH_HGEMM_MMA_STAGES_CUTE_SWIZZLE_TN(4, swizzle_stride);
489+
break;
490+
default:
491+
LAUNCH_HGEMM_MMA_STAGES_CUTE_SWIZZLE_TN(2, swizzle_stride);
492+
break;
493+
}
494+
} else {
495+
switch (stages) {
496+
case 2:
497+
LAUNCH_HGEMM_MMA_STAGES_CUTE_NO_SWIZZLE_TN(2)
498+
break;
499+
case 3:
500+
LAUNCH_HGEMM_MMA_STAGES_CUTE_NO_SWIZZLE_TN(3)
501+
break;
502+
case 4:
503+
LAUNCH_HGEMM_MMA_STAGES_CUTE_NO_SWIZZLE_TN(4)
504+
break;
505+
default:
506+
LAUNCH_HGEMM_MMA_STAGES_CUTE_NO_SWIZZLE_TN(2)
507+
break;
508+
}
455509
}
456-
457510
}
458511
#endif

0 commit comments

Comments
 (0)