Skip to content

Commit 195158a

Browse files
authored
[SGEMM][Async] Add naive copy async SGEMM (#64)
* Create sgemm_async.cu * Update sgemm.cu * Update sgemm.py * Update README.md * Update softmax.cu * Update softmax.py * Update README.md * Update README.md * Update README.md
1 parent bbec7b5 commit 195158a

File tree

8 files changed

+384
-139
lines changed

8 files changed

+384
-139
lines changed

README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@
7777
| ✔️ [safe_softmax_f16_f32](./softmax/softmax.cu)|f16|f32|[link](./softmax/)|⭐️⭐️|
7878
| ✔️ [safe_softmax_f16x2_f32](./softmax/softmax.cu)|f16|f32|[link](./softmax/)|⭐️⭐️|
7979
| ✔️ [safe_softmax_f16x8_pack_f32](./softmax/softmax.cu)|f16|f32|[link](./softmax/)|⭐️⭐️|
80-
| ✔️ [online_softmax_f32](./softmax/softmax.cu)|f32|f32|[link](./softmax/)|⭐️⭐️|
80+
| ✔️ [online_safe_softmax_f32](./softmax/softmax.cu)|f32|f32|[link](./softmax/)|⭐️⭐️|
8181
| ✔️ [layer_norm_f32](./layer-norm/layer_norm.cu)|f32|f32|[link](./layer-norm/)|⭐️⭐️|
8282
| ✔️ [layer_norm_f32x4](./layer-norm/layer_norm.cu)|f32|f32|[link](./layer-norm/)|⭐️⭐️|
8383
| ✔️ [layer_norm_f16_f16](./layer-norm/layer_norm.cu)|f16|f16|[link](./layer-norm/)|⭐️⭐️|
@@ -100,6 +100,7 @@
100100
| ✔️ [sgemm_t_8x8_sliced_k_f32x4](./sgemm/sgemm.cu)|f32|f32|[link](./sgemm/)|⭐️⭐️⭐️|
101101
| ✔️ [sgemm_t_8x8_sliced_k...bcf](./sgemm/sgemm.cu)|f32|f32|[link](./sgemm/)|⭐️⭐️⭐️|
102102
| ✔️ [sgemm_t_8x8_sliced_k...dbuf](./sgemm/sgemm.cu)|f32|f32|[link](./sgemm/)|⭐️⭐️⭐️|
103+
| ✔️ [sgemm_t_8x8_sliced_k...async](./sgemm/sgemm_async.cu)|f32|f32|[link](./sgemm/)|⭐️⭐️⭐️|
103104
| ✔️ [hgemm_naive_f16](./hgemm/hgemm.cu)|f16|f16|[link](./hgemm/)|⭐️⭐️|
104105
| ✔️ [hgemm_sliced_k_f16](./hgemm/hgemm.cu)|f16|f16|[link](./hgemm/)|⭐️⭐️⭐️|
105106
| ✔️ [hgemm_t_8x8_sliced_k_f16x4](./hgemm/hgemm.cu)|f16|f16|[link](./hgemm/)|⭐️⭐️⭐️|

sgemm/README.md

Lines changed: 99 additions & 60 deletions
Large diffs are not rendered by default.

sgemm/sgemm.cu

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -692,6 +692,9 @@ void sgemm_t_8x8_sliced_k_f32x4_bcf_dbuf_offset(torch::Tensor a, torch::Tensor b
692692
);
693693
}
694694

695+
// from sgemm_async.cu
696+
void sgemm_t_8x8_sliced_k_f32x4_bcf_dbuf_async(torch::Tensor a, torch::Tensor b, torch::Tensor c);
697+
695698
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
696699
TORCH_BINDING_COMMON_EXTENSION(sgemm_naive_f32)
697700
TORCH_BINDING_COMMON_EXTENSION(sgemm_sliced_k_f32)
@@ -700,4 +703,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
700703
TORCH_BINDING_COMMON_EXTENSION(sgemm_t_8x8_sliced_k_f32x4_bcf_offset)
701704
TORCH_BINDING_COMMON_EXTENSION(sgemm_t_8x8_sliced_k_f32x4_bcf_dbuf)
702705
TORCH_BINDING_COMMON_EXTENSION(sgemm_t_8x8_sliced_k_f32x4_bcf_dbuf_offset)
706+
TORCH_BINDING_COMMON_EXTENSION(sgemm_t_8x8_sliced_k_f32x4_bcf_dbuf_async)
703707
}

