
Commit 3acd5e2

[HGEMM][Bugfix] fix HGEMM Stage cp.async error (#75)
* [HGEMM] update HGEMM Stage kernels
* [HGEMM] update HGEMM Stage kernels
* [HGEMM] update HGEMM Stage kernels
1 parent 11d7072 commit 3acd5e2

7 files changed: +405 −339 lines


README.md

Lines changed: 3 additions & 2 deletions
@@ -91,6 +91,7 @@
 | ✔️ [safe_softmax_f16x2_f32](./softmax/softmax.cu)|f16|f32|[link](./softmax/)|⭐️⭐️|
 | ✔️ [safe_softmax_f16x8_pack_f32](./softmax/softmax.cu)|f16|f32|[link](./softmax/)|⭐️⭐️|
 | ✔️ [online_safe_softmax_f32](./softmax/softmax.cu)|f32|f32|[link](./softmax/)|⭐️⭐️|
+| ✔️ [online_safe_softmax_f32x4_pack](./softmax/softmax.cu)|f32|f32|[link](./softmax/)|⭐️⭐️|
 | ✔️ [layer_norm_f32](./layer-norm/layer_norm.cu)|f32|f32|[link](./layer-norm/)|⭐️⭐️|
 | ✔️ [layer_norm_f32x4](./layer-norm/layer_norm.cu)|f32|f32|[link](./layer-norm/)|⭐️⭐️|
 | ✔️ [layer_norm_f16_f16](./layer-norm/layer_norm.cu)|f16|f16|[link](./layer-norm/)|⭐️⭐️|
@@ -131,9 +132,9 @@
 | ✔️ [hgemm_wmma_m16n16k16...async*](./hgemm/hgemm_wmma.cu)|f16|f16|[link](./hgemm/)|⭐️⭐️⭐️|
 | ✔️ [hgemm_wmma_m16n16k16...offset*](./hgemm/hgemm_wmma.cu)|f16|f16|[link](./hgemm/)|⭐️⭐️⭐️|
 | ✔️ [hgemm_wmma_m16n16k16...dbuf*](./hgemm/hgemm_wmma.cu)|f16|f16|[link](./hgemm/)|⭐️⭐️⭐️|
-| ✔️ [hgemm_wmma_m32n8k16...dbuf*](./hgemm/hgemm_wmma.cu)|f16|f16|[link](./hgemm/)|⭐️⭐️⭐️|
 | ✔️ [hgemm_wmma_m16n16k16...rbuf*](./hgemm/hgemm_wmma.cu)|f16|f16|[link](./hgemm/)|⭐️⭐️⭐️|
-| ✔️ [hgemm_wmma_m16n16k16...stage3/4*](./hgemm/hgemm_wmma_stage.cu)|f16|f16|[link](./hgemm/)|⭐️⭐️⭐️|
+| ✔️ [hgemm_wmma_m16n16k16...stage2/3/4*](./hgemm/hgemm_wmma_stage.cu)|f16|f16|[link](./hgemm/)|⭐️⭐️⭐️|
+| ✔️ [hgemm_wmma_m32n8k16...dbuf*](./hgemm/hgemm_wmma.cu)|f16|f16|[link](./hgemm/)|⭐️⭐️⭐️|
 | ✔️ [sgemv_k32_f32](./sgemv/sgemv.cu)|f32|f32|[link](./sgemv/)|⭐️⭐️⭐️|
 | ✔️ [sgemv_k128_f32x4](./sgemv/sgemv.cu)|f32|f32|[link](./sgemv/)|⭐️⭐️⭐️|
 | ✔️ [sgemv_k16_f32](./sgemv/sgemv.cu)|f32|f32|[link](./sgemv/)|⭐️⭐️⭐️|

hgemm/README.md

Lines changed: 182 additions & 3 deletions
Large diffs are not rendered by default.

hgemm/hgemm.cu

Lines changed: 7 additions & 0 deletions
@@ -1229,10 +1229,14 @@ void hgemm_wmma_m32n8k16_mma2x4_warp2x4_dbuf_async_offset(torch::Tensor a, torch
 void hgemm_wmma_m16n16k16_mma4x2_warp2x4x2_rbuf_async(torch::Tensor a, torch::Tensor b, torch::Tensor c);
 void hgemm_wmma_m16n16k16_mma4x2_warp2x4x2_rbuf_async_offset(torch::Tensor a, torch::Tensor b, torch::Tensor c);
 // from hgemm_wmma_stage.cu
+void hgemm_wmma_m16n16k16_mma4x2_warp2x4_stage2(torch::Tensor a, torch::Tensor b, torch::Tensor c);
+void hgemm_wmma_m16n16k16_mma4x2_warp2x4_stage2_offset(torch::Tensor a, torch::Tensor b, torch::Tensor c);
 void hgemm_wmma_m16n16k16_mma4x2_warp2x4_stage3(torch::Tensor a, torch::Tensor b, torch::Tensor c);
 void hgemm_wmma_m16n16k16_mma4x2_warp2x4_stage3_offset(torch::Tensor a, torch::Tensor b, torch::Tensor c);
 void hgemm_wmma_m16n16k16_mma4x2_warp2x4_stage4(torch::Tensor a, torch::Tensor b, torch::Tensor c);
 void hgemm_wmma_m16n16k16_mma4x2_warp2x4_stage4_offset(torch::Tensor a, torch::Tensor b, torch::Tensor c);
+// from hgemm_cublas.cu
+void hgemm_cublas_tensor_op(torch::Tensor a, torch::Tensor b, torch::Tensor c);
 
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
@@ -1272,8 +1276,11 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   TORCH_BINDING_COMMON_EXTENSION(hgemm_wmma_m32n8k16_mma2x4_warp2x4_dbuf_async_offset)
   TORCH_BINDING_COMMON_EXTENSION(hgemm_wmma_m16n16k16_mma4x2_warp2x4x2_rbuf_async)
   TORCH_BINDING_COMMON_EXTENSION(hgemm_wmma_m16n16k16_mma4x2_warp2x4x2_rbuf_async_offset)
+  TORCH_BINDING_COMMON_EXTENSION(hgemm_wmma_m16n16k16_mma4x2_warp2x4_stage2)
+  TORCH_BINDING_COMMON_EXTENSION(hgemm_wmma_m16n16k16_mma4x2_warp2x4_stage2_offset)
   TORCH_BINDING_COMMON_EXTENSION(hgemm_wmma_m16n16k16_mma4x2_warp2x4_stage3)
   TORCH_BINDING_COMMON_EXTENSION(hgemm_wmma_m16n16k16_mma4x2_warp2x4_stage3_offset)
   TORCH_BINDING_COMMON_EXTENSION(hgemm_wmma_m16n16k16_mma4x2_warp2x4_stage4)
   TORCH_BINDING_COMMON_EXTENSION(hgemm_wmma_m16n16k16_mma4x2_warp2x4_stage4_offset)
+  TORCH_BINDING_COMMON_EXTENSION(hgemm_cublas_tensor_op)
 }

hgemm/hgemm.py

Lines changed: 10 additions & 6 deletions
@@ -9,7 +9,7 @@
 # Load the CUDA kernel as a python module
 lib = load(name='hgemm_lib',
            sources=['hgemm.cu', 'hgemm_async.cu', 'hgemm_wmma.cu',
-                    'hgemm_wmma_stage.cu'],
+                    'hgemm_wmma_stage.cu', 'hgemm_cublas.cu'],
            extra_cuda_cflags=[
                "-O3",
                "-U__CUDA_NO_HALF_OPERATORS__",
@@ -98,7 +98,7 @@ def run_benchmark(perf_func: callable,
                   a, b, "f16x8pack(bcf+offset)", c)
     run_benchmark(lib.hgemm_t_8x8_sliced_k_f16x8_pack_bcf_dbuf,
                   a, b, "f16x8pack(bcf+dbuf)", c)
-    print("-" * 57 + "Async" + "-" * 58)
+    print("-" * 58 + "Async" + "-" * 57)
     run_benchmark(lib.hgemm_t_8x8_sliced_k16_f16x8_pack_dbuf,
                   a, b, "f16x8pack(k16+dbuf)", c)
     run_benchmark(lib.hgemm_t_8x8_sliced_k16_f16x8_pack_dbuf_offset,
@@ -138,6 +138,8 @@ def run_benchmark(perf_func: callable,
                   a, b, "f16wmma(mma4x2+warp2x4+dbuf)", c)
     run_benchmark(lib.hgemm_wmma_m32n8k16_mma2x4_warp2x4_dbuf_async,
                   a, b, "f16wmma(m32n8k16+mma2x4+warp2x4+dbuf)", c)
+    run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stage2,
+                  a, b, "f16wmma(mma2x4+warp2x4+stage2)", c)
     run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stage3,
                   a, b, "f16wmma(mma2x4+warp2x4+stage3)", c)
     run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stage4,
@@ -150,13 +152,15 @@ def run_benchmark(perf_func: callable,
                   a, b, "f16wmma(mma4x4+warp2x2x2+dbuf+offset)", c)
     run_benchmark(lib.hgemm_wmma_m32n8k16_mma2x4_warp2x4_dbuf_async_offset,
                   a, b, "f16wmma(m32n8k16+mma2x4+warp2x4+dbuf+offset)", c)
-    run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stage4_offset,
-                  a, b, "f16wmma(mma4x2+warp2x4+stage4+offset)", c)
     run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_dbuf_async_offset,
                   a, b, "f16wmma(mma4x2+warp2x4+dbuf+offset)", c)
+    run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stage4_offset,
+                  a, b, "f16wmma(mma4x2+warp2x4+stage4+offset)", c)
+    run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stage2_offset,
+                  a, b, "f16wmma(mma4x2+warp2x4+stage2+offset)", c)
     run_benchmark(lib.hgemm_wmma_m16n16k16_mma4x2_warp2x4_stage3_offset,
                   a, b, "f16wmma(mma4x2+warp2x4+stage3+offset)", c)
-    run_benchmark(partial(torch.matmul, out=c),
-                  a, b, "f16_th")
+    run_benchmark(partial(torch.matmul, out=c), a, b, "f16_th")
+    run_benchmark(lib.hgemm_cublas_tensor_op, a, b, "f16(cublas)", c)
     print("-" * 120)

hgemm/hgemm_cublas.cu

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
+#include <stdio.h>
+#include <stdlib.h>
+#include <float.h>
+#include <vector>
+#include <algorithm>
+#include <cuda_runtime.h>
+#include <cuda_fp16.h>
+#include <cuda_bf16.h>
+#include <cuda_fp8.h>
+#include <mma.h>
+
+#include <torch/types.h>
+#include <torch/extension.h>
+
+#include "cublas_v2.h"
+
+
+void cublas_tensor_op(half *A, half *B, half *C, size_t M,
+                      size_t N, size_t K) {
+
+  cublasHandle_t handle = nullptr;
+  cublasCreate(&handle);
+  cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH);
+
+  static half alpha = 1.0;
+  static half beta = 0.0;
+
+  cublasGemmEx(handle,
+               CUBLAS_OP_N,
+               CUBLAS_OP_N,
+               N, M, K,
+               &alpha,
+               B, CUDA_R_16F, N,
+               A, CUDA_R_16F, K,
+               &beta,
+               C, CUDA_R_16F, N,
+               CUBLAS_COMPUTE_16F,
+               CUBLAS_GEMM_DEFAULT_TENSOR_OP);
+}
+
+// --------------------- PyTorch bindings for custom kernel -----------------------
+#define STRINGFY(str) #str
+#define TORCH_BINDING_COMMON_EXTENSION(func) \
+  m.def(STRINGFY(func), &func, STRINGFY(func));
+
+#define CHECK_TORCH_TENSOR_DTYPE(T, th_type)                 \
+if(((T).options().dtype() != (th_type))) {                   \
+  std::cout << "Tensor Info:" << (T).options() << std::endl; \
+  throw std::runtime_error("values must be "#th_type);       \
+}
+
+#define CHECK_TORCH_TENSOR_SHAPE(T, S0, S1)           \
+if (((T).size(0) != (S0)) || ((T).size(1) != (S1))) { \
+  throw std::runtime_error("Tensor size mismatch!");  \
+}
+
+// cublas tensor op
+void hgemm_cublas_tensor_op(
+  torch::Tensor a, torch::Tensor b, torch::Tensor c) {
+  CHECK_TORCH_TENSOR_DTYPE(a, torch::kHalf)
+  CHECK_TORCH_TENSOR_DTYPE(b, torch::kHalf)
+  CHECK_TORCH_TENSOR_DTYPE(c, torch::kHalf)
+  const int M = a.size(0);
+  const int K = a.size(1);
+  const int N = b.size(1);
+  CHECK_TORCH_TENSOR_SHAPE(a, M, K)
+  CHECK_TORCH_TENSOR_SHAPE(b, K, N)
+  CHECK_TORCH_TENSOR_SHAPE(c, M, N)
+
+  cublas_tensor_op(
+    reinterpret_cast<half*>(a.data_ptr()),
+    reinterpret_cast<half*>(b.data_ptr()),
+    reinterpret_cast<half*>(c.data_ptr()),
+    M, N, K
+  );
+}
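This new cuBLAS path is what the updated hgemm.py benchmark invokes as "f16(cublas)". Because cublasGemmEx assumes column-major storage, the wrapper passes B before A and swaps M and N, computing Cᵀ = Bᵀ·Aᵀ so the row-major c tensor ends up holding A·B. A minimal standalone sketch of calling the new binding from Python (the extension name hgemm_lib and source list mirror hgemm.py; the matrix sizes and tolerances are illustrative assumptions, not taken from the commit):

```python
import torch
from torch.utils.cpp_extension import load

# Build the extension from the same sources hgemm.py now lists
# (paths assumed relative to the hgemm/ directory).
lib = load(name='hgemm_lib',
           sources=['hgemm.cu', 'hgemm_async.cu', 'hgemm_wmma.cu',
                    'hgemm_wmma_stage.cu', 'hgemm_cublas.cu'],
           extra_cuda_cflags=['-O3'])

M, N, K = 4096, 4096, 1024  # example shapes only
a = torch.randn((M, K), dtype=torch.half, device='cuda')
b = torch.randn((K, N), dtype=torch.half, device='cuda')
c = torch.empty((M, N), dtype=torch.half, device='cuda')

# c = a @ b via the cuBLAS Tensor Core path added in this commit.
lib.hgemm_cublas_tensor_op(a, b, c)

# Loose tolerance: this path accumulates in half (CUBLAS_COMPUTE_16F).
torch.testing.assert_close(c, a @ b, rtol=3e-2, atol=3e-2)
```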
