
Commit bf283f2

[Reduce][Kernel] Pack f16/bf16x8 & fp8/i8x16 LD/ST (#43)
* Update README.md
* Update block_all_reduce.cu
* Update block_all_reduce.py
* Update README.md
* Update block_all_reduce.cu
* Update README.md
* Update block_all_reduce.cu
* Update block_all_reduce.py
* Update README.md
* Delete fuse-multihead-attention directory
* Create elementwise.cu
* Create relu.cu
* Create .gitignore
* Create README.md
* Update README.md
1 parent d43c53d commit bf283f2

File tree

9 files changed: +655 -129 lines changed

README.md

Lines changed: 9 additions & 2 deletions
@@ -5,11 +5,11 @@
 <img src=https://img.shields.io/github/watchers/DefTruth/cuda-learn-note?color=9cc >
 <img src=https://img.shields.io/github/forks/DefTruth/cuda-learn-note.svg?style=social >
 <img src=https://img.shields.io/github/stars/DefTruth/cuda-learn-note.svg?style=social >
-<img src=https://img.shields.io/badge/Release-v2.3-brightgreen.svg >
+<img src=https://img.shields.io/badge/Release-v2.4-brightgreen.svg >
 <img src=https://img.shields.io/badge/License-GPLv3.0-turquoise.svg >
 </div>

-📖 **CUDA Learn Notes with PyTorch**: **fp32, fp16/bf16, fp8/int8**, flash_attn, sgemm, sgemv, warp/block reduce, dot prod, elementwise, softmax, layernorm, rmsnorm, hist, etc. 👉 News: Most of my time now is focused on **LLM/VLM/Diffusion** Inference. Please check 📖[Awesome-LLM-Inference](https://github.com/DefTruth/Awesome-LLM-Inference) ![](https://img.shields.io/github/stars/DefTruth/Awesome-LLM-Inference.svg?style=social), 📖[Awesome-SD-Inference](https://github.com/DefTruth/Awesome-SD-Inference) ![](https://img.shields.io/github/stars/DefTruth/Awesome-SD-Inference.svg?style=social) and 📖[CUDA-Learn-Notes](https://github.com/DefTruth/CUDA-Learn-Notes) ![](https://img.shields.io/github/stars/DefTruth/CUDA-Learn-Notes.svg?style=social) for more details.
+🎉 **CUDA Learn Notes**: This repo aims to build modern **CUDA Learn Notes with PyTorch** for beginners, covering **fp32, fp16/bf16, fp8/int8, Tensor/CUDA Cores**, flash_attn, sgemm, sgemv, hgemm, hgemv, warp/block reduce, dot prod, elementwise, sigmoid, relu, softmax, layernorm, rmsnorm, hist, and some CUDA optimization techniques (pack LDST, warp gemv, sliced_k/split_k/pipeline gemm, bank-conflict-free access, MMA, etc.).

 <img width="1438" alt="image" src="https://github.com/user-attachments/assets/0c5e5125-586f-43fa-8e8b-e2c61c1afbbe">

@@ -49,13 +49,20 @@
 | ✔️ [block_all_reduce_f16_f32](./reduce/block_all_reduce.cu)|f16|f32|[link](./reduce/)|⭐️⭐️|
 | ✔️ [block_all_reduce_f16x2_f16](./reduce/block_all_reduce.cu)|f16|f16|[link](./reduce/)|⭐️⭐️|
 | ✔️ [block_all_reduce_f16x2_f32](./reduce/block_all_reduce.cu)|f16|f32|[link](./reduce/)|⭐️⭐️|
+| ✔️ [block_all_reduce_f16x8_pack_f16](./reduce/block_all_reduce.cu)|f16|f16|[link](./reduce/)|⭐️⭐️|
+| ✔️ [block_all_reduce_f16x8_pack_f32](./reduce/block_all_reduce.cu)|f16|f32|[link](./reduce/)|⭐️⭐️|
 | ✔️ [block_all_reduce_bf16_bf16](./reduce/block_all_reduce.cu)|bf16|bf16|[link](./reduce/)|⭐️⭐️|
 | ✔️ [block_all_reduce_bf16_f32](./reduce/block_all_reduce.cu)|bf16|f32|[link](./reduce/)|⭐️⭐️|
 | ✔️ [block_all_reduce_bf16x2_bf16](./reduce/block_all_reduce.cu)|bf16|bf16|[link](./reduce/)|⭐️⭐️|
 | ✔️ [block_all_reduce_bf16x2_f32](./reduce/block_all_reduce.cu)|bf16|f32|[link](./reduce/)|⭐️⭐️|
+| ✔️ [block_all_reduce_bf16x8_pack_bf16](./reduce/block_all_reduce.cu)|bf16|bf16|[link](./reduce/)|⭐️⭐️|
+| ✔️ [block_all_reduce_bf16x8_pack_f32](./reduce/block_all_reduce.cu)|bf16|f32|[link](./reduce/)|⭐️⭐️|
 | ✔️ [block_all_reduce_fp8_e4m3_f16](./reduce/block_all_reduce.cu)|fp8_e4m3|f16|[link](./reduce/)|⭐️⭐️|
 | ✔️ [block_all_reduce_fp8_e5m2_f16](./reduce/block_all_reduce.cu)|fp8_e5m2|f16|[link](./reduce/)|⭐️⭐️|
+| ✔️ [block_all_reduce_fp8_e4m3x16_pack_f16](./reduce/block_all_reduce.cu)|fp8_e4m3|f16|[link](./reduce/)|⭐️⭐️|
+| ✔️ [block_all_reduce_fp8_e5m2x16_pack_f16](./reduce/block_all_reduce.cu)|fp8_e5m2|f16|[link](./reduce/)|⭐️⭐️|
 | ✔️ [block_all_reduce_i8_i32](./reduce/block_all_reduce.cu)|i8|i32|[link](./reduce/)|⭐️⭐️|
+| ✔️ [block_all_reduce_i8x16_pack_i32](./reduce/block_all_reduce.cu)|i8|i32|[link](./reduce/)|⭐️⭐️|
 | ✔️ [dot_product_f32](./dot-product/dot_product.cu)|f32|f32|[link](./dot-product/)|⭐️⭐️|
 | ✔️ [dot_product_f32x4](./dot-product/dot_product.cu)|f32|f32|[link](./dot-product/)|⭐️⭐️|
 | ✔️ [dot_product_f16_f32](./dot-product/dot_product.cu)|f16|f32|[link](./dot-product/)|⭐️⭐️|
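
The `*_pack_*` entries added above are new in this commit, but reduce/block_all_reduce.cu itself is not expanded in this view. As a rough illustration of the packed-LD/ST idea the commit title names, here is a minimal sketch of a packed f16x8 block all-reduce; the kernel name, the warp_reduce_sum_f32 helper, and the atomicAdd accumulation are our assumptions for illustration, not the repo's actual code:

#include <cuda_runtime.h>
#include <cuda_fp16.h>

#define WARP_SIZE 32
#define LDST128BITS(value) (reinterpret_cast<float4*>(&(value))[0])

__device__ __forceinline__ float warp_reduce_sum_f32(float val) {
  #pragma unroll
  for (int mask = WARP_SIZE >> 1; mask >= 1; mask >>= 1) {
    val += __shfl_xor_sync(0xffffffff, val, mask);
  }
  return val;
}

// Sketch only (not the repo's kernel). grid(N/256/8), block(256):
// each thread loads 8 halves with one 128-bit LD. Assumes N is a
// multiple of 8, matching the tail assumption of the pack kernels below.
__global__ void block_all_reduce_f16x8_pack_f32_sketch(half* a, float* y, int N) {
  int tid = threadIdx.x;
  int idx = 8 * (blockIdx.x * blockDim.x + threadIdx.x);
  half pack_a[8]; // 8x16 bits = 128 bits
  LDST128BITS(pack_a[0]) = LDST128BITS(a[idx]); // one 128-bit load

  float sum = 0.0f;
  #pragma unroll
  for (int i = 0; i < 8; ++i) sum += __half2float(pack_a[i]);

  // two-stage reduction: warp shuffles, then one warp combines per-warp sums.
  __shared__ float smem[256 / WARP_SIZE];
  int warp = tid / WARP_SIZE, lane = tid % WARP_SIZE;
  sum = warp_reduce_sum_f32(sum);
  if (lane == 0) smem[warp] = sum;
  __syncthreads();
  if (warp == 0) {
    sum = (lane < blockDim.x / WARP_SIZE) ? smem[lane] : 0.0f;
    sum = warp_reduce_sum_f32(sum);
    if (tid == 0) atomicAdd(y, sum); // accumulate per-block partials
  }
}

The point of the pack is the single 128-bit load per thread; the reduction itself is the usual warp-shuffle plus shared-memory two-stage pattern.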

fuse-multihead-attention/.gitignore

Lines changed: 0 additions & 20 deletions
This file was deleted.

nvidia-nsight/.gitignore

Lines changed: 18 additions & 0 deletions
*.i
*.ii
*.gpu
*.ptx
*.cubin
*.fatbin
*.so
*.a
*.dylib
*.dll
*.lib
.DS_Store
build
*.whl
tmp
*.nsys*
*.profile*
*.cubin

nvidia-nsight/README.md

Lines changed: 2 additions & 0 deletions
# NVIDIA Nsight Systems

nvidia-nsight/elementwise.cu

Lines changed: 114 additions & 0 deletions
#include <stdio.h>
#include <stdlib.h>
#include <float.h>
#include <vector>
#include <algorithm>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <cuda_fp8.h>

#define WARP_SIZE 32
#define INT4(value) (reinterpret_cast<int4*>(&(value))[0])
#define FLOAT4(value) (reinterpret_cast<float4*>(&(value))[0])
#define HALF2(value) (reinterpret_cast<half2*>(&(value))[0])
#define BFLOAT2(value) (reinterpret_cast<__nv_bfloat162*>(&(value))[0])
#define LDST128BITS(value) (reinterpret_cast<float4*>(&(value))[0])

// -------------------------------------- FP32 --------------------------------------
// ElementWise Add
// grid(N/256), block(256)
// a: Nx1, b: Nx1, c: Nx1, c = elementwise_add(a, b)
__global__ void elementwise_add_f32_kernel(float* a, float* b, float* c, int N) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < N) c[idx] = a[idx] + b[idx];
}

// ElementWise Add + Vec4
// grid(N/256), block(256/4)
// a: Nx1, b: Nx1, c: Nx1, c = elementwise_add(a, b)
__global__ void elementwise_add_f32x4_kernel(float* a, float* b, float* c, int N) {
  int idx = 4 * (blockIdx.x * blockDim.x + threadIdx.x);
  if (idx < N) {
    float4 reg_a = FLOAT4(a[idx]);
    float4 reg_b = FLOAT4(b[idx]);
    float4 reg_c;
    reg_c.x = reg_a.x + reg_b.x;
    reg_c.y = reg_a.y + reg_b.y;
    reg_c.z = reg_a.z + reg_b.z;
    reg_c.w = reg_a.w + reg_b.w;
    FLOAT4(c[idx]) = reg_c;
  }
}

// -------------------------------------- FP16 --------------------------------------
// ElementWise Add
// grid(N/256), block(256)
// a: Nx1, b: Nx1, c: Nx1, c = elementwise_add(a, b)
__global__ void elementwise_add_f16_kernel(half* a, half* b, half* c, int N) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < N) c[idx] = __hadd(a[idx], b[idx]);
}

// a: Nx1, b: Nx1, c: Nx1, c = elementwise_add(a, b)
__global__ void elementwise_add_f16x2_kernel(half* a, half* b, half* c, int N) {
  int idx = 2 * (blockIdx.x * blockDim.x + threadIdx.x);
  if (idx < N) {
    half2 reg_a = HALF2(a[idx]);
    half2 reg_b = HALF2(b[idx]);
    half2 reg_c;
    reg_c.x = __hadd(reg_a.x, reg_b.x);
    reg_c.y = __hadd(reg_a.y, reg_b.y);
    HALF2(c[idx]) = reg_c;
  }
}

__global__ void elementwise_add_f16x8_kernel(half* a, half* b, half* c, int N) {
  int idx = 8 * (blockIdx.x * blockDim.x + threadIdx.x);
  // manual unroll and improve L2 cache hit rate.
  // Only L2 cache: load 32 bytes in 1 memory issue (default)
  // Enable L1 cache: load 128 bytes in 1 memory issue (-Xptxas -dlcm=ca)
  // why try fp16x8 within 1 thread? ref: https://zhuanlan.zhihu.com/p/641639133
  // 0. first, tid_0 loads 32 bytes in 1 memory issue and caches the data in L2.
  // 1. then, tid_1,...,tid_3 hit L2 and load the data from L2 directly.
  half2 reg_a_0 = HALF2(a[idx + 0]);
  half2 reg_a_1 = HALF2(a[idx + 2]);
  half2 reg_a_2 = HALF2(a[idx + 4]);
  half2 reg_a_3 = HALF2(a[idx + 6]);
  half2 reg_b_0 = HALF2(b[idx + 0]);
  half2 reg_b_1 = HALF2(b[idx + 2]);
  half2 reg_b_2 = HALF2(b[idx + 4]);
  half2 reg_b_3 = HALF2(b[idx + 6]);
  half2 reg_c_0, reg_c_1, reg_c_2, reg_c_3;
  reg_c_0.x = __hadd(reg_a_0.x, reg_b_0.x);
  reg_c_0.y = __hadd(reg_a_0.y, reg_b_0.y);
  reg_c_1.x = __hadd(reg_a_1.x, reg_b_1.x);
  reg_c_1.y = __hadd(reg_a_1.y, reg_b_1.y);
  reg_c_2.x = __hadd(reg_a_2.x, reg_b_2.x);
  reg_c_2.y = __hadd(reg_a_2.y, reg_b_2.y);
  reg_c_3.x = __hadd(reg_a_3.x, reg_b_3.x);
  reg_c_3.y = __hadd(reg_a_3.y, reg_b_3.y);
  if ((idx + 0) < N) { HALF2(c[idx + 0]) = reg_c_0; }
  if ((idx + 2) < N) { HALF2(c[idx + 2]) = reg_c_1; }
  if ((idx + 4) < N) { HALF2(c[idx + 4]) = reg_c_2; }
  if ((idx + 6) < N) { HALF2(c[idx + 6]) = reg_c_3; }
}

__global__ void elementwise_add_f16x8_pack_kernel(half* a, half* b, half* c, int N) {
  int idx = 8 * (blockIdx.x * blockDim.x + threadIdx.x);
  // temporary registers (.local space in PTX if spilled), addressable
  half pack_a[8], pack_b[8], pack_c[8]; // 8x16 bits = 128 bits.
  // reinterpret as float4 and load 128 bits in 1 memory issue.
  LDST128BITS(pack_a[0]) = LDST128BITS(a[idx]); // load 128 bits
  LDST128BITS(pack_b[0]) = LDST128BITS(b[idx]); // load 128 bits

  #pragma unroll
  for (int i = 0; i < 8; i += 2) {
    // __hadd2 adds 2 halves at once; 4 iterations cover all 8 values.
    HALF2(pack_c[i]) = __hadd2(HALF2(pack_a[i]), HALF2(pack_b[i]));
  }
  // reinterpret as float4 and store 128 bits in 1 memory issue.
  if ((idx + 7) < N) { LDST128BITS(c[idx]) = LDST128BITS(pack_c[0]); }
}
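To make the pack kernel's launch geometry concrete, here is a minimal host-side driver; it is a sketch we added for illustration (the main, the sizes, and the divisibility assumption are ours), not part of the commit:

// Illustrative host driver (not in this commit). Assumes N is a multiple
// of 8*256 so every thread's 128-bit load/store stays in bounds.
int main() {
  const int N = 1 << 20;
  half *a, *b, *c;
  cudaMalloc(&a, N * sizeof(half));
  cudaMalloc(&b, N * sizeof(half));
  cudaMalloc(&c, N * sizeof(half));

  const int threads = 256;
  const int blocks = N / (8 * threads); // 8 halves per thread
  elementwise_add_f16x8_pack_kernel<<<blocks, threads>>>(a, b, c, N);
  cudaDeviceSynchronize();

  cudaFree(a); cudaFree(b); cudaFree(c);
  return 0;
}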

nvidia-nsight/relu.cu

Lines changed: 99 additions & 0 deletions
#include <stdio.h>
#include <stdlib.h>
#include <float.h>
#include <vector>
#include <algorithm>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_bf16.h> // needed for __nv_bfloat162 in the BFLOAT2 macro

#define WARP_SIZE 32
#define INT4(value) (reinterpret_cast<int4*>(&(value))[0])
#define FLOAT4(value) (reinterpret_cast<float4*>(&(value))[0])
#define HALF2(value) (reinterpret_cast<half2*>(&(value))[0])
#define BFLOAT2(value) (reinterpret_cast<__nv_bfloat162*>(&(value))[0])
#define LDST128BITS(value) (reinterpret_cast<float4*>(&(value))[0])

// -------------------------------------- FP32 --------------------------------------
// Relu x: N, y: N, y=max(0,x)
// grid(N/256), block(256)
__global__ void relu_f32_kernel(float* x, float* y, int N) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < N) y[idx] = fmaxf(0.0f, x[idx]);
}

// Relu x: N, y: N, y=max(0,x), Vec4
// grid(N/256), block(256/4)
__global__ void relu_f32x4_kernel(float* x, float* y, int N) {
  int idx = (blockIdx.x * blockDim.x + threadIdx.x) * 4;
  if (idx < N) {
    float4 reg_x = FLOAT4(x[idx]);
    float4 reg_y;
    reg_y.x = fmaxf(0.0f, reg_x.x);
    reg_y.y = fmaxf(0.0f, reg_x.y);
    reg_y.z = fmaxf(0.0f, reg_x.z);
    reg_y.w = fmaxf(0.0f, reg_x.w);
    FLOAT4(y[idx]) = reg_y;
  }
}

// -------------------------------------- FP16 --------------------------------------
__global__ void relu_f16_kernel(half* x, half* y, int N) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < N) y[idx] = __hmax(__float2half(0.0f), x[idx]);
}

__global__ void relu_f16x2_kernel(half* x, half* y, int N) {
  int idx = 2 * (blockIdx.x * blockDim.x + threadIdx.x);
  if (idx < N) {
    half2 reg_x = HALF2(x[idx]);
    half2 reg_y; // no need to pre-load y: both lanes are overwritten below.
    reg_y.x = __hmax(__float2half(0.0f), reg_x.x);
    reg_y.y = __hmax(__float2half(0.0f), reg_x.y);
    HALF2(y[idx]) = reg_y;
  }
}

__global__ void relu_f16x8_kernel(half* x, half* y, int N) {
  int idx = 8 * (blockIdx.x * blockDim.x + threadIdx.x);
  // manual unroll and improve L2 cache hit rate.
  // Only L2 cache: load 32 bytes in 1 memory issue (default)
  // Enable L1 cache: load 128 bytes in 1 memory issue (-Xptxas -dlcm=ca)
  // why try fp16x8 within 1 thread? ref: https://zhuanlan.zhihu.com/p/641639133
  // 0. first, tid_0 loads 32 bytes in 1 memory issue and caches the data in L2.
  // 1. then, tid_1,...,tid_3 hit L2 and load the data from L2 directly.
  half2 reg_x_0 = HALF2(x[idx + 0]);
  half2 reg_x_1 = HALF2(x[idx + 2]);
  half2 reg_x_2 = HALF2(x[idx + 4]);
  half2 reg_x_3 = HALF2(x[idx + 6]);
  half2 reg_y_0, reg_y_1, reg_y_2, reg_y_3;
  reg_y_0.x = __hmax(__float2half(0.0f), reg_x_0.x);
  reg_y_0.y = __hmax(__float2half(0.0f), reg_x_0.y);
  reg_y_1.x = __hmax(__float2half(0.0f), reg_x_1.x);
  reg_y_1.y = __hmax(__float2half(0.0f), reg_x_1.y);
  reg_y_2.x = __hmax(__float2half(0.0f), reg_x_2.x);
  reg_y_2.y = __hmax(__float2half(0.0f), reg_x_2.y);
  reg_y_3.x = __hmax(__float2half(0.0f), reg_x_3.x);
  reg_y_3.y = __hmax(__float2half(0.0f), reg_x_3.y);
  if ((idx + 0) < N) { HALF2(y[idx + 0]) = reg_y_0; }
  if ((idx + 2) < N) { HALF2(y[idx + 2]) = reg_y_1; }
  if ((idx + 4) < N) { HALF2(y[idx + 4]) = reg_y_2; }
  if ((idx + 6) < N) { HALF2(y[idx + 6]) = reg_y_3; }
}

__global__ void relu_f16x8_pack_kernel(half* x, half* y, int N) {
  int idx = 8 * (blockIdx.x * blockDim.x + threadIdx.x);
  const half2 z2 = {__float2half(0.0f), __float2half(0.0f)};
  // temporary registers (.local space in PTX if spilled), addressable
  half pack_x[8], pack_y[8]; // 8x16 bits = 128 bits.
  // reinterpret as float4 and load 128 bits in 1 memory issue.
  LDST128BITS(pack_x[0]) = LDST128BITS(x[idx]); // load 128 bits

  #pragma unroll
  for (int i = 0; i < 8; i += 2) {
    // __hmax2 compares 2 halves at once; 4 iterations cover all 8 values.
    HALF2(pack_y[i]) = __hmax2(HALF2(pack_x[i]), z2);
  }
  // reinterpret as float4 and store 128 bits in 1 memory issue.
  if ((idx + 7) < N) { LDST128BITS(y[idx]) = LDST128BITS(pack_y[0]); }
}
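Given the nvidia-nsight/ folder name, these kernels are presumably meant to be compared under Nsight Systems (for example, nsys profile -o relu_report ./relu after building with nvcc). For a quick in-code comparison, here is a hedged cudaEvent timing sketch; the main below is ours for illustration, not part of the commit:

// Illustrative benchmark sketch (not in this commit). Times the packed
// kernel with CUDA events; N is assumed to be a multiple of 8*256.
int main() {
  const int N = 1 << 24;
  half *x, *y;
  cudaMalloc(&x, N * sizeof(half));
  cudaMalloc(&y, N * sizeof(half));

  const int threads = 256;
  const int blocks = N / (8 * threads); // 8 halves per thread

  cudaEvent_t start, stop;
  cudaEventCreate(&start);
  cudaEventCreate(&stop);
  cudaEventRecord(start);
  for (int i = 0; i < 100; ++i) {
    relu_f16x8_pack_kernel<<<blocks, threads>>>(x, y, N);
  }
  cudaEventRecord(stop);
  cudaEventSynchronize(stop);
  float ms = 0.0f;
  cudaEventElapsedTime(&ms, start, stop);
  printf("relu_f16x8_pack: %.3f ms / 100 launches\n", ms);

  cudaEventDestroy(start); cudaEventDestroy(stop);
  cudaFree(x); cudaFree(y);
  return 0;
}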
