10
10
- [X] rms_norm_f16x2_f16_kernel
11
11
- [X] rms_norm_f16x8_f16_kernel
12
12
- [X] rms_norm_f16x8_f32_kernel
13
- - [X] rms_norm_f16x16_f16_kernel
14
- - [X] rms_norm_f16x16_f32_kernel
13
+ - [X] rms_norm_f16x8_pack_f16_kernel
14
+ - [X] rms_norm_f16x8_pack_f32_kernel
15
15
- [X] rms_norm_f16_f32_kernel
16
16
- [X] PyTorch bindings
17
17
@@ -26,18 +26,73 @@ python3 rms_norm.py
26
26
输出:
27
27
28
28
``` bash
29
- --------------------------------------------------------------------------------
30
- out_f32: [0.92419142, -0.08846965, 1.06359947], time:0.03389192ms
31
- out_f32x4: [0.92419147, -0.08846966, 1.06359959], time:0.00855207ms
32
- out_f32_th: [0.92419606, -0.08847010, 1.06360483], time:0.04171062ms
33
- --------------------------------------------------------------------------------
34
- out_f16f16: [0.92431641, -0.08843994, 1.06347656], time:0.03518176ms
35
- out_f16x2f16: [0.92431641, -0.08843994, 1.06347656], time:0.01200986ms
36
- out_f16x8f16: [0.92431641, -0.08843994, 1.06347656], time:0.00625682ms
37
- out_f16x8f32: [0.92431641, -0.08843994, 1.06347656], time:0.00625014ms
38
- out_f16x16f16: [0.92431641, -0.08843994, 1.06347656], time:0.02620339ms
39
- out_f16x16f32: [0.92431641, -0.08843994, 1.06347656], time:0.01505637ms
40
- out_f16f32: [0.92431641, -0.08843994, 1.06347656], time:0.03300810ms
41
- out_f16_th: [0.92431641, -0.08843994, 1.06347656], time:0.04187107ms
42
- --------------------------------------------------------------------------------
29
+ -------------------------------------------------------------------------------------
30
+ N=4096, K=512
31
+ out_f32: [' 0.04078517 ' , ' 0.74503314 ' , ' 0.87149841 ' ], time:0.01198173ms
32
+ out_f32x4: [' 0.04078517 ' , ' 0.74503314 ' , ' 0.87149841 ' ], time:0.00517488ms
33
+ out_f32_th: [' 0.04078539 ' , ' 0.74503714 ' , ' 0.87150306 ' ], time:0.04351616ms
34
+ -------------------------------------------------------------------------------------
35
+ out_f16f16: [' 0.040802 ' , ' 0.74511719 ' , ' 0.87158203 ' ], time:0.01200986ms
36
+ out_f16f32: [' 0.040802 ' , ' 0.74511719 ' , ' 0.87109375 ' ], time:0.01180410ms
37
+ out_f16x2f16: [' 0.040802 ' , ' 0.74511719 ' , ' 0.87158203 ' ], time:0.00670171ms
38
+ out_f16x8f16: [' 0.040802 ' , ' 0.74511719 ' , ' 0.87158203 ' ], time:0.00411820ms
39
+ out_f16x8f32: [' 0.040802 ' , ' 0.74511719 ' , ' 0.87158203 ' ], time:0.00411677ms
40
+ out_f16x8packf16: [' 0.040802 ' , ' 0.74511719 ' , ' 0.87158203 ' ], time:0.00411630ms
41
+ out_f16x8packf32: [' 0.040802 ' , ' 0.74511719 ' , ' 0.87109375 ' ], time:0.00399137ms
42
+ out_f16_th: [' 0.040802 ' , ' 0.74511719 ' , ' 0.87158203 ' ], time:0.04383564ms
43
+ -------------------------------------------------------------------------------------
44
+ -------------------------------------------------------------------------------------
45
+ N=4096, K=1024
46
+ out_f32: [' -0.76329279 ' , ' -0.62111992 ' , ' -1.45531178 ' ], time:0.03398657ms
47
+ out_f32x4: [' -0.76329279 ' , ' -0.62111992 ' , ' -1.45531178 ' ], time:0.00862885ms
48
+ out_f32_th: [' -0.76329684 ' , ' -0.62112319 ' , ' -1.4553194 ' ], time:0.04355550ms
49
+ -------------------------------------------------------------------------------------
50
+ out_f16f16: [' -0.76318359 ' , ' -0.62109375 ' , ' -1.45507812 ' ], time:0.03526235ms
51
+ out_f16f32: [' -0.76318359 ' , ' -0.62109375 ' , ' -1.45605469 ' ], time:0.03302288ms
52
+ out_f16x2f16: [' -0.76318359 ' , ' -0.62109375 ' , ' -1.45507812 ' ], time:0.01215649ms
53
+ out_f16x8f16: [' -0.76318359 ' , ' -0.62109375 ' , ' -1.45507812 ' ], time:0.00632071ms
54
+ out_f16x8f32: [' -0.76318359 ' , ' -0.62109375 ' , ' -1.45507812 ' ], time:0.00631690ms
55
+ out_f16x8packf16: [' -0.76318359 ' , ' -0.62109375 ' , ' -1.45507812 ' ], time:0.00528240ms
56
+ out_f16x8packf32: [' -0.76318359 ' , ' -0.62109375 ' , ' -1.45605469 ' ], time:0.00519514ms
57
+ out_f16_th: [' -0.76318359 ' , ' -0.62109375 ' , ' -1.45507812 ' ], time:0.04399920ms
58
+ -------------------------------------------------------------------------------------
59
+ -------------------------------------------------------------------------------------
60
+ N=4096, K=2048
61
+ out_f32x4: [' -0.17984088 ' , ' -1.76387513 ' , ' -0.32782754 ' ], time:0.01650691ms
62
+ out_f32_th: [' -0.17984176 ' , ' -1.76388371 ' , ' -0.32782915 ' ], time:0.09451318ms
63
+ -------------------------------------------------------------------------------------
64
+ out_f16x2f16: [' -0.17980957 ' , ' -1.76367188 ' , ' -0.32788086 ' ], time:0.03497124ms
65
+ out_f16x8f16: [' -0.17980957 ' , ' -1.76367188 ' , ' -0.32788086 ' ], time:0.01254177ms
66
+ out_f16x8f32: [' -0.17980957 ' , ' -1.76367188 ' , ' -0.32788086 ' ], time:0.01253581ms
67
+ out_f16x8packf16: [' -0.17980957 ' , ' -1.76367188 ' , ' -0.32788086 ' ], time:0.00903535ms
68
+ out_f16x8packf32: [' -0.17980957 ' , ' -1.76367188 ' , ' -0.32788086 ' ], time:0.00894380ms
69
+ out_f16_th: [' -0.17980957 ' , ' -1.76367188 ' , ' -0.32788086 ' ], time:0.04889655ms
70
+ -------------------------------------------------------------------------------------
71
+ -------------------------------------------------------------------------------------
72
+ N=4096, K=4096
73
+ out_f32x4: [' -1.14100003 ' , ' -0.71529448 ' , ' 2.26544118 ' ], time:0.18783689ms
74
+ out_f32_th: [' -1.14100587 ' , ' -0.71529812 ' , ' 2.26545286 ' ], time:0.52556086ms
75
+ -------------------------------------------------------------------------------------
76
+ out_f16x8f16: [' -1.140625 ' , ' -0.71484375 ' , ' 2.26367188 ' ], time:0.03605795ms
77
+ out_f16x8f32: [' -1.140625 ' , ' -0.71484375 ' , ' 2.26367188 ' ], time:0.03605533ms
78
+ out_f16x8packf16: [' -1.140625 ' , ' -0.71484375 ' , ' 2.26367188 ' ], time:0.01718473ms
79
+ out_f16x8packf32: [' -1.140625 ' , ' -0.71533203 ' , ' 2.26367188 ' ], time:0.01735568ms
80
+ out_f16_th: [' -1.140625 ' , ' -0.71484375 ' , ' 2.26367188 ' ], time:0.11150384ms
81
+ -------------------------------------------------------------------------------------
82
+ -------------------------------------------------------------------------------------
83
+ N=4096, K=8192
84
+ out_f16x8f16: [' -0.40844727 ' , ' -0.14294434 ' , ' -0.93359375 ' ], time:0.19292974ms
85
+ out_f16x8f32: [' -0.40844727 ' , ' -0.14294434 ' , ' -0.93359375 ' ], time:0.19298863ms
86
+ out_f16x8packf16: [' -0.40844727 ' , ' -0.14294434 ' , ' -0.93359375 ' ], time:0.18497562ms
87
+ out_f16x8packf32: [' -0.40844727 ' , ' -0.14294434 ' , ' -0.93310547 ' ], time:0.18479729ms
88
+ out_f16_th: [' -0.40844727 ' , ' -0.14294434 ' , ' -0.93359375 ' ], time:0.59557104ms
89
+ -------------------------------------------------------------------------------------
90
+ -------------------------------------------------------------------------------------
91
+ N=8192, K=8192
92
+ out_f16x8f16: [' -0.35253906 ' , ' -1.04101562 ' , ' 0.17358398 ' ], time:0.38169765ms
93
+ out_f16x8f32: [' -0.35253906 ' , ' -1.04101562 ' , ' 0.17358398 ' ], time:0.38264203ms
94
+ out_f16x8packf16: [' -0.35253906 ' , ' -1.04101562 ' , ' 0.17358398 ' ], time:0.40794849ms
95
+ out_f16x8packf32: [' -0.35229492 ' , ' -1.04003906 ' , ' 0.17346191 ' ], time:0.40747380ms
96
+ out_f16_th: [' -0.35229492 ' , ' -1.04003906 ' , ' 0.17346191 ' ], time:1.35807014ms
97
+ -------------------------------------------------------------------------------------
43
98
```
0 commit comments