
Commit 068e6fe

[Elementwise][Half] support f16x8_pack kernel, boost 1.1x (#40)
* Update elementwise.cu
* Update elementwise.py
* Update README.md
* Update README.md
1 parent 8529e9d commit 068e6fe

4 files changed, +83 −94 lines changed


README.md

Lines changed: 1 addition & 0 deletions
@@ -26,6 +26,7 @@
 | ✔️ [elementwise_f16](./elementwise/elementwise.cu)|f16|/|[link](./elementwise/)|⭐️|
 | ✔️ [elementwise_f16x2](./elementwise/elementwise.cu)|f16|/|[link](./elementwise/)|⭐️|
 | ✔️ [elementwise_f16x8](./elementwise/elementwise.cu)|f16|/|[link](./elementwise/)|⭐️|
+| ✔️ [elementwise_f16x8_pack](./elementwise/elementwise.cu)|f16|/|[link](./elementwise/)|⭐️⭐️|
 | ✔️ [histogram_i32](./histogram/histogram.cu)|i32|/|[link](./histogram/)|⭐️|
 | ✔️ [histogram_i32x4](./histogram/histogram.cu)|i32|/|[link](./histogram/)|⭐️|
 | ✔️ [sigmoid_f32](./sigmoid/sigmoid.cu)|f32|/|[link](./sigmoid/)|⭐️|

elementwise/README.md

Lines changed: 9 additions & 16 deletions
@@ -9,6 +9,7 @@
 - [X] elementwise_add_f16_kernel (fp16 version)
 - [X] elementwise_add_f16x2_kernel (fp16 vectorized version)
 - [X] elementwise_add_f16x8_kernel (fp16 vectorized version)
+- [X] elementwise_add_f16x8_pack_kernel (fp16 vectorized version, pack)
 - [X] PyTorch bindings
 
 
@@ -24,22 +25,14 @@ python3 elementwise.py
 
 ```bash
 --------------------------------------------------------------------------------
-out_f32: [-1.8014312982559204, 0.38691335916519165], time:0.01107502ms
-out_f32x4: [-1.8014312982559204, 0.38691335916519165], time:0.01091743ms
-out_f32_th: [-1.8014312982559204, 0.38691335916519165], time:0.00744152ms
+out_f32: [-1.53079593, 0.52963573], time:0.28430200ms
+out_f32x4: [-1.53079593, 0.52963573], time:0.29020834ms
+out_f32_th: [-1.53079593, 0.52963573], time:0.29701710ms
 --------------------------------------------------------------------------------
-out_f16: [-1.80078125, 0.38671875], time:0.01076937ms
-out_f16x2: [-1.80078125, 0.38671875], time:0.01071215ms
-out_f16x8: [-1.80078125, 0.38671875], time:0.01074862ms
-out_f16_th: [-1.80078125, 0.38671875], time:0.00737953ms
---------------------------------------------------------------------------------
-out_f32(v2): [-1.8014312982559204, 0.38691335916519165], time:0.00359011ms
-out_f32x4(v2): [-1.8014312982559204, 0.38691335916519165], time:0.00357652ms
-out_f32_th: [-1.8014312982559204, 0.38691335916519165], time:0.00575542ms
---------------------------------------------------------------------------------
-out_f16(v2): [-1.80078125, 0.38671875], time:0.00358772ms
-out_f16x2(v2): [-1.80078125, 0.38671875], time:0.00354576ms
-out_f16x8(v2): [-1.80078125, 0.38671875], time:0.00353265ms
-out_f16_th: [-1.80078125, 0.38671875], time:0.00590253ms
+out_f16: [-1.53027344, 0.52929688], time:0.05925465ms
+out_f16x2: [-1.53027344, 0.52929688], time:0.04892802ms
+out_f16x8: [-1.53027344, 0.52929688], time:0.04291439ms
+out_f16x8pack: [-1.53027344, 0.52929688], time:0.03846574ms
+out_f16_th: [-1.53027344, 0.52929688], time:0.04044223ms
 --------------------------------------------------------------------------------
 ```
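For reference, the timings above give f16x8 → f16x8pack as 0.04291 ms → 0.03847 ms, roughly a 1.12x speedup, in line with the ~1.1x claimed in the commit title; on this 4096x4096 fp16 case the packed kernel also edges out torch.add (0.04044 ms).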

elementwise/elementwise.cu

Lines changed: 55 additions & 48 deletions
@@ -15,6 +15,7 @@
 #define FLOAT4(value) (reinterpret_cast<float4*>(&(value))[0])
 #define HALF2(value) (reinterpret_cast<half2*>(&(value))[0])
 #define BFLOAT2(value) (reinterpret_cast<__nv_bfloat162*>(&(value))[0])
+#define LDST128BITS(value) (reinterpret_cast<float4*>(&(value))[0])
 
 // -------------------------------------- FP32 --------------------------------------
 // ElementWise Add
@@ -95,6 +96,23 @@ __global__ void elementwise_add_f16x8_kernel(half* a, half* b, half* c, int N) {
   if ((idx + 6) < N) { HALF2(c[idx + 6]) = reg_c_3; }
 }
 
+__global__ void elementwise_add_f16x8_pack_kernel(half* a, half* b, half* c, int N) {
+  int idx = 8 * (blockIdx.x * blockDim.x + threadIdx.x);
+  // temporary register(memory), .local space in ptx, addressable
+  half pack_a[8], pack_b[8], pack_c[8]; // 8x16 bits=128 bits.
+  // reinterpret as float4 and load 128 bits in 1 memory issue.
+  LDST128BITS(pack_a[0]) = LDST128BITS(a[idx]); // load 128 bits
+  LDST128BITS(pack_b[0]) = LDST128BITS(b[idx]); // load 128 bits
+
+  #pragma unroll
+  for (int i = 0; i < 8; i += 2) {
+    // __hadd2 for half2 x 4
+    HALF2(pack_c[i]) = __hadd2(HALF2(pack_a[i]), HALF2(pack_b[i]));
+  }
+  // reinterpret as float4 and store 128 bits in 1 memory issue.
+  if ((idx + 7) < N) { LDST128BITS(c[idx]) = LDST128BITS(pack_c[0]); }
+}
+
 
 // --------------------- PyTorch bindings for custom kernel -----------------------
 #define STRINGFY(str) #str
@@ -107,70 +125,59 @@ if(((T).options().dtype() != (th_type))) { \
   throw std::runtime_error("values must be "#th_type); \
 }
 
-#define CHECK_TORCH_TENSOR_SHAPE(T, S0) \
-if (((T).size(0) != (S0))) { throw std::runtime_error("Tensor size mismatch!"); }
-
 #define TORCH_BINDING_ELEM_ADD(packed_type, th_type, element_type, n_elements) \
-torch::Tensor elementwise_add_##packed_type(torch::Tensor a, torch::Tensor b) { \
-  CHECK_TORCH_TENSOR_DTYPE(a, (th_type)) \
-  CHECK_TORCH_TENSOR_DTYPE(b, (th_type)) \
-  auto options = torch::TensorOptions().dtype((th_type)).device( \
-    torch::kCUDA, 0); \
-  const int N = a.size(0); \
-  CHECK_TORCH_TENSOR_SHAPE(b, N) \
-  auto c = torch::zeros({N}, options); \
-  static const int NUM_THREADS_PER_BLOCK = 256 / (n_elements); \
-  const int NUM_BLOCKS = (N + 256 - 1) / 256; \
-  dim3 block(NUM_THREADS_PER_BLOCK); \
-  dim3 grid(NUM_BLOCKS); \
-  elementwise_add_##packed_type##_kernel<<<grid, block>>>( \
-    reinterpret_cast<element_type*>(a.data_ptr()), \
-    reinterpret_cast<element_type*>(b.data_ptr()), \
-    reinterpret_cast<element_type*>(c.data_ptr()), N); \
-  return c; \
-}
-
-#define TORCH_BINDING_ELEM_ADD_V2(packed_type, th_type, element_type, n_elements)\
-void elementwise_add_##packed_type##_v2( \
+void elementwise_add_##packed_type( \
   torch::Tensor a, torch::Tensor b, torch::Tensor c) { \
   CHECK_TORCH_TENSOR_DTYPE(a, (th_type)) \
  CHECK_TORCH_TENSOR_DTYPE(b, (th_type)) \
  CHECK_TORCH_TENSOR_DTYPE(c, (th_type)) \
-  const int N = a.size(0); \
-  CHECK_TORCH_TENSOR_SHAPE(b, N) \
-  CHECK_TORCH_TENSOR_SHAPE(c, N) \
-  static const int NUM_THREADS_PER_BLOCK = 256 / (n_elements); \
-  const int NUM_BLOCKS = (N + 256 - 1) / 256; \
-  dim3 block(NUM_THREADS_PER_BLOCK); \
-  dim3 grid(NUM_BLOCKS); \
-  elementwise_add_##packed_type##_kernel<<<grid, block>>>( \
+  const int ndim = a.dim(); \
+  if (ndim != 2) { \
+    int N = 1; \
+    for (int i = 0; i < ndim; ++i) { N *= a.size(i); } \
+    dim3 block(256 / (n_elements)); \
+    dim3 grid((N + 256 - 1) / 256); \
+    elementwise_add_##packed_type##_kernel<<<grid, block>>>( \
      reinterpret_cast<element_type*>(a.data_ptr()), \
      reinterpret_cast<element_type*>(b.data_ptr()), \
      reinterpret_cast<element_type*>(c.data_ptr()), N); \
+  } else { \
+    const int S = a.size(0); \
+    const int K = a.size(1); \
+    const int N = S * K; \
+    if ((K/(n_elements)) <= 1024) { \
+      dim3 block(K/(n_elements)); \
+      dim3 grid(S); \
+      elementwise_add_##packed_type##_kernel<<<grid, block>>>( \
+        reinterpret_cast<element_type*>(a.data_ptr()), \
+        reinterpret_cast<element_type*>(b.data_ptr()), \
+        reinterpret_cast<element_type*>(c.data_ptr()), N); \
+    } else { \
+      int N = 1; \
+      for (int i = 0; i < ndim; ++i) { N *= a.size(i); } \
+      dim3 block(256 / (n_elements)); \
+      dim3 grid((N + 256 - 1) / 256); \
+      elementwise_add_##packed_type##_kernel<<<grid, block>>>( \
+        reinterpret_cast<element_type*>(a.data_ptr()), \
+        reinterpret_cast<element_type*>(b.data_ptr()), \
+        reinterpret_cast<element_type*>(c.data_ptr()), N); \
+    } \
+  } \
 }
 
 
-TORCH_BINDING_ELEM_ADD(f32, torch::kFloat32, float, 1)
-TORCH_BINDING_ELEM_ADD(f32x4, torch::kFloat32, float, 4)
-TORCH_BINDING_ELEM_ADD(f16, torch::kHalf, half, 1)
-TORCH_BINDING_ELEM_ADD(f16x2, torch::kHalf, half, 2)
-TORCH_BINDING_ELEM_ADD(f16x8, torch::kHalf, half, 8)
-// v2: no copy of c Tensor
-TORCH_BINDING_ELEM_ADD_V2(f32, torch::kFloat32, float, 1)
-TORCH_BINDING_ELEM_ADD_V2(f32x4, torch::kFloat32, float, 4)
-TORCH_BINDING_ELEM_ADD_V2(f16, torch::kHalf, half, 1)
-TORCH_BINDING_ELEM_ADD_V2(f16x2, torch::kHalf, half, 2)
-TORCH_BINDING_ELEM_ADD_V2(f16x8, torch::kHalf, half, 8)
+TORCH_BINDING_ELEM_ADD(f32, torch::kFloat32, float, 1)
+TORCH_BINDING_ELEM_ADD(f32x4, torch::kFloat32, float, 4)
+TORCH_BINDING_ELEM_ADD(f16, torch::kHalf, half, 1)
+TORCH_BINDING_ELEM_ADD(f16x2, torch::kHalf, half, 2)
+TORCH_BINDING_ELEM_ADD(f16x8, torch::kHalf, half, 8)
+TORCH_BINDING_ELEM_ADD(f16x8_pack, torch::kHalf, half, 8)
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   TORCH_BINDING_COMMON_EXTENSION(elementwise_add_f32)
   TORCH_BINDING_COMMON_EXTENSION(elementwise_add_f32x4)
   TORCH_BINDING_COMMON_EXTENSION(elementwise_add_f16)
   TORCH_BINDING_COMMON_EXTENSION(elementwise_add_f16x2)
   TORCH_BINDING_COMMON_EXTENSION(elementwise_add_f16x8)
-  TORCH_BINDING_COMMON_EXTENSION(elementwise_add_f32_v2)
-  TORCH_BINDING_COMMON_EXTENSION(elementwise_add_f32x4_v2)
-  TORCH_BINDING_COMMON_EXTENSION(elementwise_add_f16_v2)
-  TORCH_BINDING_COMMON_EXTENSION(elementwise_add_f16x2_v2)
-  TORCH_BINDING_COMMON_EXTENSION(elementwise_add_f16x8_v2)
+  TORCH_BINDING_COMMON_EXTENSION(elementwise_add_f16x8_pack)
 }
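The pack kernel above gets its win from issuing a single 128-bit load per operand and a single 128-bit store per thread, instead of four 32-bit half2 transactions. As a rough illustration of how it can be driven outside the PyTorch bindings, here is a minimal standalone sketch: the kernel body is copied from the diff above, while the `main` launcher, buffer sizes, and block/grid choice are illustrative assumptions and not part of this commit.

```cuda
#include <cuda_fp16.h>
#include <cuda_runtime.h>

#define LDST128BITS(value) (reinterpret_cast<float4*>(&(value))[0])
#define HALF2(value)       (reinterpret_cast<half2*>(&(value))[0])

// Kernel as added in this commit: each thread packs 8 halves (128 bits),
// adds them as four half2 pairs, and writes back with one 128-bit store.
__global__ void elementwise_add_f16x8_pack_kernel(half* a, half* b, half* c, int N) {
  int idx = 8 * (blockIdx.x * blockDim.x + threadIdx.x);
  half pack_a[8], pack_b[8], pack_c[8];          // 8 x 16 bits = 128 bits
  LDST128BITS(pack_a[0]) = LDST128BITS(a[idx]);  // one 128-bit load
  LDST128BITS(pack_b[0]) = LDST128BITS(b[idx]);  // one 128-bit load
  #pragma unroll
  for (int i = 0; i < 8; i += 2) {
    HALF2(pack_c[i]) = __hadd2(HALF2(pack_a[i]), HALF2(pack_b[i]));
  }
  if ((idx + 7) < N) { LDST128BITS(c[idx]) = LDST128BITS(pack_c[0]); }  // one 128-bit store
}

int main() {
  // Hypothetical launcher, only to show the launch shape: N is a multiple of 8
  // here, and each thread covers 8 elements, so the grid is sized over N / 8.
  const int N = 4096 * 4096;
  half *a, *b, *c;
  cudaMalloc(&a, N * sizeof(half));
  cudaMalloc(&b, N * sizeof(half));
  cudaMalloc(&c, N * sizeof(half));
  cudaMemset(a, 0, N * sizeof(half));  // inputs zeroed; values don't matter for the sketch
  cudaMemset(b, 0, N * sizeof(half));
  dim3 block(256);
  dim3 grid((N / 8 + block.x - 1) / block.x);
  elementwise_add_f16x8_pack_kernel<<<grid, block>>>(a, b, c, N);
  cudaDeviceSynchronize();
  cudaFree(a); cudaFree(b); cudaFree(c);
  return 0;
}
```

The binding macro in the diff picks its launch configuration differently: for a 2D tensor where K/(n_elements) fits in a block (<= 1024) it launches one block per row (grid = S, block = K/n_elements), otherwise it falls back to a flat 1D grid of 256-element tiles.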

elementwise/elementwise.py

Lines changed: 18 additions & 30 deletions
@@ -49,41 +49,29 @@ def run_benchmark(perf_func: callable, a: torch.Tensor, b: torch.Tensor, tag: st
     total_time = (end - start) * 1000 # ms
     mean_time = total_time / iters
     out_info = f"out_{tag}"
-    out_val = out.detach().cpu().numpy().tolist()[:2]
-    print(f"{out_info:>14}: {out_val}, time:{mean_time:.8f}ms")
+    out_val = out.flatten().detach().cpu().numpy().tolist()[:2]
+    out_val = [round(v, 8) for v in out_val]
+    print(f"{out_info:>18}: {out_val}, time:{mean_time:.8f}ms")
     if show_all: print(out)
     return out, mean_time
 
 
 print("-" * 80)
-N_ELEMENTS = 256*92*4
-a = torch.randn((N_ELEMENTS)).cuda().float()
-b = torch.randn((N_ELEMENTS)).cuda().float()
-run_benchmark(lib.elementwise_add_f32, a, b, "f32")
-run_benchmark(lib.elementwise_add_f32x4, a, b, "f32x4")
-run_benchmark(torch.add, a, b, "f32_th")
+S, K = 4096, 4096
+a = torch.randn((S, K)).cuda().float().contiguous()
+b = torch.randn((S, K)).cuda().float().contiguous()
+c = torch.zeros_like(a).cuda().float().contiguous()
+run_benchmark(lib.elementwise_add_f32, a, b, "f32", c)
+run_benchmark(lib.elementwise_add_f32x4, a, b, "f32x4", c)
+run_benchmark(partial(torch.add, out=c), a, b, "f32_th")
 
 print("-" * 80)
-a_f16 = a.half()
-b_f16 = b.half()
-run_benchmark(lib.elementwise_add_f16, a_f16, b_f16, "f16")
-run_benchmark(lib.elementwise_add_f16x2, a_f16, b_f16, "f16x2")
-run_benchmark(lib.elementwise_add_f16x8, a_f16, b_f16, "f16x8")
-run_benchmark(torch.add, a_f16, b_f16, "f16_th")
-
-print("-" * 80)
-# v2: no copy of c Tensor
-c = torch.zeros_like(a).cuda().float()
-run_benchmark(lib.elementwise_add_f32_v2, a, b, "f32(v2)", c)
-run_benchmark(lib.elementwise_add_f32x4_v2, a, b, "f32x4(v2)", c)
-run_benchmark(partial(torch.add, out=c), a, b, "f32_th")
-
-print("-" * 80)
-# v2: no copy of c Tensor
-c_f16 = torch.zeros_like(a_f16).cuda().half()
-run_benchmark(lib.elementwise_add_f16_v2, a_f16, b_f16, "f16(v2)", c_f16)
-run_benchmark(lib.elementwise_add_f16x2_v2, a_f16, b_f16, "f16x2(v2)", c_f16)
-run_benchmark(lib.elementwise_add_f16x8_v2, a_f16, b_f16, "f16x8(v2)", c_f16)
-run_benchmark(partial(torch.add, out=c_f16), a_f16, b_f16, "f16_th")
-
+a_f16 = a.half().contiguous()
+b_f16 = b.half().contiguous()
+c_f16 = c.half().contiguous()
+run_benchmark(lib.elementwise_add_f16, a_f16, b_f16, "f16", c_f16)
+run_benchmark(lib.elementwise_add_f16x2, a_f16, b_f16, "f16x2", c_f16)
+run_benchmark(lib.elementwise_add_f16x8, a_f16, b_f16, "f16x8", c_f16)
+run_benchmark(lib.elementwise_add_f16x8_pack, a_f16, b_f16, "f16x8pack", c_f16)
+run_benchmark(partial(torch.add, out=c_f16), a_f16, b_f16, "f16_th")
 print("-" * 80)
