Commit 54c761d

[RMSNorm][FP16] Pack f16x8 rmsnorm (#47)
* Update .gitmodules
* Update README.md
* Update rms_norm.cu
* Update rms_norm.py
* Update README.md
* Update rms_norm.cu
* Update rms_norm.py
* Update README.md
* Update README.md
1 parent 4667308 commit 54c761d
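
The commit title refers to an "f16x8 pack" RMSNorm variant. Eight f16 values occupy 16 bytes, so a pack is presumably meant to move as one 128-bit access. As a rough illustration only, the sketch below models that idea in pure Python: it walks a row in 16-byte packs of eight half-precision values and checks that the result matches a plain per-token RMSNorm. The function names, the scalar weight `g`, and `eps` are assumptions made for this example; none of them are taken from `rms_norm.cu`, and the CUDA kernel itself is not reproduced here.

```python
import math
import struct

PACK = 8  # eight f16 values = 16 bytes = 128 bits


def to_f16(x):
    """Round a Python float to the nearest IEEE half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def rms_norm_row(row, g=1.0, eps=1e-5):
    """Reference per-token RMSNorm: y = x / sqrt(mean(x^2) + eps) * g.

    `g` and `eps` are illustrative defaults (assumptions).
    """
    inv_rms = 1.0 / math.sqrt(sum(v * v for v in row) / len(row) + eps)
    return [to_f16(v * inv_rms * g) for v in row]


def rms_norm_row_packed(row, g=1.0, eps=1e-5):
    """Same math, but the row is read and written in packs of eight f16."""
    assert len(row) % PACK == 0, "this sketch assumes K is a multiple of 8"
    buf = struct.pack(f"<{len(row)}e", *row)  # the row as raw f16 bytes
    step = 2 * PACK                           # 16 bytes per pack

    # Pass 1: accumulate the sum of squares one pack at a time.
    total = 0.0
    for off in range(0, len(buf), step):
        pack = struct.unpack_from(f"<{PACK}e", buf, off)
        total += sum(v * v for v in pack)
    inv_rms = 1.0 / math.sqrt(total / len(row) + eps)

    # Pass 2: scale each pack and round the results back to f16.
    out = []
    for off in range(0, len(buf), step):
        pack = struct.unpack_from(f"<{PACK}e", buf, off)
        out.extend(to_f16(v * inv_rms * g) for v in pack)
    return out
```

Calling both functions on the same f16-representable row, for example `[float(i % 5 - 2) for i in range(16)]`, should give identical lists: packing changes how the data is moved, not what is computed.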

5 files changed: +351 -210 lines changed

.gitmodules

Lines changed: 1 addition & 1 deletion
@@ -1,3 +1,3 @@
 [submodule "third-party/cutlass"]
 path = third-party/cutlass
-url = git@github.com:NVIDIA/cutlass.git
+url = https://github.com/NVIDIA/cutlass.git

README.md

Lines changed: 2 additions & 2 deletions
@@ -88,8 +88,8 @@
 | ✔️ [rms_norm_f16x2_f16(per token)](./rms-norm/rms_norm.cu)|f16|f16|[link](./rms-norm/)|⭐️⭐️|
 | ✔️ [rms_norm_f16x8_f16(per token)](./rms-norm/rms_norm.cu)|f16|f16|[link](./rms-norm/)|⭐️⭐️|
 | ✔️ [rms_norm_f16x8_f32(per token)](./rms-norm/rms_norm.cu)|f16|f32|[link](./rms-norm/)|⭐️⭐️|
-| ✔️ [rms_norm_f16x16_f16(per token)](./rms-norm/rms_norm.cu)|f16|f16|[link](./rms-norm/)|⭐️⭐️|
-| ✔️ [rms_norm_f16x16_f32(per token)](./rms-norm/rms_norm.cu)|f16|f32|[link](./rms-norm/)|⭐️⭐️|
+| ✔️ [rms_norm_f16x8_pack_f16(per token)](./rms-norm/rms_norm.cu)|f16|f16|[link](./rms-norm/)|⭐️⭐️|
+| ✔️ [rms_norm_f16x8_pack_f32(per token)](./rms-norm/rms_norm.cu)|f16|f32|[link](./rms-norm/)|⭐️⭐️|
 | ✔️ [rms_norm_f16_f32(per token)](./rms-norm/rms_norm.cu)|f16|f32|[link](./rms-norm/)|⭐️⭐️|
 | ✔️ [sgemm_sliced_k_f32](./sgemm/sgemm.cu)|f32|f32|[link](./sgemm/)|⭐️⭐️⭐️|
 | ✔️ [sgemm_t_8x8_sliced_k_f32x4](./sgemm/sgemm.cu)|f32|f32|[link](./sgemm/)|⭐️⭐️⭐️|

rms-norm/README.md

Lines changed: 71 additions & 16 deletions
@@ -10,8 +10,8 @@
 - [X] rms_norm_f16x2_f16_kernel
 - [X] rms_norm_f16x8_f16_kernel
 - [X] rms_norm_f16x8_f32_kernel
-- [X] rms_norm_f16x16_f16_kernel
-- [X] rms_norm_f16x16_f32_kernel
+- [X] rms_norm_f16x8_pack_f16_kernel
+- [X] rms_norm_f16x8_pack_f32_kernel
 - [X] rms_norm_f16_f32_kernel
 - [X] PyTorch bindings
 
@@ -26,18 +26,73 @@ python3 rms_norm.py
 Output:
 
 ```bash
---------------------------------------------------------------------------------
-out_f32: [0.92419142, -0.08846965, 1.06359947], time:0.03389192ms
-out_f32x4: [0.92419147, -0.08846966, 1.06359959], time:0.00855207ms
-out_f32_th: [0.92419606, -0.08847010, 1.06360483], time:0.04171062ms
---------------------------------------------------------------------------------
-out_f16f16: [0.92431641, -0.08843994, 1.06347656], time:0.03518176ms
-out_f16x2f16: [0.92431641, -0.08843994, 1.06347656], time:0.01200986ms
-out_f16x8f16: [0.92431641, -0.08843994, 1.06347656], time:0.00625682ms
-out_f16x8f32: [0.92431641, -0.08843994, 1.06347656], time:0.00625014ms
-out_f16x16f16: [0.92431641, -0.08843994, 1.06347656], time:0.02620339ms
-out_f16x16f32: [0.92431641, -0.08843994, 1.06347656], time:0.01505637ms
-out_f16f32: [0.92431641, -0.08843994, 1.06347656], time:0.03300810ms
-out_f16_th: [0.92431641, -0.08843994, 1.06347656], time:0.04187107ms
---------------------------------------------------------------------------------
+-------------------------------------------------------------------------------------
+N=4096, K=512
+out_f32: ['0.04078517 ', '0.74503314 ', '0.87149841 '], time:0.01198173ms
+out_f32x4: ['0.04078517 ', '0.74503314 ', '0.87149841 '], time:0.00517488ms
+out_f32_th: ['0.04078539 ', '0.74503714 ', '0.87150306 '], time:0.04351616ms
+-------------------------------------------------------------------------------------
+out_f16f16: ['0.040802 ', '0.74511719 ', '0.87158203 '], time:0.01200986ms
+out_f16f32: ['0.040802 ', '0.74511719 ', '0.87109375 '], time:0.01180410ms
+out_f16x2f16: ['0.040802 ', '0.74511719 ', '0.87158203 '], time:0.00670171ms
+out_f16x8f16: ['0.040802 ', '0.74511719 ', '0.87158203 '], time:0.00411820ms
+out_f16x8f32: ['0.040802 ', '0.74511719 ', '0.87158203 '], time:0.00411677ms
+out_f16x8packf16: ['0.040802 ', '0.74511719 ', '0.87158203 '], time:0.00411630ms
+out_f16x8packf32: ['0.040802 ', '0.74511719 ', '0.87109375 '], time:0.00399137ms
+out_f16_th: ['0.040802 ', '0.74511719 ', '0.87158203 '], time:0.04383564ms
+-------------------------------------------------------------------------------------
+-------------------------------------------------------------------------------------
+N=4096, K=1024
+out_f32: ['-0.76329279 ', '-0.62111992 ', '-1.45531178 '], time:0.03398657ms
+out_f32x4: ['-0.76329279 ', '-0.62111992 ', '-1.45531178 '], time:0.00862885ms
+out_f32_th: ['-0.76329684 ', '-0.62112319 ', '-1.4553194 '], time:0.04355550ms
+-------------------------------------------------------------------------------------
+out_f16f16: ['-0.76318359 ', '-0.62109375 ', '-1.45507812 '], time:0.03526235ms
+out_f16f32: ['-0.76318359 ', '-0.62109375 ', '-1.45605469 '], time:0.03302288ms
+out_f16x2f16: ['-0.76318359 ', '-0.62109375 ', '-1.45507812 '], time:0.01215649ms
+out_f16x8f16: ['-0.76318359 ', '-0.62109375 ', '-1.45507812 '], time:0.00632071ms
+out_f16x8f32: ['-0.76318359 ', '-0.62109375 ', '-1.45507812 '], time:0.00631690ms
+out_f16x8packf16: ['-0.76318359 ', '-0.62109375 ', '-1.45507812 '], time:0.00528240ms
+out_f16x8packf32: ['-0.76318359 ', '-0.62109375 ', '-1.45605469 '], time:0.00519514ms
+out_f16_th: ['-0.76318359 ', '-0.62109375 ', '-1.45507812 '], time:0.04399920ms
+-------------------------------------------------------------------------------------
+-------------------------------------------------------------------------------------
+N=4096, K=2048
+out_f32x4: ['-0.17984088 ', '-1.76387513 ', '-0.32782754 '], time:0.01650691ms
+out_f32_th: ['-0.17984176 ', '-1.76388371 ', '-0.32782915 '], time:0.09451318ms
+-------------------------------------------------------------------------------------
+out_f16x2f16: ['-0.17980957 ', '-1.76367188 ', '-0.32788086 '], time:0.03497124ms
+out_f16x8f16: ['-0.17980957 ', '-1.76367188 ', '-0.32788086 '], time:0.01254177ms
+out_f16x8f32: ['-0.17980957 ', '-1.76367188 ', '-0.32788086 '], time:0.01253581ms
+out_f16x8packf16: ['-0.17980957 ', '-1.76367188 ', '-0.32788086 '], time:0.00903535ms
+out_f16x8packf32: ['-0.17980957 ', '-1.76367188 ', '-0.32788086 '], time:0.00894380ms
+out_f16_th: ['-0.17980957 ', '-1.76367188 ', '-0.32788086 '], time:0.04889655ms
+-------------------------------------------------------------------------------------
+-------------------------------------------------------------------------------------
+N=4096, K=4096
+out_f32x4: ['-1.14100003 ', '-0.71529448 ', '2.26544118 '], time:0.18783689ms
+out_f32_th: ['-1.14100587 ', '-0.71529812 ', '2.26545286 '], time:0.52556086ms
+-------------------------------------------------------------------------------------
+out_f16x8f16: ['-1.140625 ', '-0.71484375 ', '2.26367188 '], time:0.03605795ms
+out_f16x8f32: ['-1.140625 ', '-0.71484375 ', '2.26367188 '], time:0.03605533ms
+out_f16x8packf16: ['-1.140625 ', '-0.71484375 ', '2.26367188 '], time:0.01718473ms
+out_f16x8packf32: ['-1.140625 ', '-0.71533203 ', '2.26367188 '], time:0.01735568ms
+out_f16_th: ['-1.140625 ', '-0.71484375 ', '2.26367188 '], time:0.11150384ms
+-------------------------------------------------------------------------------------
+-------------------------------------------------------------------------------------
+N=4096, K=8192
+out_f16x8f16: ['-0.40844727 ', '-0.14294434 ', '-0.93359375 '], time:0.19292974ms
+out_f16x8f32: ['-0.40844727 ', '-0.14294434 ', '-0.93359375 '], time:0.19298863ms
+out_f16x8packf16: ['-0.40844727 ', '-0.14294434 ', '-0.93359375 '], time:0.18497562ms
+out_f16x8packf32: ['-0.40844727 ', '-0.14294434 ', '-0.93310547 '], time:0.18479729ms
+out_f16_th: ['-0.40844727 ', '-0.14294434 ', '-0.93359375 '], time:0.59557104ms
+-------------------------------------------------------------------------------------
+-------------------------------------------------------------------------------------
+N=8192, K=8192
+out_f16x8f16: ['-0.35253906 ', '-1.04101562 ', '0.17358398 '], time:0.38169765ms
+out_f16x8f32: ['-0.35253906 ', '-1.04101562 ', '0.17358398 '], time:0.38264203ms
+out_f16x8packf16: ['-0.35253906 ', '-1.04101562 ', '0.17358398 '], time:0.40794849ms
+out_f16x8packf32: ['-0.35229492 ', '-1.04003906 ', '0.17346191 '], time:0.40747380ms
+out_f16_th: ['-0.35229492 ', '-1.04003906 ', '0.17346191 '], time:1.35807014ms
+-------------------------------------------------------------------------------------
 ```
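
To make the new log easier to compare, the following sketch computes the speedup of `out_f16x8packf16` over `out_f16x8f16` for each shape. The timings are copied from the added output above; the dictionary layout and the `speedups` helper are hypothetical names introduced for this example and are not part of `rms_norm.py`.

```python
# (N, K) -> (out_f16x8f16 time in ms, out_f16x8packf16 time in ms),
# copied from the benchmark log in this commit's README diff.
TIMINGS_MS = {
    (4096, 512): (0.00411820, 0.00411630),
    (4096, 1024): (0.00632071, 0.00528240),
    (4096, 2048): (0.01254177, 0.00903535),
    (4096, 4096): (0.03605795, 0.01718473),
    (4096, 8192): (0.19292974, 0.18497562),
    (8192, 8192): (0.38169765, 0.40794849),
}


def speedups(timings):
    """Return unpacked_time / packed_time per shape (>1 means pack is faster)."""
    return {shape: base / packed for shape, (base, packed) in timings.items()}


def report(timings):
    """Format one line per shape, e.g. for printing next to the raw log."""
    return [
        f"N={n}, K={k}: {ratio:.2f}x"
        for (n, k), ratio in speedups(timings).items()
    ]
```

On these numbers the packed kernel is about 2.1x faster at N=4096, K=4096, roughly on par at K=512 and K=8192, and slightly slower (about 0.94x) at N=8192, K=8192. The log reports a single timing per configuration, so small differences may be within run-to-run noise.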
