Commit 5901796

[Softmax][FP16] Pack f16x8 softmax kernel (#49)
* Update README.md
* Update softmax.cu
* Update softmax.py
* Update README.md
* Update layer_norm.cu
* Update README.md
* Update rms_norm.cu
1 parent 93636df commit 5901796
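
The headline change packs eight fp16 values into a single 128-bit load/store, so each softmax thread issues one LDG.128/STG.128 instead of eight scalar accesses. Below is a hedged sketch of the pack-LDST idea only, not the exact code in softmax/softmax.cu; the macro and kernel names here are illustrative:

```cuda
#include <cuda_fp16.h>

// Illustrative pack-LDST helper: reinterpret 16 bytes (8 halves) as one
// float4 so the compiler emits a single 128-bit memory instruction.
#define LDST128BITS(value) (reinterpret_cast<float4 *>(&(value))[0])

// Minimal demo: copy 8 halves per thread with two 128-bit accesses.
// Assumes N % 8 == 0 and 16-byte aligned pointers.
__global__ void copy_f16x8_pack_kernel(half *x, half *y, int N) {
  const int idx = 8 * (blockIdx.x * blockDim.x + threadIdx.x);
  if (idx + 7 < N) {
    half pack[8];                                // 8 x f16 = 128 bits
    LDST128BITS(pack[0]) = LDST128BITS(x[idx]);  // one LDG.128
    LDST128BITS(y[idx]) = LDST128BITS(pack[0]);  // one STG.128
  }
}
```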

6 files changed: +589 −266 lines

README.md

Lines changed: 5 additions & 2 deletions
@@ -9,15 +9,15 @@
 <img src=https://img.shields.io/badge/License-GPLv3.0-turquoise.svg >
 </div>

-🎉 **CUDA Learn Notes**: This repo aims to build a **Modern CUDA Learn Notes with PyTorch** for **[Beginners]**, including **fp32, fp16/bf16, fp8/int8, Tensor/CUDA Cores**, flash_attn, sgemm, sgemv, hgemm, hgemv, warp/block reduce, dot prod, elementwise, sigmoid, relu, softmax, layernorm, rmsnorm, hist and some CUDA optimization techniques (pack LDST, warp gemv, sliced_k/split_k/pipeline gemm, bank conflicts free, MMA, etc).
+🎉 **CUDA Learn Notes**: This repo aims to build a **Modern CUDA Learn Notes with PyTorch** for **[B]eginners**, including **fp32, fp16/bf16, fp8/int8, Tensor/CUDA Cores**, flash_attn, sgemm, sgemv, hgemm, hgemv, warp/block reduce, dot prod, elementwise, sigmoid, relu, softmax, layernorm, rmsnorm, hist and some CUDA optimization techniques (pack LDST, warp gemv, sliced_k/split_k/pipeline gemm, bank conflicts free, MMA, etc).

 <img width="1438" alt="image" src="https://github.com/user-attachments/assets/0c5e5125-586f-43fa-8e8b-e2c61c1afbbe">

 ## 0x00 📖 CUDA Kernel Index (Common Interview Questions)
 - / = not supported now.
 - ✔️ = known to work and already supported now.
 - ❔ = in my plan, but not coming soon, maybe a few weeks later.
-- **workflow**: custom **CUDA** kernel impl -> **Torch** python binding -> Run tests.
+- **workflow**: custom **CUDA** kernel impl -> **PyTorch** python binding -> Run tests.

 |📖 cuda kernel| 📖 elem dtype| 📖 acc dtype| 📖 docs | 📖 level |
 |:---|:---|:---|:---|:---|
@@ -75,6 +75,9 @@
 | ✔️ [softmax_f32x4(per token)](./softmax/softmax.cu)|f32|f32|[link](./softmax/)|⭐️⭐️|
 | ✔️ [safe_softmax_f32(per token)](./softmax/softmax.cu)|f32|f32|[link](./softmax/)|⭐️⭐️|
 | ✔️ [safe_softmax_f32x4(per token)](./softmax/softmax.cu)|f32|f32|[link](./softmax/)|⭐️⭐️|
+| ✔️ [safe_softmax_f16_f32(per token)](./softmax/softmax.cu)|f16|f32|[link](./softmax/)|⭐️⭐️|
+| ✔️ [safe_softmax_f16x2_f32(per token)](./softmax/softmax.cu)|f16|f32|[link](./softmax/)|⭐️⭐️|
+| ✔️ [safe_softmax_f16x8_pack_f32(per token)](./softmax/softmax.cu)|f16|f32|[link](./softmax/)|⭐️⭐️|
 | ✔️ [layer_norm_f32(per token)](./layer-norm/layer_norm.cu)|f32|f32|[link](./layer-norm/)|⭐️⭐️|
 | ✔️ [layer_norm_f32x4(per token)](./layer-norm/layer_norm.cu)|f32|f32|[link](./layer-norm/)|⭐️⭐️|
 | ✔️ [layer_norm_f16_f16(per token)](./layer-norm/layer_norm.cu)|f16|f16|[link](./layer-norm/)|⭐️⭐️|
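
The workflow bullet above (custom CUDA kernel impl -> PyTorch python binding -> run tests) is the pattern every row of this table follows. As a hedged sketch of the binding step only, using torch's C++ extension machinery; this demo op just delegates to torch::softmax rather than launching a custom kernel:

```cpp
#include <torch/extension.h>

// Demo op for the binding step only: delegates to torch's own softmax.
// A real entry in this repo would launch a custom CUDA kernel instead.
torch::Tensor softmax_per_token_demo(torch::Tensor x) {
  return torch::softmax(x, /*dim=*/-1);
}

// TORCH_EXTENSION_NAME is injected by torch.utils.cpp_extension at build
// time, so the same source works under any extension name.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("softmax_per_token_demo", &softmax_per_token_demo,
        "per-token softmax (binding demo)");
}
```

On the Python side, torch.utils.cpp_extension.load(name=..., sources=[...]) compiles and imports this module, and the tests then call the bound function directly.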

layer-norm/layer_norm.cu

Lines changed: 6 additions & 3 deletions
@@ -433,9 +433,12 @@ if(((T).options().dtype() != (th_type))) { \
   throw std::runtime_error("values must be "#th_type); \
 }

-#define CHECK_TORCH_TENSOR_SHAPE(T1, T2) \
-if (((T2).size(0) != (T1).size(0)) || ((T2).size(1) != (T1).size(1))) { \
-  throw std::runtime_error("Tensor size mismatch!"); \
+#define CHECK_TORCH_TENSOR_SHAPE(T1, T2) \
+assert((T1).dim() == (T2).dim()); \
+for (int i = 0; i < (T1).dim(); ++i) { \
+  if ((T2).size(i) != (T1).size(i)) { \
+    throw std::runtime_error("Tensor size mismatch!"); \
+  } \
 }

 // fp32
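
A hedged side note on the new macro (not part of the commit): it now expands to two top-level statements, an assert plus a for loop, so an unbraced `if (...) CHECK_TORCH_TENSOR_SHAPE(a, b); else ...` would attach the else to the wrong branch. The usual guard is the do/while(0) idiom, sketched here with the same semantics:

```cpp
#include <cassert>
#include <stdexcept>
#include <torch/extension.h>

// do/while(0) keeps the multi-statement macro a single statement, so it
// stays safe inside unbraced if/else bodies. Behavior is unchanged.
#define CHECK_TORCH_TENSOR_SHAPE(T1, T2)                    \
  do {                                                      \
    assert((T1).dim() == (T2).dim());                       \
    for (int i = 0; i < (T1).dim(); ++i) {                  \
      if ((T2).size(i) != (T1).size(i)) {                   \
        throw std::runtime_error("Tensor size mismatch!");  \
      }                                                     \
    }                                                       \
  } while (0)
```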

rms-norm/rms_norm.cu

Lines changed: 6 additions & 3 deletions
@@ -382,9 +382,12 @@ if(((T).options().dtype() != (th_type))) { \
   throw std::runtime_error("values must be "#th_type); \
 }

-#define CHECK_TORCH_TENSOR_SHAPE(T1, T2) \
-if (((T2).size(0) != (T1).size(0)) || ((T2).size(1) != (T1).size(1))) { \
-  throw std::runtime_error("Tensor size mismatch!"); \
+#define CHECK_TORCH_TENSOR_SHAPE(T1, T2) \
+assert((T1).dim() == (T2).dim()); \
+for (int i = 0; i < (T1).dim(); ++i) { \
+  if ((T2).size(i) != (T1).size(i)) { \
+    throw std::runtime_error("Tensor size mismatch!"); \
+  } \
 }

 #define LANUCH_RMS_NORM_F32_KERNEL(K) \

softmax/README.md

Lines changed: 87 additions & 25 deletions
@@ -5,11 +5,14 @@
 Includes the following:

 - [X] softmax_f32_kernel (grid level memory fence)
-- [X] softmax_f32x4_kernel(grid level memory fence, float4 vectorized version)
+- [X] softmax_f32x4_kernel(grid level memory fence)
 - [X] softmax_f32_per_token_kernel(per token)
-- [X] softmax_f32x4_per_token_kernel(per token, float4 vectorized version)
+- [X] softmax_f32x4_per_token_kernel(per token)
 - [X] safe_softmax_f32_per_token_kernel(per token)
-- [X] safe_softmax_f32x4_per_token_kernel(per token, float4 vectorized version)
+- [X] safe_softmax_f32x4_per_token_kernel(per token)
+- [X] safe_softmax_f16_f32_per_token_kernel(per token)
+- [X] safe_softmax_f16x2_f32_per_token_kernel(per token)
+- [X] safe_softmax_f16x8_pack_f32_per_token_kernel(per token)
 - [X] PyTorch bindings

@@ -24,25 +27,84 @@ python3 softmax.py
 Output:

 ```bash
---------------------------------------------------------------------------------
-out_f32: [1.909e-05, 0.00023536, 0.00010881], time:0.01697016ms
-out_f32x4: [1.909e-05, 0.00023536, 0.00010881], time:0.01716042ms
-out_f32_th: [1.909e-05, 0.00023536, 0.00010881], time:0.00715089ms
---------------------------------------------------------------------------------
-out_f32(v2): [1.909e-05, 0.00023536, 0.00010881], time:0.01011539ms
-out_f32x4(v2): [1.909e-05, 0.00023536, 0.00010881], time:0.01006842ms
-out_f32_th(v2): [1.909e-05, 0.00023536, 0.00010881], time:0.00547409ms
---------------------------------------------------------------------------------
-out_f32(per): [0.00569158, 0.00022239, 0.00137839], time:0.01047754ms
-out_f32x4(per): [0.00569158, 0.00022239, 0.00137839], time:0.01045704ms
-out_f32(safe): [0.00569158, 0.00022239, 0.00137839], time:0.01054454ms
-out_f32x4(safe): [0.00569158, 0.00022239, 0.00137839], time:0.01042986ms
-out_f32_th(per): [0.00569158, 0.00022239, 0.00137839], time:0.00741696ms
---------------------------------------------------------------------------------
-out_f32(per v2): [0.00569158, 0.00022239, 0.00137839], time:0.00419974ms
-out_f32x4(per v2): [0.00569158, 0.00022239, 0.00137839], time:0.00316834ms
-out_f32(safe v2): [0.00569158, 0.00022239, 0.00137839], time:0.00603890ms
-out_f32x4(safe v2): [0.00569158, 0.00022239, 0.00137839], time:0.00319862ms
-out_f32_th(per v2): [0.00569158, 0.00022239, 0.00137839], time:0.00577068ms
---------------------------------------------------------------------------------
-```
+----------------------------------------------------------------------------------------------------
+N=16384
+----------------------------------------------------------------------------------------------------
+out_f32(fence): ['5.912e-05 ', '9.61e-05 ', '4.271e-05 '], time:0.01040053ms
+out_f32x4(fence): ['5.912e-05 ', '9.61e-05 ', '4.271e-05 '], time:0.01053643ms
+out_f32_th: ['5.912e-05 ', '9.61e-05 ', '4.271e-05 '], time:0.00582504ms
+----------------------------------------------------------------------------------------------------
+S=4096, H=256
+----------------------------------------------------------------------------------------------------
+out_f32(per): ['0.0015298 ', '0.00619088 ', '0.00529766 '], time:0.00627208ms
+out_f32x4(per): ['0.0015298 ', '0.00619088 ', '0.00529766 '], time:0.00394082ms
+out_f32(safe): ['0.0015298 ', '0.00619088 ', '0.00529766 '], time:0.00941491ms
+out_f32x4(safe): ['0.0015298 ', '0.00619088 ', '0.00529766 '], time:0.00413442ms
+out_f32_th(per): ['0.0015298 ', '0.00619088 ', '0.00529766 '], time:0.00602674ms
+----------------------------------------------------------------------------------------------------
+out_f16f32(safe): ['0.00152969 ', '0.00619125 ', '0.00529861 '], time:0.00912046ms
+out_f16x2f32(safe): ['0.00152969 ', '0.00619125 ', '0.00529861 '], time:0.00522232ms
+out_f16x8packf32(safe): ['0.00152969 ', '0.00619125 ', '0.00529861 '], time:0.00413895ms
+out_f16_th(per): ['0.00152969 ', '0.00619125 ', '0.00529861 '], time:0.00605321ms
+----------------------------------------------------------------------------------------------------
+----------------------------------------------------------------------------------------------------
+S=4096, H=512
+----------------------------------------------------------------------------------------------------
+out_f32(per): ['0.00200376 ', '0.00063461 ', '0.00163568 '], time:0.01139641ms
+out_f32x4(per): ['0.00200376 ', '0.00063461 ', '0.00163568 '], time:0.00515914ms
+out_f32(safe): ['0.00200376 ', '0.00063461 ', '0.00163568 '], time:0.01834297ms
+out_f32x4(safe): ['0.00200376 ', '0.00063461 ', '0.00163568 '], time:0.00574923ms
+out_f32_th(per): ['0.00200376 ', '0.00063461 ', '0.00163568 '], time:0.00657558ms
+----------------------------------------------------------------------------------------------------
+out_f16f32(safe): ['0.00200462 ', '0.00063467 ', '0.00163555 '], time:0.01782560ms
+out_f16x2f32(safe): ['0.00200462 ', '0.00063467 ', '0.00163555 '], time:0.00919509ms
+out_f16x8packf32(safe): ['0.00200462 ', '0.00063467 ', '0.00163555 '], time:0.00415683ms
+out_f16_th(per): ['0.00200462 ', '0.00063467 ', '0.00163555 '], time:0.00634599ms
+----------------------------------------------------------------------------------------------------
+----------------------------------------------------------------------------------------------------
+S=4096, H=1024
+----------------------------------------------------------------------------------------------------
+out_f32(per): ['0.0009461 ', '0.00073918 ', '0.00074397 '], time:0.03191805ms
+out_f32x4(per): ['0.0009461 ', '0.00073918 ', '0.00074397 '], time:0.00862813ms
+out_f32(safe): ['0.0009461 ', '0.00073918 ', '0.00074397 '], time:0.04873967ms
+out_f32x4(safe): ['0.0009461 ', '0.00073918 ', '0.00074397 '], time:0.01027441ms
+out_f32_th(per): ['0.0009461 ', '0.00073918 ', '0.00074397 '], time:0.01181388ms
+----------------------------------------------------------------------------------------------------
+out_f16f32(safe): ['0.00094604 ', '0.0007391 ', '0.00074387 '], time:0.04671884ms
+out_f16x2f32(safe): ['0.00094604 ', '0.0007391 ', '0.00074387 '], time:0.01810408ms
+out_f16x8packf32(safe): ['0.00094604 ', '0.0007391 ', '0.00074387 '], time:0.00601912ms
+out_f16_th(per): ['0.00094604 ', '0.0007391 ', '0.00074387 '], time:0.01047063ms
+----------------------------------------------------------------------------------------------------
+----------------------------------------------------------------------------------------------------
+S=4096, H=2048
+----------------------------------------------------------------------------------------------------
+out_f32x4(per): ['9.216e-05 ', '0.00045569 ', '0.00013162 '], time:0.01605988ms
+out_f32x4(safe): ['9.216e-05 ', '0.00045569 ', '0.00013162 '], time:0.02089310ms
+out_f32_th(per): ['9.216e-05 ', '0.00045569 ', '0.00013162 '], time:0.06726241ms
+----------------------------------------------------------------------------------------------------
+out_f16x2f32(safe): ['9.215e-05 ', '0.00045562 ', '0.00013161 '], time:0.04824972ms
+out_f16x8packf32(safe): ['9.215e-05 ', '0.00045562 ', '0.00013161 '], time:0.01086283ms
+out_f16_th(per): ['9.215e-05 ', '0.00045562 ', '0.00013161 '], time:0.07232165ms
+----------------------------------------------------------------------------------------------------
+----------------------------------------------------------------------------------------------------
+S=4096, H=4096
+----------------------------------------------------------------------------------------------------
+out_f32x4(per): ['0.00017665 ', '0.00035685 ', '0.00017236 '], time:0.18465948ms
+out_f32x4(safe): ['0.00017665 ', '0.00035685 ', '0.00017236 '], time:0.18565655ms
+out_f32_th(per): ['0.00017665 ', '0.00035685 ', '0.00017236 '], time:0.18744922ms
+----------------------------------------------------------------------------------------------------
+out_f16x8packf32(safe): ['0.00017667 ', '0.00035691 ', '0.00017238 '], time:0.02254891ms
+out_f16_th(per): ['0.00017667 ', '0.00035691 ', '0.00017238 '], time:0.08283138ms
+----------------------------------------------------------------------------------------------------
+----------------------------------------------------------------------------------------------------
+S=4096, H=8192
+----------------------------------------------------------------------------------------------------
+out_f16x8packf32(safe): ['4.166e-05 ', '3.767e-05 ', '1.562e-05 '], time:0.19313049ms
+out_f16_th(per): ['4.166e-05 ', '3.767e-05 ', '1.562e-05 '], time:0.19356799ms
+----------------------------------------------------------------------------------------------------
+S=8192, H=8192
+----------------------------------------------------------------------------------------------------
+out_f16x8packf32(safe): ['4.208e-05 ', '0.00015438 ', '7.409e-05 '], time:0.39828229ms
+out_f16_th(per): ['4.208e-05 ', '0.00015438 ', '7.409e-05 '], time:0.40599036ms
+----------------------------------------------------------------------------------------------------
+```
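
To connect the timings above back to the kernel itself, here is a hedged sketch of how a per-token safe_softmax_f16x8_pack_f32 kernel can be structured: f16 packed I/O with f32 reductions. It is illustrative and simplified, not the exact code in softmax/softmax.cu; the helper names and the launch assumptions (one block per row, blockDim.x = H/8 and a multiple of 32, 16-byte aligned rows) are mine.

```cuda
#include <cuda_fp16.h>
#include <float.h>

// Pack-LDST helper: 8 halves move as one 128-bit instruction.
#define LDST128BITS(value) (reinterpret_cast<float4 *>(&(value))[0])

__device__ __forceinline__ float warp_reduce_max_f32(float v) {
#pragma unroll
  for (int mask = 16; mask > 0; mask >>= 1)
    v = fmaxf(v, __shfl_xor_sync(0xffffffff, v, mask));
  return v;
}

__device__ __forceinline__ float warp_reduce_sum_f32(float v) {
#pragma unroll
  for (int mask = 16; mask > 0; mask >>= 1)
    v += __shfl_xor_sync(0xffffffff, v, mask);
  return v;
}

// Block-wide max, broadcast to all threads (shuffle reduce + one smem hop).
__device__ float block_reduce_max_f32(float v) {
  __shared__ float smem[32];
  __shared__ float result;
  const int lane = threadIdx.x & 31;
  const int warp = threadIdx.x >> 5;
  v = warp_reduce_max_f32(v);
  if (lane == 0) smem[warp] = v;
  __syncthreads();
  if (warp == 0) {
    const int nwarps = (blockDim.x + 31) >> 5;
    v = (lane < nwarps) ? smem[lane] : -FLT_MAX;
    v = warp_reduce_max_f32(v);
    if (lane == 0) result = v;
  }
  __syncthreads();
  return result;
}

// Block-wide sum, broadcast to all threads.
__device__ float block_reduce_sum_f32(float v) {
  __shared__ float smem[32];
  __shared__ float result;
  const int lane = threadIdx.x & 31;
  const int warp = threadIdx.x >> 5;
  v = warp_reduce_sum_f32(v);
  if (lane == 0) smem[warp] = v;
  __syncthreads();
  if (warp == 0) {
    const int nwarps = (blockDim.x + 31) >> 5;
    v = (lane < nwarps) ? smem[lane] : 0.0f;
    v = warp_reduce_sum_f32(v);
    if (lane == 0) result = v;
  }
  __syncthreads();
  return result;
}

// One block per token (row) of length H; each thread owns 8 packed halves.
// Launch: <<<S, H / 8>>>, assuming H % 8 == 0 and H / 8 a multiple of 32.
__global__ void safe_softmax_f16x8_pack_f32_per_token_kernel(half *x, half *y,
                                                             int H) {
  const int base = blockIdx.x * H + threadIdx.x * 8;
  half pack[8];
  LDST128BITS(pack[0]) = LDST128BITS(x[base]);  // one LDG.128 per thread

  float v[8];
  float local_max = -FLT_MAX;
#pragma unroll
  for (int i = 0; i < 8; ++i) {  // widen to f32, take the local max
    v[i] = __half2float(pack[i]);
    local_max = fmaxf(local_max, v[i]);
  }
  const float row_max = block_reduce_max_f32(local_max);  // safe-softmax max

  float local_sum = 0.0f;
#pragma unroll
  for (int i = 0; i < 8; ++i) {  // exp(x - max) in f32, local partial sum
    v[i] = expf(v[i] - row_max);
    local_sum += v[i];
  }
  const float row_sum = block_reduce_sum_f32(local_sum);

#pragma unroll
  for (int i = 0; i < 8; ++i)  // normalize, narrow back to f16
    pack[i] = __float2half_rn(v[i] / row_sum);
  LDST128BITS(y[base]) = LDST128BITS(pack[0]);  // one STG.128 per thread
}
```

The benchmark trend above is consistent with this structure: as H grows, the packed kernel issues far fewer memory instructions per row, which is where most of its advantage over the scalar f16 and f16x2 variants appears to come from.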
