
Commit d43c53d

[RELU][FP16] Add f16x8_pack kernel, boost 2.1x (#42)
* Update README.md
* Update relu.cu
* Update relu.py
* Update README.md
1 parent 4be041f commit d43c53d

File tree

4 files changed: +80 -86 lines

- README.md
- relu/README.md
- relu/relu.cu
- relu/relu.py


README.md

Lines changed: 1 addition & 0 deletions
@@ -40,6 +40,7 @@
 | ✔️ [relu_f16](./relu/relu.cu)|f16|/|[link](./relu/)|⭐️|
 | ✔️ [relu_f16x2](./relu/relu.cu)|f16|/|[link](./relu/)|⭐️|
 | ✔️ [relu_f16x8](./relu/relu.cu)|f16|/|[link](./relu/)|⭐️|
+| ✔️ [relu_f16x8_pack](./relu/relu.cu)|f16|/|[link](./relu/)|⭐️⭐️|
 | ✔️ [warp_reduce_f16/bf16/f32/f8/i8](./reduce/block_all_reduce.cu)|all|all|[link](./reduce/)|⭐️⭐️|
 | ✔️ [block_reduce_f32](./reduce/block_all_reduce.cu)|f32|f32|[link](./reduce/)|⭐️⭐️|
 | ✔️ [block_all_reduce_f32_f32](./reduce/block_all_reduce.cu)|f32|f32|[link](./reduce/)|⭐️⭐️|

relu/README.md

Lines changed: 9 additions & 16 deletions
@@ -9,6 +9,7 @@
 - [X] relu_f16_kernel (fp16 version)
 - [X] relu_f16x2_kernel (fp16 vectorized version)
 - [X] relu_f16x8_kernel (fp16 vectorized version)
+- [X] relu_f16x8_pack_kernel (fp16 vectorized, packed version)
 - [X] PyTorch bindings


@@ -24,22 +25,14 @@ python3 relu.py

 ```bash
 --------------------------------------------------------------------------------
-out_f32: [0.0, 0.0], time:0.01072860ms
-out_f32x4: [0.0, 0.0], time:0.01059222ms
-out_f32_th: [0.0, 0.0], time:0.00772071ms
+out_f32: [0.0, 0.23360847], time:0.18854451ms
+out_f32x4: [0.0, 0.23360847], time:0.18829441ms
+out_f32_th: [0.0, 0.23360847], time:0.20471048ms
 --------------------------------------------------------------------------------
-out_f16: [0.0, 0.0], time:0.01077199ms
-out_f16x2: [0.0, 0.0], time:0.01084924ms
-out_f16x8: [0.0, 0.0], time:0.01083326ms
-out_f16_th: [0.0, 0.0], time:0.00762105ms
---------------------------------------------------------------------------------
-out_f32(v2): [0.0, 0.0], time:0.00346351ms
-out_f32x4(v2): [0.0, 0.0], time:0.00342798ms
-out_f32_th: [0.0, 0.0], time:0.01125073ms
---------------------------------------------------------------------------------
-out_f16(v2): [0.0, 0.0], time:0.00343585ms
-out_f16x2(v2): [0.0, 0.0], time:0.00339842ms
-out_f16x8(v2): [0.0, 0.0], time:0.00347090ms
-out_f16_th: [0.0, 0.0], time:0.00776792ms
+out_f16: [0.0, 0.23364258], time:0.04058957ms
+out_f16x2: [0.0, 0.23364258], time:0.03622127ms
+out_f16x8: [0.0, 0.23364258], time:0.03658152ms
+out_f16x8pack: [0.0, 0.23364258], time:0.01454449ms
+out_f16_th: [0.0, 0.23364258], time:0.04748964ms
 --------------------------------------------------------------------------------
 ```
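
From the fp16 rows of the new log, the gain of the packed kernel on this particular run can be read off directly; a minimal sketch of the arithmetic, with the times copied from the log above:

```python
# Times in ms, copied from the fp16 rows of the log above (this particular run).
t_f16x8, t_f16x8_pack, t_f16_th = 0.03658152, 0.01454449, 0.04748964
print(f"f16x8_pack vs f16x8:      {t_f16x8 / t_f16x8_pack:.2f}x")   # ~2.52x
print(f"f16x8_pack vs torch.relu: {t_f16_th / t_f16x8_pack:.2f}x")  # ~3.27x
```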

relu/relu.cu

Lines changed: 55 additions & 43 deletions
@@ -13,6 +13,7 @@
 #define FLOAT4(value) (reinterpret_cast<float4*>(&(value))[0])
 #define HALF2(value) (reinterpret_cast<half2*>(&(value))[0])
 #define BFLOAT2(value) (reinterpret_cast<__nv_bfloat162*>(&(value))[0])
+#define LDST128BITS(value) (reinterpret_cast<float4*>(&(value))[0])

 // -------------------------------------- FP32 --------------------------------------
 // Relu x: N, y: N y=max(0,x)
@@ -81,6 +82,24 @@ __global__ void relu_f16x8_kernel(half* x, half* y, int N) {
   if ((idx + 6) < N) { HALF2(y[idx + 6]) = reg_y_3; }
 }

+__global__ void relu_f16x8_pack_kernel(half* x, half* y, int N) {
+  int idx = 8 * (blockIdx.x * blockDim.x + threadIdx.x);
+  const half2 z2 = {__float2half(0.0f), __float2half(0.0f)};
+  // temporary register(memory), .local space in ptx, addressable
+  half pack_x[8], pack_y[8]; // 8x16 bits=128 bits.
+  // reinterpret as float4 and load 128 bits in 1 memory issue.
+  LDST128BITS(pack_x[0]) = LDST128BITS(x[idx]); // load 128 bits
+
+  #pragma unroll
+  for (int i = 0; i < 8; i += 2) {
+    // __hmax2 for half2 x 4
+    HALF2(pack_y[i]) = __hmax2(HALF2(pack_x[i]), z2);
+  }
+  // reinterpret as float4 and store 128 bits in 1 memory issue.
+  if ((idx + 7) < N) { LDST128BITS(y[idx]) = LDST128BITS(pack_y[0]); }
+}
+
+
 // --------------------- PyTorch bindings for custom kernel -----------------------
 #define STRINGFY(str) #str
 #define TORCH_BINDING_COMMON_EXTENSION(func) \
@@ -92,61 +111,54 @@ if(((T).options().dtype() != (th_type))) { \
 throw std::runtime_error("values must be "#th_type); \
 }

-#define CHECK_TORCH_TENSOR_SHAPE(T, S0) \
-if (((T).size(0) != (S0))) { throw std::runtime_error("Tensor size mismatch!"); }
-
 #define TORCH_BINDING_RELU(packed_type, th_type, element_type, n_elements) \
-torch::Tensor relu_##packed_type(torch::Tensor x) { \
-  CHECK_TORCH_TENSOR_DTYPE(x, (th_type)) \
-  auto options = torch::TensorOptions().dtype((th_type)).device( \
-    torch::kCUDA, 0); \
-  const int N = x.size(0); \
-  auto y = torch::zeros({N}, options); \
-  static const int NUM_THREADS_PER_BLOCK = 256 / (n_elements); \
-  const int NUM_BLOCKS = (N + 256 - 1) / 256; \
-  dim3 block(NUM_THREADS_PER_BLOCK); \
-  dim3 grid(NUM_BLOCKS); \
-  relu_##packed_type##_kernel<<<grid, block>>>( \
-    reinterpret_cast<element_type*>(x.data_ptr()), \
-    reinterpret_cast<element_type*>(y.data_ptr()), N); \
-  return y; \
-}
-
-#define TORCH_BINDING_RELU_V2(packed_type, th_type, element_type, n_elements) \
-void relu_##packed_type##_v2(torch::Tensor x, torch::Tensor y) { \
+void relu_##packed_type(torch::Tensor x, torch::Tensor y) { \
   CHECK_TORCH_TENSOR_DTYPE(x, (th_type)) \
   CHECK_TORCH_TENSOR_DTYPE(y, (th_type)) \
-  const int N = x.size(0); \
-  CHECK_TORCH_TENSOR_SHAPE(y, N) \
-  static const int NUM_THREADS_PER_BLOCK = 256 / (n_elements); \
-  const int NUM_BLOCKS = (N + 256 - 1) / 256; \
-  dim3 block(NUM_THREADS_PER_BLOCK); \
-  dim3 grid(NUM_BLOCKS); \
-  relu_##packed_type##_kernel<<<grid, block>>>( \
+  const int ndim = x.dim(); \
+  if (ndim != 2) { \
+    int N = 1; \
+    for (int i = 0; i < ndim; ++i) { N *= x.size(i); } \
+    dim3 block(256 / (n_elements)); \
+    dim3 grid((N + 256 - 1) / 256); \
+    relu_##packed_type##_kernel<<<grid, block>>>( \
     reinterpret_cast<element_type*>(x.data_ptr()), \
     reinterpret_cast<element_type*>(y.data_ptr()), N); \
+  } else { \
+    const int S = x.size(0); \
+    const int K = x.size(1); \
+    const int N = S * K; \
+    if ((K/(n_elements)) <= 1024) { \
+      dim3 block(K/(n_elements)); \
+      dim3 grid(S); \
+      relu_##packed_type##_kernel<<<grid, block>>>( \
+        reinterpret_cast<element_type*>(x.data_ptr()), \
+        reinterpret_cast<element_type*>(y.data_ptr()), N); \
+    } else { \
+      int N = 1; \
+      for (int i = 0; i < ndim; ++i) { N *= x.size(i); } \
+      dim3 block(256 / (n_elements)); \
+      dim3 grid((N + 256 - 1) / 256); \
+      relu_##packed_type##_kernel<<<grid, block>>>( \
+        reinterpret_cast<element_type*>(x.data_ptr()), \
+        reinterpret_cast<element_type*>(y.data_ptr()), N); \
+    } \
+  } \
 }

-TORCH_BINDING_RELU(f32, torch::kFloat32, float, 1)
-TORCH_BINDING_RELU(f32x4, torch::kFloat32, float, 4)
-TORCH_BINDING_RELU(f16, torch::kHalf, half, 1)
-TORCH_BINDING_RELU(f16x2, torch::kHalf, half, 2)
-TORCH_BINDING_RELU(f16x8, torch::kHalf, half, 8)
-TORCH_BINDING_RELU_V2(f32, torch::kFloat32, float, 1)
-TORCH_BINDING_RELU_V2(f32x4, torch::kFloat32, float, 4)
-TORCH_BINDING_RELU_V2(f16, torch::kHalf, half, 1)
-TORCH_BINDING_RELU_V2(f16x2, torch::kHalf, half, 2)
-TORCH_BINDING_RELU_V2(f16x8, torch::kHalf, half, 8)
+
+TORCH_BINDING_RELU(f32, torch::kFloat32, float, 1)
+TORCH_BINDING_RELU(f32x4, torch::kFloat32, float, 4)
+TORCH_BINDING_RELU(f16, torch::kHalf, half, 1)
+TORCH_BINDING_RELU(f16x2, torch::kHalf, half, 2)
+TORCH_BINDING_RELU(f16x8, torch::kHalf, half, 8)
+TORCH_BINDING_RELU(f16x8_pack, torch::kHalf, half, 8)

 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   TORCH_BINDING_COMMON_EXTENSION(relu_f32)
   TORCH_BINDING_COMMON_EXTENSION(relu_f32x4)
   TORCH_BINDING_COMMON_EXTENSION(relu_f16)
   TORCH_BINDING_COMMON_EXTENSION(relu_f16x2)
   TORCH_BINDING_COMMON_EXTENSION(relu_f16x8)
-  TORCH_BINDING_COMMON_EXTENSION(relu_f32_v2)
-  TORCH_BINDING_COMMON_EXTENSION(relu_f32x4_v2)
-  TORCH_BINDING_COMMON_EXTENSION(relu_f16_v2)
-  TORCH_BINDING_COMMON_EXTENSION(relu_f16x2_v2)
-  TORCH_BINDING_COMMON_EXTENSION(relu_f16x8_v2)
+  TORCH_BINDING_COMMON_EXTENSION(relu_f16x8_pack)
 }
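
The new relu_f16x8_pack_kernel is the heart of this commit: each thread stages eight half values in a local array, moves them with one 128-bit (float4-sized) load and one 128-bit store via the LDST128BITS macro, and applies __hmax2 to four half2 pairs in between. Below is a minimal standalone sketch of the same pattern with a host-side launch; it is not part of the commit, and the problem size, block size, and the assumption that N is a multiple of 8 are illustrative choices. Like the kernel in the diff, it requires a GPU architecture on which __hmax2 is available.

```cuda
// Hypothetical standalone harness (not from this repo): packed 128-bit ReLU,
// assuming N is a multiple of 8 so every thread performs full 128-bit accesses.
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdio>

#define HALF2(value)       (reinterpret_cast<half2*>(&(value))[0])
#define LDST128BITS(value) (reinterpret_cast<float4*>(&(value))[0])

__global__ void relu_f16x8_pack(half* x, half* y, int N) {
  int idx = 8 * (blockIdx.x * blockDim.x + threadIdx.x);
  if (idx + 7 >= N) return;                      // guard the full 8-element pack
  const half2 z2 = __float2half2_rn(0.0f);
  half pack_x[8], pack_y[8];                     // 8 x 16 bits = 128 bits
  LDST128BITS(pack_x[0]) = LDST128BITS(x[idx]);  // one 128-bit load
#pragma unroll
  for (int i = 0; i < 8; i += 2) {
    HALF2(pack_y[i]) = __hmax2(HALF2(pack_x[i]), z2);  // 4x half2 max against zero
  }
  LDST128BITS(y[idx]) = LDST128BITS(pack_y[0]);  // one 128-bit store
}

int main() {
  const int N = 1 << 20;                         // illustrative size, multiple of 8
  half *x, *y;
  cudaMalloc(&x, N * sizeof(half));
  cudaMalloc(&y, N * sizeof(half));              // data left uninitialized; launch demo only
  dim3 block(256);                               // each thread covers 8 halves
  dim3 grid((N / 8 + block.x - 1) / block.x);
  relu_f16x8_pack<<<grid, block>>>(x, y, N);
  cudaDeviceSynchronize();
  printf("launched %u blocks of %u threads\n", grid.x, block.x);
  cudaFree(x);
  cudaFree(y);
  return 0;
}
```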

relu/relu.py

Lines changed: 15 additions & 27 deletions
@@ -49,39 +49,27 @@ def run_benchmark(perf_func: callable, x: torch.Tensor, tag: str,
     total_time = (end - start) * 1000 # ms
     mean_time = total_time / iters
     out_info = f"out_{tag}"
-    out_val = out.detach().cpu().numpy().tolist()[:2]
+    out_val = out.flatten().detach().cpu().numpy().tolist()[:2]
     out_val = [round(v, 8) for v in out_val]
-    print(f"{out_info:>15}: {out_val}, time:{mean_time:.8f}ms")
+    print(f"{out_info:>18}: {out_val}, time:{mean_time:.8f}ms")
     if show_all: print(out)
     return out, mean_time


 print("-" * 80)
-N_ELEMENTS = 256*256*4
-x = torch.randn((N_ELEMENTS)).cuda().float()
-run_benchmark(lib.relu_f32, x, "f32")
-run_benchmark(lib.relu_f32x4, x, "f32x4")
-run_benchmark(torch.relu, x , "f32_th")
+S, K = 4096, 4096
+x = torch.randn((S, K)).cuda().float().contiguous()
+y = torch.zeros_like(x).cuda().float().contiguous()
+run_benchmark(lib.relu_f32, x, "f32", y)
+run_benchmark(lib.relu_f32x4, x, "f32x4", y)
+run_benchmark(torch.relu, x, "f32_th")

 print("-" * 80)
-x_f16 = x.half()
-run_benchmark(lib.relu_f16, x_f16, "f16")
-run_benchmark(lib.relu_f16x2, x_f16, "f16x2")
-run_benchmark(lib.relu_f16x8, x_f16, "f16x8")
-run_benchmark(torch.relu, x_f16 , "f16_th")
-
-print("-" * 80)
-# v2: no copy of y Tensor
-y = torch.zeros_like(x).cuda().float()
-run_benchmark(lib.relu_f32_v2, x, "f32(v2)", y)
-run_benchmark(lib.relu_f32x4_v2, x, "f32x4(v2)", y)
-run_benchmark(torch.relu, x , "f32_th")
-
-print("-" * 80)
-# v2: no copy of y Tensor
-y_f16 = torch.zeros_like(x_f16).cuda().half()
-run_benchmark(lib.relu_f16_v2, x_f16, "f16(v2)", y_f16)
-run_benchmark(lib.relu_f16x2_v2, x_f16, "f16x2(v2)", y_f16)
-run_benchmark(lib.relu_f16x8_v2, x_f16, "f16x8(v2)", y_f16)
-run_benchmark(torch.relu, x_f16 , "f16_th")
+x_f16 = x.half().contiguous()
+y_f16 = y.half().contiguous()
+run_benchmark(lib.relu_f16, x_f16, "f16", y_f16)
+run_benchmark(lib.relu_f16x2, x_f16, "f16x2", y_f16)
+run_benchmark(lib.relu_f16x8, x_f16, "f16x8", y_f16)
+run_benchmark(lib.relu_f16x8_pack, x_f16, "f16x8pack", y_f16)
+run_benchmark(torch.relu, x_f16, "f16_th")
 print("-" * 80)
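
Since the reworked bindings now write into a caller-provided output tensor rather than allocating one per call, a quick correctness check can reuse the same pattern as the benchmark. A minimal sketch, assuming the extension is built from relu.cu with torch.utils.cpp_extension.load (the actual load options in relu.py may differ):

```python
# Hypothetical correctness check, not part of this commit. The load() options
# below are illustrative assumptions; relu.py's real build flags may differ.
import torch
from torch.utils.cpp_extension import load

lib = load(name="relu_lib", sources=["relu.cu"], extra_cuda_cflags=["-O3"])

S, K = 4096, 4096
x_f16 = torch.randn((S, K), device="cuda", dtype=torch.half).contiguous()
y_f16 = torch.zeros_like(x_f16)        # preallocated output buffer

lib.relu_f16x8_pack(x_f16, y_f16)      # kernel writes into y_f16 in place
torch.testing.assert_close(y_f16, torch.relu(x_f16))
print("relu_f16x8_pack matches torch.relu")
```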
