#define HALF2(value) (reinterpret_cast<half2*>(&(value))[0])
#define BFLOAT2(value) (reinterpret_cast<__nv_bfloat162*>(&(value))[0])
#define LDST128BITS(value) (reinterpret_cast<float4*>(&(value))[0])
+ // Data structure required for Online Softmax: m is the running max of the
+ // row, d is the running denominator, i.e. the sum of exp(x_i - m).
+ struct __align__(8) MD
+ {
+   float m;
+   float d;
+ };

// -------------------------------------- FP32 --------------------------------------
+ // Warp Reduce for Online Softmax
+
+ template <const int kWarpSize = WARP_SIZE>
+ __device__ __forceinline__ MD warp_reduce_md_op(MD value) {
+   unsigned int mask = 0xffffffff;
+ #pragma unroll
+   for (int stride = kWarpSize >> 1; stride >= 1; stride >>= 1) {
+     MD other;
+     other.m = __shfl_xor_sync(mask, value.m, stride);
+     other.d = __shfl_xor_sync(mask, value.d, stride);
+
+     bool value_bigger = (value.m > other.m);
+     MD bigger_m = value_bigger ? value : other;
+     MD smaller_m = value_bigger ? other : value;
+
+     value.d = bigger_m.d + smaller_m.d * __expf(smaller_m.m - bigger_m.m);
+     value.m = bigger_m.m;
+   }
+   return value;
+ }
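+ // The reduction above merges two partial states (m1, d1) and (m2, d2) into
+ // (max(m1, m2), d1 * exp(m1 - m) + d2 * exp(m2 - m)) with m = max(m1, m2),
+ // so the running denominator stays scaled to the current running max.
+ // A minimal host-side sketch of the same merge rule, for illustration only;
+ // the helper name md_merge_host is hypothetical and not part of these kernels.
+ static inline MD md_merge_host(MD a, MD b) {
+   MD out;
+   out.m = fmaxf(a.m, b.m);
+   out.d = a.d * expf(a.m - out.m) + b.d * expf(b.m - out.m);
+   return out;
+ }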
+
// Warp Reduce Sum
template <const int kWarpSize = WARP_SIZE>
__device__ __forceinline__ float warp_reduce_sum_f32(float val) {
@@ -289,6 +316,40 @@ __global__ void safe_softmax_f16x8_pack_f32_per_token_kernel(half* x, half* y, i
  // TODO: support non 8-multiple K here
}

+ template <const int NUM_THREADS = 256>
+ __global__ void online_softmax_f32_per_token_kernel(const float* x, float* y, int N) {
+
+   int local_tid = threadIdx.x;
+   int global_tid = blockIdx.x * NUM_THREADS + threadIdx.x;
+   const int WARP_NUM = NUM_THREADS / WARP_SIZE;
+   int warp_id = local_tid / WARP_SIZE;
+   int lane_id = local_tid % WARP_SIZE;
+   MD val;
+   val.m = global_tid < N ? x[global_tid] : -FLT_MAX;
+   val.d = global_tid < N ? 1.0f : 0.0f;
+
+   __shared__ MD shared[WARP_NUM];
+   MD res = warp_reduce_md_op<WARP_SIZE>(val);
+
+   if (lane_id == 0) shared[warp_id] = res;
+   __syncthreads();
+
+   if (local_tid < WARP_SIZE) {
+     // only the first WARP_NUM slots of shared[] are valid; pad the extra
+     // lanes with the identity element (m = -FLT_MAX, d = 0) to avoid an
+     // out-of-bounds shared memory read
+     MD block_res = (local_tid < WARP_NUM) ? shared[local_tid] : MD{-FLT_MAX, 0.0f};
+     block_res = warp_reduce_md_op<WARP_NUM>(block_res);
+     if (local_tid == 0) {
+       shared[0] = block_res;
+     }
+   }
+   __syncthreads();
+
+   MD final_res = shared[0];
+   float d_total_inverse = __fdividef(1.0f, final_res.d);
+   if (global_tid < N) {
+     y[global_tid] = __expf(x[global_tid] - final_res.m) * d_total_inverse;
+   }
+ }
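+ // For comparison, a single-pass host reference of what the kernel computes for
+ // one row of length H (one "token"); the function name
+ // online_softmax_row_reference is hypothetical and only illustrative.
+ static inline void online_softmax_row_reference(const float* x, float* y, int H) {
+   float m = -FLT_MAX; // running max
+   float d = 0.0f;     // running denominator: sum of exp(x[i] - m)
+   for (int i = 0; i < H; ++i) {
+     float m_new = fmaxf(m, x[i]);
+     d = d * expf(m - m_new) + expf(x[i] - m_new);
+     m = m_new;
+   }
+   for (int i = 0; i < H; ++i) {
+     y[i] = expf(x[i] - m) / d;
+   }
+ }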
+
// --------------------- PyTorch bindings for custom kernel -----------------------
#define STRINGFY(str) #str
#define TORCH_BINDING_COMMON_EXTENSION(func) \
@@ -440,6 +501,41 @@ safe_softmax_f32_per_token_kernel<(H)><<<grid, block>>>( \
  break; \
}

+ // online softmax per token
+ #define LANUCH_ONLINE_SOFTMAX_F32_PER_TOKEN_KERNEL(H) \
+ online_softmax_f32_per_token_kernel<(H)><<<grid, block>>>( \
+   reinterpret_cast<float*>(x.data_ptr()), \
+   reinterpret_cast<float*>(y.data_ptr()), \
+   N);
+
+ #define DISPATCH_ONLINE_SOFTMAX_F32_PER_TOKEN_KERNEL(S, H) \
+ dim3 block((H)); \
+ dim3 grid((S)); \
+ switch ((H)) \
+ { \
+   case 32: \
+     LANUCH_ONLINE_SOFTMAX_F32_PER_TOKEN_KERNEL(32) \
+     break; \
+   case 64: \
+     LANUCH_ONLINE_SOFTMAX_F32_PER_TOKEN_KERNEL(64) \
+     break; \
+   case 128: \
+     LANUCH_ONLINE_SOFTMAX_F32_PER_TOKEN_KERNEL(128) \
+     break; \
+   case 256: \
+     LANUCH_ONLINE_SOFTMAX_F32_PER_TOKEN_KERNEL(256) \
+     break; \
+   case 512: \
+     LANUCH_ONLINE_SOFTMAX_F32_PER_TOKEN_KERNEL(512) \
+     break; \
+   case 1024: \
+     LANUCH_ONLINE_SOFTMAX_F32_PER_TOKEN_KERNEL(1024) \
+     break; \
+   default: \
+     throw std::runtime_error( \
+       "only support H: 32/64/128/256/512/1024"); \
+     break; \
+ }
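+ // Note: the dispatch above launches one block per token (grid = S) with H
+ // threads per block (block = H), so each thread loads exactly one element of
+ // its row; this is why only the power-of-two head sizes listed in the switch
+ // (32 up to 1024) are supported.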
#define LANUCH_SAFE_SOFTMAX_F32x4_PER_TOKEN_KERNEL(H) \
safe_softmax_f32x4_per_token_kernel<(H)/4><<< \
  grid, block>>>( \
@@ -674,6 +770,16 @@ void safe_softmax_f16x8_pack_f32_per_token(torch::Tensor x, torch::Tensor y) {
  DISPATCH_SATE_SOFTMAX_F16x8_PACK_F32_PER_TOKEN_KERNEL(S, H)
}

+ void online_softmax_f32_per_token(torch::Tensor x, torch::Tensor y) {
+   CHECK_TORCH_TENSOR_DTYPE(x, torch::kFloat32)
+   CHECK_TORCH_TENSOR_DTYPE(y, torch::kFloat32)
+   CHECK_TORCH_TENSOR_SHAPE(x, y)
+   const int S = x.size(0); // seqlens
+   const int H = x.size(1); // head size/kv_len
+   const int N = S * H;
+   DISPATCH_ONLINE_SOFTMAX_F32_PER_TOKEN_KERNEL(S, H)
+ }
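+ // Expected usage (a sketch of the intended call pattern, not enforced beyond
+ // the checks above): x and y are contiguous fp32 CUDA tensors of shape [S, H],
+ // and the kernel writes the row-wise softmax of x into y.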
+
// grid memory fence fp32
TORCH_BINDING_SOFTMAX(f32, torch::kFloat32, float, 1)
TORCH_BINDING_SOFTMAX(f32x4, torch::kFloat32, float, 4)
@@ -688,4 +794,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  TORCH_BINDING_COMMON_EXTENSION(safe_softmax_f16_f32_per_token)
  TORCH_BINDING_COMMON_EXTENSION(safe_softmax_f16x2_f32_per_token)
  TORCH_BINDING_COMMON_EXTENSION(safe_softmax_f16x8_pack_f32_per_token)
+   TORCH_BINDING_COMMON_EXTENSION(online_softmax_f32_per_token)
}