Commit f30fd33

[SGEMM] test bank conflicts free with smem offset (#56)
* Update README.md
* Update README.md
* Update sgemm.cu
* Update sgemm.py
* Update README.md
1 parent da0b939 commit f30fd33

File tree

4 files changed: +171 −88 lines

hgemm/README.md

Lines changed: 21 additions & 0 deletions
@@ -13,6 +13,27 @@
- [X] hgemm_t_8x8_sliced_k_f16x8_pack_bcf_kernel(bank conflicts reduce, pack)
- [X] PyTorch bindings

## Shared Memory Bank Conflicts

Definition: a bank conflict occurs when multiple threads of a warp access different addresses that fall into the same shared memory bank; the normally concurrent shared memory accesses then degrade into serialized accesses.

![](https://github.com/PaddleJitLab/CUDATutorial/blob/develop/docs/09_optimize_reduce/02_bank_conflict/images/ef322be7c3e5b6b9be69d2b90e88083f50569a58a97129f348e483b946ab4edf.png)
The SM schedules work in units of a warp (32 threads). Shared memory is visible to all 32 threads of a warp and is split across 32 equally sized banks, each with a read bandwidth of 32 bits (4 bytes) per cycle. Bank conflicts therefore only need to be analyzed among the 32 threads of a single warp. The sketch below shows how an address maps to a bank.
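As a quick mental model (the helper below is hypothetical, not part of this repo): with a 4-byte bank width and 32 banks, the bank of a shared memory word is simply its word index modulo 32.

```cuda
// Hypothetical helper (not from this repo): bank index of a shared
// memory byte offset, assuming 32 banks and a 4-byte bank width.
__host__ __device__ inline int smem_bank(int byte_offset) {
  return (byte_offset / 4) % 32;  // word index modulo 32 banks
}
// Example: for float s[64], s[0] and s[32] both map to bank 0, so two
// threads of one warp reading them in the same instruction conflict.
```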
When multiple threads access different addresses within the same bank, the hardware splits the request into conflict-free sub-requests that are served sequentially, issuing multiple memory transactions. As a special case, when all threads of a warp read the same address, a broadcast is triggered and the access does not degrade into serial reads. The condition usually quoted for the broadcast is that *all* threads access the same address, but the CUDA C Programming Guide and the latest version of the [NVProfGuide](https://docs.nvidia.com/nsight-compute/ProfilingGuide/index.html) state that a broadcast is triggered whenever several threads access the same address (it does not require all of them):
- When multiple threads read the same word, only one thread performs the read and the value is broadcast to the others.
- When multiple threads write to the same word, only one of the writes takes effect.
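A common fix, and the one this commit tests through the `OFFSET` template parameter, is to pad the shared memory row stride so that logically adjacent rows no longer start in the same bank. A minimal, self-contained sketch (the kernel name and tile size here are illustrative, not taken from sgemm.cu):

```cuda
// Illustrative padding demo (not from this repo). With a 32-float row
// stride, a column read makes all 32 threads of a warp hit one bank;
// padding each row by one float spreads them across 32 distinct banks.
__global__ void transpose_tile_padded(const float* in, float* out) {
  // __shared__ float tile[32][32];   // stride 32: 32-way conflicts below
  __shared__ float tile[32][32 + 1];  // stride 33: conflict-free
  const int x = threadIdx.x, y = threadIdx.y;
  tile[y][x] = in[y * 32 + x];        // row write, no conflicts
  __syncthreads();
  out[x * 32 + y] = tile[x][y];       // column read, fixed by the padding
}
```

The same idea is what `s_a[BK][BM + OFFSET]` does in the SGEMM kernels changed below.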
[Using Shared Memory in CUDA C/C++](https://developer.nvidia.com/blog/using-shared-memory-cuda-cc/) also points out that bank conflicts can be avoided by changing the default bank size (4 bytes) with the `cudaDeviceSetSharedMemConfig()` function, which accepts cudaSharedMemBankSizeFourByte or cudaSharedMemBankSizeEightByte. For some workloads, cudaSharedMemBankSizeEightByte may be the better fit.
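A minimal host-side sketch of that API (standard CUDA runtime calls; note that on architectures where the bank size is fixed at 4 bytes, this setting is only a hint and may have no effect):

```cuda
#include <cstdio>
#include <cuda_runtime.h>

int main() {
  // Ask for 8-byte shared memory banks (can help double-precision kernels).
  cudaError_t err = cudaDeviceSetSharedMemConfig(cudaSharedMemBankSizeEightByte);
  if (err != cudaSuccess) {
    printf("set config failed: %s\n", cudaGetErrorString(err));
  }

  cudaSharedMemConfig cfg;
  cudaDeviceGetSharedMemConfig(&cfg);  // read back what the device accepted
  printf("active bank size config: %d\n", static_cast<int>(cfg));
  return 0;
}
```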
## References

- [【CUDA编程概念】一、什么是bank conflict?](https://zhuanlan.zhihu.com/p/659142274)
- [解决 bank conflict](https://github.com/PaddleJitLab/CUDATutorial/blob/develop/docs/09_optimize_reduce/02_bank_conflict/README.md)
- [Bank Conflict free 的几种方式](https://zhuanlan.zhihu.com/p/722286440)
- [Using Shared Memory in CUDA C/C++](https://developer.nvidia.com/blog/using-shared-memory-cuda-cc/)
## Test

```bash

sgemm/README.md

Lines changed: 66 additions & 72 deletions
@@ -22,76 +22,70 @@ python3 sgemm.py
Output:

```bash
--------------------------------------------------------------------------------------
-M=2048, N=2048, K=1024
-out_f32: [-23.44512749, 105.22006226, -72.40318298], time:2.581863ms
-out_f32(sk): [-23.44512749, 105.22006226, -72.40318298], time:1.837885ms
-out_f32x4(t8x8sk): [-23.44512749, 105.22006226, -72.40318298], time:0.325584ms
-out_f32x4(t8x8bcf): [-23.44512749, 105.22006226, -72.40318298], time:0.298755ms
-out_f32x4(t8x8dbuf): [-23.44512749, 105.22006226, -72.40318298], time:0.229251ms
-out_f32_th: [-23.44515038, 105.22006226, -72.40312958], time:0.255888ms
--------------------------------------------------------------------------------------
--------------------------------------------------------------------------------------
-M=2048, N=2048, K=2048
-out_f32: [4.73375559, -2.49913216, 111.71539307], time:5.155475ms
-out_f32(sk): [4.73375559, -2.49913216, 111.71539307], time:3.653073ms
-out_f32x4(t8x8sk): [4.73375559, -2.49913216, 111.71539307], time:0.635004ms
-out_f32x4(t8x8bcf): [4.73375559, -2.49913216, 111.71539307], time:0.593204ms
-out_f32x4(t8x8dbuf): [4.73375559, -2.49913216, 111.71539307], time:0.460200ms
-out_f32_th: [4.73375702, -2.49916267, 111.71534729], time:0.467465ms
--------------------------------------------------------------------------------------
--------------------------------------------------------------------------------------
-M=2048, N=4096, K=1024
-out_f32: [27.58790588, 18.39359474, -23.69882774], time:5.127516ms
-out_f32(sk): [27.58790588, 18.39359474, -23.69882774], time:3.652875ms
-out_f32x4(t8x8sk): [27.58790588, 18.39359474, -23.69882774], time:0.626333ms
-out_f32x4(t8x8bcf): [27.58790588, 18.39359474, -23.69882774], time:0.549185ms
-out_f32x4(t8x8dbuf): [27.58790588, 18.39359474, -23.69882774], time:0.463538ms
-out_f32_th: [27.58790588, 18.39359474, -23.69882774], time:0.555634ms
--------------------------------------------------------------------------------------
--------------------------------------------------------------------------------------
-M=2048, N=4096, K=2048
-out_f32: [54.19274139, -0.29313943, 26.92167664], time:10.221355ms
-out_f32(sk): [54.19274139, -0.29313943, 26.92167664], time:7.268925ms
-out_f32x4(t8x8sk): [54.19274139, -0.29313943, 26.92167664], time:1.249781ms
-out_f32x4(t8x8bcf): [54.19274139, -0.29313943, 26.92167664], time:1.119103ms
-out_f32x4(t8x8dbuf): [54.19274139, -0.29313943, 26.92167664], time:0.960808ms
-out_f32_th: [54.19275284, -0.29314613, 26.92167473], time:0.920537ms
--------------------------------------------------------------------------------------
--------------------------------------------------------------------------------------
-M=4096, N=2048, K=1024
-out_f32: [-37.67934418, 12.49935532, 40.71273804], time:5.120614ms
-out_f32(sk): [-37.67934418, 12.49935532, 40.71273804], time:3.652627ms
-out_f32x4(t8x8sk): [-37.67934418, 12.49935532, 40.71273804], time:0.624588ms
-out_f32x4(t8x8bcf): [-37.67934418, 12.49935532, 40.71273804], time:0.545461ms
-out_f32x4(t8x8dbuf): [-37.67934418, 12.49935532, 40.71273804], time:0.462778ms
-out_f32_th: [-37.67934418, 12.49935532, 40.71273804], time:0.560777ms
--------------------------------------------------------------------------------------
--------------------------------------------------------------------------------------
-M=4096, N=2048, K=2048
-out_f32: [-15.01755524, -0.44903478, 72.23948669], time:10.213506ms
-out_f32(sk): [-15.01755524, -0.44903478, 72.23948669], time:7.269592ms
-out_f32x4(t8x8sk): [-15.01755524, -0.44903478, 72.23948669], time:1.242898ms
-out_f32x4(t8x8bcf): [-15.01755524, -0.44903478, 72.23948669], time:1.099443ms
-out_f32x4(t8x8dbuf): [-15.01755524, -0.44903478, 72.23948669], time:0.941424ms
-out_f32_th: [-15.01752663, -0.44904327, 72.23952484], time:0.940223ms
--------------------------------------------------------------------------------------
--------------------------------------------------------------------------------------
-M=4096, N=4096, K=1024
-out_f32: [-5.76778412, 22.12718964, 17.76623344], time:10.221822ms
-out_f32(sk): [-5.76778412, 22.12718964, 17.76623344], time:7.308133ms
-out_f32x4(t8x8sk): [-5.76778412, 22.12718964, 17.76623344], time:1.263077ms
-out_f32x4(t8x8bcf): [-5.76778412, 22.12718964, 17.76623344], time:1.134577ms
-out_f32x4(t8x8dbuf): [-5.76778412, 22.12718964, 17.76623344], time:1.009488ms
-out_f32_th: [-5.76778412, 22.12718964, 17.76623344], time:0.926571ms
--------------------------------------------------------------------------------------
--------------------------------------------------------------------------------------
-M=4096, N=4096, K=2048
-out_f32: [35.152565, 56.02351761, 29.87486458], time:20.362103ms
-out_f32(sk): [35.152565, 56.02351761, 29.87486458], time:14.596984ms
-out_f32x4(t8x8sk): [35.152565, 56.02351761, 29.87486458], time:2.558391ms
-out_f32x4(t8x8bcf): [35.152565, 56.02351761, 29.87486458], time:2.313538ms
-out_f32x4(t8x8dbuf): [35.152565, 56.02351761, 29.87486458], time:2.144170ms
-out_f32_th: [35.152565, 56.02351761, 29.87486458], time:1.896987ms
--------------------------------------------------------------------------------------
+----------------------------------------------------------------------------------------------------
+M=2048, N=2048, K=1024
+out_f32: ['-41.69404602', '-15.22974205', '12.31010342 '], time:2.583222ms
+out_f32(sk): ['-41.69404602', '-15.22974205', '12.31010342 '], time:1.836123ms
+out_f32x4(t8x8sk): ['-41.69404602', '-15.22974205', '12.31010342 '], time:0.324936ms
+out_f32x4(t8x8bcf): ['-41.69404602', '-15.22974205', '12.31010342 '], time:0.290537ms
+out_f32x4(t8x8bcf+offset): ['-41.69404602', '-15.22974205', '12.31010342 '], time:0.289106ms
+out_f32x4(t8x8dbuf): ['-41.69404602', '-15.22974205', '12.31010342 '], time:0.229044ms
+out_f32x4(t8x8dbuf+offset): ['-41.69404602', '-15.22974205', '12.31010342 '], time:0.230970ms
+out_f32_th: ['-41.69403076', '-15.229743 ', '12.31009007 '], time:0.255721ms
+----------------------------------------------------------------------------------------------------
+----------------------------------------------------------------------------------------------------
+M=2048, N=2048, K=2048
+out_f32: ['-11.50634861', '-30.57016182', '14.03067684 '], time:5.152175ms
+out_f32(sk): ['-11.50634861', '-30.57016182', '14.03067684 '], time:3.652353ms
+out_f32x4(t8x8sk): ['-11.50634861', '-30.57016182', '14.03067684 '], time:0.639246ms
+out_f32x4(t8x8bcf): ['-11.50634861', '-30.57016182', '14.03067684 '], time:0.576742ms
+out_f32x4(t8x8bcf+offset): ['-11.50634861', '-30.57016182', '14.03067684 '], time:0.575581ms
+out_f32x4(t8x8dbuf): ['-11.50634861', '-30.57016182', '14.03067684 '], time:0.460470ms
+out_f32x4(t8x8dbuf+offset): ['-11.50634861', '-30.57016182', '14.03067684 '], time:0.465369ms
+out_f32_th: ['-11.50632 ', '-30.57013321', '14.03067398 '], time:0.465064ms
+----------------------------------------------------------------------------------------------------
+----------------------------------------------------------------------------------------------------
+M=2048, N=4096, K=1024
+out_f32: ['35.35253143 ', '44.40952682 ', '-10.71832466'], time:5.122924ms
+out_f32(sk): ['35.35253143 ', '44.40952682 ', '-10.71832466'], time:3.653028ms
+out_f32x4(t8x8sk): ['35.35253143 ', '44.40952682 ', '-10.71832466'], time:0.625312ms
+out_f32x4(t8x8bcf): ['35.35253143 ', '44.40952682 ', '-10.71832466'], time:0.534370ms
+out_f32x4(t8x8bcf+offset): ['35.35253143 ', '44.40952682 ', '-10.71832466'], time:0.530348ms
+out_f32x4(t8x8dbuf): ['35.35253143 ', '44.40952682 ', '-10.71832466'], time:0.462132ms
+out_f32x4(t8x8dbuf+offset): ['35.35253143 ', '44.40952682 ', '-10.71832466'], time:0.464492ms
+out_f32_th: ['35.35253143 ', '44.40952682 ', '-10.71832466'], time:0.557373ms
+----------------------------------------------------------------------------------------------------
+----------------------------------------------------------------------------------------------------
+M=2048, N=4096, K=2048
+out_f32: ['61.41757584 ', '107.04826355', '37.28448868 '], time:10.218813ms
+out_f32(sk): ['61.41757584 ', '107.04826355', '37.28448868 '], time:7.268655ms
+out_f32x4(t8x8sk): ['61.41757584 ', '107.04826355', '37.28448868 '], time:1.237755ms
+out_f32x4(t8x8bcf): ['61.41757584 ', '107.04826355', '37.28448868 '], time:1.065564ms
+out_f32x4(t8x8bcf+offset): ['61.41757584 ', '107.04826355', '37.28448868 '], time:1.053824ms
+out_f32x4(t8x8dbuf): ['61.41757584 ', '107.04826355', '37.28448868 '], time:0.935848ms
+out_f32x4(t8x8dbuf+offset): ['61.41757584 ', '107.04826355', '37.28448868 '], time:0.967648ms
+out_f32_th: ['61.41755676 ', '107.04829407', '37.28450775 '], time:0.921094ms
+----------------------------------------------------------------------------------------------------
+----------------------------------------------------------------------------------------------------
+M=4096, N=2048, K=1024
+out_f32: ['69.17631531 ', '2.35151434 ', '14.92191601 '], time:5.120900ms
+out_f32(sk): ['69.17631531 ', '2.35151434 ', '14.92191601 '], time:3.651984ms
+out_f32x4(t8x8sk): ['69.17631531 ', '2.35151434 ', '14.92191601 '], time:0.622756ms
+out_f32x4(t8x8bcf): ['69.17631531 ', '2.35151434 ', '14.92191601 '], time:0.526509ms
+out_f32x4(t8x8bcf+offset): ['69.17631531 ', '2.35151434 ', '14.92191601 '], time:0.529506ms
+out_f32x4(t8x8dbuf): ['69.17631531 ', '2.35151434 ', '14.92191601 '], time:0.451362ms
+out_f32x4(t8x8dbuf+offset): ['69.17631531 ', '2.35151434 ', '14.92191601 '], time:0.462964ms
+out_f32_th: ['69.17631531 ', '2.35151434 ', '14.92191601 '], time:0.552487ms
+----------------------------------------------------------------------------------------------------
+----------------------------------------------------------------------------------------------------
+M=4096, N=2048, K=2048
+out_f32: ['62.51137161 ', '-45.17026138', '61.54212952 '], time:10.213661ms
+out_f32(sk): ['62.51137161 ', '-45.17026138', '61.54212952 '], time:7.267971ms
+out_f32x4(t8x8sk): ['62.51137161 ', '-45.17026138', '61.54212952 '], time:1.244769ms
+out_f32x4(t8x8bcf): ['62.51137161 ', '-45.17026138', '61.54212952 '], time:1.076307ms
+out_f32x4(t8x8bcf+offset): ['62.51137161 ', '-45.17026138', '61.54212952 '], time:1.074743ms
+out_f32x4(t8x8dbuf): ['62.51137161 ', '-45.17026138', '61.54212952 '], time:0.948534ms
+out_f32x4(t8x8dbuf+offset): ['62.51137161 ', '-45.17026138', '61.54212952 '], time:0.963700ms
+out_f32_th: ['62.51136398 ', '-45.17026138', '61.54217911 '], time:0.916274ms
+----------------------------------------------------------------------------------------------------
```

sgemm/sgemm.cu

Lines changed: 64 additions & 6 deletions
@@ -159,7 +159,7 @@ __global__ void sgemm_t_8x8_sliced_k_f32x4_kernel(float* a, float* b, float* c, const int M, const int N, const int K) {
     }
   }
 
-template<const int BM=128, const int BN=128, const int BK=8, const int TM=8, const int TN=8>
+template<const int BM=128, const int BN=128, const int BK=8, const int TM=8, const int TN=8, const int OFFSET=0>
 __global__ void sgemm_t_8x8_sliced_k_f32x4_bcf_kernel(
   float* a, float* b, float* c, const int M, const int N, const int K) {
 
@@ -169,8 +169,8 @@ __global__ void sgemm_t_8x8_sliced_k_f32x4_bcf_kernel(
   const int ty = threadIdx.y;
   const int tid = ty * blockDim.x + tx;
 
-  __shared__ float s_a[BK][BM];
-  __shared__ float s_b[BK][BN];
+  __shared__ float s_a[BK][BM + OFFSET];
+  __shared__ float s_b[BK][BN + OFFSET];
   // __shared__ float s_a[BK][BM + 4];
   // __shared__ float s_b[BK][BN + 4];
 
@@ -334,7 +334,7 @@ __global__ void sgemm_t_8x8_sliced_k_f32x4_bcf_kernel(
     }
   }
 
-template<const int BM=128, const int BN=128, const int BK=8, const int TM=8, const int TN=8>
+template<const int BM=128, const int BN=128, const int BK=8, const int TM=8, const int TN=8, const int OFFSET=0>
 __global__ void sgemm_t_8x8_sliced_k_f32x4_bcf_dbuf_kernel(
   float* a, float* b, float* c, const int M, const int N, const int K) {
 

@@ -344,8 +344,8 @@ __global__ void sgemm_t_8x8_sliced_k_f32x4_bcf_dbuf_kernel(
344344
const int ty = threadIdx.y;
345345
const int tid = ty * blockDim.x + tx;
346346

347-
__shared__ float s_a[2][BK][BM];
348-
__shared__ float s_b[2][BK][BN];
347+
__shared__ float s_a[2][BK][BM + OFFSET];
348+
__shared__ float s_b[2][BK][BN + OFFSET];
349349

350350
float r_load_a[TM/2];
351351
float r_load_b[TN/2];
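Why the padding helps here (my reading of the change; the exact conflict pattern depends on each kernel's indexing): BM = BN = 128 floats is a multiple of the 32 banks, so same-column elements of successive shared memory rows map to the same bank, and a small OFFSET breaks that alignment. A standalone sketch of the bank arithmetic:

```cuda
// Standalone sketch of the bank arithmetic behind the OFFSET padding.
// Assumes 32 banks, 4-byte bank width; stride is the row stride in floats.
#include <cstdio>

int bank_of(int row, int col, int stride) {
  return (row * stride + col) % 32;  // word index modulo 32 banks
}

int main() {
  const int BM = 128, OFFSET = 4;
  for (int k = 0; k < 4; ++k)
    printf("s_a[%d][0]: bank %2d unpadded, bank %2d with OFFSET=4\n",
           k, bank_of(k, 0, BM), bank_of(k, 0, BM + OFFSET));
  // unpadded: every row of column 0 hits bank 0; padded: banks 0, 4, 8, 12.
  return 0;
}
```

A pad of 4 floats (16 bytes) also keeps each row 16-byte aligned, which the `float4` vectorized accesses require.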
@@ -592,6 +592,34 @@ void sgemm_t_8x8_sliced_k_f32x4_bcf(torch::Tensor a, torch::Tensor b, torch::Tensor c) {
   );
 }
 
+void sgemm_t_8x8_sliced_k_f32x4_bcf_offset(torch::Tensor a, torch::Tensor b, torch::Tensor c) {
+  CHECK_TORCH_TENSOR_DTYPE(a, torch::kFloat32)
+  CHECK_TORCH_TENSOR_DTYPE(b, torch::kFloat32)
+  CHECK_TORCH_TENSOR_DTYPE(c, torch::kFloat32)
+  const int M = a.size(0);
+  const int K = a.size(1);
+  const int N = b.size(1);
+  CHECK_TORCH_TENSOR_SHAPE(a, M, K)
+  CHECK_TORCH_TENSOR_SHAPE(b, K, N)
+  CHECK_TORCH_TENSOR_SHAPE(c, M, N)
+  constexpr int BM = 128;
+  constexpr int BN = 128;
+  constexpr int BK = 8;
+  constexpr int TM = 8;
+  constexpr int TN = 8;
+  constexpr int OFFSET = 4;
+
+  dim3 block(BN/TN, BM/TM);
+  dim3 grid((N + BN - 1) / BN, (M + BM - 1) / BM);
+
+  sgemm_t_8x8_sliced_k_f32x4_bcf_kernel<BM, BN, BK, TM, TN, OFFSET><<<grid, block>>>(
+    reinterpret_cast<float*>(a.data_ptr()),
+    reinterpret_cast<float*>(b.data_ptr()),
+    reinterpret_cast<float*>(c.data_ptr()),
+    M, N, K
+  );
+}
+
 void sgemm_t_8x8_sliced_k_f32x4_bcf_dbuf(torch::Tensor a, torch::Tensor b, torch::Tensor c) {
   CHECK_TORCH_TENSOR_DTYPE(a, torch::kFloat32)
   CHECK_TORCH_TENSOR_DTYPE(b, torch::kFloat32)
@@ -619,10 +647,40 @@ void sgemm_t_8x8_sliced_k_f32x4_bcf_dbuf(torch::Tensor a, torch::Tensor b, torch::Tensor c) {
   );
 }
 
+void sgemm_t_8x8_sliced_k_f32x4_bcf_dbuf_offset(torch::Tensor a, torch::Tensor b, torch::Tensor c) {
+  CHECK_TORCH_TENSOR_DTYPE(a, torch::kFloat32)
+  CHECK_TORCH_TENSOR_DTYPE(b, torch::kFloat32)
+  CHECK_TORCH_TENSOR_DTYPE(c, torch::kFloat32)
+  const int M = a.size(0);
+  const int K = a.size(1);
+  const int N = b.size(1);
+  CHECK_TORCH_TENSOR_SHAPE(a, M, K)
+  CHECK_TORCH_TENSOR_SHAPE(b, K, N)
+  CHECK_TORCH_TENSOR_SHAPE(c, M, N)
+  constexpr int BM = 128;
+  constexpr int BN = 128;
+  constexpr int BK = 8;
+  constexpr int TM = 8;
+  constexpr int TN = 8;
+  constexpr int OFFSET = 4;
+
+  dim3 block(BN/TN, BM/TM);
+  dim3 grid((N + BN - 1) / BN, (M + BM - 1) / BM);
+
+  sgemm_t_8x8_sliced_k_f32x4_bcf_dbuf_kernel<BM, BN, BK, TM, TN, OFFSET><<<grid, block>>>(
+    reinterpret_cast<float*>(a.data_ptr()),
+    reinterpret_cast<float*>(b.data_ptr()),
+    reinterpret_cast<float*>(c.data_ptr()),
+    M, N, K
+  );
+}
+
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   TORCH_BINDING_COMMON_EXTENSION(sgemm_naive_f32)
   TORCH_BINDING_COMMON_EXTENSION(sgemm_sliced_k_f32)
   TORCH_BINDING_COMMON_EXTENSION(sgemm_t_8x8_sliced_k_f32x4)
   TORCH_BINDING_COMMON_EXTENSION(sgemm_t_8x8_sliced_k_f32x4_bcf)
+  TORCH_BINDING_COMMON_EXTENSION(sgemm_t_8x8_sliced_k_f32x4_bcf_offset)
   TORCH_BINDING_COMMON_EXTENSION(sgemm_t_8x8_sliced_k_f32x4_bcf_dbuf)
+  TORCH_BINDING_COMMON_EXTENSION(sgemm_t_8x8_sliced_k_f32x4_bcf_dbuf_offset)
 }
