Skip to content

Commit cb869e2

Browse files
authored
[Docs] Add docs for HGEMM/SGEMM double buffers (#58)
* Update README.md * Update hgemm.cu * Update hgemm.cu * Update README.md * Update README.md * Update README.md * Update sgemm.cu * Update README.md * Update README.md * Update README.md
1 parent f1bb64b commit cb869e2

File tree

4 files changed

+240
-6
lines changed

4 files changed

+240
-6
lines changed

hgemm/README.md

Lines changed: 90 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,14 +26,103 @@ SM调度单位为一个warp(一个warp内32个Thread),shared_memory 可以
2626
- 多个线程读同一个数据时,仅有一个线程读,然后broadcast到其他线程
2727
- 多个线程写同一个数据时,仅会有一个线程写成功
2828

29-
[Using Shared Memory in CUDA C/C++](https://developer.nvidia.com/blog/using-shared-memory-cuda-cc/)中指出,我们还可以通过 `cudaDeviceSetSharedMemConfig()` 函数设置默认Bank Size(默认为4 bytes)来避免bank conflicts,可设置为cudaSharedMemBankSizeFourByte或者cudaSharedMemBankSizeEightByte。对于某些场景来说,设置cudaSharedMemBankSizeEightByte或许更合格式。
29+
NVIDIA的[文章](https://developer.nvidia.com/blog/using-shared-memory-cuda-cc/)中指出,我们还可以通过 `cudaDeviceSetSharedMemConfig()` 函数设置默认Bank Size(默认为4 bytes)来避免bank conflicts,可设置为cudaSharedMemBankSizeFourByte或者cudaSharedMemBankSizeEightByte。对于某些场景来说,设置cudaSharedMemBankSizeEightByte或许更加合适,比如使用double数据类型时。
30+
31+
```C
32+
cudaDeviceSetSharedMemConfig(cudaSharedMemBankSizeEightByte);
33+
```
34+
35+
## 双缓冲 Double Buffers
36+
37+
本仓库实现的HGEMM Double Buffers策略如下:1)主循环从bk = 1 开始,第一次数据加载在主循环之前,最后一次计算在主循环之后,这是pipeline 的特点决定的;2)由于计算和下一次访存使用的Shared Memory不同,因此主循环中每次循环只需要一次__syncthreads()即可,对比非double buffers版本,总共节省了 ((K + BK - 1) / BK) - 1 次block内的同步操作。比如,bk=1时,HFMA计算使用的是s_a[0]和s_b[0],因此,和s_a[1]和s_b[1]的加载是没有依赖关系的。HFMA计算,从global内存到s_a[1]和s_b[1]和HFMA计算可以并行。s_a[1]和s_b[1]用于加载下一块BK需要的数据到共享内存;3)由于GPU不能向CPU那样支持乱序执行,主循环中需要先将下一次循环计算需要的Gloabal Memory中的数据load 到寄存器,然后进行本次计算,之后再将load到寄存器中的数据写到Shared Memory,这样在LDG指令向Global Memory做load时,不会影响后续HFMA及其它运算指令的 launch 执行,也就达到了Double Buffers的目的。
38+
39+
```C
40+
// bk = 0 is loading here, buffer 0
41+
{
42+
int load_a_gmem_k = load_a_smem_k;
43+
int load_a_gmem_addr = load_a_gmem_m * K + load_a_gmem_k;
44+
int load_b_gmem_k = load_b_smem_k;
45+
int load_b_gmem_addr = load_b_gmem_k * N + load_b_gmem_n;
46+
LDST64BITS(r_load_a[0]) = LDST64BITS(a[load_a_gmem_addr]);
47+
LDST64BITS(r_load_b[0]) = LDST64BITS(b[load_b_gmem_addr]);
48+
49+
s_a[0][load_a_smem_k + 0][load_a_smem_m] = r_load_a[0];
50+
s_a[0][load_a_smem_k + 1][load_a_smem_m] = r_load_a[1];
51+
s_a[0][load_a_smem_k + 2][load_a_smem_m] = r_load_a[2];
52+
s_a[0][load_a_smem_k + 3][load_a_smem_m] = r_load_a[3];
53+
LDST64BITS(s_b[0][load_b_smem_k][load_b_smem_n]) = LDST64BITS(r_load_b[0]);
54+
}
55+
// Without this synchronization, accuracy may occasionally be abnormal.
56+
__syncthreads();
57+
58+
// bk start from 1,需要注意的是,虽然 bk 从 1 开始,但实际上 bk=1时,使用的是
59+
// 第0块BK中的数据(已经加载到共享内存s_a[0]和s_b[0]);bk=2时,实际计算的是第1块
60+
// BK中的数据。其余以此类推,这个循环结束后,剩下最后一块BK大小的数据需要计算。
61+
for (int bk = 1; bk < (K + BK - 1) / BK; bk++) {
62+
63+
int smem_sel = (bk - 1) & 1; // bk 1->0, bk 2->1, bk 3->0, ...
64+
int smem_sel_next = bk & 1; // bk 1->1, bk 2->0, bk 3->1, ...
65+
66+
int load_a_gmem_k = bk * BK + load_a_smem_k;
67+
int load_a_gmem_addr = load_a_gmem_m * K + load_a_gmem_k;
68+
int load_b_gmem_k = bk * BK + load_b_smem_k;
69+
int load_b_gmem_addr = load_b_gmem_k * N + load_b_gmem_n;
70+
LDST64BITS(r_load_a[0]) = LDST64BITS(a[load_a_gmem_addr]);
71+
LDST64BITS(r_load_b[0]) = LDST64BITS(b[load_b_gmem_addr]);
72+
73+
#pragma unroll
74+
for (int tk = 0; tk < BK; tk++) {
75+
LDST128BITS(r_comp_a[0]) = LDST128BITS(s_a[smem_sel][tk][ty * TM]);
76+
LDST128BITS(r_comp_b[0]) = LDST128BITS(s_b[smem_sel][tk][tx * TN]);
77+
78+
#pragma unroll
79+
for (int tm = 0; tm < TM; tm++) {
80+
#pragma unroll
81+
for (int tn = 0; tn < TN; tn++) {
82+
r_c[tm][tn] = __hfma(r_comp_a[tm], r_comp_b[tn], r_c[tm][tn]);
83+
}
84+
}
85+
}
86+
87+
// 对比非double buffers版本,此处不需要__syncthreads(),总共节省了
88+
// ((K + BK - 1) / BK) - 1 次block内的同步操作。比如,bk=1时,HFMA计算
89+
// 使用的是s_a[0]和s_b[0],因此,和s_a[1]和s_b[1]的加载是没有依赖关系的。
90+
// 从global内存到s_a[1]和s_b[1]和HFMA计算可以并行。s_a[1]和s_b[1]用于
91+
// 加载下一块BK需要的数据到共享内存。
92+
s_a[smem_sel_next][load_a_smem_k + 0][load_a_smem_m] = r_load_a[0];
93+
s_a[smem_sel_next][load_a_smem_k + 1][load_a_smem_m] = r_load_a[1];
94+
s_a[smem_sel_next][load_a_smem_k + 2][load_a_smem_m] = r_load_a[2];
95+
s_a[smem_sel_next][load_a_smem_k + 3][load_a_smem_m] = r_load_a[3];
96+
LDST128BITS(s_b[smem_sel_next][load_b_smem_k][load_b_smem_n]) = LDST128BITS(r_load_b[0]);
97+
98+
__syncthreads();
99+
}
100+
101+
// 计算剩下最后一块BK
102+
#pragma unroll
103+
for (int tk = 0; tk < BK; tk++) {
104+
LDST128BITS(r_comp_a[0]) = LDST128BITS(s_a[1][tk][ty * TM]);
105+
LDST128BITS(r_comp_b[0]) = LDST128BITS(s_b[1][tk][tx * TN]);
106+
107+
#pragma unroll
108+
for (int tm = 0; tm < TM; tm++) {
109+
#pragma unroll
110+
for (int tn = 0; tn < TN; tn++) {
111+
r_c[tm][tn] = __hfma(r_comp_a[tm], r_comp_b[tn], r_c[tm][tn]);
112+
}
113+
}
114+
}
115+
116+
```
117+
30118

