Commit 4667308
[LayerNorm][FP16] Add pack support for f16x8 LD/ST (#46)
* Update README.md
* Update README.md
* Update README.md
* Update README.md
* Update README.md
* Update layer_norm.cu
* Update layer_norm.py
* Update README.md

1 parent e28cb4d · commit 4667308

File tree: 4 files changed, +291 −34 lines changed
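The technique named in the commit title is packed 128-bit load/store: eight fp16 values (8 × 16 bits = 128 bits) are reinterpreted as a single `float4`, so each thread issues one vectorized LD/ST instead of eight scalar 16-bit accesses. Below is a minimal sketch of the pattern, reusing the `LDST128BITS` macro this commit adds to `layer_norm.cu`; the copy kernel itself is only an illustration, not part of the commit, and it assumes `x`/`y` are 16-byte aligned and `n` is a multiple of 8.

```cuda
#include <cuda_fp16.h>

// Same reinterpret macro the commit adds to layer_norm.cu.
#define LDST128BITS(value) (reinterpret_cast<float4*>(&(value))[0])

// Illustration only: each thread moves 8 halves with one 128-bit load and
// one 128-bit store instead of eight 16-bit accesses.
__global__ void copy_f16x8_pack(half* x, half* y, int n) {
  int idx = (blockIdx.x * blockDim.x + threadIdx.x) * 8;
  if (idx + 7 < n) {
    half pack[8];                                 // 8 x 16 bits = 128 bits
    LDST128BITS(pack[0]) = LDST128BITS(x[idx]);   // one 128-bit load
    LDST128BITS(y[idx])  = LDST128BITS(pack[0]);  // one 128-bit store
  }
}
```

The new `layer_norm_f16x8_pack_f16_kernel` in this commit wraps the same load/store pattern around its per-token mean/variance reduction.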

README.md
Lines changed: 3 additions & 2 deletions

@@ -9,7 +9,7 @@
 <img src=https://img.shields.io/badge/License-GPLv3.0-turquoise.svg >
 </div>
 
-🎉 **CUDA Learn Notes**: This repo aims to build a **Modern CUDA Learn Notes with PyTorch** for beginners, including **fp32, fp16/bf16, fp8/int8, Tensor/CUDA Cores**, flash_attn, sgemm, sgemv, hgemm, hgemv, warp/block reduce, dot prod, elementwise, sigmoid, relu, softmax, layernorm, rmsnorm, hist and some CUDA optimization techniques (pack LDST, warp gemv, sliced_k/split_k/pipeline gemm, bank conflicts free, MMA, etc).
+🎉 **CUDA Learn Notes**: This repo aims to build a **Modern CUDA Learn Notes with PyTorch** for **[Beginners]**, including **fp32, fp16/bf16, fp8/int8, Tensor/CUDA Cores**, flash_attn, sgemm, sgemv, hgemm, hgemv, warp/block reduce, dot prod, elementwise, sigmoid, relu, softmax, layernorm, rmsnorm, hist and some CUDA optimization techniques (pack LDST, warp gemv, sliced_k/split_k/pipeline gemm, bank conflicts free, MMA, etc).
 
 <img width="1438" alt="image" src="https://github.com/user-attachments/assets/0c5e5125-586f-43fa-8e8b-e2c61c1afbbe">
 
@@ -21,7 +21,7 @@
 
 |📖 cuda kernel| 📖 elem dtype| 📖 acc dtype| 📖 docs | 📖 level |
 |:---|:---|:---|:---|:---|
-| ✔️ [nsys/ncu usage(timeline/ptx/sass)](./nvidia-nsight/)|/|/|[link](./nvidia-nsight/)|⭐️|
+| ✔️ [nsys/ncu(timeline/ptx/sass)](./nvidia-nsight/)|/|/|[link](./nvidia-nsight/)|⭐️|
 | ✔️ [elementwise_f32](./elementwise/elementwise.cu)|f32|/|[link](./elementwise/)|⭐️|
 | ✔️ [elementwise_f32x4](./elementwise/elementwise.cu)|f32|/|[link](./elementwise/)|⭐️|
 | ✔️ [elementwise_f16](./elementwise/elementwise.cu)|f16|/|[link](./elementwise/)|⭐️|
@@ -80,6 +80,7 @@
 | ✔️ [layer_norm_f16_f16(per token)](./layer-norm/layer_norm.cu)|f16|f16|[link](./layer-norm/)|⭐️⭐️|
 | ✔️ [layer_norm_f16x2_f16(per token)](./layer-norm/layer_norm.cu)|f16|f16|[link](./layer-norm/)|⭐️⭐️|
 | ✔️ [layer_norm_f16x8_f16(per token)](./layer-norm/layer_norm.cu)|f16|f16|[link](./layer-norm/)|⭐️⭐️|
+| ✔️ [layer_norm_f16x8_pack_f16(per token)](./layer-norm/layer_norm.cu)|f16|f16|[link](./layer-norm/)|⭐️⭐️|
 | ✔️ [layer_norm_f16_f32(per token)](./layer-norm/layer_norm.cu)|f16|f32|[link](./layer-norm/)|⭐️⭐️|
 | ✔️ [rms_norm_f32(per token)](./rms-norm/rms_norm.cu)|f32|f32|[link](./rms-norm/)|⭐️⭐️|
 | ✔️ [rms_norm_f32x4(per token)](./rms-norm/rms_norm.cu)|f32|f32|[link](./rms-norm/)|⭐️⭐️|

layer-norm/README.md
Lines changed: 64 additions & 11 deletions

@@ -9,6 +9,7 @@
 - [X] layer_norm_f16_f16_kernel
 - [X] layer_norm_f16x2_f16_kernel
 - [X] layer_norm_f16x8_f16_kernel
+- [X] layer_norm_f16x8_pack_f16_kernel
 - [X] layer_norm_f16_f32_kernel
 - [X] PyTorch bindings
 
@@ -23,15 +24,67 @@ python3 layer_norm.py
 Output:
 
 ```bash
---------------------------------------------------------------------------------
-out_f32: [0.54253572, -0.13322251, 1.65217566], time:0.01894307ms
-out_f32x4: [0.54253572, -0.13322251, 1.65217566], time:0.00595951ms
-out_f32_th: [0.54200566, -0.13309236, 1.65056157], time:0.07212615ms
---------------------------------------------------------------------------------
-out_f16f16: [0.54248047, -0.13330078, 1.65332031], time:0.01863098ms
-out_f16x2f16: [0.54248047, -0.13330078, 1.65332031], time:0.00949597ms
-out_f16x8f16: [0.54248047, -0.13317871, 1.65234375], time:0.00466394ms
-out_f16f32: [0.54248047, -0.13317871, 1.65234375], time:0.01892662ms
-out_f16_th: [0.54199219, -0.13305664, 1.65039062], time:0.07164359ms
---------------------------------------------------------------------------------
+-------------------------------------------------------------------------------------
+N=4096, K=512
+-------------------------------------------------------------------------------------
+out_f32: ['-1.76292217 ', '0.04765211 ', '0.50859255 '], time:0.01897240ms
+out_f32x4: ['-1.76292217 ', '0.04765211 ', '0.50859255 '], time:0.00600266ms
+out_f32_th: ['-1.76119995 ', '0.04760556 ', '0.50809568 '], time:0.07085347ms
+-------------------------------------------------------------------------------------
+out_f16f16: ['-1.76367188 ', '0.04763794 ', '0.50878906 '], time:0.01869035ms
+out_f16f32: ['-1.76367188 ', '0.04766846 ', '0.50878906 '], time:0.01897883ms
+out_f16x2f16: ['-1.76367188 ', '0.04766846 ', '0.50878906 '], time:0.00951219ms
+out_f16x8f16: ['-1.76367188 ', '0.04766846 ', '0.50878906 '], time:0.00467825ms
+out_f16x8packf16: ['-1.76367188 ', '0.04763794 ', '0.50878906 '], time:0.00430202ms
+out_f16_th: ['-1.76171875 ', '0.04760742 ', '0.50830078 '], time:0.07009959ms
+-------------------------------------------------------------------------------------
+-------------------------------------------------------------------------------------
+N=4096, K=1024
+-------------------------------------------------------------------------------------
+out_f32: ['-0.65619785 ', '1.33576787 ', '-0.29172164 '], time:0.05123448ms
+out_f32x4: ['-0.65619785 ', '1.33576787 ', '-0.29172164 '], time:0.01073551ms
+out_f32_th: ['-0.65587735 ', '1.33511555 ', '-0.29157916 '], time:0.07034254ms
+-------------------------------------------------------------------------------------
+out_f16f16: ['-0.65576172 ', '1.3359375 ', '-0.29174805 '], time:0.05320668ms
+out_f16f32: ['-0.65576172 ', '1.3359375 ', '-0.29150391 '], time:0.05061388ms
+out_f16x2f16: ['-0.65576172 ', '1.3359375 ', '-0.29174805 '], time:0.01861978ms
+out_f16x8f16: ['-0.65576172 ', '1.3359375 ', '-0.29174805 '], time:0.00745845ms
+out_f16x8packf16: ['-0.65576172 ', '1.3359375 ', '-0.29174805 '], time:0.00648832ms
+out_f16_th: ['-0.65527344 ', '1.33398438 ', '-0.29150391 '], time:0.07068610ms
+-------------------------------------------------------------------------------------
+-------------------------------------------------------------------------------------
+N=4096, K=2048
+-------------------------------------------------------------------------------------
+out_f32x4: ['0.92044634 ', '0.37421227 ', '-2.49094558 '], time:0.02202415ms
+out_f32_th: ['0.92022169 ', '0.37412092 ', '-2.49033761 '], time:0.12026787ms
+-------------------------------------------------------------------------------------
+out_f16x2f16: ['0.92041016 ', '0.37426758 ', '-2.49023438 '], time:0.05346847ms
+out_f16x8f16: ['0.92041016 ', '0.37426758 ', '-2.49023438 '], time:0.01381087ms
+out_f16x8packf16: ['0.92041016 ', '0.37426758 ', '-2.49023438 '], time:0.01159072ms
+out_f16_th: ['0.92041016 ', '0.37426758 ', '-2.49023438 '], time:0.08454061ms
+-------------------------------------------------------------------------------------
+-------------------------------------------------------------------------------------
+N=4096, K=4096
+-------------------------------------------------------------------------------------
+out_f32x4: ['-2.05339074 ', '0.25924587 ', '0.42393678 '], time:0.18885875ms
+out_f32_th: ['-2.05314016 ', '0.25921422 ', '0.42388505 '], time:0.77834105ms
+-------------------------------------------------------------------------------------
+out_f16x8f16: ['-2.05273438 ', '0.2590332 ', '0.42382812 '], time:0.03327322ms
+out_f16x8packf16: ['-2.05273438 ', '0.2590332 ', '0.42382812 '], time:0.02402687ms
+out_f16_th: ['-2.05273438 ', '0.2590332 ', '0.42382812 '], time:0.17436218ms
+-------------------------------------------------------------------------------------
+-------------------------------------------------------------------------------------
+N=4096, K=8192
+-------------------------------------------------------------------------------------
+out_f16x8f16: ['-1.0234375 ', '-0.3371582 ', '-1.54882812 '], time:0.19311237ms
+out_f16x8packf16: ['-1.0234375 ', '-0.33691406 ', '-1.54882812 '], time:0.18668032ms
+out_f16_th: ['-1.0234375 ', '-0.33691406 ', '-1.54882812 '], time:0.84443021ms
+-------------------------------------------------------------------------------------
+-------------------------------------------------------------------------------------
+N=8192, K=8192
+-------------------------------------------------------------------------------------
+out_f16x8f16: ['-1.03320312 ', '0.41455078 ', '-0.49707031 '], time:0.38361049ms
+out_f16x8packf16: ['-1.03320312 ', '0.41455078 ', '-0.49707031 '], time:0.40809250ms
+out_f16_th: ['-1.03320312 ', '0.41455078 ', '-0.49707031 '], time:1.99517584ms
+-------------------------------------------------------------------------------------
 ```
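For reference, the per-token LayerNorm these kernels benchmark is the standard row-wise normalization over the K elements of each token, with `g` and `b` the scalar scale and shift passed to the bindings. The textbook form is shown below; the fp16 kernels evaluate it in half precision.

```math
\mu = \frac{1}{K}\sum_{k=1}^{K} x_k, \qquad
\sigma^2 = \frac{1}{K}\sum_{k=1}^{K} (x_k - \mu)^2, \qquad
y_k = \frac{x_k - \mu}{\sqrt{\sigma^2 + \varepsilon}}\, g + b
```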

layer-norm/layer_norm.cu
Lines changed: 130 additions & 9 deletions

@@ -14,6 +14,8 @@
 #define INT4(value) (reinterpret_cast<int4*>(&(value))[0])
 #define FLOAT4(value) (reinterpret_cast<float4*>(&(value))[0])
 #define HALF2(value) (reinterpret_cast<half2*>(&(value))[0])
+#define BFLOAT2(value) (reinterpret_cast<__nv_bfloat162*>(&(value))[0])
+#define LDST128BITS(value) (reinterpret_cast<float4*>(&(value))[0])
 
 // -------------------------------------- FP32 --------------------------------------
 // Warp Reduce Sum
@@ -325,6 +327,55 @@ __global__ void layer_norm_f16_f32_kernel(half* x, half* y, float g, float b, in
   }
 }
 
+template<const int NUM_THREADS=256>
+__global__ void layer_norm_f16x8_pack_f16_kernel(half* x, half* y, float g, float b, int N, int K) {
+  int tid = threadIdx.x; // 0..K-1
+  int bid = blockIdx.x; // 0..N-1
+  int idx = (bid * blockDim.x + threadIdx.x) * 8;
+  const half epsilon = __float2half(1e-5f);
+  const half g_ = __float2half(g);
+  const half b_ = __float2half(b);
+  const half K_ = __int2half_rn(K);
+  const half z_ = __float2half(0.0f);
+
+  __shared__ half s_mean;     // shared within block
+  __shared__ half s_variance; // shared within block
+  // temporary register(memory), .local space in ptx, addressable
+  half pack_x[8], pack_y[8]; // 8x16 bits=128 bits.
+  // reinterpret as float4 and load 128 bits in 1 memory issue.
+  LDST128BITS(pack_x[0]) = LDST128BITS(x[idx]); // load 128 bits
+
+  half value = z_;
+  #pragma unroll
+  for (int i = 0; i < 8; ++i) {
+    value += ((idx + i) < N * K ? pack_x[i] : z_);
+  }
+  half sum = block_reduce_sum_f16_f16<NUM_THREADS>(value);
+  if (tid == 0) s_mean = sum / K_;
+  // wait for s_mean in shared memory to be ready for all threads
+  __syncthreads();
+
+  half variance = z_;
+  #pragma unroll
+  for (int i = 0; i < 8; ++i) {
+    half v_hat = pack_x[i] - s_mean;
+    variance += ((idx + i) < N * K ? v_hat * v_hat : z_);
+  }
+  variance = block_reduce_sum_f16_f16<NUM_THREADS>(variance);
+  if (tid == 0) s_variance = hrsqrt(variance / (K_ + epsilon));
+  // wait for s_variance in shared memory to be ready for all threads
+  __syncthreads();
+
+  #pragma unroll
+  for (int i = 0; i < 8; ++i) {
+    // TODO: use __hfma2, __hsub2, __hmul2 here
+    pack_y[i] = __hfma((pack_x[i] - s_mean) * s_variance, g_, b_);
+  }
+  // reinterpret as float4 and store 128 bits in 1 memory issue.
+  if ((idx + 7) < N * K) { LDST128BITS(y[idx]) = LDST128BITS(pack_y[0]); }
+  // TODO: support non 8-multiple K here
+}
+
 // --------------------- PyTorch bindings for custom kernel -----------------------
 #define STRINGFY(str) #str
 #define TORCH_BINDING_COMMON_EXTENSION(func) \
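The kernel above leaves a TODO to move the final loop onto paired `half2` intrinsics. A minimal sketch of what that inner loop could look like, processing two elements per instruction (hypothetical helper, not part of this commit; `mean` and `rstd` stand for the kernel's `s_mean` and `s_variance`):

```cuda
#include <cuda_fp16.h>

// Hypothetical half2 rewrite of the normalization loop hinted at by the TODO:
// y = (x - mean) * rstd * g + b, two halves per __hsub2/__hmul2/__hfma2.
__device__ __forceinline__ void normalize_pack8_half2(
    const half pack_x[8], half pack_y[8],
    half mean, half rstd, half g_, half b_) {
  const half2 mean2 = __half2half2(mean);  // broadcast mean to both lanes
  const half2 rstd2 = __half2half2(rstd);  // broadcast reciprocal std
  const half2 g2    = __half2half2(g_);
  const half2 b2    = __half2half2(b_);
  #pragma unroll
  for (int i = 0; i < 8; i += 2) {
    half2 x2 = __halves2half2(pack_x[i], pack_x[i + 1]);
    half2 y2 = __hfma2(__hmul2(__hsub2(x2, mean2), rstd2), g2, b2);
    pack_y[i]     = __low2half(y2);
    pack_y[i + 1] = __high2half(y2);
  }
}
```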
@@ -350,7 +401,7 @@ layer_norm_f32_kernel<(K)><<<grid, block>>>( \
 
 #define DISPATCH_LAYER_NORM_F32_KERNEL(N, K) \
   dim3 block((K)); \
-  dim3 grid((N)); \
+  dim3 grid((N)); \
   switch ((K)) \
   { \
     case 64: \
@@ -382,7 +433,7 @@ layer_norm_f32x4_kernel<(K)/4><<<grid, block>>>( \
 
 #define DISPATCH_LAYER_NORM_F32x4_KERNEL(N, K) \
   dim3 block((K)/4); \
-  dim3 grid((N)); \
+  dim3 grid((N)); \
   switch ((K)) \
   { \
     case 64: \
@@ -400,9 +451,15 @@ layer_norm_f32x4_kernel<(K)/4><<<grid, block>>>( \
     case 1024: \
       LANUCH_LAYER_NORM_F32x4_KERNEL(1024) \
       break; \
+    case 2048: \
+      LANUCH_LAYER_NORM_F32x4_KERNEL(2048) \
+      break; \
+    case 4096: \
+      LANUCH_LAYER_NORM_F32x4_KERNEL(4096) \
+      break; \
     default: \
       throw std::runtime_error( \
-        "only support K: 64/128/256/512/1024"); \
+        "only support K: 64/128/.../1024*4"); \
       break; \
   }
 
@@ -433,7 +490,7 @@ layer_norm_f16_f16_kernel<(K)><<<grid, block>>>( \
 
 #define DISPATCH_LAYER_NORM_F16F16_KERNEL(N, K) \
   dim3 block((K)); \
-  dim3 grid((N)); \
+  dim3 grid((N)); \
   switch ((K)) \
   { \
     case 64: \
@@ -465,7 +522,7 @@ layer_norm_f16_f32_kernel<(K)><<<grid, block>>>( \
 
 #define DISPATCH_LAYER_NORM_F16F32_KERNEL(N, K) \
   dim3 block((K)); \
-  dim3 grid((N)); \
+  dim3 grid((N)); \
   switch ((K)) \
   { \
     case 64: \
@@ -497,7 +554,7 @@ layer_norm_f16x2_f16_kernel<(K)/2><<<grid, block>>>( \
 
 #define DISPATCH_LAYER_NORM_F16x2F16_KERNEL(N, K) \
   dim3 block((K)/2); \
-  dim3 grid((N)); \
+  dim3 grid((N)); \
   switch ((K)) \
   { \
     case 64: \
@@ -515,9 +572,12 @@ layer_norm_f16x2_f16_kernel<(K)/2><<<grid, block>>>( \
     case 1024: \
       LANUCH_LAYER_NORM_F16x2F16_KERNEL(1024) \
       break; \
+    case 2048: \
+      LANUCH_LAYER_NORM_F16x2F16_KERNEL(2048) \
+      break; \
     default: \
       throw std::runtime_error( \
-        "only support K: 64/128/256/512/1024"); \
+        "only support K: 64/128/.../1024*2"); \
       break; \
   }
 
@@ -529,7 +589,7 @@ layer_norm_f16x8_f16_kernel<(K)/8><<<grid, block>>>( \
 
 #define DISPATCH_LAYER_NORM_F16x8F16_KERNEL(N, K) \
   dim3 block((K)/8); \
-  dim3 grid((N)); \
+  dim3 grid((N)); \
   switch ((K)) \
   { \
     case 64: \
@@ -547,12 +607,62 @@ layer_norm_f16x8_f16_kernel<(K)/8><<<grid, block>>>( \
     case 1024: \
       LANUCH_LAYER_NORM_F16x8F16_KERNEL(1024) \
       break; \
+    case 2048: \
+      LANUCH_LAYER_NORM_F16x8F16_KERNEL(2048) \
+      break; \
+    case 4096: \
+      LANUCH_LAYER_NORM_F16x8F16_KERNEL(4096) \
+      break; \
+    case 8192: \
+      LANUCH_LAYER_NORM_F16x8F16_KERNEL(8192) \
+      break; \
     default: \
       throw std::runtime_error( \
-        "only support K: 64/128/256/512/1024"); \
+        "only support K: 64/128/.../1024*8"); \
       break; \
   }
 
+#define LANUCH_LAYER_NORM_F16x8_PACK_F16_KERNEL(K) \
+  layer_norm_f16x8_pack_f16_kernel<(K)/8><<<grid, block>>>( \
+    reinterpret_cast<half*>(x.data_ptr()), \
+    reinterpret_cast<half*>(y.data_ptr()), \
+    g, b, N, (K));
+
+#define DISPATCH_LAYER_NORM_F16x8_PACK_F16_KERNEL(N, K) \
+  dim3 block((K)/8); \
+  dim3 grid((N)); \
+  switch ((K)) \
+  { \
+    case 64: \
+      LANUCH_LAYER_NORM_F16x8_PACK_F16_KERNEL(64) \
+      break; \
+    case 128: \
+      LANUCH_LAYER_NORM_F16x8_PACK_F16_KERNEL(128) \
+      break; \
+    case 256: \
+      LANUCH_LAYER_NORM_F16x8_PACK_F16_KERNEL(256) \
+      break; \
+    case 512: \
+      LANUCH_LAYER_NORM_F16x8_PACK_F16_KERNEL(512) \
+      break; \
+    case 1024: \
+      LANUCH_LAYER_NORM_F16x8_PACK_F16_KERNEL(1024) \
+      break; \
+    case 2048: \
+      LANUCH_LAYER_NORM_F16x8_PACK_F16_KERNEL(2048) \
+      break; \
+    case 4096: \
+      LANUCH_LAYER_NORM_F16x8_PACK_F16_KERNEL(4096) \
+      break; \
+    case 8192: \
+      LANUCH_LAYER_NORM_F16x8_PACK_F16_KERNEL(8192) \
+      break; \
+    default: \
+      throw std::runtime_error( \
+        "only support K: 64/128/.../1024*8"); \
+      break; \
+  }
+
 void layer_norm_f16_f16(torch::Tensor x, torch::Tensor y, float g, float b) {
   CHECK_TORCH_TENSOR_DTYPE(x, torch::kHalf)
   CHECK_TORCH_TENSOR_DTYPE(y, torch::kHalf)
@@ -580,6 +690,16 @@ void layer_norm_f16x8_f16(torch::Tensor x, torch::Tensor y, float g, float b) {
   DISPATCH_LAYER_NORM_F16x8F16_KERNEL(N, K)
 }
 
+void layer_norm_f16x8_pack_f16(torch::Tensor x, torch::Tensor y, float g, float b) {
+  CHECK_TORCH_TENSOR_DTYPE(x, torch::kHalf)
+  CHECK_TORCH_TENSOR_DTYPE(y, torch::kHalf)
+  CHECK_TORCH_TENSOR_SHAPE(x, y)
+  const int N = x.size(0);
+  const int K = x.size(1);
+  DISPATCH_LAYER_NORM_F16x8_PACK_F16_KERNEL(N, K)
+}
+
+
 void layer_norm_f16_f32(torch::Tensor x, torch::Tensor y, float g, float b) {
   CHECK_TORCH_TENSOR_DTYPE(x, torch::kHalf)
   CHECK_TORCH_TENSOR_DTYPE(y, torch::kHalf)
@@ -595,6 +715,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   TORCH_BINDING_COMMON_EXTENSION(layer_norm_f16_f16)
   TORCH_BINDING_COMMON_EXTENSION(layer_norm_f16x2_f16)
   TORCH_BINDING_COMMON_EXTENSION(layer_norm_f16x8_f16)
+  TORCH_BINDING_COMMON_EXTENSION(layer_norm_f16x8_pack_f16)
   TORCH_BINDING_COMMON_EXTENSION(layer_norm_f16_f32)
 }