
Commit 60d4ad2

[HGEMM] manually init/destroy cublas handle (#144)
* Update README.md
* Update README.md
* Update hgemm.cu
* Update hgemm_async.cu
* Update hgemm_mma.cu
* Update hgemm_wmma.cu
* Update hgemm_wmma_stage.cu
* Update hgemm.cu
* Update hgemm_cublas.cu
* Update hgemm.py
* Update README.md
* Update README.md
1 parent 48af93d commit 60d4ad2

File tree: 9 files changed, +71 −73 lines changed


README.md

Lines changed: 4 additions & 2 deletions
@@ -21,8 +21,8 @@
  <div id="hgemm-sgemm"></div>

  <div align='left'>
- <img src='https://github.com/user-attachments/assets/89bac543-7272-44cd-b616-54df8ca23a91' height="225px" width="403px">
- <img src='https://github.com/user-attachments/assets/d8d7380b-4271-41f6-964a-ac3fa81f7f4c' height="225px" width="403px">
+ <img src='https://github.com/user-attachments/assets/71927ac9-72b3-4ce9-b0e2-788b5885bc99' height="225px" width="403px">
+ <img src='https://github.com/user-attachments/assets/05ef4f5e-d999-48ea-b58e-782cffb24e85' height="225px" width="403px">
  </div>

  Currently, on NVIDIA L20, RTX 4090 and RTX 3090 Laptop, compared with cuBLAS's default Tensor Cores math algorithm `CUBLAS_GEMM_DEFAULT_TENSOR_OP`, the `HGEMM (WMMA/MMA)` implemented in this repo (`blue`🔵) can achieve `95%~99%` of its (`orange`🟠) performance. Please check [hgemm benchmark](./hgemm) for more details.

@@ -44,6 +44,8 @@ Currently, on NVIDIA L20, RTX 4090 and RTX 3090 Laptop, compared with cuBLAS's d
  <!---
  ![NVIDIA_L20_NN+TN](https://github.com/user-attachments/assets/89bac543-7272-44cd-b616-54df8ca23a91)
  ![NVIDIA_GeForce_RTX_4090_NN+TN](https://github.com/user-attachments/assets/d8d7380b-4271-41f6-964a-ac3fa81f7f4c)
+ ![NVIDIA_L20_NN+TN+v2](https://github.com/user-attachments/assets/71927ac9-72b3-4ce9-b0e2-788b5885bc99)
+ ![NVIDIA_GeForce_RTX_4090_NN+TN+v4](https://github.com/user-attachments/assets/05ef4f5e-d999-48ea-b58e-782cffb24e85)

  <div align='left'>
  <img src='https://github.com/user-attachments/assets/89bac543-7272-44cd-b616-54df8ca23a91' width="805px">

hgemm/README.md

Lines changed: 9 additions & 56 deletions
@@ -82,16 +82,6 @@ make
  # NVIDIA L20
  ALGO = MMA16816 HGEMM NN MMA=2x4 WARP=4x4x2 STAGES=2 BLOCK SWIZZLE=2048
  M N K = 12544 12544 12544, Time = 0.03445555 0.03446098 0.03447399 s, AVG Performance = 114.5541 Tflops
- M N K = 12800 12800 12800, Time = 0.03651175 0.03652291 0.03653325 s, AVG Performance = 114.8404 Tflops
- M N K = 13056 13056 13056, Time = 0.03893658 0.03893934 0.03894375 s, AVG Performance = 114.3067 Tflops
- M N K = 13312 13312 13312, Time = 0.04108800 0.04109589 0.04111155 s, AVG Performance = 114.8052 Tflops
- M N K = 13568 13568 13568, Time = 0.04365005 0.04365251 0.04365619 s, AVG Performance = 114.4375 Tflops
- M N K = 13824 13824 13824, Time = 0.04591821 0.04593121 0.04594585 s, AVG Performance = 115.0332 Tflops
- M N K = 14080 14080 14080, Time = 0.04861338 0.04861614 0.04862054 s, AVG Performance = 114.8306 Tflops
- M N K = 14336 14336 14336, Time = 0.05134848 0.05135278 0.05136691 s, AVG Performance = 114.7493 Tflops
- M N K = 14592 14592 14592, Time = 0.05417882 0.05418947 0.05421568 s, AVG Performance = 114.6726 Tflops
- M N K = 14848 14848 14848, Time = 0.05706547 0.05706916 0.05707469 s, AVG Performance = 114.7182 Tflops
- M N K = 15104 15104 15104, Time = 0.06001767 0.06002084 0.06002586 s, AVG Performance = 114.8164 Tflops
  M N K = 15360 15360 15360, Time = 0.06307226 0.06307789 0.06308864 s, AVG Performance = 114.9017 Tflops
  M N K = 15616 15616 15616, Time = 0.06612480 0.06612798 0.06613094 s, AVG Performance = 115.1739 Tflops
  M N K = 15872 15872 15872, Time = 0.06969549 0.06970215 0.06971290 s, AVG Performance = 114.7305 Tflops

@@ -102,16 +92,6 @@ M N K = 16384 16384 16384, Time = 0.07663001 0.07663534 0.07664947 s, A
  # NVIDIA L20
  ALGO = CuTe HGEMM, TN, STAGES=2, SMEM SWIZZLE=<3, 3, 3>, BLOCK SWIZZLE=2048
  M N K = 12544 12544 12544, Time = 0.03413504 0.03414354 0.03415450 s, AVG Performance = 115.6191 Tflops
- M N K = 12800 12800 12800, Time = 0.03615642 0.03616481 0.03617178 s, AVG Performance = 115.9775 Tflops
- M N K = 13056 13056 13056, Time = 0.03821158 0.03821455 0.03821671 s, AVG Performance = 116.4747 Tflops
- M N K = 13312 13312 13312, Time = 0.04033536 0.04033894 0.04034560 s, AVG Performance = 116.9595 Tflops
- M N K = 13568 13568 13568, Time = 0.04318720 0.04319130 0.04319949 s, AVG Performance = 115.6595 Tflops
- M N K = 13824 13824 13824, Time = 0.04541542 0.04541942 0.04542157 s, AVG Performance = 116.3294 Tflops
- M N K = 14080 14080 14080, Time = 0.04770918 0.04772137 0.04772761 s, AVG Performance = 116.9836 Tflops
- M N K = 14336 14336 14336, Time = 0.05077402 0.05077955 0.05078426 s, AVG Performance = 116.0447 Tflops
- M N K = 14592 14592 14592, Time = 0.05324902 0.05326633 0.05327872 s, AVG Performance = 116.6599 Tflops
- M N K = 14848 14848 14848, Time = 0.05638758 0.05640591 0.05643162 s, AVG Performance = 116.0671 Tflops
- M N K = 15104 15104 15104, Time = 0.05892505 0.05893622 0.05894246 s, AVG Performance = 116.9294 Tflops
  M N K = 15360 15360 15360, Time = 0.06227354 0.06228111 0.06228992 s, AVG Performance = 116.3717 Tflops
  M N K = 15616 15616 15616, Time = 0.06492467 0.06493727 0.06496666 s, AVG Performance = 117.2858 Tflops
  M N K = 15872 15872 15872, Time = 0.06843085 0.06843873 0.06844723 s, AVG Performance = 116.8485 Tflops

@@ -122,16 +102,6 @@ M N K = 16384 16384 16384, Time = 0.07564493 0.07565752 0.07567462 s, A
  # NVIDIA L20
  ALGO = cuBLAS CUBLAS_GEMM_DEFAULT_TENSOR_OP TN
  M N K = 12544 12544 12544, Time = 0.03472691 0.03472968 0.03473408 s, AVG Performance = 113.6678 Tflops
- M N K = 12800 12800 12800, Time = 0.03687321 0.03687834 0.03688038 s, AVG Performance = 113.7335 Tflops
- M N K = 13056 13056 13056, Time = 0.03909427 0.03910103 0.03910963 s, AVG Performance = 113.8341 Tflops
- M N K = 13312 13312 13312, Time = 0.04140135 0.04141281 0.04148429 s, AVG Performance = 113.9266 Tflops
- M N K = 13568 13568 13568, Time = 0.04382720 0.04383375 0.04384461 s, AVG Performance = 113.9643 Tflops
- M N K = 13824 13824 13824, Time = 0.04629504 0.04630118 0.04630733 s, AVG Performance = 114.1140 Tflops
- M N K = 14080 14080 14080, Time = 0.04889805 0.04891136 0.04898202 s, AVG Performance = 114.1375 Tflops
- M N K = 14336 14336 14336, Time = 0.05156966 0.05157878 0.05158503 s, AVG Performance = 114.2465 Tflops
- M N K = 14592 14592 14592, Time = 0.05437849 0.05439980 0.05445734 s, AVG Performance = 114.2292 Tflops
- M N K = 14848 14848 14848, Time = 0.05723853 0.05725573 0.05730202 s, AVG Performance = 114.3444 Tflops
- M N K = 15104 15104 15104, Time = 0.06022963 0.06024274 0.06032179 s, AVG Performance = 114.3935 Tflops
  M N K = 15360 15360 15360, Time = 0.06332416 0.06333143 0.06334157 s, AVG Performance = 114.4417 Tflops
  M N K = 15616 15616 15616, Time = 0.06649446 0.06650184 0.06651699 s, AVG Performance = 114.5264 Tflops
  M N K = 15872 15872 15872, Time = 0.06977024 0.06977659 0.06978355 s, AVG Performance = 114.6081 Tflops
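For reference (a cross-check added here, not part of the diff): the `AVG Performance` column in these logs is consistent with the standard GEMM flop count of 2·M·N·K divided by the middle (average) time, e.g. for the retained MMA16816 row at M=N=K=15360:

$$\text{TFLOPS} = \frac{2MNK}{t \times 10^{12}} = \frac{2 \times 15360^3}{0.06307789 \times 10^{12}} \approx 114.9$$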
@@ -151,50 +121,33 @@ M N K = 16384 16384 16384, Time = 0.07668429 0.07669371 0.07670784 s, A
  <!---
  ![L20](https://github.com/user-attachments/assets/a0039200-cd9e-4ae6-be13-422fff75dd2b)
  ![L20](./NVIDIA_L20.png)
+ ![NVIDIA_L20_NN+TN](https://github.com/user-attachments/assets/89bac543-7272-44cd-b616-54df8ca23a91)

  --->
- ![NVIDIA_L20_NN+TN](https://github.com/user-attachments/assets/89bac543-7272-44cd-b616-54df8ca23a91)

+ ![NVIDIA_L20_NN+TN+v2](https://github.com/user-attachments/assets/71927ac9-72b3-4ce9-b0e2-788b5885bc99)

  - WMMA: Up to 113.76 TFLOPS, 113.83/119.5=95.25% TFLOPS utilization, 113.83/116.25=97.91% cuBLAS performance.
  - MMA: Up to 115.12 TFLOPS, 115.12/119.5=96.33% TFLOPS utilization, 115.12/116.25=99.03% cuBLAS performance.
-
- ```bash
- python3 hgemm.py --M 16384 --N 16384 --K 8192 --mma-all --wmma-all --cuda-all
- ----------------------------------------------------------------------------------------------------------------------------------
- M=16384, N=16384, K=8192, Warmup=2, Iters=10, 1/1
- ----------------------------------------------------------------------------------------------------------------------------------
- (naive): ['-236.75 ', '176.0 '], time:1835.537ms, swizzle: NOOP, TFLOPS: 2.40 (+0.00%)
- (f16x8pack+t8x8+bcf): ['-236.75 ', '176.0 '], time:99.63080ms, swizzle: NOOP, TFLOPS: 44.14 (+1742.34%)
- (f16x8pack+t8x8+k16+dbuf): ['-236.75 ', '176.0 '], time:98.20067ms, swizzle: NOOP, TFLOPS: 44.79 (+1.46%)
- --------------------------------------------------------------------WMMA----------------------------------------------------------
- (wmma4x2+warp2x4): ['-234.0 ', '181.0 '], time:55.99505ms, swizzle: NOOP, TFLOPS: 78.54 (+75.37%)
- (wmma4x2+warp2x4+stage3): ['-234.0 ', '181.0 '], time:49.62856ms, swizzle: NOOP, TFLOPS: 88.62 (+12.83%)
- (wmma4x2+warp2x4+stage3+dsmem): ['-234.0 ', '181.0 '], time:49.62389ms, swizzle: NOOP, TFLOPS: 88.63 (+0.01%)
- (wmma4x2+warp2x4+stage3+swizzle): ['-234.0 ', '181.0 '], time:39.11254ms, swizzle: 4096, TFLOPS: 112.45(+26.87%)
- (wmma4x2+warp2x4+stage2+swizzle): ['-234.0 ', '181.0 '], time:38.63754ms, swizzle: 4096, TFLOPS: 113.83(+1.23%)
- --------------------------------------------------------------------MMA-----------------------------------------------------------
- (mma2x4+warp4x4+stage2+swizzle): ['-234.0 ', '181.0 '], time:38.40544ms, swizzle: 4096, TFLOPS: 114.52(+0.60%)
- (mma2x4+warp4x4+stage2+dsmem+swizzle): ['-234.0 ', '181.0 '], time:38.20540ms, swizzle: 4096, TFLOPS: 115.12(+0.52%)
- (cublas): ['-234.0 ', '181.0 '], time:37.83144ms, swizzle: NOOP, TFLOPS: 116.25(+0.99%)
- ----------------------------------------------------------------------------------------------------------------------------------
- ```
+
  Full MNK sweep benchmark command (note: benchmarking each MNK shape individually gives more accurate performance numbers):
  ```bash
- python3 hgemm.py --cute-tn --mma --plot --dir tmp --tag NN+TN --i 20 --wmma-all
+ python3 hgemm.py --cute-tn --mma --plot
  ```

  ### NVIDIA GeForce RTX 4090
  On the NVIDIA RTX 4090 (330 TFLOPS FP16 Tensor Cores peak), WMMA (m16n16k16) performs better than MMA (m16n8k16). For most MNK shapes, this repo's implementations reach 95%~99% of cuBLAS performance, and some cases exceed cuBLAS. With this repo's implementations on the RTX 4090, WMMA is the better choice for large GEMMs (MNK>=8192) and MMA for small ones.

  <!---
  ![4090](https://github.com/user-attachments/assets/c7d65fe5-9fb9-49a8-b962-a6c09bcc030a)
+ ![NVIDIA_GeForce_RTX_4090_NN+TN](https://github.com/user-attachments/assets/d8d7380b-4271-41f6-964a-ac3fa81f7f4c)
+
  --->

- ![NVIDIA_GeForce_RTX_4090_NN+TN](https://github.com/user-attachments/assets/d8d7380b-4271-41f6-964a-ac3fa81f7f4c)
+ ![NVIDIA_GeForce_RTX_4090_NN+TN+v4](https://github.com/user-attachments/assets/05ef4f5e-d999-48ea-b58e-782cffb24e85)

  ```bash
- python3 hgemm.py --cute-tn --mma --plot --dir tmp --tag NN+TN --i 20 --wmma-all
+ python3 hgemm.py --cute-tn --mma --wmma-all --plot
  ```

  ### NVIDIA GeForce RTX 3080 Laptop

@@ -204,7 +157,7 @@ python3 hgemm.py --cute-tn --mma --plot --dir tmp --tag NN+TN --i 20 --wmma-all
  ![](./NVIDIA_GeForce_RTX_3080_Laptop_GPU_WSL2.png)

  ```bash
- python3 hgemm.py --wmma-all --plot --dir tmp
+ python3 hgemm.py --wmma-all --plot
  ```

hgemm/hgemm.cu

Lines changed: 11 additions & 1 deletion
@@ -62,7 +62,7 @@ __global__ void hgemm_sliced_k_f16_kernel(half* a, half* b, half* c, int M, int
  int load_smem_b_n = tid % 32; // 0~31, tid % 32, tid % BN, threadIdx.x
  int load_gmem_a_m = by * BM + load_smem_a_m; // global row of a and c
  int load_gmem_b_n = bx * BN + load_smem_b_n; // global col of b and c
- // if (load_gmem_a_m >= M || load_gmem_b_n >= N) return;
+ if (load_gmem_a_m >= M || load_gmem_b_n >= N) return;

  half sum = __float2half(0.f);
  for (int bk = 0; bk < (K + BK - 1) / BK; ++bk) {

@@ -121,6 +121,7 @@ __global__ void hgemm_t_8x8_sliced_k_f16x4_kernel(half* a, half* b, half* c, int
  // Row in A's global memory of the element loaded into s_a; each block computes one BM*BN tile of C
  int load_gmem_a_m = by * BM + load_smem_a_m; // global row of a and c
  int load_gmem_b_n = bx * BN + load_smem_b_n; // global col of b and c
+ if (load_gmem_a_m >= M || load_gmem_b_n >= N) return;

  half r_c[TM][TN] = {__float2half(0.0f)}; // 8x8
  // 2. Tile over K first, each tile of size BK

@@ -195,6 +196,7 @@ __global__ void hgemm_t_8x8_sliced_k_f16x4_pack_kernel(
  // Row in A's global memory of the element loaded into s_a; each block computes one BM*BN tile of C
  int load_gmem_a_m = by * BM + load_smem_a_m; // global row of a and c
  int load_gmem_b_n = bx * BN + load_smem_b_n; // global col of b and c
+ if (load_gmem_a_m >= M || load_gmem_b_n >= N) return;

  half r_c[TM][TN] = {__float2half(0.0f)}; // 8x8
  // 2. Tile over K first, each tile of size BK

@@ -279,6 +281,7 @@ __global__ void hgemm_t_8x8_sliced_k_f16x4_bcf_kernel(

  int load_a_gmem_m = by * BM + load_a_smem_m;
  int load_b_gmem_n = bx * BN + load_b_smem_n;
+ if (load_a_gmem_m >= M || load_b_gmem_n >= N) return;

  for (int bk = 0; bk < (K + BK - 1) / BK; bk++) {

@@ -388,6 +391,7 @@ __global__ void hgemm_t_8x8_sliced_k_f16x4_pack_bcf_kernel(

  int load_a_gmem_m = by * BM + load_a_smem_m;
  int load_b_gmem_n = bx * BN + load_b_smem_n;
+ if (load_a_gmem_m >= M || load_b_gmem_n >= N) return;

  for (int bk = 0; bk < (K + BK - 1) / BK; bk++) {

@@ -561,6 +565,7 @@ __global__ void hgemm_t_8x8_sliced_k_f16x8_pack_bcf_kernel(

  int load_a_gmem_m = by * BM + load_a_smem_m;
  int load_b_gmem_n = bx * BN + load_b_smem_n;
+ if (load_a_gmem_m >= M || load_b_gmem_n >= N) return;

  for (int bk = 0; bk < (K + BK - 1) / BK; bk++) {

@@ -666,6 +671,7 @@ __global__ void hgemm_t_8x8_sliced_k_f16x8_pack_bcf_dbuf_kernel(

  int load_a_gmem_m = by * BM + load_a_smem_m;
  int load_b_gmem_n = bx * BN + load_b_smem_n;
+ if (load_a_gmem_m >= M || load_b_gmem_n >= N) return;

  // 1) The main loop starts at bk = 1: the first load is issued before the loop and the last compute happens after it, as the pipelining scheme requires;
  // 2) Since compute and the next load use different shared memory buffers, each iteration of the main loop only needs a single __syncthreads()

@@ -999,6 +1005,8 @@ void hgemm_t_8x8_sliced_k32_f16x8_pack_dbuf_async(torch::Tensor a, torch::Tensor
  void hgemm_t_16x8_sliced_k32_f16x8_pack_dbuf(torch::Tensor a, torch::Tensor b, torch::Tensor c);
  void hgemm_t_16x8_sliced_k32_f16x8_pack_dbuf_async(torch::Tensor a, torch::Tensor b, torch::Tensor c);
  // from hgemm_cublas.cu
+ void init_cublas_handle();
+ void destroy_cublas_handle();
  void hgemm_cublas_tensor_op_nn(torch::Tensor a, torch::Tensor b, torch::Tensor c);
  void hgemm_cublas_tensor_op_tn(torch::Tensor a, torch::Tensor b, torch::Tensor c);
  // from hgemm_wmma.cu

@@ -1046,6 +1054,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  TORCH_BINDING_COMMON_EXTENSION(hgemm_t_16x8_sliced_k32_f16x8_pack_dbuf)
  TORCH_BINDING_COMMON_EXTENSION(hgemm_t_16x8_sliced_k32_f16x8_pack_dbuf_async)
  // cuBLAS Tensor Cores
+ TORCH_BINDING_COMMON_EXTENSION(init_cublas_handle)
+ TORCH_BINDING_COMMON_EXTENSION(destroy_cublas_handle)
  TORCH_BINDING_COMMON_EXTENSION(hgemm_cublas_tensor_op_nn)
  TORCH_BINDING_COMMON_EXTENSION(hgemm_cublas_tensor_op_tn)
  // WMMA API Tensor Cores
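The declarations above come from `hgemm_cublas.cu`, which this commit also updates but whose hunk is not shown on this page. A minimal sketch of what the manually managed handle presumably looks like, assuming a file-scope `cublasHandle_t` (the variable name and the absence of error handling are illustrative, not the repo's actual code):

```c++
#include <cublas_v2.h>

// One handle shared by the cuBLAS HGEMM wrappers, so the relatively expensive
// cublasCreate()/cublasDestroy() calls happen once per benchmark session
// instead of once per GEMM call.
static cublasHandle_t g_cublas_handle = nullptr;

void init_cublas_handle() {
  if (g_cublas_handle == nullptr) {
    cublasCreate(&g_cublas_handle);
  }
}

void destroy_cublas_handle() {
  if (g_cublas_handle != nullptr) {
    cublasDestroy(g_cublas_handle);
    g_cublas_handle = nullptr;
  }
}
```

`hgemm.py` (next file) then calls `lib.init_cublas_handle()` before the cuBLAS benchmarks and `lib.destroy_cublas_handle()` afterwards, which keeps handle creation out of the timed path.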

hgemm/hgemm.py

Lines changed: 4 additions & 0 deletions
@@ -478,12 +478,16 @@ def row2col(x: torch.Tensor):
      run_benchmark(lib.hgemm_mma_stages_tn_cute, a, b_col_major, "tn(cute+stage2+swizzle<smem+block>)", c, stages=2, swizzle=True)
  # TN layout: cublas
  if not args.disable_cublas_tn and any((args.enable_mma_tn, args.enable_cute_tn)):
+     lib.init_cublas_handle()
      run_benchmark(lib.hgemm_cublas_tensor_op_tn, a, b_col_major, "tn(cublas)", c)
+     lib.destroy_cublas_handle()
  # NN layout: cublas/torch
  if (not args.disable_cublas) and any((
          args.enable_mma, args.enable_mma_all, args.enable_wmma, args.enable_wmma_all,
          args.enable_cuda, args.enable_cuda_all, args.enable_torch)):
+     lib.init_cublas_handle()
      run_benchmark(lib.hgemm_cublas_tensor_op_nn, a, b, "(cublas)", c)
+     lib.destroy_cublas_handle()
  if args.enable_torch:
      run_benchmark(partial(torch.matmul, out=c), a, b, "(torch)")
  torch.cuda.synchronize()

hgemm/hgemm_async.cu

Lines changed: 6 additions & 0 deletions
@@ -54,6 +54,7 @@ __global__ void hgemm_t_8x8_sliced_k16_f16x8_pack_dbuf_kernel(
  // Row in A's global memory of the element loaded into s_a; each block computes one BM*BN tile of C
  int load_gmem_a_m = by * BM + load_smem_a_m; // global row of a and c
  int load_gmem_b_n = bx * BN + load_smem_b_n; // global col of b and c
+ if (load_gmem_a_m >= M || load_gmem_b_n >= N) return;

  // bk = 0 is loading here, buffer 0
  {

@@ -156,6 +157,7 @@ __global__ void hgemm_t_8x8_sliced_k16_f16x8_pack_dbuf_async_kernel(
  int load_smem_b_n = (tid % 16) * 8; // col 0,8,...,120
  int load_gmem_a_m = by * BM + load_smem_a_m; // global row of a and c
  int load_gmem_b_n = bx * BN + load_smem_b_n; // global col of b and c
+ if (load_gmem_a_m >= M || load_gmem_b_n >= N) return;

  // bk = 0 is loading here, buffer 0
  {

@@ -269,6 +271,7 @@ __global__ void hgemm_t_8x8_sliced_k32_f16x8_pack_dbuf_kernel(
  // Row in A's global memory of the element loaded into s_a; each block computes one BM*BN tile of C
  int load_gmem_a_m = by * BM + load_smem_a_m; // global row of a and c
  int load_gmem_b_n = bx * BN + load_smem_b_n; // global col of b and c
+ if (load_gmem_a_m >= M || load_gmem_b_n >= N) return;

  // bk = 0 is loading here, buffer 0
  {

@@ -371,6 +374,7 @@ __global__ void hgemm_t_8x8_sliced_k32_f16x8_pack_dbuf_async_kernel(
  int load_smem_b_n = (tid % 8) * 16; // col 0,16,...,
  int load_gmem_a_m = by * BM + load_smem_a_m; // global row of a and c
  int load_gmem_b_n = bx * BN + load_smem_b_n; // global col of b and c
+ if (load_gmem_a_m >= M || load_gmem_b_n >= N) return;

  // bk = 0 is loading here, buffer 0
  {

@@ -498,6 +502,7 @@ __global__ void hgemm_t_16x8_sliced_k32_f16x8_pack_dbuf_kernel(
  // Row in A's global memory of the element loaded into s_a; each block computes one BM*BN tile of C
  int load_gmem_a_m = by * BM + load_smem_a_m; // global row of a and c
  int load_gmem_b_n = bx * BN + load_smem_b_n; // global col of b and c
+ if (load_gmem_a_m >= M || load_gmem_b_n >= N) return;

  // bk = 0 is loading here, buffer 0
  {

@@ -611,6 +616,7 @@ __global__ void hgemm_t_16x8_sliced_k32_f16x8_pack_dbuf_async_kernel(
  // Row in A's global memory of the element loaded into s_a; each block computes one BM*BN tile of C
  int load_gmem_a_m = by * BM + load_smem_a_m; // global row of a and c
  int load_gmem_b_n = bx * BN + load_smem_b_n; // global col of b and c
+ if (load_gmem_a_m >= M || load_gmem_b_n >= N) return;

  // bk = 0 is loading here, buffer 0
  {
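Every kernel hunk in `hgemm.cu` and `hgemm_async.cu` adds the same early-exit bounds guard (in the first hunk it was previously present but commented out). A minimal, self-contained sketch of the pattern, with an illustrative kernel name, tile sizes and thread mapping rather than the repo's exact ones:

```cuda
#include <cuda_fp16.h>

// Sketch of the guard added in this commit: each block computes a BM x BN
// tile of C, and threads whose tile coordinates fall outside the M x N
// problem return before issuing any global-memory loads.
__global__ void hgemm_tile_guard_sketch(const half* a, const half* b, half* c,
                                        int M, int N, int K) {
  constexpr int BM = 128, BN = 128;               // example block-tile sizes
  const int bx = blockIdx.x, by = blockIdx.y;
  const int tid = threadIdx.y * blockDim.x + threadIdx.x;

  int load_smem_a_m = tid / 32;                   // row this thread stages into s_a
  int load_smem_b_n = tid % 32;                   // col this thread stages into s_b
  int load_gmem_a_m = by * BM + load_smem_a_m;    // global row of A and C
  int load_gmem_b_n = bx * BN + load_smem_b_n;    // global col of B and C

  if (load_gmem_a_m >= M || load_gmem_b_n >= N) return;

  // ... shared-memory staging and the K loop follow in the real kernels ...
}
```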