31119
## 参考文献
32120

33121
- [CUDA编程概念】一、什么是bank conflict?](https://zhuanlan.zhihu.com/p/659142274)
34122
- [解决 bank conflict](https://github.com/PaddleJitLab/CUDATutorial/blob/develop/docs/09_optimize_reduce/02_bank_conflict/README.md)
35123
- [Bank Conflict free 的几种方式](https://zhuanlan.zhihu.com/p/722286440)
36124
- [Using Shared Memory in CUDA C/C++](https://developer.nvidia.com/blog/using-shared-memory-cuda-cc/)
125+
- [CUDA(三):通用矩阵乘法:从入门到熟练](https://zhuanlan.zhihu.com/p/657632577)
37126

38127
## 测试
39128

hgemm/hgemm.cu

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -762,7 +762,9 @@ __global__ void hgemm_t_8x8_sliced_k_f16x8_pack_bcf_dbuf_kernel(
762762
// Without this synchronization, accuracy may occasionally be abnormal.
763763
__syncthreads();
764764

765-
// bk start from 1
765+
// bk start from 1,需要注意的是,虽然 bk 从 1 开始,但实际上 bk=1时,使用的是
766+
// 第0块BK中的数据(已经加载到共享内存s_a[0]和s_b[0]);bk=2时,实际计算的是第1块
767+
// BK中的数据。其余以此类推,这个循环结束后,剩下最后一块BK大小的数据需要计算。
766768
for (int bk = 1; bk < (K + BK - 1) / BK; bk++) {
767769

768770
int smem_sel = (bk - 1) & 1; // bk 1->0, bk 2->1, bk 3->0, ...
@@ -789,6 +791,11 @@ __global__ void hgemm_t_8x8_sliced_k_f16x8_pack_bcf_dbuf_kernel(
789791
}
790792
}
791793

794+
// 对比非double buffers版本,此处不需要__syncthreads(),总共节省了
795+
// ((K + BK - 1) / BK) - 1 次block内的同步操作。比如,bk=1时,HFMA计算
796+
// 使用的是s_a[0]和s_b[0],因此,和s_a[1]和s_b[1]的加载是没有依赖关系的。
797+
// 从global内存到s_a[1]和s_b[1]和HFMA计算可以并行。s_a[1]和s_b[1]用于
798+
// 加载下一块BK需要的数据到共享内存。
792799
s_a[smem_sel_next][load_a_smem_k + 0][load_a_smem_m] = r_load_a[0];
793800
s_a[smem_sel_next][load_a_smem_k + 1][load_a_smem_m] = r_load_a[1];
794801
s_a[smem_sel_next][load_a_smem_k + 2][load_a_smem_m] = r_load_a[2];
@@ -798,7 +805,7 @@ __global__ void hgemm_t_8x8_sliced_k_f16x8_pack_bcf_dbuf_kernel(
798805
__syncthreads();
799806
}
800807

801-
// buffer 1
808+
// 计算剩下最后一块BK
802809
#pragma unroll
803810
for (int tk = 0; tk < BK; tk++) {
804811
LDST128BITS(r_comp_a[0]) = LDST128BITS(s_a[1][tk][ty * TM]);
@@ -1165,6 +1172,7 @@ void hgemm_t_8x8_sliced_k_f16x8_pack_bcf_dbuf(torch::Tensor a, torch::Tensor b,
11651172
constexpr int BK = 8;
11661173
constexpr int TM = 8;
11671174
constexpr int TN = 8;
1175+
// cudaDeviceSetSharedMemConfig(cudaSharedMemBankSizeEightByte);
11681176

11691177
dim3 block(BN/TN, BM/TM);
11701178
dim3 grid((N + BN - 1) / BN, (M + BM - 1) / BM);

sgemm/README.md

Lines changed: 121 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,134 @@
1111
- [X] sgemm_t_8x8_sliced_k_f32x4_bcf_dbuf_kernel (bank conflicts free, double buffers)
1212
- [X] PyTorch bindings
1313

14+
## 共享内存 Bank Conflicts
15+
16+
含义:在访问shared memory时,因多个线程读写同一个Bank中的不同数据地址时,导致shared memory 并发读写 退化 成顺序读写的现象叫做Bank Conflict;
17+
18+
![](https://github.com/PaddleJitLab/CUDATutorial/blob/develop/docs/09_optimize_reduce/02_bank_conflict/images/ef322be7c3e5b6b9be69d2b90e88083f50569a58a97129f348e483b946ab4edf.png)
19+
20+
SM调度单位为一个warp(一个warp内32个Thread),shared_memory 可以 被一个warp中的所有(32个)线程进行访问,shared_memory 映射到大小相等的32个Bank上,Bank的数据读取带宽为32bit / cycle (4 bytes),因此,主要需要考虑一个Warp内32线程的访问共享内存时的bank冲突。
21+
对于多个线程读取同一个Bank数据时(不同地址),硬件把内存读写请求,拆分成 conflict-free requests,进行顺序读写,此时将会触发多次内存事务。特别地,当一个warp中的所有线程读写同一个地址时,会触发broadcast机制,此时不会退化成顺序读写。上面提到触发broadcast机制的条件是all threads acess same address,但在翻阅cuda-c-programming-guide以及最新版本的[NVProfGuide](https://docs.nvidia.com/nsight-compute/ProfilingGuide/index.html) 时,发现只要是多个thread 读写就会触发broadcast(不需要All)。
22+
23+
- 多个线程读同一个数据时,仅有一个线程读,然后broadcast到其他线程
24+
- 多个线程写同一个数据时,仅会有一个线程写成功
25+
26+
NVIDIA的[文章](https://developer.nvidia.com/blog/using-shared-memory-cuda-cc/)中指出,我们还可以通过 `cudaDeviceSetSharedMemConfig()` 函数设置默认Bank Size(默认为4 bytes)来避免bank conflicts,可设置为cudaSharedMemBankSizeFourByte或者cudaSharedMemBankSizeEightByte。对于某些场景来说,设置cudaSharedMemBankSizeEightByte或许更加合适,比如使用double数据类型时。
27+
28+
```C
29+
cudaDeviceSetSharedMemConfig(cudaSharedMemBankSizeEightByte);
30+
```
31+
32+
## 双缓冲 Double Buffers
33+
34+
本仓库实现的SGEMM Double Buffers策略如下:1)主循环从bk = 1 开始,第一次数据加载在主循环之前,最后一次计算在主循环之后,这是pipeline 的特点决定的;2)由于计算和下一次访存使用的Shared Memory不同,因此主循环中每次循环只需要一次__syncthreads()即可,对比非double buffers版本,总共节省了 ((K + BK - 1) / BK) - 1 次block内的同步操作。比如,bk=1时,FFMA计算使用的是s_a[0]和s_b[0],因此,和s_a[1]和s_b[1]的加载是没有依赖关系的。FFMA计算,从global内存到s_a[1]和s_b[1]和HFMA计算可以并行。s_a[1]和s_b[1]用于加载下一块BK需要的数据到共享内存;3)由于GPU不能向CPU那样支持乱序执行,主循环中需要先将下一次循环计算需要的Gloabal Memory中的数据load 到寄存器,然后进行本次计算,之后再将load到寄存器中的数据写到Shared Memory,这样在LDG指令向Global Memory做load时,不会影响后续HFMA及其它运算指令的 launch 执行,也就达到了Double Buffers的目的。
35+
36+
```C
37+
// 1)主循环从bk = 1 开始,第一次数据加载在主循环之前,最后一次计算在主循环之后,这是pipeline 的特点决定的;
38+
// 2)由于计算和下一次访存使用的Shared Memory不同,因此主循环中每次循环只需要一次__syncthreads()即可
39+
// 3)由于GPU不能向CPU那样支持乱序执行,主循环中需要先将下一次循环计算需要的Gloabal Memory中的数据load
40+
// 到寄存器,然后进行本次计算,之后再将load到寄存器中的数据写到Shared Memory,这样在LDG指令向Global
41+
// Memory做load时,不会影响后续FFMA及其它运算指令的 launch 执行,也就达到了Double Buffering的目的。
42+
43+
// bk = 0 is loading here, buffer 0
44+
45+
{
46+
int load_a_gmem_k = load_a_smem_k;
47+
int load_a_gmem_addr = load_a_gmem_m * K + load_a_gmem_k;
48+
int load_b_gmem_k = load_b_smem_k;
49+
int load_b_gmem_addr = load_b_gmem_k * N + load_b_gmem_n;
50+
FLOAT4(r_load_a[0]) = FLOAT4(a[load_a_gmem_addr]);
51+
FLOAT4(r_load_b[0]) = FLOAT4(b[load_b_gmem_addr]);
52+
53+
s_a[0][load_a_smem_k + 0][load_a_smem_m] = r_load_a[0];
54+
s_a[0][load_a_smem_k + 1][load_a_smem_m] = r_load_a[1];
55+
s_a[0][load_a_smem_k + 2][load_a_smem_m] = r_load_a[2];
56+
s_a[0][load_a_smem_k + 3][load_a_smem_m] = r_load_a[3];
57+
FLOAT4(s_b[0][load_b_smem_k][load_b_smem_n]) = FLOAT4(r_load_b[0]);
58+
}
59+
// Without this synchronization, accuracy may occasionally be abnormal.
60+
__syncthreads();
61+
62+
// bk start from 1,需要注意的是,虽然 bk 从 1 开始,但实际上 bk=1时,使用的是
63+
// 第0块BK中的数据(已经加载到共享内存s_a[0]和s_b[0]);bk=2时,实际计算的是第1块
64+
// BK中的数据。其余以此类推,这个循环结束后,剩下最后一块BK大小的数据需要计算。
65+
for (int bk = 1; bk < (K + BK - 1) / BK; bk++) {
66+
67+
int smem_sel = (bk - 1) & 1;
68+
int smem_sel_next = bk & 1;
69+
70+
int load_a_gmem_k = bk * BK + load_a_smem_k;
71+
int load_a_gmem_addr = load_a_gmem_m * K + load_a_gmem_k;
72+
int load_b_gmem_k = bk * BK + load_b_smem_k;
73+
int load_b_gmem_addr = load_b_gmem_k * N + load_b_gmem_n;
74+
FLOAT4(r_load_a[0]) = FLOAT4(a[load_a_gmem_addr]);
75+
FLOAT4(r_load_b[0]) = FLOAT4(b[load_b_gmem_addr]);
76+
77+
#pragma unroll
78+
for (int tk = 0; tk < BK; tk++) {
79+
FLOAT4(r_comp_a[0]) = FLOAT4(s_a[smem_sel][tk][ty * TM / 2 ]);
80+
FLOAT4(r_comp_a[4]) = FLOAT4(s_a[smem_sel][tk][ty * TM / 2 + BM / 2]);
81+
FLOAT4(r_comp_b[0]) = FLOAT4(s_b[smem_sel][tk][tx * TN / 2 ]);
82+
FLOAT4(r_comp_b[4]) = FLOAT4(s_b[smem_sel][tk][tx * TN / 2 + BN / 2]);
83+
84+
#pragma unroll
85+
for (int tm = 0; tm < TM; tm++) {
86+
#pragma unroll
87+
for (int tn = 0; tn < TN; tn++) {
88+
// r_c[tm][tn] += r_comp_a[tm] * r_comp_b[tn];
89+
r_c[tm][tn] = __fmaf_rn(r_comp_a[tm], r_comp_b[tn], r_c[tm][tn]);
90+
}
91+
}
92+
}
93+
94+
// 对比非double buffers版本,此处不需要__syncthreads(),总共节省了
95+
// ((K + BK - 1) / BK) - 1 次block内的同步操作。比如,bk=1时,HFMA计算
96+
// 使用的是s_a[0]和s_b[0],因此,和s_a[1]和s_b[1]的加载是没有依赖关系的。
97+
// 从global内存到s_a[1]和s_b[1]和HFMA计算可以并行。s_a[1]和s_b[1]用于
98+
// 加载下一块BK需要的数据到共享内存。
99+
s_a[smem_sel_next][load_a_smem_k + 0][load_a_smem_m] = r_load_a[0];
100+
s_a[smem_sel_next][load_a_smem_k + 1][load_a_smem_m] = r_load_a[1];
101+
s_a[smem_sel_next][load_a_smem_k + 2][load_a_smem_m] = r_load_a[2];
102+
s_a[smem_sel_next][load_a_smem_k + 3][load_a_smem_m] = r_load_a[3];
103+
FLOAT4(s_b[smem_sel_next][load_b_smem_k][load_b_smem_n]) = FLOAT4(r_load_b[0]);
104+
105+
__syncthreads();
106+
}
107+
108+
// 计算剩下最后一块BK
109+
#pragma unroll
110+
for (int tk = 0; tk < BK; tk++) {
111+
FLOAT4(r_comp_a[0]) = FLOAT4(s_a[1][tk][ty * TM / 2 ]);
112+
FLOAT4(r_comp_a[4]) = FLOAT4(s_a[1][tk][ty * TM / 2 + BM / 2]);
113+
FLOAT4(r_comp_b[0]) = FLOAT4(s_b[1][tk][tx * TN / 2 ]);
114+
FLOAT4(r_comp_b[4]) = FLOAT4(s_b[1][tk][tx * TN / 2 + BN / 2]);
115+
116+
#pragma unroll
117+
for (int tm = 0; tm < TM; tm++) {
118+
#pragma unroll
119+
for (int tn = 0; tn < TN; tn++) {
120+
// r_c[tm][tn] += r_comp_a[tm] * r_comp_b[tn];
121+
r_c[tm][tn] = __fmaf_rn(r_comp_a[tm], r_comp_b[tn], r_c[tm][tn]);
122+
}
123+
}
124+
}
125+
```
126+
127+
## 参考文献
128+
129+
- [CUDA编程概念】一、什么是bank conflict?](https://zhuanlan.zhihu.com/p/659142274)
130+
- [解决 bank conflict](https://github.com/PaddleJitLab/CUDATutorial/blob/develop/docs/09_optimize_reduce/02_bank_conflict/README.md)
131+
- [Bank Conflict free 的几种方式](https://zhuanlan.zhihu.com/p/722286440)
132+
- [Using Shared Memory in CUDA C/C++](https://developer.nvidia.com/blog/using-shared-memory-cuda-cc/)
133+
- [CUDA(三):通用矩阵乘法:从入门到熟练](https://zhuanlan.zhihu.com/p/657632577)
134+
14135
## 测试
15136

16137
```bash
17138
# 只测试Ada架构 不指定默认编译所有架构 耗时较长: Volta, Ampere, Ada, Hopper, ...
18139
export TORCH_CUDA_ARCH_LIST=Ada
19140
python3 sgemm.py
20141
```
21-
22142
输出:
23143

24144
```bash

sgemm/sgemm.cu

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -373,6 +373,14 @@ __global__ void sgemm_t_8x8_sliced_k_f32x4_bcf_dbuf_kernel(
373373
int load_a_gmem_m = by * BM + load_a_smem_m;
374374
int load_b_gmem_n = bx * BN + load_b_smem_n;
375375

376+
// 1)主循环从bk = 1 开始,第一次数据加载在主循环之前,最后一次计算在主循环之后,这是pipeline 的特点决定的;
377+
// 2)由于计算和下一次访存使用的Shared Memory不同,因此主循环中每次循环只需要一次__syncthreads()即可
378+
// 3)由于GPU不能向CPU那样支持乱序执行,主循环中需要先将下一次循环计算需要的Gloabal Memory中的数据load
379+
// 到寄存器,然后进行本次计算,之后再将load到寄存器中的数据写到Shared Memory,这样在LDG指令向Global
380+
// Memory做load时,不会影响后续FFMA及其它运算指令的 launch 执行,也就达到了Double Buffering的目的。
381+
382+
// bk = 0 is loading here, buffer 0
383+
376384
{
377385
int load_a_gmem_k = load_a_smem_k;
378386
int load_a_gmem_addr = load_a_gmem_m * K + load_a_gmem_k;
@@ -390,6 +398,9 @@ __global__ void sgemm_t_8x8_sliced_k_f32x4_bcf_dbuf_kernel(
390398
// Without this synchronization, accuracy may occasionally be abnormal.
391399
__syncthreads();
392400

401+
// bk start from 1,需要注意的是,虽然 bk 从 1 开始,但实际上 bk=1时,使用的是
402+
// 第0块BK中的数据(已经加载到共享内存s_a[0]和s_b[0]);bk=2时,实际计算的是第1块
403+
// BK中的数据。其余以此类推,这个循环结束后,剩下最后一块BK大小的数据需要计算。
393404
for (int bk = 1; bk < (K + BK - 1) / BK; bk++) {
394405

395406
int smem_sel = (bk - 1) & 1;
@@ -418,7 +429,12 @@ __global__ void sgemm_t_8x8_sliced_k_f32x4_bcf_dbuf_kernel(
418429
}
419430
}
420431
}
421-
432+
433+
// 对比非double buffers版本,此处不需要__syncthreads(),总共节省了
434+
// ((K + BK - 1) / BK) - 1 次block内的同步操作。比如,bk=1时,HFMA计算
435+
// 使用的是s_a[0]和s_b[0],因此,和s_a[1]和s_b[1]的加载是没有依赖关系的。
436+
// 从global内存到s_a[1]和s_b[1]和HFMA计算可以并行。s_a[1]和s_b[1]用于
437+
// 加载下一块BK需要的数据到共享内存。
422438
s_a[smem_sel_next][load_a_smem_k + 0][load_a_smem_m] = r_load_a[0];
423439
s_a[smem_sel_next][load_a_smem_k + 1][load_a_smem_m] = r_load_a[1];
424440
s_a[smem_sel_next][load_a_smem_k + 2][load_a_smem_m] = r_load_a[2];
@@ -427,7 +443,8 @@ __global__ void sgemm_t_8x8_sliced_k_f32x4_bcf_dbuf_kernel(
427443

428444
__syncthreads();
429445
}
430-
446+
447+
// 计算剩下最后一块BK
431448
#pragma unroll
432449
for (int tk = 0; tk < BK; tk++) {
433450
FLOAT4(r_comp_a[0]) = FLOAT4(s_a[1][tk][ty * TM / 2 ]);

0 commit comments

Comments
 (0)