5
5
包含以下内容:
6
6
7
7
- [X] softmax_f32_kernel(grid level memory fence)
8
- - [X] softmax_f32x4_kernel(grid level memory fence, float4向量化版本 )
8
+ - [X] softmax_f32x4_kernel(grid level memory fence)
9
9
- [X] softmax_f32_per_token_kernel(per token)
10
- - [X] softmax_f32x4_per_token_kernel(per token, float4向量化版本 )
10
+ - [X] softmax_f32x4_per_token_kernel(per token)
11
11
- [X] safe_softmax_f32_per_token_kernel(per token)
12
- - [X] safe_softmax_f32x4_per_token_kernel(per token, float4向量化版本)
12
+ - [X] safe_softmax_f32x4_per_token_kernel(per token)
13
+ - [X] safe_softmax_f16_f32_per_token_kernel(per token)
14
+ - [X] safe_softmax_f16x2_f32_per_token_kernel(per token)
15
+ - [X] safe_softmax_f16x8_pack_f32_per_token_kernel(per token)
13
16
- [X] PyTorch bindings
14
17
15
18
@@ -24,25 +27,84 @@ python3 softmax.py
24
27
输出:
25
28
26
29
``` bash
27
- --------------------------------------------------------------------------------
28
- out_f32: [1.909e-05, 0.00023536, 0.00010881], time:0.01697016ms
29
- out_f32x4: [1.909e-05, 0.00023536, 0.00010881], time:0.01716042ms
30
- out_f32_th: [1.909e-05, 0.00023536, 0.00010881], time:0.00715089ms
31
- --------------------------------------------------------------------------------
32
- out_f32(v2): [1.909e-05, 0.00023536, 0.00010881], time:0.01011539ms
33
- out_f32x4(v2): [1.909e-05, 0.00023536, 0.00010881], time:0.01006842ms
34
- out_f32_th(v2): [1.909e-05, 0.00023536, 0.00010881], time:0.00547409ms
35
- --------------------------------------------------------------------------------
36
- out_f32(per): [0.00569158, 0.00022239, 0.00137839], time:0.01047754ms
37
- out_f32x4(per): [0.00569158, 0.00022239, 0.00137839], time:0.01045704ms
38
- out_f32(safe): [0.00569158, 0.00022239, 0.00137839], time:0.01054454ms
39
- out_f32x4(safe): [0.00569158, 0.00022239, 0.00137839], time:0.01042986ms
40
- out_f32_th(per): [0.00569158, 0.00022239, 0.00137839], time:0.00741696ms
41
- --------------------------------------------------------------------------------
42
- out_f32(per v2): [0.00569158, 0.00022239, 0.00137839], time:0.00419974ms
43
- out_f32x4(per v2): [0.00569158, 0.00022239, 0.00137839], time:0.00316834ms
44
- out_f32(safe v2): [0.00569158, 0.00022239, 0.00137839], time:0.00603890ms
45
- out_f32x4(safe v2): [0.00569158, 0.00022239, 0.00137839], time:0.00319862ms
46
- out_f32_th(per v2): [0.00569158, 0.00022239, 0.00137839], time:0.00577068ms
47
- --------------------------------------------------------------------------------
48
- ```
30
+ ----------------------------------------------------------------------------------------------------
31
+ N=16384
32
+ ----------------------------------------------------------------------------------------------------
33
+ out_f32(fence): [' 5.912e-05 ' , ' 9.61e-05 ' , ' 4.271e-05 ' ], time:0.01040053ms
34
+ out_f32x4(fence): [' 5.912e-05 ' , ' 9.61e-05 ' , ' 4.271e-05 ' ], time:0.01053643ms
35
+ out_f32_th: [' 5.912e-05 ' , ' 9.61e-05 ' , ' 4.271e-05 ' ], time:0.00582504ms
36
+ ----------------------------------------------------------------------------------------------------
37
+ S=4096, H=256
38
+ ----------------------------------------------------------------------------------------------------
39
+ out_f32(per): [' 0.0015298 ' , ' 0.00619088 ' , ' 0.00529766 ' ], time:0.00627208ms
40
+ out_f32x4(per): [' 0.0015298 ' , ' 0.00619088 ' , ' 0.00529766 ' ], time:0.00394082ms
41
+ out_f32(safe): [' 0.0015298 ' , ' 0.00619088 ' , ' 0.00529766 ' ], time:0.00941491ms
42
+ out_f32x4(safe): [' 0.0015298 ' , ' 0.00619088 ' , ' 0.00529766 ' ], time:0.00413442ms
43
+ out_f32_th(per): [' 0.0015298 ' , ' 0.00619088 ' , ' 0.00529766 ' ], time:0.00602674ms
44
+ ----------------------------------------------------------------------------------------------------
45
+ out_f16f32(safe): [' 0.00152969 ' , ' 0.00619125 ' , ' 0.00529861 ' ], time:0.00912046ms
46
+ out_f16x2f32(safe): [' 0.00152969 ' , ' 0.00619125 ' , ' 0.00529861 ' ], time:0.00522232ms
47
+ out_f16x8packf32(safe): [' 0.00152969 ' , ' 0.00619125 ' , ' 0.00529861 ' ], time:0.00413895ms
48
+ out_f16_th(per): [' 0.00152969 ' , ' 0.00619125 ' , ' 0.00529861 ' ], time:0.00605321ms
49
+ ----------------------------------------------------------------------------------------------------
50
+ ----------------------------------------------------------------------------------------------------
51
+ S=4096, H=512
52
+ ----------------------------------------------------------------------------------------------------
53
+ out_f32(per): [' 0.00200376 ' , ' 0.00063461 ' , ' 0.00163568 ' ], time:0.01139641ms
54
+ out_f32x4(per): [' 0.00200376 ' , ' 0.00063461 ' , ' 0.00163568 ' ], time:0.00515914ms
55
+ out_f32(safe): [' 0.00200376 ' , ' 0.00063461 ' , ' 0.00163568 ' ], time:0.01834297ms
56
+ out_f32x4(safe): [' 0.00200376 ' , ' 0.00063461 ' , ' 0.00163568 ' ], time:0.00574923ms
57
+ out_f32_th(per): [' 0.00200376 ' , ' 0.00063461 ' , ' 0.00163568 ' ], time:0.00657558ms
58
+ ----------------------------------------------------------------------------------------------------
59
+ out_f16f32(safe): [' 0.00200462 ' , ' 0.00063467 ' , ' 0.00163555 ' ], time:0.01782560ms
60
+ out_f16x2f32(safe): [' 0.00200462 ' , ' 0.00063467 ' , ' 0.00163555 ' ], time:0.00919509ms
61
+ out_f16x8packf32(safe): [' 0.00200462 ' , ' 0.00063467 ' , ' 0.00163555 ' ], time:0.00415683ms
62
+ out_f16_th(per): [' 0.00200462 ' , ' 0.00063467 ' , ' 0.00163555 ' ], time:0.00634599ms
63
+ ----------------------------------------------------------------------------------------------------
64
+ ----------------------------------------------------------------------------------------------------
65
+ S=4096, H=1024
66
+ ----------------------------------------------------------------------------------------------------
67
+ out_f32(per): [' 0.0009461 ' , ' 0.00073918 ' , ' 0.00074397 ' ], time:0.03191805ms
68
+ out_f32x4(per): [' 0.0009461 ' , ' 0.00073918 ' , ' 0.00074397 ' ], time:0.00862813ms
69
+ out_f32(safe): [' 0.0009461 ' , ' 0.00073918 ' , ' 0.00074397 ' ], time:0.04873967ms
70
+ out_f32x4(safe): [' 0.0009461 ' , ' 0.00073918 ' , ' 0.00074397 ' ], time:0.01027441ms
71
+ out_f32_th(per): [' 0.0009461 ' , ' 0.00073918 ' , ' 0.00074397 ' ], time:0.01181388ms
72
+ ----------------------------------------------------------------------------------------------------
73
+ out_f16f32(safe): [' 0.00094604 ' , ' 0.0007391 ' , ' 0.00074387 ' ], time:0.04671884ms
74
+ out_f16x2f32(safe): [' 0.00094604 ' , ' 0.0007391 ' , ' 0.00074387 ' ], time:0.01810408ms
75
+ out_f16x8packf32(safe): [' 0.00094604 ' , ' 0.0007391 ' , ' 0.00074387 ' ], time:0.00601912ms
76
+ out_f16_th(per): [' 0.00094604 ' , ' 0.0007391 ' , ' 0.00074387 ' ], time:0.01047063ms
77
+ ----------------------------------------------------------------------------------------------------
78
+ ----------------------------------------------------------------------------------------------------
79
+ S=4096, H=2048
80
+ ----------------------------------------------------------------------------------------------------
81
+ out_f32x4(per): [' 9.216e-05 ' , ' 0.00045569 ' , ' 0.00013162 ' ], time:0.01605988ms
82
+ out_f32x4(safe): [' 9.216e-05 ' , ' 0.00045569 ' , ' 0.00013162 ' ], time:0.02089310ms
83
+ out_f32_th(per): [' 9.216e-05 ' , ' 0.00045569 ' , ' 0.00013162 ' ], time:0.06726241ms
84
+ ----------------------------------------------------------------------------------------------------
85
+ out_f16x2f32(safe): [' 9.215e-05 ' , ' 0.00045562 ' , ' 0.00013161 ' ], time:0.04824972ms
86
+ out_f16x8packf32(safe): [' 9.215e-05 ' , ' 0.00045562 ' , ' 0.00013161 ' ], time:0.01086283ms
87
+ out_f16_th(per): [' 9.215e-05 ' , ' 0.00045562 ' , ' 0.00013161 ' ], time:0.07232165ms
88
+ ----------------------------------------------------------------------------------------------------
89
+ ----------------------------------------------------------------------------------------------------
90
+ S=4096, H=4096
91
+ ----------------------------------------------------------------------------------------------------
92
+ out_f32x4(per): [' 0.00017665 ' , ' 0.00035685 ' , ' 0.00017236 ' ], time:0.18465948ms
93
+ out_f32x4(safe): [' 0.00017665 ' , ' 0.00035685 ' , ' 0.00017236 ' ], time:0.18565655ms
94
+ out_f32_th(per): [' 0.00017665 ' , ' 0.00035685 ' , ' 0.00017236 ' ], time:0.18744922ms
95
+ ----------------------------------------------------------------------------------------------------
96
+ out_f16x8packf32(safe): [' 0.00017667 ' , ' 0.00035691 ' , ' 0.00017238 ' ], time:0.02254891ms
97
+ out_f16_th(per): [' 0.00017667 ' , ' 0.00035691 ' , ' 0.00017238 ' ], time:0.08283138ms
98
+ ----------------------------------------------------------------------------------------------------
99
+ ----------------------------------------------------------------------------------------------------
100
+ S=4096, H=8192
101
+ ----------------------------------------------------------------------------------------------------
102
+ out_f16x8packf32(safe): [' 4.166e-05 ' , ' 3.767e-05 ' , ' 1.562e-05 ' ], time:0.19313049ms
103
+ out_f16_th(per): [' 4.166e-05 ' , ' 3.767e-05 ' , ' 1.562e-05 ' ], time:0.19356799ms
104
+ ----------------------------------------------------------------------------------------------------
105
+ S=8192, H=8192
106
+ ----------------------------------------------------------------------------------------------------
107
+ out_f16x8packf32(safe): [' 4.208e-05 ' , ' 0.00015438 ' , ' 7.409e-05 ' ], time:0.39828229ms
108
+ out_f16_th(per): [' 4.208e-05 ' , ' 0.00015438 ' , ' 7.409e-05 ' ], time:0.40599036ms
109
+ ----------------------------------------------------------------------------------------------------
110
+ ```
0 commit comments