1414- [X] safe_softmax_f16x2_f32_per_token_kernel(per token)
1515- [X] safe_softmax_f16x8_pack_f32_per_token_kernel(per token)
1616- [X] online_safe_softmax_f32_per_token_kernel(per token, online softmax)
17+ - [X] online_safe_softmax_f32x4_pack_per_token_kernel(per token, online softmax)
1718- [X] PyTorch bindings
1819
1920
@@ -31,84 +32,87 @@ python3 softmax.py
3132----------------------------------------------------------------------------------------------------
3233 N=16384
3334----------------------------------------------------------------------------------------------------
34- out_f32(fence): [' 3.359e-05 ' , ' 1.657e -05 ' , ' 0.0001522 ' ], time:0.01000977ms
35- out_f32x4(fence): [' 3.359e-05 ' , ' 1.657e -05 ' , ' 0.0001522 ' ], time:0.01015735ms
36- out_f32_th: [' 3.359e-05 ' , ' 1.657e -05 ' , ' 0.0001522 ' ], time:0.00575948ms
35+ out_f32(fence): [' 0.00011554 ' , ' 1.172e -05 ' , ' 3.789e-05 ' ], time:0.00707126ms
36+ out_f32x4(fence): [' 0.00011554 ' , ' 1.172e -05 ' , ' 3.789e-05 ' ], time:0.00714874ms
37+ out_f32_th: [' 0.00011554 ' , ' 1.172e -05 ' , ' 3.789e-05 ' ], time:0.00871110ms
3738----------------------------------------------------------------------------------------------------
3839 S=4096, H=256
3940----------------------------------------------------------------------------------------------------
40- out_f32(per): [' 0.00425925 ' , ' 0.00819569 ' , ' 0.00073704 ' ], time:0.00633717ms
41- out_f32x4(per): [' 0.00425925 ' , ' 0.00819569 ' , ' 0.00073704 ' ], time:0.00395060ms
42- out_f32(safe): [' 0.00425925 ' , ' 0.00819569 ' , ' 0.00073704 ' ], time:0.00937152ms
43- out_f32(safe+online): [' 0.00425925 ' , ' 0.00819569 ' , ' 0.00073704 ' ], time:0.00749898ms
44- out_f32x4(safe): [' 0.00425925 ' , ' 0.00819569 ' , ' 0.00073704 ' ], time:0.00413203ms
45- out_f32_th(per): [' 0.00425925 ' , ' 0.00819569 ' , ' 0.00073704 ' ], time:0.00574470ms
41+ out_f32(per): [' 0.00489144 ' , ' 0.00030952 ' , ' 0.00112878 ' ], time:0.01259184ms
42+ out_f32x4(per): [' 0.00489144 ' , ' 0.00030952 ' , ' 0.00112878 ' ], time:0.01004362ms
43+ out_f32(safe): [' 0.00489144 ' , ' 0.00030952 ' , ' 0.00112878 ' ], time:0.01583433ms
44+ out_f32(safe+online): [' 0.00489144 ' , ' 0.00030952 ' , ' 0.00112878 ' ], time:0.01357031ms
45+ out_f32x4(safe+online): [' 0.00489145 ' , ' 0.00030952 ' , ' 0.00112878 ' ], time:0.01050377ms
46+ out_f32x4(safe): [' 0.00489144 ' , ' 0.00030952 ' , ' 0.00112878 ' ], time:0.01027584ms
47+ out_f32_th(per): [' 0.00489144 ' , ' 0.00030952 ' , ' 0.00112878 ' ], time:0.01042914ms
4648----------------------------------------------------------------------------------------------------
47- out_f16f32(safe): [' 0.00426102 ' , ' 0.00819397 ' , ' 0.00073671 ' ], time:0.00907254ms
48- out_f16x2f32(safe): [' 0.00426102 ' , ' 0.00819397 ' , ' 0.00073671 ' ], time:0.00526237ms
49- out_f16x8packf32(safe): [' 0.00426102 ' , ' 0.00819397 ' , ' 0.00073671 ' ], time:0.00414038ms
50- out_f16_th(per): [' 0.00426102 ' , ' 0.00819397 ' , ' 0.00073671 ' ], time:0.00579095ms
49+ out_f16f32(safe): [' 0.00489044 ' , ' 0.00030971 ' , ' 0.00112915 ' ], time:0.01418757ms
50+ out_f16x2f32(safe): [' 0.00489044 ' , ' 0.00030971 ' , ' 0.00112915 ' ], time:0.00781608ms
51+ out_f16x8packf32(safe): [' 0.00489044 ' , ' 0.00030971 ' , ' 0.00112915 ' ], time:0.00523329ms
52+ out_f16_th(per): [' 0.00489044 ' , ' 0.00030971 ' , ' 0.00112915 ' ], time:0.00563836ms
5153----------------------------------------------------------------------------------------------------
5254----------------------------------------------------------------------------------------------------
5355 S=4096, H=512
5456----------------------------------------------------------------------------------------------------
55- out_f32(per): [' 0.00203266 ' , ' 7.054e-05 ' , ' 0.00042398 ' ], time:0.01142383ms
56- out_f32x4(per): [' 0.00203266 ' , ' 7.054e-05 ' , ' 0.00042398 ' ], time:0.00514126ms
57- out_f32(safe): [' 0.00203266 ' , ' 7.054e-05 ' , ' 0.00042398 ' ], time:0.01835704ms
58- out_f32(safe+online): [' 0.00203266 ' , ' 7.054e-05 ' , ' 0.00042398 ' ], time:0.01364374ms
59- out_f32x4(safe): [' 0.00203266 ' , ' 7.054e-05 ' , ' 0.00042398 ' ], time:0.00578308ms
60- out_f32_th(per): [' 0.00203266 ' , ' 7.054e-05 ' , ' 0.00042398 ' ], time:0.00650859ms
57+ out_f32(per): [' 0.00042486 ' , ' 0.00308358 ' , ' 0.00113099 ' ], time:0.02372313ms
58+ out_f32x4(per): [' 0.00042486 ' , ' 0.00308358 ' , ' 0.00113099 ' ], time:0.02219534ms
59+ out_f32(safe): [' 0.00042486 ' , ' 0.00308358 ' , ' 0.00113099 ' ], time:0.03100491ms
60+ out_f32(safe+online): [' 0.00042486 ' , ' 0.00308358 ' , ' 0.00113099 ' ], time:0.02549100ms
61+ out_f32x4(safe+online): [' 0.00042486 ' , ' 0.00308358 ' , ' 0.00113099 ' ], time:0.02228165ms
62+ out_f32x4(safe): [' 0.00042486 ' , ' 0.00308358 ' , ' 0.00113099 ' ], time:0.02230835ms
63+ out_f32_th(per): [' 0.00042486 ' , ' 0.00308358 ' , ' 0.00113099 ' ], time:0.02294350ms
6164----------------------------------------------------------------------------------------------------
62- out_f16f32(safe): [' 0.00203323 ' , ' 7.057e-05 ' , ' 0.00042415 ' ], time:0.01780558ms
63- out_f16x2f32(safe): [' 0.00203323 ' , ' 7.057e-05 ' , ' 0.00042415 ' ], time:0.00920749ms
64- out_f16x8packf32(safe): [' 0.00203323 ' , ' 7.057e-05 ' , ' 0.00042415 ' ], time:0.00416279ms
65- out_f16_th(per): [' 0.00203323 ' , ' 7.057e-05 ' , ' 0.00042415 ' ], time:0.00592852ms
65+ out_f16f32(safe): [' 0.00042486 ' , ' 0.00308418 ' , ' 0.00113106 ' ], time:0.02967048ms
66+ out_f16x2f32(safe): [' 0.00042486 ' , ' 0.00308418 ' , ' 0.00113106 ' ], time:0.01563406ms
67+ out_f16x8packf32(safe): [' 0.00042486 ' , ' 0.00308418 ' , ' 0.00113106 ' ], time:0.01033092ms
68+ out_f16_th(per): [' 0.00042486 ' , ' 0.00308418 ' , ' 0.00113106 ' ], time:0.01410413ms
6669----------------------------------------------------------------------------------------------------
6770----------------------------------------------------------------------------------------------------
6871 S=4096, H=1024
6972----------------------------------------------------------------------------------------------------
70- out_f32(per): [' 4.202e-05 ' , ' 0.00064992 ' , ' 0.00070006 ' ], time:0.03191423ms
71- out_f32x4(per): [' 4.202e-05 ' , ' 0.00064992 ' , ' 0.00070006 ' ], time:0.00858426ms
72- out_f32(safe): [' 4.202e-05 ' , ' 0.00064992 ' , ' 0.00070006 ' ], time:0.04868317ms
73- out_f32(safe+online): [' 4.202e-05 ' , ' 0.00064992 ' , ' 0.00070006 ' ], time:0.03698754ms
74- out_f32x4(safe): [' 4.202e-05 ' , ' 0.00064992 ' , ' 0.00070006 ' ], time:0.01025891ms
75- out_f32_th(per): [' 4.202e-05 ' , ' 0.00064992 ' , ' 0.00070006 ' ], time:0.01172018ms
73+ out_f32(per): [' 0.00015042 ' , ' 0.00127817 ' , ' 0.00087939 ' ], time:0.06144118ms
74+ out_f32x4(per): [' 0.00015042 ' , ' 0.00127817 ' , ' 0.00087939 ' ], time:0.04208207ms
75+ out_f32(safe): [' 0.00015042 ' , ' 0.00127817 ' , ' 0.00087939 ' ], time:0.08846235ms
76+ out_f32(safe+online): [' 0.00015042 ' , ' 0.00127817 ' , ' 0.00087939 ' ], time:0.06275535ms
77+ out_f32x4(safe+online): [' 0.00015042 ' , ' 0.00127817 ' , ' 0.00087939 ' ], time:0.04195666ms
78+ out_f32x4(safe): [' 0.00015042 ' , ' 0.00127817 ' , ' 0.00087939 ' ], time:0.04199767ms
79+ out_f32_th(per): [' 0.00015042 ' , ' 0.00127817 ' , ' 0.00087939 ' ], time:0.04214501ms
7680----------------------------------------------------------------------------------------------------
77- out_f16f32(safe): [' 4.202e-05 ' , ' 0.00064993 ' , ' 0.0007 ' ], time:0.04668665ms
78- out_f16x2f32(safe): [' 4.202e-05 ' , ' 0.00064993 ' , ' 0.0007 ' ], time:0.01805592ms
79- out_f16x8packf32(safe): [' 4.202e-05 ' , ' 0.00064993 ' , ' 0.0007 ' ], time:0.00600147ms
80- out_f16_th(per): [' 4.202e-05 ' , ' 0.00064993 ' , ' 0.0007 ' ], time:0.01042104ms
81+ out_f16f32(safe): [' 0.00015044 ' , ' 0.00127792 ' , ' 0.00087929 ' ], time:0.07461023ms
82+ out_f16x2f32(safe): [' 0.00015044 ' , ' 0.00127792 ' , ' 0.00087929 ' ], time:0.02805471ms
83+ out_f16x8packf32(safe): [' 0.00015044 ' , ' 0.00127792 ' , ' 0.00087929 ' ], time:0.02210021ms
84+ out_f16_th(per): [' 0.00015044 ' , ' 0.00127792 ' , ' 0.00087929 ' ], time:0.02429175ms
8185----------------------------------------------------------------------------------------------------
8286----------------------------------------------------------------------------------------------------
8387 S=4096, H=2048
8488----------------------------------------------------------------------------------------------------
85- out_f32x4(per): [' 0.00068028 ' , ' 0.00138677 ' , ' 0.00012553 ' ], time:0.01602578ms
86- out_f32x4(safe): [' 0.00068028 ' , ' 0.00138677 ' , ' 0.00012553 ' ], time:0.02085137ms
87- out_f32_th(per): [' 0.00068028 ' , ' 0.00138677 ' , ' 0.00012553 ' ], time:0.06727862ms
89+ out_f32x4(per): [' 0.00014777 ' , ' 0.00018938 ' , ' 9.769e-05 ' ], time:0.08160353ms
90+ out_f32x4(safe): [' 0.00014777 ' , ' 0.00018938 ' , ' 9.769e-05 ' ], time:0.08181977ms
91+ out_f32_th(per): [' 0.00014777 ' , ' 0.00018938 ' , ' 9.769e-05 ' ], time:0.10212374ms
8892----------------------------------------------------------------------------------------------------
89- out_f16x2f32(safe): [' 0.00067997 ' , ' 0.00138664 ' , ' 0.00012553 ' ], time:0.04822373ms
90- out_f16x8packf32(safe): [' 0.00067997 ' , ' 0.00138664 ' , ' 0.00012553 ' ], time:0.01078343ms
91- out_f16_th(per): [' 0.00067997 ' , ' 0.00138664 ' , ' 0.00012553 ' ], time:0.07226229ms
93+ out_f16x2f32(safe): [' 0.0001477 ' , ' 0.00018942 ' , ' 9.769e-05 ' ], time:0.07831120ms
94+ out_f16x8packf32(safe): [' 0.0001477 ' , ' 0.00018942 ' , ' 9.769e-05 ' ], time:0.04206920ms
95+ out_f16_th(per): [' 0.0001477 ' , ' 0.00018942 ' , ' 9.769e-05 ' ], time:0.05331278ms
9296----------------------------------------------------------------------------------------------------
9397----------------------------------------------------------------------------------------------------
9498 S=4096, H=4096
9599----------------------------------------------------------------------------------------------------
96- out_f32x4(per): [' 3.5e -05 ' , ' 8.788e-05 ' , ' 0.00017372 ' ], time:0.18450212ms
97- out_f32x4(safe): [' 3.5e -05 ' , ' 8.788e-05 ' , ' 0.00017372 ' ], time:0.18548727ms
98- out_f32_th(per): [' 3.5e -05 ' , ' 8.788e-05 ' , ' 0.00017372 ' ], time:0.18735909ms
100+ out_f32x4(per): [' 4.063e -05 ' , ' 0.00038625 ' , ' 0.00019391 ' ], time:0.16202784ms
101+ out_f32x4(safe): [' 4.063e -05 ' , ' 0.00038625 ' , ' 0.00019391 ' ], time:0.16271973ms
102+ out_f32_th(per): [' 4.063e -05 ' , ' 0.00038625 ' , ' 0.00019391 ' ], time:0.19028711ms
99103----------------------------------------------------------------------------------------------------
100- out_f16x8packf32(safe): [' 3.499e -05 ' , ' 8.792e-05 ' , ' 0.00017369 ' ], time:0.02230954ms
101- out_f16_th(per): [' 3.499e -05 ' , ' 8.792e-05 ' , ' 0.00017369 ' ], time:0.08258724ms
104+ out_f16x8packf32(safe): [' 4.065e -05 ' , ' 0.00038624 ' , ' 0.00019383 ' ], time:0.08193207ms
105+ out_f16_th(per): [' 4.065e -05 ' , ' 0.00038624 ' , ' 0.00019383 ' ], time:0.10132599ms
102106----------------------------------------------------------------------------------------------------
103107----------------------------------------------------------------------------------------------------
104108 S=4096, H=8192
105109----------------------------------------------------------------------------------------------------
106- out_f16x8packf32(safe): [' 8.47e-05 ' , ' 0.00048876 ' , ' 2.718e-05 ' ], time:0.19314885ms
107- out_f16_th(per): [' 8.47e-05 ' , ' 0.00048876 ' , ' 2.718e-05 ' ], time:0.19355965ms
110+ out_f16x8packf32(safe): [' 0.00044656 ' , ' 1.872e-05 ' , ' 0.00054884 ' ], time:0.16337919ms
111+ out_f16_th(per): [' 0.00044656 ' , ' 1.872e-05 ' , ' 0.00054884 ' ], time:0.18709970ms
108112----------------------------------------------------------------------------------------------------
109113 S=8192, H=8192
110114----------------------------------------------------------------------------------------------------
111- out_f16x8packf32(safe): [' 5.829e -05 ' , ' 8.482e -05 ' , ' 0.00021875 ' ], time:0.39851356ms
112- out_f16_th(per): [' 5.829e -05 ' , ' 8.482e -05 ' , ' 0.00021875 ' ], time:0.40570927ms
115+ out_f16x8packf32(safe): [' 4.601e -05 ' , ' 9.853e -05 ' , ' 1.711e-05 ' ], time:0.32324409ms
116+ out_f16_th(per): [' 4.601e -05 ' , ' 9.853e -05 ' , ' 1.711e-05 ' ], time:0.36632204ms
113117----------------------------------------------------------------------------------------------------
114118```
0 commit comments