Commit dfabac3
[HGEMV][Half] support hgemv k32/k128/f16 (#32)
* [Refactor][6/N] CUDA Learn Notes refactor Part-6
* [HGEMM] Add sliced_k & t_8x8_sliced_k_f16x4
* [HGEMV] support hgemv k32/k128/f16
1 parent b417e20 commit dfabac3

File tree: 4 files changed, +315 −3 lines
README.md

Lines changed: 3 additions & 3 deletions

```diff
@@ -82,9 +82,9 @@
 | ✔️ [sgemv_k32_f32_kernel](./sgemv/sgemv.cu)|f32|f32|[link](./sgemv/)|⭐️⭐️⭐️|
 | ✔️ [sgemv_k128_f32x4_kernel](./sgemv/sgemv.cu)|f32|f32|[link](./sgemv/)|⭐️⭐️⭐️|
 | ✔️ [sgemv_k16_f32_kernel](./sgemv/sgemv.cu)|f32|f32|[link](./sgemv/)|⭐️⭐️⭐️|
-| [hgemv_k32_f16_kernel](./hgemv)|f16|f16||⭐️⭐️⭐️|
-| [hgemv_k128_f16x2_kernel](./hgemv)|f16|f16||⭐️⭐️⭐️|
-| [hgemv_k16_f16_kernel](./hgemv)|f16|f16||⭐️⭐️⭐️|
+| ✔️ [hgemv_k32_f16_kernel](./hgemv)|f16|f16|[link](./hgemv/)|⭐️⭐️⭐️|
+| ✔️ [hgemv_k128_f16x4_kernel](./hgemv)|f16|f16|[link](./hgemv/)|⭐️⭐️⭐️|
+| ✔️ [hgemv_k16_f16_kernel](./hgemv)|f16|f16|[link](./hgemv/)|⭐️⭐️⭐️|
 | ✔️ [flash_attn_1_fwd_f32_kernel](./flash-attn/flash_attn_1_fwd_f32.cu)|f32|f32|[link](./flash-attn)|⭐️⭐️⭐️|
 |[flash_attn_2_fwd_f32_kernel](./flash-attn/flash_attn_2_fwd_f32.cu)|f32|f32|[link](./flash-attn)|⭐️⭐️⭐️|
 |[flash_attn_2_fwd_f16_kernel](./flash-attn/flash_attn_2_fwd_f32.cu)|f16|f32|[link](./flash-attn)|⭐️⭐️⭐️|
```

hgemv/README.md

Lines changed: 31 additions & 0 deletions

# HGEMV

## 0x00 Description

This directory contains:

- [X] hgemv_k32_f16_kernel
- [X] hgemv_k128_f16x4_kernel
- [X] hgemv_k16_f16_kernel
- [X] PyTorch bindings

## Test

```bash
# Build only for the Ada architecture; without this, all architectures are compiled, which takes much longer.
export TORCH_CUDA_ARCH_LIST=Ada
python3 hgemv.py
```

Output:

```bash
--------------------------------------------------------------------------------
out_k32f16: [-12.9843750, 11.6406250, -12.75], time:0.00751138ms
out_k128f16x4: [-12.9765625, 11.6328125, -12.75], time:0.00726104ms
out_f16_th: [-12.9765625, 11.6406250, -12.75], time:0.04772186ms
--------------------------------------------------------------------------------
out_k16f16: [5.85546875, 2.00390625, -1.79882812], time:0.02308726ms
out_f16_th: [5.85156250, 2.00195312, -1.79882812], time:0.05379081ms
--------------------------------------------------------------------------------
```
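The output above only eyeballs the first three values against the torch reference. A minimal correctness check is sketched below; it is not part of this commit, reuses the same build flags as `hgemv.py`, and the tolerances are an assumption chosen to absorb fp16 accumulation-order differences:

```python
import torch
from torch.utils.cpp_extension import load

# Same build flags as hgemv.py; the -U half-operator flags are needed because
# PyTorch's extension build disables half arithmetic operators by default.
lib = load(name='hgemv_lib', sources=['hgemv.cu'],
           extra_cuda_cflags=[
               "-O3",
               "-U__CUDA_NO_HALF_OPERATORS__",
               "-U__CUDA_NO_HALF_CONVERSIONS__",
               "-U__CUDA_NO_HALF2_OPERATORS__",
               "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
               "--expt-relaxed-constexpr",
               "--expt-extended-lambda",
               "--use_fast_math"],
           extra_cflags=['-std=c++17'])

M, K = 1024, 128
a = torch.randn((M, K), dtype=torch.half, device='cuda').contiguous()
x = torch.randn((K, 1), dtype=torch.half, device='cuda').contiguous()
y = torch.zeros((M, 1), dtype=torch.half, device='cuda').contiguous()

lib.hgemv_k128_f16x4(a, x, y)  # custom kernel writes into y (K=128 satisfies the assert)
y_ref = torch.matmul(a, x)     # fp16 torch reference
# Accumulation order differs between the kernel and torch, so compare loosely.
print(torch.allclose(y, y_ref, atol=1e-2, rtol=1e-2))
```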

hgemv/hgemv.cu

Lines changed: 205 additions & 0 deletions

```c++
#include <stdio.h>
#include <stdlib.h>
#include <float.h>
#include <vector>
#include <algorithm>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <cuda_fp8.h>
#include <torch/types.h>
#include <torch/extension.h>

#define WARP_SIZE 32
#define INT4(value) (reinterpret_cast<int4*>(&(value))[0])
#define FLOAT4(value) (reinterpret_cast<float4*>(&(value))[0])
#define HALF2(value) (reinterpret_cast<half2*>(&(value))[0])
#define BFLOAT2(value) (reinterpret_cast<__nv_bfloat162*>(&(value))[0])

// -------------------------------------- FP16 --------------------------------------
// Warp Reduce Sum
template<const int kWarpSize = WARP_SIZE>
__device__ __forceinline__ half warp_reduce_sum_f16(half val) {
  #pragma unroll
  for (int mask = kWarpSize >> 1; mask >= 1; mask >>= 1) {
    val += __shfl_xor_sync(0xffffffff, val, mask);
  }
  return val;
}

// HGEMV: Warp HGEMV K32
// Assumes K is a multiple of 32; each warp handles one row.
// grid(M/4), block(32,4): blockDim.x=32, blockDim.y=4
// a: MxK, x: Kx1, y: Mx1, compute: y = a * x
__global__ void hgemv_k32_f16_kernel(half* a, half* x, half* y, int M, int K) {
  int tx = threadIdx.x;         // 0~31
  int ty = threadIdx.y;         // 0~3
  int bx = blockIdx.x;          // 0~M/4
  int lane = tx % WARP_SIZE;    // 0~31
  int m = bx * blockDim.y + ty; // (0~M/4) * 4 + (0~3)
  if (m < M) {
    half sum = 0.0f;
    int NUM_WARPS = (K + WARP_SIZE - 1) / WARP_SIZE;
    #pragma unroll
    for (int w = 0; w < NUM_WARPS; ++w) {
      // If NUM_WARPS >= 2, each lane first accumulates its strided chunks
      // of the current row before the warp-level reduction.
      int k = w * WARP_SIZE + lane;
      sum += a[m * K + k] * x[k];
    }
    sum = warp_reduce_sum_f16<WARP_SIZE>(sum);
    if (lane == 0) y[m] = sum;
  }
}

// HGEMV: Warp HGEMV K128 + half2x2
// Assumes K is a multiple of 128; loads are vectorized, 4 halves per thread.
// grid(M/4), block(32,4): blockDim.x=32, blockDim.y=4
// a: MxK, x: Kx1, y: Mx1, compute: y = a * x
__global__ void hgemv_k128_f16x4_kernel(half* a, half* x, half* y, int M, int K) {
  // Each thread handles 4 elements, so one warp covers 128 elements.
  int tx = threadIdx.x;         // 0~31
  int ty = threadIdx.y;         // 0~3
  int bx = blockIdx.x;          // 0~M/4
  int lane = tx % WARP_SIZE;    // 0~31
  int m = blockDim.y * bx + ty; // (0~M/4) * 4 + (0~3)

  if (m < M) {
    half sum = 0.0f;
    // process 4*WARP_SIZE elements per warp.
    int NUM_WARPS = (((K + WARP_SIZE - 1) / WARP_SIZE) + 4 - 1) / 4;
    #pragma unroll
    for (int w = 0; w < NUM_WARPS; ++w) {
      int k = (w * WARP_SIZE + lane) * 4;
      half2 reg_x_0 = HALF2(x[k + 0]);
      half2 reg_x_1 = HALF2(x[k + 2]);
      half2 reg_a_0 = HALF2(a[m * K + k + 0]);
      half2 reg_a_1 = HALF2(a[m * K + k + 2]);
      sum += (reg_x_0.x * reg_a_0.x + reg_x_0.y * reg_a_0.y
            + reg_x_1.x * reg_a_1.x + reg_x_1.y * reg_a_1.y);
    }
    sum = warp_reduce_sum_f16<WARP_SIZE>(sum);
    if (lane == 0) y[m] = sum;
  }
}

// HGEMV: Warp HGEMV K16
// Assumes K = 16 (< 32); each warp handles 2 rows of 16 elements each.
// NUM_THREADS=128, NUM_WARPS=NUM_THREADS/WARP_SIZE;
// NUM_ROWS=NUM_WARPS * ROW_PER_WARP, grid(M/NUM_ROWS), block(32,NUM_WARPS)
// a: MxK, x: Kx1, y: Mx1, compute: y = a * x
template<const int ROW_PER_WARP = 2>
__global__ void hgemv_k16_f16_kernel(half* A, half* x, half* y, int M, int K) {
  constexpr int K_WARP_SIZE = (WARP_SIZE + ROW_PER_WARP - 1) / ROW_PER_WARP;
  int tx = threadIdx.x;       // 0~31
  int ty = threadIdx.y;       // 0~NUM_WARPS-1
  int bx = blockIdx.x;        // 0~M/NUM_ROWS (NUM_ROWS=NUM_WARPS * ROW_PER_WARP)
  int lane = tx % WARP_SIZE;  // 0~31
  int k = lane % K_WARP_SIZE; // 0~15
  // global row of a: MxK and y: Mx1, blockDim.y=NUM_WARPS
  int m = (blockDim.y * bx + ty) * ROW_PER_WARP + lane / K_WARP_SIZE;
  if (m < M) {
    half sum = A[m * K + k] * x[k];
    sum = warp_reduce_sum_f16<K_WARP_SIZE>(sum);
    // Note: the store is guarded by k == 0 (one writer per row), not lane == 0.
    if (k == 0) y[m] = sum;
  }
}

// --------------------- PyTorch bindings for custom kernel -----------------------
#define STRINGFY(str) #str
#define TORCH_BINDING_COMMON_EXTENSION(func) \
  m.def(STRINGFY(func), &func, STRINGFY(func));

#define CHECK_TORCH_TENSOR_DTYPE(T, th_type)                   \
if (((T).options().dtype() != (th_type))) {                    \
  std::cout << "Tensor Info:" << (T).options() << std::endl;   \
  throw std::runtime_error("values must be " #th_type);        \
}

#define CHECK_TORCH_TENSOR_SHAPE(T, S0, S1)            \
if (((T).size(0) != (S0)) || ((T).size(1) != (S1))) {  \
  throw std::runtime_error("Tensor size mismatch!");   \
}

#define ASSERT_K_IS_MULTIBLE_OF(V) \
if (K % (V) != 0) { throw std::runtime_error("K must be multiple of " #V); }

#define ASSERT_K_IS_EQUAL_OF(V) \
if (K != (V)) { throw std::runtime_error("K must be " #V); }

void hgemv_k32_f16(torch::Tensor a, torch::Tensor x, torch::Tensor y) {
  CHECK_TORCH_TENSOR_DTYPE(a, torch::kHalf)
  CHECK_TORCH_TENSOR_DTYPE(x, torch::kHalf)
  CHECK_TORCH_TENSOR_DTYPE(y, torch::kHalf)
  const int M = a.size(0);
  const int K = a.size(1);
  CHECK_TORCH_TENSOR_SHAPE(a, M, K)
  CHECK_TORCH_TENSOR_SHAPE(x, K, 1)
  CHECK_TORCH_TENSOR_SHAPE(y, M, 1)
  ASSERT_K_IS_MULTIBLE_OF(32)

  dim3 block(32, 4);
  dim3 grid((M + 4 - 1) / 4);

  hgemv_k32_f16_kernel<<<grid, block>>>(
    reinterpret_cast<half*>(a.data_ptr()),
    reinterpret_cast<half*>(x.data_ptr()),
    reinterpret_cast<half*>(y.data_ptr()),
    M, K
  );
}

void hgemv_k128_f16x4(torch::Tensor a, torch::Tensor x, torch::Tensor y) {
  CHECK_TORCH_TENSOR_DTYPE(a, torch::kHalf)
  CHECK_TORCH_TENSOR_DTYPE(x, torch::kHalf)
  CHECK_TORCH_TENSOR_DTYPE(y, torch::kHalf)
  const int M = a.size(0);
  const int K = a.size(1);
  CHECK_TORCH_TENSOR_SHAPE(a, M, K)
  CHECK_TORCH_TENSOR_SHAPE(x, K, 1)
  CHECK_TORCH_TENSOR_SHAPE(y, M, 1)
  ASSERT_K_IS_MULTIBLE_OF(128)

  dim3 block(32, 4);
  dim3 grid((M + 4 - 1) / 4);

  hgemv_k128_f16x4_kernel<<<grid, block>>>(
    reinterpret_cast<half*>(a.data_ptr()),
    reinterpret_cast<half*>(x.data_ptr()),
    reinterpret_cast<half*>(y.data_ptr()),
    M, K
  );
}

void hgemv_k16_f16(torch::Tensor a, torch::Tensor x, torch::Tensor y) {
  CHECK_TORCH_TENSOR_DTYPE(a, torch::kHalf)
  CHECK_TORCH_TENSOR_DTYPE(x, torch::kHalf)
  CHECK_TORCH_TENSOR_DTYPE(y, torch::kHalf)
  const int M = a.size(0);
  const int K = a.size(1);
  CHECK_TORCH_TENSOR_SHAPE(a, M, K)
  CHECK_TORCH_TENSOR_SHAPE(x, K, 1)
  CHECK_TORCH_TENSOR_SHAPE(y, M, 1)
  ASSERT_K_IS_EQUAL_OF(16)

  constexpr int NUM_THREADS = 128;
  constexpr int ROW_PER_WARP = 2;
  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; // 4
  constexpr int NUM_ROWS = NUM_WARPS * ROW_PER_WARP; // 4 * 2 = 8

  dim3 block(32, NUM_WARPS);
  dim3 grid((M + NUM_ROWS - 1) / NUM_ROWS);

  hgemv_k16_f16_kernel<ROW_PER_WARP><<<grid, block>>>(
    reinterpret_cast<half*>(a.data_ptr()),
    reinterpret_cast<half*>(x.data_ptr()),
    reinterpret_cast<half*>(y.data_ptr()),
    M, K
  );
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  TORCH_BINDING_COMMON_EXTENSION(hgemv_k32_f16)
  TORCH_BINDING_COMMON_EXTENSION(hgemv_k128_f16x4)
  TORCH_BINDING_COMMON_EXTENSION(hgemv_k16_f16)
}
```
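The k16 kernel is the least obvious of the three, since one warp covers two rows and the reduction runs over only half a warp. The short Python sketch below is illustration only (not part of this commit); it enumerates the thread-to-element mapping for one block using the launch parameters above (block(32, 4), ROW_PER_WARP=2, K=16):

```python
# Illustration only: index mapping of hgemv_k16_f16_kernel for one block.
WARP_SIZE, ROW_PER_WARP, NUM_WARPS = 32, 2, 4
K_WARP_SIZE = (WARP_SIZE + ROW_PER_WARP - 1) // ROW_PER_WARP  # 16

def thread_to_elem(bx: int, ty: int, tx: int):
    lane = tx % WARP_SIZE
    k = lane % K_WARP_SIZE                                          # column 0..15
    m = (NUM_WARPS * bx + ty) * ROW_PER_WARP + lane // K_WARP_SIZE  # global row
    return m, k

print(thread_to_elem(0, 0, 0))    # (0, 0)  lanes 0..15 of warp 0 -> row 0
print(thread_to_elem(0, 0, 16))   # (1, 0)  lanes 16..31 of warp 0 -> row 1
print(thread_to_elem(0, 3, 31))   # (7, 15) last thread of block 0 -> row 7
```

After `warp_reduce_sum_f16<K_WARP_SIZE>`, the lanes with `k == 0` (lanes 0 and 16 of each warp) hold the two row sums, which is why the store is guarded by `k == 0` rather than `lane == 0`.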

hgemv/hgemv.py

Lines changed: 76 additions & 0 deletions

```python
import torch
import time
from torch.utils.cpp_extension import load
from functools import partial
from typing import Optional

torch.set_grad_enabled(False)

# Load the CUDA kernel as a python module
lib = load(name='hgemv_lib',
           sources=['hgemv.cu'],
           extra_cuda_cflags=[
               "-O3",
               "-U__CUDA_NO_HALF_OPERATORS__",
               "-U__CUDA_NO_HALF_CONVERSIONS__",
               "-U__CUDA_NO_HALF2_OPERATORS__",
               "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
               "--expt-relaxed-constexpr",
               "--expt-extended-lambda",
               "--use_fast_math"
           ],
           extra_cflags=['-std=c++17'])


def run_benchmark(perf_func: callable,
                  a: torch.Tensor, b: torch.Tensor,
                  tag: str, out: Optional[torch.Tensor] = None,
                  warmup: int = 10, iters: int = 200,
                  show_all: bool = False):
    if out is not None:
        out.fill_(0)
    if out is not None:
        for i in range(warmup):
            perf_func(a, b, out)
    else:
        for i in range(warmup):
            _ = perf_func(a, b)

    torch.cuda.synchronize()
    start = time.time()
    # iters
    if out is not None:
        for i in range(iters):
            perf_func(a, b, out)
    else:
        for i in range(iters):
            out = perf_func(a, b)
    torch.cuda.synchronize()
    end = time.time()
    total_time = (end - start) * 1000  # ms
    mean_time = total_time / iters
    out_info = f"out_{tag}"
    out_val = out.flatten().detach().cpu().numpy().tolist()[:3]
    out_val = [round(v, 8) for v in out_val]
    print(f"{out_info:>13}: {out_val}, time:{mean_time:.8f}ms")
    if show_all: print(out)
    return out.clone(), mean_time


print("-" * 80)
M, N, K = 1024, 1, 128
a = torch.randn((M, K)).cuda().half().contiguous()
b = torch.randn((K, N)).cuda().half().contiguous()
c = torch.randn((M, N)).cuda().half().contiguous()
run_benchmark(lib.hgemv_k32_f16, a, b, "k32f16", c)
run_benchmark(lib.hgemv_k128_f16x4, a, b, "k128f16x4", c)
run_benchmark(partial(torch.matmul, out=c), a, b, "f16_th")
print("-" * 80)

M, N, K = 1024, 1, 16
a = torch.randn((M, K)).cuda().half().contiguous()
b = torch.randn((K, N)).cuda().half().contiguous()
c = torch.randn((M, N)).cuda().half().contiguous()
run_benchmark(lib.hgemv_k16_f16, a, b, "k16f16", c)
run_benchmark(partial(torch.matmul, out=c), a, b, "f16_th")
print("-" * 80)
```
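`run_benchmark` measures wall-clock time with `time.time()` around `torch.cuda.synchronize()`, which is fine at this scale but includes host-side launch overhead. For these microsecond-scale kernels, a CUDA-event-based variant can reduce timing noise; the sketch below is not part of this commit and assumes `lib`, `a`, `b`, `c` from the script above:

```python
import torch

def time_with_cuda_events(fn, *args, iters: int = 200, warmup: int = 10):
    # Warm up so JIT compilation and caching do not pollute the measurement.
    for _ in range(warmup):
        fn(*args)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn(*args)
    end.record()
    torch.cuda.synchronize()                 # wait for the end event
    return start.elapsed_time(end) / iters   # mean time in ms

# Example usage (hypothetical, reusing the tensors defined above):
# print(time_with_cuda_events(lib.hgemv_k128_f16x4, a, b, c))
```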