sgemm/sgemm.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
# Load the CUDA kernel as a python module
1010
lib = load(name='sgemm_lib',
11-
sources=['sgemm.cu'],
11+
sources=['sgemm.cu', 'sgemm_async.cu'],
1212
extra_cuda_cflags=[
1313
"-O3",
1414
"-U__CUDA_NO_HALF_OPERATORS__",
@@ -53,7 +53,7 @@ def run_benchmark(perf_func: callable,
5353
out_val = out.flatten().detach().cpu().numpy().tolist()[:3]
5454
out_val = [round(v, 8) for v in out_val]
5555
out_val = [f"{v:<12}" for v in out_val]
56-
print(f"{out_info:>27}: {out_val}, time:{mean_time:.6f}ms")
56+
print(f"{out_info:>32}: {out_val}, time:{mean_time:.6f}ms")
5757
if show_all: print(out)
5858
return out.clone(), mean_time
5959

@@ -63,7 +63,7 @@ def run_benchmark(perf_func: callable,
6363
Ks = [1024, 2048]
6464
MNKs = [(M, N, K) for M in Ms for N in Ns for K in Ks]
6565
for (M, N, K) in MNKs:
66-
print("-" * 100)
66+
print("-" * 110)
6767
print(" " * 45 + f"M={M}, N={N}, K={K}")
6868
a = torch.randn((M, K)).cuda().float().contiguous()
6969
b = torch.randn((K, N)).cuda().float().contiguous()
@@ -82,6 +82,9 @@ def run_benchmark(perf_func: callable,
8282
a, b, "f32x4(t8x8dbuf)", c)
8383
run_benchmark(lib.sgemm_t_8x8_sliced_k_f32x4_bcf_dbuf_offset,
8484
a, b, "f32x4(t8x8dbuf+offset)", c)
85+
print("-" * 52 + "Async" + "-" * 53)
86+
run_benchmark(lib.sgemm_t_8x8_sliced_k_f32x4_bcf_dbuf_async,
87+
a, b, "f32x4(t8x8dbuf+async)", c)
8588
run_benchmark(partial(torch.matmul, out=c),
8689
a, b, "f32_th")
87-
print("-" * 100)
90+
print("-" * 110)

sgemm/sgemm_async.cu

Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,201 @@
1+
#include <stdio.h>
2+
#include <stdlib.h>
3+
#include <float.h>
4+
#include <vector>
5+
#include <algorithm>
6+
#include <cuda_runtime.h>
7+
#include <cuda_fp16.h>
8+
#include <cuda_bf16.h>
9+
#include <cuda_fp8.h>
10+
#include <torch/types.h>
11+
#include <torch/extension.h>
12+
13+
#define INT4(value) (reinterpret_cast<int4*>(&(value))[0])
14+
#define FLOAT4(value) (reinterpret_cast<float4*>(&(value))[0])
15+
#define HALF2(value) (reinterpret_cast<half2*>(&(value))[0])
16+
#define BFLOAT2(value) (reinterpret_cast<__nv_bfloat162*>(&(value))[0])
17+
#define LDST64BITS(value) (reinterpret_cast<float2*>(&(value))[0])
18+
#define LDST128BITS(value) (reinterpret_cast<float4*>(&(value))[0])
19+
#define CP_ASYNC_COMMIT_GROUP() asm volatile("cp.async.commit_group;\n" ::)
20+
#define CP_ASYNC_WAIT_ALL() asm volatile("cp.async.wait_all;\n" ::)
21+
#define CP_ASYNC_WAIT_GROUP(n) asm volatile("cp.async.wait_group %0;\n" ::"n"(n))
22+
// ca (cache all: stages through L1 and L2) supports 4-, 8-, and 16-byte copies;
// cg (cache global: bypasses L1, caches in L2 only) supports only 16-byte copies.
23+
#define CP_ASYNC_CA(dst, src, bytes) asm volatile("cp.async.ca.shared.global.L2::128B [%0], [%1], %2;\n" ::"r"(dst), "l"(src), "n"(bytes))
24+
#define CP_ASYNC_CG(dst, src, bytes) asm volatile("cp.async.cg.shared.global.L2::128B [%0], [%1], %2;\n" ::"r"(dst), "l"(src), "n"(bytes))
25+
26+
27+
// Double-buffered, thread-tiled SGEMM using cp.async (requires SM80+).
// Launch layout: grid = (N/BN, M/BM) blocks of (BN/TN, BM/TM) = 16x16 = 256
// threads; each block computes a BM x BN (128x128) tile of C, each thread an
// 8x8 sub-tile. Global->shared copies for the NEXT K-slice are issued with
// cp.async and overlapped with the FMAs on the CURRENT slice.
//
// Preconditions (not checked here): M % BM == 0, N % BN == 0, K % BK == 0,
// and a/b/c are contiguous row-major float32.
template<const int BM=128, const int BN=128, const int BK=8,
         const int TM=8, const int TN=8, const int OFFSET=0>
__global__ void sgemm_t_8x8_sliced_k_f32x4_bcf_dbuf_async_kernel(
  float* a, float* b, float* c, const int M, const int N, const int K) {

  const int bx = blockIdx.x;
  const int by = blockIdx.y;
  const int tx = threadIdx.x;
  const int ty = threadIdx.y;
  const int tid = ty * blockDim.x + tx;

  // Two shared-memory buffers per operand (double buffering). A is stored
  // transposed (k-major) so the compute loop can read TM consecutive rows of
  // A with float4 loads; OFFSET is optional padding against bank conflicts.
  __shared__ float s_a[2][BK][BM + OFFSET];
  __shared__ float s_b[2][BK][BN + OFFSET];

  float r_comp_a[TM];             // per-thread A fragment for one k step
  float r_comp_b[TN];             // per-thread B fragment for one k step
  float r_c[TM][TN] = {0.0};      // per-thread accumulator tile

  // gmem->smem copy assignments: 256 threads fill a 128x8 A tile (2 threads
  // per row, 4 floats each) and an 8x128 B tile (32 threads per row).
  int load_a_smem_m = tid / 2;          // 0..127
  int load_a_smem_k = (tid & 1) << 2;   // 0 or 4
  int load_b_smem_k = tid / 32;         // 0..7
  int load_b_smem_n = (tid & 31) << 2;  // 0,4,...,124
  int load_a_gmem_m = by * BM + load_a_smem_m;
  int load_b_gmem_n = bx * BN + load_b_smem_n;

  // Number of K-slices; the prologue below fills buffer 0 with slice 0.
  const int NUM_K_TILES = (K + BK - 1) / BK;

  // Prologue: asynchronously stage K-slice 0 into buffer 0, then wait.
  {
    int load_a_gmem_k = load_a_smem_k;
    int load_a_gmem_addr = load_a_gmem_m * K + load_a_gmem_k;
    int load_b_gmem_k = load_b_smem_k;
    int load_b_gmem_addr = load_b_gmem_k * N + load_b_gmem_n;

    uint32_t load_b_smem_ptr = __cvta_generic_to_shared(
      &s_b[0][load_b_smem_k][load_b_smem_n]);
    // B row is copied contiguously: 1 cp.async issue, 16 bytes = 4 floats.
    CP_ASYNC_CA(load_b_smem_ptr, &b[load_b_gmem_addr], 16);
    CP_ASYNC_COMMIT_GROUP();

    #pragma unroll
    for (int i = 0; i < 4; ++i) {
      // A is transposed into smem, so each of the 4 floats lands in a
      // different smem row: 4 cp.async issues of 4 bytes each (the "naive"
      // async pattern; a wide single copy is not possible here).
      uint32_t load_a_smem_ptr = __cvta_generic_to_shared(
        &s_a[0][load_a_smem_k + i][load_a_smem_m]);
      CP_ASYNC_CA(load_a_smem_ptr, &a[load_a_gmem_addr + i], 4);
    }
    CP_ASYNC_COMMIT_GROUP();
    CP_ASYNC_WAIT_GROUP(0); // drain all outstanding groups
  }
  __syncthreads();

  // Main loop: while computing on slice bk-1 (buffer smem_sel), prefetch
  // slice bk into the other buffer (smem_sel_next) with cp.async.
  for (int bk = 1; bk < NUM_K_TILES; bk++) {

    int smem_sel = (bk - 1) & 1;   // buffer being computed on
    int smem_sel_next = bk & 1;    // buffer being filled asynchronously

    int load_a_gmem_k = bk * BK + load_a_smem_k;
    int load_a_gmem_addr = load_a_gmem_m * K + load_a_gmem_k;
    int load_b_gmem_k = bk * BK + load_b_smem_k;
    int load_b_gmem_addr = load_b_gmem_k * N + load_b_gmem_n;

    uint32_t load_b_smem_ptr = __cvta_generic_to_shared(
      &s_b[smem_sel_next][load_b_smem_k][load_b_smem_n]);
    // 1 cp.async issue, 16 bytes = 4 floats.
    CP_ASYNC_CA(load_b_smem_ptr, &b[load_b_gmem_addr], 16);
    CP_ASYNC_COMMIT_GROUP();

    #pragma unroll
    for (int i = 0; i < 4; ++i) {
      // 4 cp.async issues, 4 bytes = 1 float each (transposed A store).
      uint32_t load_a_smem_ptr = __cvta_generic_to_shared(
        &s_a[smem_sel_next][load_a_smem_k + i][load_a_smem_m]);
      CP_ASYNC_CA(load_a_smem_ptr, &a[load_a_gmem_addr + i], 4);
    }
    CP_ASYNC_COMMIT_GROUP();

    // Compute on the previously staged slice while the copies fly.
    #pragma unroll
    for (int tk = 0; tk < BK; tk++) {
      FLOAT4(r_comp_a[0]) = FLOAT4(s_a[smem_sel][tk][ty * TM / 2         ]);
      FLOAT4(r_comp_a[4]) = FLOAT4(s_a[smem_sel][tk][ty * TM / 2 + BM / 2]);
      FLOAT4(r_comp_b[0]) = FLOAT4(s_b[smem_sel][tk][tx * TN / 2         ]);
      FLOAT4(r_comp_b[4]) = FLOAT4(s_b[smem_sel][tk][tx * TN / 2 + BN / 2]);

      #pragma unroll
      for (int tm = 0; tm < TM; tm++) {
        #pragma unroll
        for (int tn = 0; tn < TN; tn++) {
          r_c[tm][tn] = __fmaf_rn(r_comp_a[tm], r_comp_b[tn], r_c[tm][tn]);
        }
      }
    }

    CP_ASYNC_WAIT_GROUP(0); // next slice must be resident before reuse
    __syncthreads();
  }

  // Epilogue: compute on the last staged slice. It lives in buffer
  // (NUM_K_TILES - 1) & 1; the original hard-coded buffer [1], which is only
  // correct for an even tile count (it failed e.g. for K <= BK, where the
  // main loop never runs and only buffer 0 was filled).
  const int smem_sel_last = (NUM_K_TILES - 1) & 1;
  #pragma unroll
  for (int tk = 0; tk < BK; tk++) {
    FLOAT4(r_comp_a[0]) = FLOAT4(s_a[smem_sel_last][tk][ty * TM / 2         ]);
    FLOAT4(r_comp_a[4]) = FLOAT4(s_a[smem_sel_last][tk][ty * TM / 2 + BM / 2]);
    FLOAT4(r_comp_b[0]) = FLOAT4(s_b[smem_sel_last][tk][tx * TN / 2         ]);
    FLOAT4(r_comp_b[4]) = FLOAT4(s_b[smem_sel_last][tk][tx * TN / 2 + BN / 2]);

    #pragma unroll
    for (int tm = 0; tm < TM; tm++) {
      #pragma unroll
      for (int tn = 0; tn < TN; tn++) {
        r_c[tm][tn] = __fmaf_rn(r_comp_a[tm], r_comp_b[tn], r_c[tm][tn]);
      }
    }
  }

  // Write back the 8x8 accumulator as four float4 stores per half: the
  // thread's rows/cols are split across the tile halves (block-cyclic
  // fragment layout), hence the BM/2 and BN/2 offsets.
  #pragma unroll
  for (int i = 0; i < TM / 2; i++) {
    int store_c_gmem_m = by * BM + ty * TM / 2 + i;
    int store_c_gmem_n = bx * BN + tx * TN / 2;
    int store_c_gmem_addr = store_c_gmem_m * N + store_c_gmem_n;
    FLOAT4(c[store_c_gmem_addr]) = FLOAT4(r_c[i][0]);
    FLOAT4(c[store_c_gmem_addr + BN / 2]) = FLOAT4(r_c[i][4]);
  }
  #pragma unroll
  for (int i = 0; i < TM / 2; i++) {
    int store_c_gmem_m = by * BM + BM / 2 + ty * TM / 2 + i;
    int store_c_gmem_n = bx * BN + tx * TN / 2;
    int store_c_gmem_addr = store_c_gmem_m * N + store_c_gmem_n;
    FLOAT4(c[store_c_gmem_addr]) = FLOAT4(r_c[i + TM / 2][0]);
    FLOAT4(c[store_c_gmem_addr + BN / 2]) = FLOAT4(r_c[i + TM / 2][4]);
  }
}
154+
155+
// TODO: sgemm_t_8x8_sliced_k16_f32x4_bcf_dbuf_{async}_kernel
156+
157+
// --------------------- PyTorch bindings for custom kernel -----------------------
158+
#define STRINGFY(str) #str
159+
#define TORCH_BINDING_COMMON_EXTENSION(func) \
160+
m.def(STRINGFY(func), &func, STRINGFY(func));
161+
162+
#define CHECK_TORCH_TENSOR_DTYPE(T, th_type) \
163+
if(((T).options().dtype() != (th_type))) { \
164+
std::cout << "Tensor Info:" << (T).options() << std::endl; \
165+
throw std::runtime_error("values must be "#th_type); \
166+
}
167+
168+
#define CHECK_TORCH_TENSOR_SHAPE(T, S0, S1) \
169+
if (((T).size(0) != (S0)) || ((T).size(1) != (S1))) { \
170+
throw std::runtime_error("Tensor size mismatch!"); \
171+
}
172+
173+
174+
// Host-side launcher for the cp.async double-buffered SGEMM kernel.
// Expects contiguous row-major float32 tensors: a is (M, K), b is (K, N),
// c is (M, N); writes the product a @ b into c.
// Throws std::runtime_error on dtype/shape mismatch or unsupported sizes.
void sgemm_t_8x8_sliced_k_f32x4_bcf_dbuf_async(
  torch::Tensor a, torch::Tensor b, torch::Tensor c) {
  CHECK_TORCH_TENSOR_DTYPE(a, torch::kFloat32)
  CHECK_TORCH_TENSOR_DTYPE(b, torch::kFloat32)
  CHECK_TORCH_TENSOR_DTYPE(c, torch::kFloat32)
  const int M = a.size(0);
  const int K = a.size(1);
  const int N = b.size(1);
  CHECK_TORCH_TENSOR_SHAPE(a, M, K)
  CHECK_TORCH_TENSOR_SHAPE(b, K, N)
  CHECK_TORCH_TENSOR_SHAPE(c, M, N)
  constexpr int BM = 128;
  constexpr int BN = 128;
  constexpr int BK = 8;
  constexpr int TM = 8;
  constexpr int TN = 8;

  // The kernel has no tail handling: a partial tile would read/write out of
  // bounds, so reject sizes the 128x128x8 tiling cannot cover exactly
  // (previously such sizes silently corrupted memory).
  if ((M % BM != 0) || (N % BN != 0) || (K % BK != 0)) {
    throw std::runtime_error(
      "sgemm_t_8x8_sliced_k_f32x4_bcf_dbuf_async requires "
      "M % 128 == 0, N % 128 == 0 and K % 8 == 0");
  }

  dim3 block(BN / TN, BM / TM);                   // 16 x 16 = 256 threads
  dim3 grid((N + BN - 1) / BN, (M + BM - 1) / BM);

  sgemm_t_8x8_sliced_k_f32x4_bcf_dbuf_async_kernel<
    BM, BN, BK, TM, TN><<<grid, block>>>(
    reinterpret_cast<float*>(a.data_ptr()),
    reinterpret_cast<float*>(b.data_ptr()),
    reinterpret_cast<float*>(c.data_ptr()),
    M, N, K
  );
}

0 commit comments

Comments
 (0)