Skip to content

Commit 11d7072

Browse files
authored
[Softmax] Add online softmax f32x4 pack kernel (#73)
* [Softmax] Add online softmax f32x4 pack kernel
1 parent c061812 commit 11d7072

File tree

3 files changed

+141
-48
lines changed

3 files changed

+141
-48
lines changed

softmax/README.md

Lines changed: 52 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
- [X] safe_softmax_f16x2_f32_per_token_kernel(per token)
1515
- [X] safe_softmax_f16x8_pack_f32_per_token_kernel(per token)
1616
- [X] online_safe_softmax_f32_per_token_kernel(per token, online softmax)
17+
- [X] online_safe_softmax_f32x4_pack_per_token_kernel(per token, online softmax)
1718
- [X] PyTorch bindings
1819

1920

@@ -31,84 +32,87 @@ python3 softmax.py
3132
----------------------------------------------------------------------------------------------------
3233
N=16384
3334
----------------------------------------------------------------------------------------------------
34-
out_f32(fence): ['3.359e-05 ', '1.657e-05 ', '0.0001522 '], time:0.01000977ms
35-
out_f32x4(fence): ['3.359e-05 ', '1.657e-05 ', '0.0001522 '], time:0.01015735ms
36-
out_f32_th: ['3.359e-05 ', '1.657e-05 ', '0.0001522 '], time:0.00575948ms
35+
out_f32(fence): ['0.00011554 ', '1.172e-05 ', '3.789e-05 '], time:0.00707126ms
36+
out_f32x4(fence): ['0.00011554 ', '1.172e-05 ', '3.789e-05 '], time:0.00714874ms
37+
out_f32_th: ['0.00011554 ', '1.172e-05 ', '3.789e-05 '], time:0.00871110ms
3738
----------------------------------------------------------------------------------------------------
3839
S=4096, H=256
3940
----------------------------------------------------------------------------------------------------
40-
out_f32(per): ['0.00425925 ', '0.00819569 ', '0.00073704 '], time:0.00633717ms
41-
out_f32x4(per): ['0.00425925 ', '0.00819569 ', '0.00073704 '], time:0.00395060ms
42-
out_f32(safe): ['0.00425925 ', '0.00819569 ', '0.00073704 '], time:0.00937152ms
43-
out_f32(safe+online): ['0.00425925 ', '0.00819569 ', '0.00073704 '], time:0.00749898ms
44-
out_f32x4(safe): ['0.00425925 ', '0.00819569 ', '0.00073704 '], time:0.00413203ms
45-
out_f32_th(per): ['0.00425925 ', '0.00819569 ', '0.00073704 '], time:0.00574470ms
41+
out_f32(per): ['0.00489144 ', '0.00030952 ', '0.00112878 '], time:0.01259184ms
42+
out_f32x4(per): ['0.00489144 ', '0.00030952 ', '0.00112878 '], time:0.01004362ms
43+
out_f32(safe): ['0.00489144 ', '0.00030952 ', '0.00112878 '], time:0.01583433ms
44+
out_f32(safe+online): ['0.00489144 ', '0.00030952 ', '0.00112878 '], time:0.01357031ms
45+
out_f32x4(safe+online): ['0.00489145 ', '0.00030952 ', '0.00112878 '], time:0.01050377ms
46+
out_f32x4(safe): ['0.00489144 ', '0.00030952 ', '0.00112878 '], time:0.01027584ms
47+
out_f32_th(per): ['0.00489144 ', '0.00030952 ', '0.00112878 '], time:0.01042914ms
4648
----------------------------------------------------------------------------------------------------
47-
out_f16f32(safe): ['0.00426102 ', '0.00819397 ', '0.00073671 '], time:0.00907254ms
48-
out_f16x2f32(safe): ['0.00426102 ', '0.00819397 ', '0.00073671 '], time:0.00526237ms
49-
out_f16x8packf32(safe): ['0.00426102 ', '0.00819397 ', '0.00073671 '], time:0.00414038ms
50-
out_f16_th(per): ['0.00426102 ', '0.00819397 ', '0.00073671 '], time:0.00579095ms
49+
out_f16f32(safe): ['0.00489044 ', '0.00030971 ', '0.00112915 '], time:0.01418757ms
50+
out_f16x2f32(safe): ['0.00489044 ', '0.00030971 ', '0.00112915 '], time:0.00781608ms
51+
out_f16x8packf32(safe): ['0.00489044 ', '0.00030971 ', '0.00112915 '], time:0.00523329ms
52+
out_f16_th(per): ['0.00489044 ', '0.00030971 ', '0.00112915 '], time:0.00563836ms
5153
----------------------------------------------------------------------------------------------------
5254
----------------------------------------------------------------------------------------------------
5355
S=4096, H=512
5456
----------------------------------------------------------------------------------------------------
55-
out_f32(per): ['0.00203266 ', '7.054e-05 ', '0.00042398 '], time:0.01142383ms
56-
out_f32x4(per): ['0.00203266 ', '7.054e-05 ', '0.00042398 '], time:0.00514126ms
57-
out_f32(safe): ['0.00203266 ', '7.054e-05 ', '0.00042398 '], time:0.01835704ms
58-
out_f32(safe+online): ['0.00203266 ', '7.054e-05 ', '0.00042398 '], time:0.01364374ms
59-
out_f32x4(safe): ['0.00203266 ', '7.054e-05 ', '0.00042398 '], time:0.00578308ms
60-
out_f32_th(per): ['0.00203266 ', '7.054e-05 ', '0.00042398 '], time:0.00650859ms
57+
out_f32(per): ['0.00042486 ', '0.00308358 ', '0.00113099 '], time:0.02372313ms
58+
out_f32x4(per): ['0.00042486 ', '0.00308358 ', '0.00113099 '], time:0.02219534ms
59+
out_f32(safe): ['0.00042486 ', '0.00308358 ', '0.00113099 '], time:0.03100491ms
60+
out_f32(safe+online): ['0.00042486 ', '0.00308358 ', '0.00113099 '], time:0.02549100ms
61+
out_f32x4(safe+online): ['0.00042486 ', '0.00308358 ', '0.00113099 '], time:0.02228165ms
62+
out_f32x4(safe): ['0.00042486 ', '0.00308358 ', '0.00113099 '], time:0.02230835ms
63+
out_f32_th(per): ['0.00042486 ', '0.00308358 ', '0.00113099 '], time:0.02294350ms
6164
----------------------------------------------------------------------------------------------------
62-
out_f16f32(safe): ['0.00203323 ', '7.057e-05 ', '0.00042415 '], time:0.01780558ms
63-
out_f16x2f32(safe): ['0.00203323 ', '7.057e-05 ', '0.00042415 '], time:0.00920749ms
64-
out_f16x8packf32(safe): ['0.00203323 ', '7.057e-05 ', '0.00042415 '], time:0.00416279ms
65-
out_f16_th(per): ['0.00203323 ', '7.057e-05 ', '0.00042415 '], time:0.00592852ms
65+
out_f16f32(safe): ['0.00042486 ', '0.00308418 ', '0.00113106 '], time:0.02967048ms
66+
out_f16x2f32(safe): ['0.00042486 ', '0.00308418 ', '0.00113106 '], time:0.01563406ms
67+
out_f16x8packf32(safe): ['0.00042486 ', '0.00308418 ', '0.00113106 '], time:0.01033092ms
68+
out_f16_th(per): ['0.00042486 ', '0.00308418 ', '0.00113106 '], time:0.01410413ms
6669
----------------------------------------------------------------------------------------------------
6770
----------------------------------------------------------------------------------------------------
6871
S=4096, H=1024
6972
----------------------------------------------------------------------------------------------------
70-
out_f32(per): ['4.202e-05 ', '0.00064992 ', '0.00070006 '], time:0.03191423ms
71-
out_f32x4(per): ['4.202e-05 ', '0.00064992 ', '0.00070006 '], time:0.00858426ms
72-
out_f32(safe): ['4.202e-05 ', '0.00064992 ', '0.00070006 '], time:0.04868317ms
73-
out_f32(safe+online): ['4.202e-05 ', '0.00064992 ', '0.00070006 '], time:0.03698754ms
74-
out_f32x4(safe): ['4.202e-05 ', '0.00064992 ', '0.00070006 '], time:0.01025891ms
75-
out_f32_th(per): ['4.202e-05 ', '0.00064992 ', '0.00070006 '], time:0.01172018ms
73+
out_f32(per): ['0.00015042 ', '0.00127817 ', '0.00087939 '], time:0.06144118ms
74+
out_f32x4(per): ['0.00015042 ', '0.00127817 ', '0.00087939 '], time:0.04208207ms
75+
out_f32(safe): ['0.00015042 ', '0.00127817 ', '0.00087939 '], time:0.08846235ms
76+
out_f32(safe+online): ['0.00015042 ', '0.00127817 ', '0.00087939 '], time:0.06275535ms
77+
out_f32x4(safe+online): ['0.00015042 ', '0.00127817 ', '0.00087939 '], time:0.04195666ms
78+
out_f32x4(safe): ['0.00015042 ', '0.00127817 ', '0.00087939 '], time:0.04199767ms
79+
out_f32_th(per): ['0.00015042 ', '0.00127817 ', '0.00087939 '], time:0.04214501ms
7680
----------------------------------------------------------------------------------------------------
77-
out_f16f32(safe): ['4.202e-05 ', '0.00064993 ', '0.0007 '], time:0.04668665ms
78-
out_f16x2f32(safe): ['4.202e-05 ', '0.00064993 ', '0.0007 '], time:0.01805592ms
79-
out_f16x8packf32(safe): ['4.202e-05 ', '0.00064993 ', '0.0007 '], time:0.00600147ms
80-
out_f16_th(per): ['4.202e-05 ', '0.00064993 ', '0.0007 '], time:0.01042104ms
81+
out_f16f32(safe): ['0.00015044 ', '0.00127792 ', '0.00087929 '], time:0.07461023ms
82+
out_f16x2f32(safe): ['0.00015044 ', '0.00127792 ', '0.00087929 '], time:0.02805471ms
83+
out_f16x8packf32(safe): ['0.00015044 ', '0.00127792 ', '0.00087929 '], time:0.02210021ms
84+
out_f16_th(per): ['0.00015044 ', '0.00127792 ', '0.00087929 '], time:0.02429175ms
8185
----------------------------------------------------------------------------------------------------
8286
----------------------------------------------------------------------------------------------------
8387
S=4096, H=2048
8488
----------------------------------------------------------------------------------------------------
85-
out_f32x4(per): ['0.00068028 ', '0.00138677 ', '0.00012553 '], time:0.01602578ms
86-
out_f32x4(safe): ['0.00068028 ', '0.00138677 ', '0.00012553 '], time:0.02085137ms
87-
out_f32_th(per): ['0.00068028 ', '0.00138677 ', '0.00012553 '], time:0.06727862ms
89+
out_f32x4(per): ['0.00014777 ', '0.00018938 ', '9.769e-05 '], time:0.08160353ms
90+
out_f32x4(safe): ['0.00014777 ', '0.00018938 ', '9.769e-05 '], time:0.08181977ms
91+
out_f32_th(per): ['0.00014777 ', '0.00018938 ', '9.769e-05 '], time:0.10212374ms
8892
----------------------------------------------------------------------------------------------------
89-
out_f16x2f32(safe): ['0.00067997 ', '0.00138664 ', '0.00012553 '], time:0.04822373ms
90-
out_f16x8packf32(safe): ['0.00067997 ', '0.00138664 ', '0.00012553 '], time:0.01078343ms
91-
out_f16_th(per): ['0.00067997 ', '0.00138664 ', '0.00012553 '], time:0.07226229ms
93+
out_f16x2f32(safe): ['0.0001477 ', '0.00018942 ', '9.769e-05 '], time:0.07831120ms
94+
out_f16x8packf32(safe): ['0.0001477 ', '0.00018942 ', '9.769e-05 '], time:0.04206920ms
95+
out_f16_th(per): ['0.0001477 ', '0.00018942 ', '9.769e-05 '], time:0.05331278ms
9296
----------------------------------------------------------------------------------------------------
9397
----------------------------------------------------------------------------------------------------
9498
S=4096, H=4096
9599
----------------------------------------------------------------------------------------------------
96-
out_f32x4(per): ['3.5e-05 ', '8.788e-05 ', '0.00017372 '], time:0.18450212ms
97-
out_f32x4(safe): ['3.5e-05 ', '8.788e-05 ', '0.00017372 '], time:0.18548727ms
98-
out_f32_th(per): ['3.5e-05 ', '8.788e-05 ', '0.00017372 '], time:0.18735909ms
100+
out_f32x4(per): ['4.063e-05 ', '0.00038625 ', '0.00019391 '], time:0.16202784ms
101+
out_f32x4(safe): ['4.063e-05 ', '0.00038625 ', '0.00019391 '], time:0.16271973ms
102+
out_f32_th(per): ['4.063e-05 ', '0.00038625 ', '0.00019391 '], time:0.19028711ms
99103
----------------------------------------------------------------------------------------------------
100-
out_f16x8packf32(safe): ['3.499e-05 ', '8.792e-05 ', '0.00017369 '], time:0.02230954ms
101-
out_f16_th(per): ['3.499e-05 ', '8.792e-05 ', '0.00017369 '], time:0.08258724ms
104+
out_f16x8packf32(safe): ['4.065e-05 ', '0.00038624 ', '0.00019383 '], time:0.08193207ms
105+
out_f16_th(per): ['4.065e-05 ', '0.00038624 ', '0.00019383 '], time:0.10132599ms
102106
----------------------------------------------------------------------------------------------------
103107
----------------------------------------------------------------------------------------------------
104108
S=4096, H=8192
105109
----------------------------------------------------------------------------------------------------
106-
out_f16x8packf32(safe): ['8.47e-05 ', '0.00048876 ', '2.718e-05 '], time:0.19314885ms
107-
out_f16_th(per): ['8.47e-05 ', '0.00048876 ', '2.718e-05 '], time:0.19355965ms
110+
out_f16x8packf32(safe): ['0.00044656 ', '1.872e-05 ', '0.00054884 '], time:0.16337919ms
111+
out_f16_th(per): ['0.00044656 ', '1.872e-05 ', '0.00054884 '], time:0.18709970ms
108112
----------------------------------------------------------------------------------------------------
109113
S=8192, H=8192
110114
----------------------------------------------------------------------------------------------------
111-
out_f16x8packf32(safe): ['5.829e-05 ', '8.482e-05 ', '0.00021875 '], time:0.39851356ms
112-
out_f16_th(per): ['5.829e-05 ', '8.482e-05 ', '0.00021875 '], time:0.40570927ms
115+
out_f16x8packf32(safe): ['4.601e-05 ', '9.853e-05 ', '1.711e-05 '], time:0.32324409ms
116+
out_f16_th(per): ['4.601e-05 ', '9.853e-05 ', '1.711e-05 '], time:0.36632204ms
113117
----------------------------------------------------------------------------------------------------
114118
```

softmax/softmax.cu

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -345,6 +345,50 @@ __global__ void online_safe_softmax_f32_per_token_kernel(const float* x, float*
345345
}
346346
}
347347

348+
// Online softmax (f32, x4-packed): one block per token/row; each thread loads
// 4 consecutive floats with a single float4 access, computes a local
// (max m, denominator d) pair, then reduces it warp-wise and block-wise.
// reference: https://arxiv.org/pdf/1805.02867 (Online normalizer calculation for softmax)
template <const int NUM_THREADS = 256 / 4>
__global__ void online_safe_softmax_f32x4_pack_per_token_kernel(float *x, float *y, int N)
{
  int local_tid = threadIdx.x;
  // Each thread owns 4 consecutive elements, hence the *4 stride.
  int global_tid = (blockIdx.x * NUM_THREADS + local_tid) * 4;

  const int WARP_NUM = NUM_THREADS / WARP_SIZE; // warps per block (typo fix: was WAPR_NUM)
  int warp_id = local_tid / WARP_SIZE;
  int lane_id = local_tid % WARP_SIZE;

  // NOTE(review): this load is unguarded; the dispatch macro launches with
  // block = H/4 and grid = S, so global_tid stays within N = S*H. Confirm if
  // the kernel is ever launched outside that macro.
  float4 val = FLOAT4((x)[global_tid]);

  // Per-thread online (m, d) over the 4 packed values.
  float local_m = fmaxf(fmaxf(val.x, val.y), fmaxf(val.z, val.w));
  float local_d = __expf(val.x - local_m) + __expf(val.y - local_m) +
                  __expf(val.z - local_m) + __expf(val.w - local_m);

  MD local_md = {local_m, local_d};
  // Warp-level reduce of the (max, denominator) pair.
  MD res = warp_reduce_md_op<WARP_SIZE>(local_md);
  __shared__ MD shared[WARP_NUM];

  if (lane_id == 0) shared[warp_id] = res;
  __syncthreads();

  // Block-level reduce across per-warp partials. BUGFIX: the original read
  // shared[local_tid] for all lanes < WARP_SIZE, which indexes past the end of
  // shared[WARP_NUM] whenever WARP_NUM < WARP_SIZE (e.g. H=128 -> WARP_NUM=1).
  // Clamp the index instead of narrowing the guard so all 32 lanes remain
  // active for the full-mask warp shuffles inside warp_reduce_md_op; lanes
  // >= WARP_NUM carry a duplicate of shared[0] that a width-WARP_NUM reduce
  // never folds into lane 0.
  if (local_tid < WARP_SIZE)
  {
    MD block_res = shared[(local_tid < WARP_NUM) ? local_tid : 0];
    block_res = warp_reduce_md_op<WARP_NUM>(block_res);
    if (local_tid == 0) shared[0] = block_res;
  }
  __syncthreads();

  // Broadcast the row's (max, denominator) and write the normalized outputs
  // back with one packed float4 store.
  MD final_res = shared[0];
  float d_total_inverse = __fdividef(1.0f, final_res.d);
  if (global_tid < N)
  {
    float4 reg_y;
    reg_y.x = __expf(val.x - final_res.m) * d_total_inverse;
    reg_y.y = __expf(val.y - final_res.m) * d_total_inverse;
    reg_y.z = __expf(val.z - final_res.m) * d_total_inverse;
    reg_y.w = __expf(val.w - final_res.m) * d_total_inverse;
    FLOAT4((y)[global_tid]) = reg_y;
  }
}
391+
348392
// --------------------- PyTorch bindings for custom kernel -----------------------
349393
#define STRINGFY(str) #str
350394
#define TORCH_BINDING_COMMON_EXTENSION(func) \
@@ -531,6 +575,37 @@ online_safe_softmax_f32_per_token_kernel<(H)><<<grid, block>>>( \
531575
"only support H: 64/128/256/512/1024"); \
532576
break; \
533577
}
578+
579+
// online softmax per token
// Launch helper: runs the x4-packed online-softmax kernel with NUM_THREADS =
// H/4 (each thread handles 4 floats, so one block covers one row of width H).
// Expects `grid`, `block`, `x`, `y`, `N` to be in scope at the expansion site.
#define LANUCH_ONLINE_SOFTMAX_F32X4_PACK_PER_TOKEN_KERNEL(H) \
online_safe_softmax_f32x4_pack_per_token_kernel<(H/4)><<<grid, block>>>( \
reinterpret_cast<float*>(x.data_ptr()), \
reinterpret_cast<float*>(y.data_ptr()), \
N);

// Dispatch on the row width H: one block per row (grid = S), H/4 threads per
// block. Unlike the scalar online-softmax dispatch, H=64 is not supported
// here — H/4 threads must be at least a full warp for the reduction.
#define DISPATCH_ONLINE_SOFTMAX_F32X4_PACK_PER_TOKEN_KERNEL(S, H) \
dim3 block((H/4)); \
dim3 grid((S)); \
switch ((H)) \
{ \
case 128: \
LANUCH_ONLINE_SOFTMAX_F32X4_PACK_PER_TOKEN_KERNEL(128) \
break; \
case 256: \
LANUCH_ONLINE_SOFTMAX_F32X4_PACK_PER_TOKEN_KERNEL(256) \
break; \
case 512: \
LANUCH_ONLINE_SOFTMAX_F32X4_PACK_PER_TOKEN_KERNEL(512) \
break; \
case 1024: \
LANUCH_ONLINE_SOFTMAX_F32X4_PACK_PER_TOKEN_KERNEL(1024) \
break; \
default: \
throw std::runtime_error( \
"only support H: 128/256/512/1024; raise error if warp_num*4 > H"); \
break; \
}
608+
534609
#define LANUCH_SAFE_SOFTMAX_F32x4_PER_TOKEN_KERNEL(H) \
535610
safe_softmax_f32x4_per_token_kernel<(H)/4><<< \
536611
grid, block>>>( \
@@ -775,6 +850,16 @@ void online_safe_softmax_f32_per_token(torch::Tensor x, torch::Tensor y) {
775850
DISPATCH_ONLINE_SOFTMAX_F32_PER_TOKEN_KERNEL(S, H)
776851
}
777852

853+
// PyTorch binding for the x4-packed online-softmax kernel.
// x: (S, H) float32 input; y: (S, H) float32 output, same shape/dtype as x.
// Validates dtype/shape, then dispatches on H (128/256/512/1024 only; the
// dispatch macro throws std::runtime_error for other widths).
void online_safe_softmax_f32x4_pack_per_token(torch::Tensor x, torch::Tensor y) {
  CHECK_TORCH_TENSOR_DTYPE(x, torch::kFloat32)
  CHECK_TORCH_TENSOR_DTYPE(y, torch::kFloat32)
  CHECK_TORCH_TENSOR_SHAPE(x, y)
  const int S = x.size(0);  // number of tokens (rows)
  const int H = x.size(1);  // hidden size (row width)
  const int N = S * H;      // total elements; consumed inside the launch macro
  DISPATCH_ONLINE_SOFTMAX_F32X4_PACK_PER_TOKEN_KERNEL(S, H)
}
862+
778863
// grid memory fence fp32
779864
TORCH_BINDING_SOFTMAX(f32, torch::kFloat32, float, 1)
780865
TORCH_BINDING_SOFTMAX(f32x4, torch::kFloat32, float, 4)
@@ -790,4 +875,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
790875
TORCH_BINDING_COMMON_EXTENSION(safe_softmax_f16x2_f32_per_token)
791876
TORCH_BINDING_COMMON_EXTENSION(safe_softmax_f16x8_pack_f32_per_token)
792877
TORCH_BINDING_COMMON_EXTENSION(online_safe_softmax_f32_per_token)
878+
TORCH_BINDING_COMMON_EXTENSION(online_safe_softmax_f32x4_pack_per_token)
793879
}

softmax/softmax.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ def run_benchmark(perf_func: callable, x: torch.Tensor,
7777
run_benchmark(lib.softmax_f32x4_per_token, x, "f32x4(per)", out)
7878
run_benchmark(lib.safe_softmax_f32_per_token, x, "f32(safe)", out)
7979
run_benchmark(lib.online_safe_softmax_f32_per_token, x, "f32(safe+online)", out)
80+
run_benchmark(lib.online_safe_softmax_f32x4_pack_per_token, x, "f32x4(safe+online)", out)
8081
run_benchmark(lib.safe_softmax_f32x4_per_token, x, "f32x4(safe)", out)
8182
run_benchmark(partial(torch.softmax, dim=1, out=out), x, "f32_th(per)")
8283

@@ -100,6 +101,7 @@ def run_benchmark(perf_func: callable, x: torch.Tensor,
100101
run_benchmark(lib.softmax_f32x4_per_token, x, "f32x4(per)", out)
101102
run_benchmark(lib.safe_softmax_f32_per_token, x, "f32(safe)", out)
102103
run_benchmark(lib.online_safe_softmax_f32_per_token, x, "f32(safe+online)", out)
104+
run_benchmark(lib.online_safe_softmax_f32x4_pack_per_token, x, "f32x4(safe+online)", out)
103105
run_benchmark(lib.safe_softmax_f32x4_per_token, x, "f32x4(safe)", out)
104106
run_benchmark(partial(torch.softmax, dim=1, out=out), x, "f32_th(per)")
105107

@@ -123,6 +125,7 @@ def run_benchmark(perf_func: callable, x: torch.Tensor,
123125
run_benchmark(lib.softmax_f32x4_per_token, x, "f32x4(per)", out)
124126
run_benchmark(lib.safe_softmax_f32_per_token, x, "f32(safe)", out)
125127
run_benchmark(lib.online_safe_softmax_f32_per_token, x, "f32(safe+online)", out)
128+
run_benchmark(lib.online_safe_softmax_f32x4_pack_per_token, x, "f32x4(safe+online)", out)
126129
run_benchmark(lib.safe_softmax_f32x4_per_token, x, "f32x4(safe)", out)
127130
run_benchmark(partial(torch.softmax, dim=1, out=out), x, "f32_th(per)")
128131

0 commit comments

Comments
 (0)