
Commit b417e20

[HGEMM] Add sliced_k & t_8x8_sliced_k_f16x4 (#31)

* [Refactor][6/N] CUDA Learn Notes refactor Part-6
* [HGEMM] Add sliced_k & t_8x8_sliced_k_f16x4
1 parent ab1f9bf commit b417e20

4 files changed, +327 -2 lines changed

README.md (2 additions, 2 deletions)

@@ -77,8 +77,8 @@
 | ✔️ [rms_norm_f16_f32_kernel(per token)](./rms-norm/rms_norm.cu)|f16|f32|[link](./rms-norm/)|⭐️⭐️|
 | ✔️ [sgemm_sliced_k_f32_kernel](./sgemm/sgemm.cu)|f32|f32|[link](./sgemm/)|⭐️⭐️⭐️|
 | ✔️ [sgemm_t_8x8_sliced_k_f32x4_kernel](./sgemm/sgemm.cu)|f32|f32|[link](./sgemm/)|⭐️⭐️⭐️|
-| [hgemm_sliced_k_f16_f32_kernel](./hgemm)|f16|f32||⭐️⭐️⭐️|
-| [hgemm_t_8x8_sliced_k_f16x2_f32_kernel](./hgemm)|f16|f32||⭐️⭐️⭐️|
+| ✔️ [hgemm_sliced_k_f16_kernel](./hgemm)|f16|f16|[link](./hgemm/)|⭐️⭐️⭐️|
+| ✔️ [hgemm_t_8x8_sliced_k_f16x4_kernel](./hgemm)|f16|f16|[link](./hgemm/)|⭐️⭐️⭐️|
 | ✔️ [sgemv_k32_f32_kernel](./sgemv/sgemv.cu)|f32|f32|[link](./sgemv/)|⭐️⭐️⭐️|
 | ✔️ [sgemv_k128_f32x4_kernel](./sgemv/sgemv.cu)|f32|f32|[link](./sgemv/)|⭐️⭐️⭐️|
 | ✔️ [sgemv_k16_f32_kernel](./sgemv/sgemv.cu)|f32|f32|[link](./sgemv/)|⭐️⭐️⭐️|

hgemm/README.md (new file, 27 additions)

# HGEMM

## 0x00 Overview

This directory contains:

- [X] hgemm_sliced_k_f16_kernel
- [X] hgemm_t_8x8_sliced_k_f16x4_kernel
- [X] PyTorch bindings

## Testing

```bash
# Build only for the Ada architecture; if TORCH_CUDA_ARCH_LIST is unset, all architectures are compiled by default, which takes much longer.
export TORCH_CUDA_ARCH_LIST=Ada
python3 hgemm.py
```

Output:

```bash
--------------------------------------------------------------------------------
out_f16(sk): [-1.08691406, 14.2890625, -0.57226562], time:0.38200140ms
out_f16x4(t8x8sk): [-1.08691406, 14.2890625, -0.57226562], time:0.06475449ms
out_f16_th: [-1.08398438, 14.3046875, -0.56152344], time:0.02875686ms
--------------------------------------------------------------------------------
```
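For orientation, here is a minimal usage sketch of the two bindings (illustrative only; it assumes the `hgemm_lib` extension built by `hgemm.py` in this commit, and actual output values depend on the random inputs):

```python
# Minimal usage sketch (illustrative; hgemm.py in this commit runs the full benchmark).
import torch
from torch.utils.cpp_extension import load

# Build the extension; the half-operator flags mirror the ones hgemm.py passes.
lib = load(name='hgemm_lib', sources=['hgemm.cu'],
           extra_cuda_cflags=['-O3', '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF2_OPERATORS__'],
           extra_cflags=['-std=c++17'])

M, N, K = 2048, 1024, 128
a = torch.randn((M, K), device='cuda').half().contiguous()  # A: MxK, row major
b = torch.randn((K, N), device='cuda').half().contiguous()  # B: KxN, row major
c = torch.zeros((M, N), device='cuda').half().contiguous()  # C: MxN, written in place

lib.hgemm_sliced_k_f16(a, b, c)          # naive sliced-K kernel
print(c.flatten()[:3].tolist())          # first three outputs, same format as the log above
lib.hgemm_t_8x8_sliced_k_f16x4(a, b, c)  # 8x8 thread-tile, half2-vectorized kernel
print(c.flatten()[:3].tolist())
```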

hgemm/hgemm.cu (new file, 229 additions)

#include <stdio.h>
#include <stdlib.h>
#include <float.h>
#include <vector>
#include <algorithm>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <cuda_fp8.h>
#include <torch/types.h>
#include <torch/extension.h>

#define WARP_SIZE 32
#define INT4(value) (reinterpret_cast<int4*>(&(value))[0])
#define FLOAT4(value) (reinterpret_cast<float4*>(&(value))[0])
#define HALF2(value) (reinterpret_cast<half2*>(&(value))[0])
#define BFLOAT2(value) (reinterpret_cast<__nv_bfloat162*>(&(value))[0])

// -------------------------------------- FP16 --------------------------------------
// HGEMM: Block Tile + K Tile, with smem
// Block Tile (BM, BN) + K Tile (BK=32)
// grid((N + BN - 1) / BN, (M + BM - 1) / BM), block(BN, BM)
// a: MxK, b: KxN, c: MxN, compute: c = a * b, all row major
template<const int BM=32, const int BN=32, const int BK=32>
__global__ void hgemm_sliced_k_f16_kernel(half* a, half* b, half* c, int M, int N, int K) {
  // [1] Block Tile: each 32x32 thread block computes a 32x32 tile of c
  // [2] K Tile: use shared memory and split K into chunks of size BK
  __shared__ half s_a[BM][BK], s_b[BK][BN];

  int bx = blockIdx.x;
  int by = blockIdx.y;
  int tx = threadIdx.x;
  int ty = threadIdx.y;
  int tid = threadIdx.y * blockDim.x + tx; // tid within the block
  // Load values to shared memory: the 32x32 threads of the block cooperate,
  // fetching along the row direction of a and b for both s_a and s_b
  // (s_a + s_b = 2*32*32*2 bytes = 4KB of half data). The 32x32 threads load
  // 32x32 elements from global memory to shared memory, i.e. each thread
  // loads exactly 1 element.
  int load_smem_a_m = tid / 32; // 0~31, tid / 32, tid / BM, threadIdx.y
  int load_smem_a_k = tid % 32; // 0~31, tid % 32, tid % BK, threadIdx.x
  int load_smem_b_k = tid / 32; // 0~31, tid / 32, tid / BK, threadIdx.y
  int load_smem_b_n = tid % 32; // 0~31, tid % 32, tid % BN, threadIdx.x
  int load_gmem_a_m = by * BM + load_smem_a_m; // global row of a and c
  int load_gmem_b_n = bx * BN + load_smem_b_n; // global col of b and c
  // if (load_gmem_a_m >= M || load_gmem_b_n >= N) return;

  half sum = __float2half(0.f);
  for (int bk = 0; bk < (K + BK - 1) / BK; ++bk) {
    int load_gmem_a_k = bk * BK + load_smem_a_k;
    int load_gmem_a_addr = load_gmem_a_m * K + load_gmem_a_k;
    s_a[load_smem_a_m][load_smem_a_k] = a[load_gmem_a_addr];
    int load_gmem_b_k = bk * BK + load_smem_b_k;
    int load_gmem_b_addr = load_gmem_b_k * N + load_gmem_b_n;
    s_b[load_smem_b_k][load_smem_b_n] = b[load_gmem_b_addr];
    __syncthreads();
    #pragma unroll
    for (int k = 0; k < BK; ++k) {
      int comp_smem_a_m = load_smem_a_m;
      int comp_smem_b_n = load_smem_b_n;
      sum += s_a[comp_smem_a_m][k] * s_b[k][comp_smem_b_n];
    }
    __syncthreads();
  }
  int store_gmem_c_m = load_gmem_a_m;
  int store_gmem_c_n = load_gmem_b_n;
  int store_gmem_c_addr = store_gmem_c_m * N + store_gmem_c_n;
  c[store_gmem_c_addr] = sum;
}

// HGEMM: Block Tile + Thread Tile + K Tile + Vec4, with smem
// BK:TILE_K=8 BM=BN=128
// TM=TN=8 to increase arithmetic intensity; BM/TM=16, BN/TN=16
// dim3 blockDim(BN/TN, BM/TM);
// dim3 gridDim((N + BN - 1) / BN, (M + BM - 1) / BM)
template<const int BM=128, const int BN=128, const int BK=8, const int TM=8, const int TN=8>
__global__ void hgemm_t_8x8_sliced_k_f16x4_kernel(half* a, half* b, half* c, int M, int N, int K) {
  // [1] Block Tile: one 16x16 thread block computes a 128x128 target tile of C
  // [2] Thread Tile: each thread computes TM*TN (8*8) elements, increasing arithmetic intensity
  // [3] K Tile: split K into chunks of size BK and iterate (K + BK - 1) / BK times,
  //     accumulating the partial products of the TM*TN elements each iteration
  // [4] Vectorize: use half2 to reduce the number of load and store instructions

  // 16x16 = 256 threads in total; each thread computes an 8x8 tile of elements
  int bx = blockIdx.x;
  int by = blockIdx.y;
  int tx = threadIdx.x;
  int ty = threadIdx.y;
  int tid = threadIdx.y * blockDim.x + tx; // tid within the block
  __shared__ half s_a[BM][BK], s_b[BK][BN]; // 2*128*8*2=4KB

  // 0. First compute the shared-memory indices:
  // mapping from tid to s_a[BM][BK], BM=128 BK=8, read row by row (A is row major);
  // each row of s_a has 8 elements and each thread loads 4, so 2 threads per row;
  // 128 rows x 2 threads = exactly the 256 threads in the block.
  int load_smem_a_m = tid / 2; // tid/2, (128/8)*(128/8)=256 threads per block, tid/2 -> [0,128), BM=128 0~127
  int load_smem_a_k = (tid % 2 == 0) ? 0 : 4; // (tid % 2 == 0) ? 0 : 4, col of s_a: 0 or 4
  // mapping from tid to s_b[BK][BN], BK=8 BN=128, read row by row (B is row major);
  // each row of s_b has 128 elements and each thread loads 4, so 32 threads per row;
  // 8 rows x 32 threads = 256 threads.
  int load_smem_b_k = tid / 32; // tid/32, row of s_b, 256/32=8 rows, 0~7
  int load_smem_b_n = (tid % 32) * 4; // (tid % 32) * 4, col of s_b: 0,4,...,124
  // 1. Then compute the global-memory indices:
  // the row in A that this thread loads into s_a; each block produces a BM*BN tile of C.
  int load_gmem_a_m = by * BM + load_smem_a_m; // global row of a and c
  int load_gmem_b_n = bx * BN + load_smem_b_n; // global col of b and c

  half r_c[TM][TN] = {__float2half(0.0f)}; // 8x8
  // 2. Slice K into chunks of size BK
  for (int bk = 0; bk < (K + BK - 1) / BK; ++bk) {
    // load a BM*BK (128x8) tile into s_a, vectorized as two half2 loads (4 halves per thread)
    int load_gmem_a_k = bk * BK + load_smem_a_k; // global col of a
    int load_gmem_a_addr = load_gmem_a_m * K + load_gmem_a_k;
    HALF2(s_a[load_smem_a_m][load_smem_a_k + 0]) = HALF2(a[load_gmem_a_addr + 0]);
    HALF2(s_a[load_smem_a_m][load_smem_a_k + 2]) = HALF2(a[load_gmem_a_addr + 2]);
    // load a BK*BN (8x128) tile into s_b, vectorized as two half2 loads (4 halves per thread)
    int load_gmem_b_k = bk * BK + load_smem_b_k; // global row of b
    int load_gmem_b_addr = load_gmem_b_k * N + load_gmem_b_n;
    HALF2(s_b[load_smem_b_k][load_smem_b_n + 0]) = HALF2(b[load_gmem_b_addr + 0]);
    HALF2(s_b[load_smem_b_k][load_smem_b_n + 2]) = HALF2(b[load_gmem_b_addr + 2]);
    __syncthreads();
    #pragma unroll
    for (int k = 0; k < BK; k++) {
      // 3. Each thread computes TM*TN (8x8) of the BM*BN (128x128) elements
      #pragma unroll
      for (int m = 0; m < TM; m++) {
        #pragma unroll
        for (int n = 0; n < TN; n++) {
          // k ranges over 0~7 (0 ~ BK-1); ty and tx range from 0 to 15, 16x8=128
          int comp_smem_a_m = ty * TM + m; // 128*8, 128/TM(8)=16 threads along M
          int comp_smem_b_n = tx * TN + n; // 8*128, 128/TN(8)=16 threads along N
          r_c[m][n] += s_a[comp_smem_a_m][k] * s_b[k][comp_smem_b_n];
        }
      }
    }
    __syncthreads();
  }

  #pragma unroll
  for (int m = 0; m < TM; ++m) {
    int store_gmem_c_m = by * BM + ty * TM + m;
    #pragma unroll
    for (int n = 0; n < TN; n += 2) {
      int store_gmem_c_n = bx * BN + tx * TN + n;
      int store_gmem_c_addr = store_gmem_c_m * N + store_gmem_c_n;
      HALF2(c[store_gmem_c_addr]) = HALF2(r_c[m][n]);
    }
  }
}

// --------------------- PyTorch bindings for custom kernel -----------------------
#define STRINGFY(str) #str
#define TORCH_BINDING_COMMON_EXTENSION(func) \
  m.def(STRINGFY(func), &func, STRINGFY(func));

#define CHECK_TORCH_TENSOR_DTYPE(T, th_type) \
if(((T).options().dtype() != (th_type))) { \
  std::cout << "Tensor Info:" << (T).options() << std::endl; \
  throw std::runtime_error("values must be "#th_type); \
}

#define CHECK_TORCH_TENSOR_SHAPE(T, S0, S1) \
if (((T).size(0) != (S0)) || ((T).size(1) != (S1))) { \
  throw std::runtime_error("Tensor size mismatch!"); \
}

// HGEMM: Block Tile + K Tile, with smem
// Block Tile (BM, BN) + K Tile (BK=32)
// grid((N + BN - 1) / BN, (M + BM - 1) / BM), block(BN, BM)
// a: MxK, b: KxN, c: MxN, compute: c = a * b, all row major
void hgemm_sliced_k_f16(torch::Tensor a, torch::Tensor b, torch::Tensor c) {
  CHECK_TORCH_TENSOR_DTYPE(a, torch::kHalf)
  CHECK_TORCH_TENSOR_DTYPE(b, torch::kHalf)
  CHECK_TORCH_TENSOR_DTYPE(c, torch::kHalf)
  const int M = a.size(0);
  const int K = a.size(1);
  const int N = b.size(1);
  CHECK_TORCH_TENSOR_SHAPE(a, M, K)
  CHECK_TORCH_TENSOR_SHAPE(b, K, N)
  CHECK_TORCH_TENSOR_SHAPE(c, M, N)
  constexpr int BM = 32;
  constexpr int BN = 32;
  constexpr int BK = 32;

  dim3 block(BN, BM);
  dim3 grid((N + BN - 1) / BN, (M + BM - 1) / BM);

  hgemm_sliced_k_f16_kernel<BM, BN, BK><<<grid, block>>>(
    reinterpret_cast<half*>(a.data_ptr()),
    reinterpret_cast<half*>(b.data_ptr()),
    reinterpret_cast<half*>(c.data_ptr()),
    M, N, K
  );
}

// HGEMM: Block Tile + Thread Tile + K Tile + half2x2, with smem
// BK:TILE_K=8 BM=BN=128
// TM=TN=8 to increase arithmetic intensity; BM/TM=16, BN/TN=16
// dim3 blockDim(BN/TN, BM/TM);
// dim3 gridDim((N + BN - 1) / BN, (M + BM - 1) / BM)
void hgemm_t_8x8_sliced_k_f16x4(torch::Tensor a, torch::Tensor b, torch::Tensor c) {
  CHECK_TORCH_TENSOR_DTYPE(a, torch::kHalf)
  CHECK_TORCH_TENSOR_DTYPE(b, torch::kHalf)
  CHECK_TORCH_TENSOR_DTYPE(c, torch::kHalf)
  const int M = a.size(0);
  const int K = a.size(1);
  const int N = b.size(1);
  CHECK_TORCH_TENSOR_SHAPE(a, M, K)
  CHECK_TORCH_TENSOR_SHAPE(b, K, N)
  CHECK_TORCH_TENSOR_SHAPE(c, M, N)
  constexpr int BM = 128;
  constexpr int BN = 128;
  constexpr int BK = 8;
  constexpr int TM = 8;
  constexpr int TN = 8;

  dim3 block(BN/TN, BM/TM);
  dim3 grid((N + BN - 1) / BN, (M + BM - 1) / BM);

  hgemm_t_8x8_sliced_k_f16x4_kernel<BM, BN, BK, TM, TN><<<grid, block>>>(
    reinterpret_cast<half*>(a.data_ptr()),
    reinterpret_cast<half*>(b.data_ptr()),
    reinterpret_cast<half*>(c.data_ptr()),
    M, N, K
  );
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  TORCH_BINDING_COMMON_EXTENSION(hgemm_sliced_k_f16)
  TORCH_BINDING_COMMON_EXTENSION(hgemm_t_8x8_sliced_k_f16x4)
}
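To make the tid-to-shared-memory mapping in `hgemm_t_8x8_sliced_k_f16x4_kernel` concrete, here is a standalone Python sketch (illustrative, not part of the commit) that checks that the 256 threads of a block, each loading 4 halves, cover `s_a` (128x8) and `s_b` (8x128) exactly once:

```python
# Check the tid -> smem index mapping used by hgemm_t_8x8_sliced_k_f16x4_kernel.
covered_a, covered_b = set(), set()
for tid in range(256):  # blockDim = (BN/TN, BM/TM) = (16, 16) -> 256 threads
    a_m, a_k = tid // 2, 0 if tid % 2 == 0 else 4   # two threads per 8-wide row of s_a
    b_k, b_n = tid // 32, (tid % 32) * 4            # 32 threads per 128-wide row of s_b
    covered_a.update((a_m, a_k + i) for i in range(4))
    covered_b.update((b_k, b_n + i) for i in range(4))

assert covered_a == {(m, k) for m in range(128) for k in range(8)}   # s_a fully covered
assert covered_b == {(k, n) for k in range(8) for n in range(128)}   # s_b fully covered
print(len(covered_a), len(covered_b))  # 1024 1024
```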

hgemm/hgemm.py (new file, 69 additions)

import torch
import time
from torch.utils.cpp_extension import load
from functools import partial
from typing import Optional

torch.set_grad_enabled(False)

# Load the CUDA kernel as a python module
lib = load(name='hgemm_lib',
           sources=['hgemm.cu'],
           extra_cuda_cflags=[
               "-O3",
               "-U__CUDA_NO_HALF_OPERATORS__",
               "-U__CUDA_NO_HALF_CONVERSIONS__",
               "-U__CUDA_NO_HALF2_OPERATORS__",
               "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
               "--expt-relaxed-constexpr",
               "--expt-extended-lambda",
               "--use_fast_math"
           ],
           extra_cflags=['-std=c++17'])


def run_benchmark(perf_func: callable,
                  a: torch.Tensor, b: torch.Tensor,
                  tag: str, out: Optional[torch.Tensor] = None,
                  warmup: int = 10, iters: int = 200,
                  show_all: bool = False):
    if out is not None:
        out.fill_(0)
    if out is not None:
        for i in range(warmup):
            perf_func(a, b, out)
    else:
        for i in range(warmup):
            _ = perf_func(a, b)

    torch.cuda.synchronize()
    start = time.time()
    # iters
    if out is not None:
        for i in range(iters):
            perf_func(a, b, out)
    else:
        for i in range(iters):
            out = perf_func(a, b)
    torch.cuda.synchronize()
    end = time.time()
    total_time = (end - start) * 1000  # ms
    mean_time = total_time / iters
    out_info = f"out_{tag}"
    out_val = out.flatten().detach().cpu().numpy().tolist()[:3]
    out_val = [round(v, 8) for v in out_val]
    print(f"{out_info:>17}: {out_val}, time:{mean_time:.8f}ms")
    if show_all: print(out)
    return out.clone(), mean_time


print("-" * 80)
M, N, K = 2048, 1024, 128
a = torch.randn((M, K)).cuda().half().contiguous()
b = torch.randn((K, N)).cuda().half().contiguous()
c = torch.randn((M, N)).cuda().half().contiguous()
run_benchmark(lib.hgemm_sliced_k_f16, a, b, "f16(sk)", c)
run_benchmark(lib.hgemm_t_8x8_sliced_k_f16x4, a, b, "f16x4(t8x8sk)", c)
run_benchmark(partial(torch.matmul, out=c), a, b, "f16_th")
print("-" * 80)
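A possible follow-up to this script (not in the commit) is an explicit correctness check of the custom kernels against `torch.matmul`; a short sketch reusing `lib`, `a`, `b`, `M`, `N` from above:

```python
# Illustrative correctness check against torch.matmul (reuses lib, a, b, M, N from the script above).
ref = torch.matmul(a, b)                       # fp16 reference via torch/cuBLAS
out = torch.zeros((M, N)).cuda().half().contiguous()
lib.hgemm_t_8x8_sliced_k_f16x4(a, b, out)      # run the 8x8 thread-tile kernel
max_diff = (out.float() - ref.float()).abs().max().item()
print(f"max abs diff vs torch.matmul: {max_diff:.6f}")
```

The printed difference is small but nonzero because the custom kernels accumulate in fp16, which is also why the `out_f16(sk)` and `out_f16_th` rows in the README output differ slightly.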
