Commit e28cb4d

[DotProd][FP16] support f16x8_pack kernel (#45)

* Update dot_product.cu
* Update dot_product.py
* Update README.md
* Update README.md

1 parent d96f83b commit e28cb4d

File tree: 4 files changed, +119 -30 lines changed

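The core of the change is a packed FP16 kernel: each thread loads eight half values through a single 128-bit transaction and multiplies them as four half2 pairs. Before the per-file diffs, here is a minimal, self-contained sketch of just that packing trick (an elementwise multiply rather than the full dot-product reduction; the kernel name and the locally redefined macros are illustrative, not copied from the commit):

```cuda
// Illustrative sketch of the f16x8 "pack" idea: reinterpret 8 consecutive
// half values as one float4 so the hardware issues a single 128-bit load
// instead of eight 16-bit loads. Not part of the commit.
#include <cuda_fp16.h>

#define LDST128BITS(value) (reinterpret_cast<float4*>(&(value))[0])
#define HALF2(value)       (reinterpret_cast<half2*>(&(value))[0])

__global__ void f16x8_pack_mul(half* a, half* b, half* c, int N) {
  int idx = (blockIdx.x * blockDim.x + threadIdx.x) * 8; // 8 halves per thread
  if (idx + 7 >= N) return;                              // skip the ragged tail here
  half pack_a[8], pack_b[8], pack_c[8];                  // 8 x 16 bits = 128 bits
  LDST128BITS(pack_a[0]) = LDST128BITS(a[idx]);          // one 128-bit load
  LDST128BITS(pack_b[0]) = LDST128BITS(b[idx]);          // one 128-bit load
  #pragma unroll
  for (int i = 0; i < 8; i += 2) {                       // 4 half2 multiplies
    HALF2(pack_c[i]) = __hmul2(HALF2(pack_a[i]), HALF2(pack_b[i]));
  }
  LDST128BITS(c[idx]) = LDST128BITS(pack_c[0]);          // one 128-bit store
}
```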

README.md

Lines changed: 1 addition & 0 deletions

```diff
@@ -68,6 +68,7 @@
 | ✔️ [dot_product_f32x4](./dot-product/dot_product.cu)|f32|f32|[link](./dot-product/)|⭐️⭐️|
 | ✔️ [dot_product_f16_f32](./dot-product/dot_product.cu)|f16|f32|[link](./dot-product/)|⭐️⭐️|
 | ✔️ [dot_product_f16x2_f32](./dot-product/dot_product.cu)|f16|f32|[link](./dot-product/)|⭐️⭐️|
+| ✔️ [dot_product_f16x8_pack_f32](./dot-product/dot_product.cu)|f16|f32|[link](./dot-product/)|⭐️⭐️|
 | ✔️ [softmax_f32(memory fence)](./softmax/softmax.cu)|f32|f32|[link](./softmax/)|⭐️⭐️|
 | ✔️ [softmax_f32x4(memory fence)](./softmax/softmax.cu)|f32|f32|[link](./softmax/)|⭐️⭐️|
 | ✔️ [softmax_f32(per token)](./softmax/softmax.cu)|f32|f32|[link](./softmax/)|⭐️⭐️|
```

dot-product/README.md

Lines changed: 8 additions & 6 deletions

````diff
@@ -8,6 +8,7 @@
 - [X] dot_prod_f32x4_f32_kernel (float4 vectorized version)
 - [X] dot_prod_f16_f32_kernel (fp16 version, using fp32 acc)
 - [X] dot_prod_f16x2_f32_kernel (fp16 vectorized version, using fp32 acc)
+- [X] dot_prod_f16x8_pack_f32_kernel (fp16 vectorized version, using fp32 acc, packed)
 - [X] PyTorch bindings
 
 ## Test
@@ -22,12 +23,13 @@ python3 dot_product.py
 
 ```bash
 --------------------------------------------------------------------------------
-out_f32f32: -88.81410217 , time:0.01135945ms
-out_f32x4f32: -88.81417847 , time:0.01171017ms
-out_f32f32_th: -88.81379700 , time:0.01147819ms
+out_f32f32: -1534.59301758 , time:0.17350578ms
+out_f32x4f32: -1534.61364746 , time:0.18058038ms
+out_f32f32_th: -1534.61157227 , time:0.18307972ms
 --------------------------------------------------------------------------------
-out_f16f32: -88.62890625 , time:0.01113868ms
-out_f16x2f32: -88.65764618 , time:0.01108241ms
-out_f16f16_th: -88.75000000 , time:0.01112628ms
+out_f16f32: -1538.26318359 , time:0.10106802ms
+out_f16x2f32: -1537.58288574 , time:0.05217433ms
+out_f16x8packf32: -1536.44006348 , time:0.02096844ms
+out_f16f16_th: -1536.00000000 , time:0.02491832ms
 --------------------------------------------------------------------------------
 ```
````

dot-product/dot_product.cu

Lines changed: 102 additions & 16 deletions

```diff
@@ -15,6 +15,7 @@
 #define FLOAT4(value) (reinterpret_cast<float4*>(&(value))[0])
 #define HALF2(value) (reinterpret_cast<half2*>(&(value))[0])
 #define BFLOAT2(value) (reinterpret_cast<__nv_bfloat162*>(&(value))[0])
+#define LDST128BITS(value) (reinterpret_cast<float4*>(&(value))[0])
 
 // -------------------------------------- FP32 --------------------------------------
 // Warp Reduce Sum
@@ -123,7 +124,7 @@ __global__ void dot_prod_f16_f32_kernel(half* a, half* b, float* y, int N) {
   if (tid == 0) atomicAdd(y, prod);
 }
 
-template<const int NUM_THREADS = 256>
+template<const int NUM_THREADS = 256/2>
 __global__ void dot_prod_f16x2_f32_kernel(half* a, half* b, float* y, int N) {
   int tid = threadIdx.x;
   int idx = (blockIdx.x * NUM_THREADS + tid) * 2; // 2 half elements per thread
@@ -148,6 +149,38 @@ __global__ void dot_prod_f16x2_f32_kernel(half* a, half* b, float* y, int N) {
   if (tid == 0) atomicAdd(y, prod);
 }
 
+template<const int NUM_THREADS = 256/8>
+__global__ void dot_prod_f16x8_pack_f32_kernel(half* a, half* b, float* y, int N) {
+  int tid = threadIdx.x;
+  int idx = (blockIdx.x * NUM_THREADS + tid) * 8; // 8 half elements per thread
+  constexpr int NUM_WARPS = (NUM_THREADS + WARP_SIZE - 1) / WARP_SIZE;
+  __shared__ float reduce_smem[NUM_WARPS];
+  // temporary register(memory), .local space in ptx, addressable
+  half pack_a[8], pack_b[8]; // 8x16 bits=128 bits.
+  LDST128BITS(pack_a[0]) = LDST128BITS(a[idx]); // load 128 bits
+  LDST128BITS(pack_b[0]) = LDST128BITS(b[idx]); // load 128 bits
+  const half z = __float2half(0.0f);
+
+  half prod_f16 = z;
+  #pragma unroll
+  for (int i = 0; i < 8; i += 2) {
+    half2 v = __hmul2(HALF2(pack_a[i]), HALF2(pack_b[i]));
+    prod_f16 += (((idx + i) < N) ? (v.x + v.y) : z);
+  }
+
+  int warp = tid / WARP_SIZE;
+  int lane = tid % WARP_SIZE;
+  // perform warp sync reduce.
+  float prod = warp_reduce_sum_f16_f32<WARP_SIZE>(prod_f16);
+  // warp leaders store the data to shared memory.
+  if (lane == 0) reduce_smem[warp] = prod;
+  __syncthreads(); // make sure the data is in shared memory.
+  // the first warp computes the final sum.
+  prod = (lane < NUM_WARPS) ? reduce_smem[lane] : 0.0f;
+  if (warp == 0) prod = warp_reduce_sum_f32<NUM_WARPS>(prod);
+  if (tid == 0) atomicAdd(y, prod);
+}
+
 // --------------------- PyTorch bindings for custom kernel -----------------------
 #define STRINGFY(str) #str
 #define TORCH_BINDING_COMMON_EXTENSION(func) \
```
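The new `dot_prod_f16x8_pack_f32_kernel` relies on `warp_reduce_sum_f16_f32` and `warp_reduce_sum_f32`, which already live in `dot_product.cu` and therefore do not appear in this hunk. A typical shuffle-based warp reduction looks like the sketch below; treat it as an assumption about those helpers rather than a copy of the file's code (WARP_SIZE is assumed to be 32):

```cuda
// Assumed shape of the warp-reduce helpers used by the packed kernel.
// Each lane exchanges partial sums via butterfly shuffles; after
// log2(kWarpSize) steps every lane holds the warp-wide sum.
#include <cuda_fp16.h>

#define WARP_SIZE 32

template<const int kWarpSize = WARP_SIZE>
__device__ __forceinline__ float warp_reduce_sum_f32(float val) {
  #pragma unroll
  for (int mask = kWarpSize >> 1; mask >= 1; mask >>= 1)
    val += __shfl_xor_sync(0xffffffff, val, mask);   // butterfly exchange
  return val;
}

template<const int kWarpSize = WARP_SIZE>
__device__ __forceinline__ float warp_reduce_sum_f16_f32(half val) {
  float sum = __half2float(val);                     // widen to fp32 before reducing
  #pragma unroll
  for (int mask = kWarpSize >> 1; mask >= 1; mask >>= 1)
    sum += __shfl_xor_sync(0xffffffff, sum, mask);
  return sum;
}
```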
```diff
@@ -159,8 +192,42 @@ if(((T).options().dtype() != (th_type))) { \
   throw std::runtime_error("values must be "#th_type); \
 }
 
-#define CHECK_TORCH_TENSOR_SHAPE(T, S0) \
-if (((T).size(0) != (S0))) { throw std::runtime_error("Tensor size mismatch!"); }
+#define LANUCH_DOT_PROD_KERNEL(NT, packed_type, acc_type, element_type) \
+dot_prod_##packed_type##_##acc_type##_kernel<(NT)><<<grid, block>>>( \
+  reinterpret_cast<element_type*>(a.data_ptr()), \
+  reinterpret_cast<element_type*>(b.data_ptr()), \
+  prod.data_ptr<float>(), N);
+
+#define DISPATCH_DOT_PROD_KERNEL(K, packed_type, acc_type, element_type, n_elements) \
+  const int NT = (K)/(n_elements); \
+  dim3 block(NT); \
+  dim3 grid((S)); \
+  switch (NT) \
+  { \
+    case 32: \
+      LANUCH_DOT_PROD_KERNEL(32, packed_type, acc_type, element_type) \
+      break; \
+    case 64: \
+      LANUCH_DOT_PROD_KERNEL(64, packed_type, acc_type, element_type) \
+      break; \
+    case 128: \
+      LANUCH_DOT_PROD_KERNEL(128, packed_type, acc_type, element_type) \
+      break; \
+    case 256: \
+      LANUCH_DOT_PROD_KERNEL(256, packed_type, acc_type, element_type) \
+      break; \
+    case 512: \
+      LANUCH_DOT_PROD_KERNEL(512, packed_type, acc_type, element_type) \
+      break; \
+    case 1024: \
+      LANUCH_DOT_PROD_KERNEL(1024, packed_type, acc_type, element_type) \
+      break; \
+    default: \
+      throw std::runtime_error( \
+        "only support (K)/(n_elements): 32/64/128/256/512/1024"); \
+      break; \
+  }
+
 
 #define TORCH_BINDING_DOT_PROD(packed_type, acc_type, th_type, element_type, n_elements) \
 torch::Tensor dot_prod_##packed_type##_##acc_type(torch::Tensor a, torch::Tensor b) { \
```
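To make the new dispatch concrete: for the 2-D path with the (S, K) = (4096, 4096) shape used by `dot_product.py` further down, the f16x8_pack binding has n_elements = 8, so NT = 4096 / 8 = 512 and the `case 512` branch fires. The hypothetical helper below (written only for illustration; it is not part of the commit and assumes the kernel from dot_product.cu is visible in the same translation unit) shows roughly what the macros expand to in that case:

```cuda
// Rough expansion of DISPATCH_DOT_PROD_KERNEL(K, f16x8_pack, f32, half, 8)
// for S = K = 4096. Hypothetical helper, for illustration only.
#include <torch/extension.h>
#include <cuda_fp16.h>

torch::Tensor dot_prod_f16x8_pack_f32_expanded(torch::Tensor a, torch::Tensor b) {
  auto prod = torch::zeros({1}, torch::TensorOptions()
                                    .dtype(torch::kFloat32)
                                    .device(torch::kCUDA, 0));
  const int S = a.size(0);           // 4096 rows
  const int K = a.size(1);           // 4096 halves per row
  const int N = S * K;               // 16,777,216 elements in total
  dim3 block(K / 8);                 // NT = 4096 / 8 = 512 threads, 8 halves each
  dim3 grid(S);                      // grid((S)): one block per row of `a`
  dot_prod_f16x8_pack_f32_kernel<512><<<grid, block>>>(  // the `case 512` branch
      reinterpret_cast<half*>(a.data_ptr()),
      reinterpret_cast<half*>(b.data_ptr()),
      prod.data_ptr<float>(), N);
  return prod;
}
```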
```diff
@@ -169,30 +236,49 @@ torch::Tensor dot_prod_##packed_type##_##acc_type(torch::Tensor a, torch::Tensor
   auto options = torch::TensorOptions().dtype(torch::kFloat32).device( \
     torch::kCUDA, 0); \
   auto prod = torch::zeros({1}, options); \
-  const int N = a.size(0); \
-  CHECK_TORCH_TENSOR_SHAPE(b, N) \
-  static const int NUM_THREADS_PER_BLOCK = 256 / (n_elements); \
-  const int NUM_BLOCKS = (N + 256 - 1) / 256; \
-  dim3 block(NUM_THREADS_PER_BLOCK); \
-  dim3 grid(NUM_BLOCKS); \
-  dot_prod_##packed_type##_##acc_type##_kernel< \
-    NUM_THREADS_PER_BLOCK><<<grid, block>>>( \
+  const int ndim = a.dim(); \
+  if (ndim != 2) { \
+    int N = 1; \
+    for (int i = 0; i < ndim; ++i) { N *= a.size(i); } \
+    dim3 block(256); \
+    dim3 grid(((N + 256 - 1) / 256) / (n_elements)); \
+    dot_prod_##packed_type##_##acc_type##_kernel< \
+      256 ><<<grid, block>>>( \
       reinterpret_cast<element_type*>(a.data_ptr()), \
       reinterpret_cast<element_type*>(b.data_ptr()), \
       prod.data_ptr<float>(), N); \
+  } else { \
+    const int S = a.size(0); \
+    const int K = a.size(1); \
+    const int N = S * K; \
+    if ((K/(n_elements)) <= 1024) { \
+      DISPATCH_DOT_PROD_KERNEL(K, packed_type, acc_type, element_type, n_elements) \
+    } else { \
+      int N = 1; \
+      for (int i = 0; i < ndim; ++i) { N *= a.size(i); } \
+      dim3 block(256); \
+      dim3 grid(((N + 256 - 1) / 256) / (n_elements)); \
+      dot_prod_##packed_type##_##acc_type##_kernel< \
+        256 ><<<grid, block>>>( \
+        reinterpret_cast<element_type*>(a.data_ptr()), \
+        reinterpret_cast<element_type*>(b.data_ptr()), \
+        prod.data_ptr<float>(), N); \
+    } \
+  } \
   return prod; \
 }
 
 // packed_type, acc_type, th_type, element_type, n_elements_per_pack
-TORCH_BINDING_DOT_PROD(f32,   f32, torch::kFloat32, float, 1)
-TORCH_BINDING_DOT_PROD(f32x4, f32, torch::kFloat32, float, 4)
-TORCH_BINDING_DOT_PROD(f16,   f32, torch::kHalf,    half,  1)
-TORCH_BINDING_DOT_PROD(f16x2, f32, torch::kHalf,    half,  2)
-
+TORCH_BINDING_DOT_PROD(f32,        f32, torch::kFloat32, float, 1)
+TORCH_BINDING_DOT_PROD(f32x4,      f32, torch::kFloat32, float, 4)
+TORCH_BINDING_DOT_PROD(f16,        f32, torch::kHalf,    half,  1)
+TORCH_BINDING_DOT_PROD(f16x2,      f32, torch::kHalf,    half,  2)
+TORCH_BINDING_DOT_PROD(f16x8_pack, f32, torch::kHalf,    half,  8)
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   TORCH_BINDING_COMMON_EXTENSION(dot_prod_f32_f32)
   TORCH_BINDING_COMMON_EXTENSION(dot_prod_f32x4_f32)
   TORCH_BINDING_COMMON_EXTENSION(dot_prod_f16_f32)
   TORCH_BINDING_COMMON_EXTENSION(dot_prod_f16x2_f32)
+  TORCH_BINDING_COMMON_EXTENSION(dot_prod_f16x8_pack_f32)
 }
```
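For readers who want to poke at the packed kernel without going through the PyTorch extension, a minimal raw-CUDA harness along these lines should work. It is not part of the commit and assumes `dot_prod_f16x8_pack_f32_kernel` (plus its `WARP_SIZE` and warp-reduce helpers) is visible in the same translation unit:

```cuda
// Standalone harness sketch: fill two fp16 vectors with constants, launch the
// packed kernel, and compare against the analytic dot product.
#include <cstdio>
#include <vector>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

int main() {
  const int N = 4096 * 4096;                    // same element count as the benchmark
  std::vector<half> h_a(N), h_b(N);
  for (int i = 0; i < N; ++i) { h_a[i] = __float2half(0.01f); h_b[i] = __float2half(0.02f); }

  half *d_a, *d_b; float *d_y;
  cudaMalloc(&d_a, N * sizeof(half));
  cudaMalloc(&d_b, N * sizeof(half));
  cudaMalloc(&d_y, sizeof(float));
  cudaMemcpy(d_a, h_a.data(), N * sizeof(half), cudaMemcpyHostToDevice);
  cudaMemcpy(d_b, h_b.data(), N * sizeof(half), cudaMemcpyHostToDevice);
  cudaMemset(d_y, 0, sizeof(float));            // kernel accumulates via atomicAdd

  constexpr int NT = 512;                       // threads per block (K / 8 for K = 4096)
  dim3 block(NT), grid(N / (NT * 8));           // each thread consumes 8 halves
  dot_prod_f16x8_pack_f32_kernel<NT><<<grid, block>>>(d_a, d_b, d_y, N);

  float y = 0.0f;
  cudaMemcpy(&y, d_y, sizeof(float), cudaMemcpyDeviceToHost);
  printf("dot = %f (expected ~%f)\n", y, 0.01f * 0.02f * N);
  cudaFree(d_a); cudaFree(d_b); cudaFree(d_y);
  return 0;
}
```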

dot-product/dot_product.py

Lines changed: 8 additions & 8 deletions

```diff
@@ -43,18 +43,18 @@ def run_benchmark(perf_func: callable, a: torch.Tensor, b: torch.Tensor, tag: st
 
 
 print("-" * 80)
-N_ELEMENTS = 256*92*16
-a = torch.randn((N_ELEMENTS)).cuda().float()
-b = torch.randn((N_ELEMENTS)).cuda().float()
+S, K = 4096, 4096
+a = torch.randn((S*K)).cuda().float()
+b = torch.randn((S*K)).cuda().float()
 run_benchmark(lib.dot_prod_f32_f32, a, b, "f32f32")
 run_benchmark(lib.dot_prod_f32x4_f32, a, b, "f32x4f32")
-run_benchmark(torch.dot, a, b , "f32f32_th")
+run_benchmark(torch.dot, a, b, "f32f32_th")
 
 print("-" * 80)
 a_f16 = a.half()
 b_f16 = b.half()
-run_benchmark(lib.dot_prod_f16_f32, a_f16, b_f16, "f16f32")
-run_benchmark(lib.dot_prod_f16x2_f32, a_f16, b_f16, "f16x2f32")
-run_benchmark(torch.dot, a_f16, b_f16 , "f16f16_th")
-
+run_benchmark(lib.dot_prod_f16_f32,        a_f16, b_f16, "f16f32")
+run_benchmark(lib.dot_prod_f16x2_f32,      a_f16, b_f16, "f16x2f32")
+run_benchmark(lib.dot_prod_f16x8_pack_f32, a_f16, b_f16, "f16x8packf32")
+run_benchmark(torch.dot, a_f16, b_f16, "f16f16_th")
 print("-" * 80)
```
