Skip to content

Commit a0daf10

Browse files
authored
[HGEMM] Add HGEMM MMA Col Major Kernel (#104)
* Update and rename hgemm_mma_stage_col_major.cu to hgemm_mma_stage_tn.cu
* Update hgemm_mma_stage_tn.cu
* Update hgemm.py
* Update hgemm.cu
* Update hgemm_cublas.cu
* Update hgemm_cublas.cu
* Update hgemm_mma_stage_tn.cu
* Update hgemm.py
* Update hgemm.py
* Update hgemm.py
1 parent 1492631 commit a0daf10

File tree

5 files changed

+523
-70
lines changed

5 files changed

+523
-70
lines changed

hgemm/hgemm.cu

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -999,7 +999,8 @@ void hgemm_t_8x8_sliced_k32_f16x8_pack_dbuf_async(torch::Tensor a, torch::Tensor
999999
void hgemm_t_16x8_sliced_k32_f16x8_pack_dbuf(torch::Tensor a, torch::Tensor b, torch::Tensor c);
10001000
void hgemm_t_16x8_sliced_k32_f16x8_pack_dbuf_async(torch::Tensor a, torch::Tensor b, torch::Tensor c);
10011001
// from hgemm_cublas.cu
1002-
void hgemm_cublas_tensor_op_row_major(torch::Tensor a, torch::Tensor b, torch::Tensor c);
1002+
void hgemm_cublas_tensor_op_nn(torch::Tensor a, torch::Tensor b, torch::Tensor c);
1003+
void hgemm_cublas_tensor_op_tn(torch::Tensor a, torch::Tensor b, torch::Tensor c);
10031004
// from hgemm_wmma.cu
10041005
void hgemm_wmma_m16n16k16_naive(torch::Tensor a, torch::Tensor b, torch::Tensor c);
10051006
void hgemm_wmma_m16n16k16_mma4x2(torch::Tensor a, torch::Tensor b, torch::Tensor c);
@@ -1018,6 +1019,9 @@ void hgemm_mma_m16n8k16_mma2x4_warp4x4(torch::Tensor a, torch::Tensor b, torch::
10181019
void hgemm_mma_m16n8k16_mma2x4_warp4x4_stages(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride);
10191020
void hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride);
10201021
void hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride);
1022+
// from hgemm_mma_stage_tn.cu
1023+
void hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem_tn(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride);
1024+
10211025

10221026
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
10231027
// CUDA Cores FP16
@@ -1037,7 +1041,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
10371041
TORCH_BINDING_COMMON_EXTENSION(hgemm_t_16x8_sliced_k32_f16x8_pack_dbuf)
10381042
TORCH_BINDING_COMMON_EXTENSION(hgemm_t_16x8_sliced_k32_f16x8_pack_dbuf_async)
10391043
// cuBLAS Tensor Cores
1040-
TORCH_BINDING_COMMON_EXTENSION(hgemm_cublas_tensor_op_row_major)
1044+
TORCH_BINDING_COMMON_EXTENSION(hgemm_cublas_tensor_op_nn)
1045+
TORCH_BINDING_COMMON_EXTENSION(hgemm_cublas_tensor_op_tn)
10411046
// WMMA API Tensor Cores
10421047
TORCH_BINDING_COMMON_EXTENSION(hgemm_wmma_m16n16k16_naive)
10431048
TORCH_BINDING_COMMON_EXTENSION(hgemm_wmma_m16n16k16_mma4x2)
@@ -1056,5 +1061,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
10561061
TORCH_BINDING_COMMON_EXTENSION(hgemm_mma_m16n8k16_mma2x4_warp4x4_stages)
10571062
TORCH_BINDING_COMMON_EXTENSION(hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem)
10581063
TORCH_BINDING_COMMON_EXTENSION(hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem)
1064+
// TN: A row major MxK, B col major NxK, C row major MxN
1065+
TORCH_BINDING_COMMON_EXTENSION(hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem_tn)
10591066
}
10601067

hgemm/hgemm.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,15 @@ def get_args():
1616
parser.add_argument("--iters", "--i", type=int, default=10, help="Benchmark iters")
1717
parser.add_argument("--show-all", "--show", action="store_true", help="Show all matrix values ")
1818
parser.add_argument("--enable-mma", "--mma", action="store_true", help="Enable MMA kernel tests")
19+
parser.add_argument("--enable-mma-tn", "--mma-tn", action="store_true", help="Enable TN MMA kernel tests")
1920
parser.add_argument("--enable-wmma", "--wmma", action="store_true", help="Enable WMMA kernel tests")
2021
parser.add_argument("--enable-cuda", "--cuda", action="store_true", help="Enable CUDA kernel tests")
2122
parser.add_argument("--enable-mma-all", "--mma-all", action="store_true", help="Enable all MMA kernel tests")
2223
parser.add_argument("--enable-wmma-all", "--wmma-all", action="store_true", help="Enable all WMMA kernel tests")
2324
parser.add_argument("--enable-cuda-all", "--cuda-all", action="store_true", help="Enable all CUDA kernel tests")
2425
parser.add_argument("--enable-torch", "--torch", action="store_true", help="Enable torch matmul")
2526
parser.add_argument("--disable-cublas", "--no-cublas", action="store_true", help="Disable cublas hgemm")
27+
parser.add_argument("--disable-cublas-tn", "--no-cublas-tn", action="store_true", help="Disable cublas TN hgemm")
2628
parser.add_argument("--sleep-duration", "--sleep", type=float, default=0.1, help="Sleep duration")
2729
parser.add_argument("--swizzle-factor", "--swizzle", type=float, default=0.25, help="Swizzle factor")
2830
return parser.parse_args()
@@ -35,7 +37,8 @@ def get_args():
3537
lib = load(name='hgemm_lib',
3638
sources=['hgemm.cu', 'hgemm_async.cu', 'hgemm_wmma.cu',
3739
'hgemm_wmma_stage.cu', 'hgemm_cublas.cu',
38-
'hgemm_mma.cu', 'hgemm_mma_stage.cu'],
40+
'hgemm_mma.cu', 'hgemm_mma_stage.cu',
41+
'hgemm_mma_stage_tn.cu'],
3942
extra_cuda_cflags=[
4043
"-O3",
4144
"-U__CUDA_NO_HALF_OPERATORS__",
@@ -65,6 +68,8 @@ def run_benchmark(perf_func: callable,
6568
M = a.size(0)
6669
K = a.size(1)
6770
N = b.size(1)
71+
if 'tn' in tag:
72+
N = b.size(0)
6873
if swizzle:
6974
# make swizzle stride as N/4 or N/2 and multiples of 256
7075
swizzle_stride = int((int(N * args.swizzle_factor) // 256) * 256)
@@ -217,8 +222,17 @@ def run_benchmark(perf_func: callable,
217222
run_benchmark(lib.hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem, a, b, "(mma2x4+warp4x4x2+stage4+dsmem+swizzle)", c, stages=4, swizzle=True)
218223
run_benchmark(lib.hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem, a, b, "(mma2x4+warp4x4x2+stage3+dsmem+swizzle)", c, stages=3, swizzle=True)
219224
run_benchmark(lib.hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem, a, b, "(mma2x4+warp4x4x2+stage2+dsmem+swizzle)", c, stages=2, swizzle=True)
220-
if not args.disable_cublas:
221-
run_benchmark(lib.hgemm_cublas_tensor_op_row_major, a, b, "(cublas)", c)
225+
if (not args.disable_cublas) and any((
226+
args.enable_mma, args.enable_mma_all, args.enable_wmma, args.enable_wmma_all,
227+
args.enable_cuda, args.enable_cuda_all, args.enable_torch)):
228+
run_benchmark(lib.hgemm_cublas_tensor_op_nn, a, b, "(cublas)", c)
229+
if args.enable_mma_tn:
230+
run_benchmark(lib.hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem_tn, a, b.transpose(1, 0), "tn(mma2x4+warp4x4+stage3+dsmem)", c, stages=3)
231+
run_benchmark(lib.hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem_tn, a, b.transpose(1, 0), "tn(mma2x4+warp4x4+stage2+dsmem)", c, stages=2)
232+
run_benchmark(lib.hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem_tn, a, b.transpose(1, 0), "tn(mma2x4+warp4x4+stage3+dsmem+swizzle)", c, stages=3, swizzle=True)
233+
run_benchmark(lib.hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem_tn, a, b.transpose(1, 0), "tn(mma2x4+warp4x4+stage2+dsmem+swizzle)", c, stages=2, swizzle=True)
234+
if not args.disable_cublas_tn:
235+
run_benchmark(lib.hgemm_cublas_tensor_op_tn, a, b.transpose(1, 0), "tn(cublas)", c)
222236
if args.enable_torch:
223237
run_benchmark(partial(torch.matmul, out=c), a, b, "(torch)")
224238
torch.cuda.synchronize()

hgemm/hgemm_cublas.cu

Lines changed: 49 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@
1414

1515
#include "cublas_v2.h"
1616

17-
void cublas_tensor_op_row_major(half *A, half *B, half *C, size_t M,
18-
size_t N, size_t K) {
17+
// NN: A/B/C All row major
18+
void cublas_tensor_op_nn(half *A, half *B, half *C, size_t M, size_t N, size_t K) {
1919

2020
static cublasHandle_t handle = nullptr;
2121
cublasCreate(&handle);
@@ -36,11 +36,33 @@ void cublas_tensor_op_row_major(half *A, half *B, half *C, size_t M,
3636
CUBLAS_COMPUTE_16F,
3737
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
3838

39-
// why this line will make cublas slow down?
4039
// cublasDestroy(handle);
4140
}
4241

43-
// TODO: add cublas_tensor_op_col_major
42+
// TN: A row major MxK, B col major NxK, C row major MxN
43+
void cublas_tensor_op_tn(half *A, half *B, half *C, size_t M, size_t N, size_t K) {
44+
45+
static cublasHandle_t handle = nullptr;
46+
cublasCreate(&handle);
47+
cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH);
48+
49+
static half alpha = 1.0;
50+
static half beta = 0.0;
51+
52+
cublasGemmEx(handle,
53+
CUBLAS_OP_T,
54+
CUBLAS_OP_N,
55+
N, M, K,
56+
&alpha,
57+
B, CUDA_R_16F, K,
58+
A, CUDA_R_16F, K,
59+
&beta,
60+
C, CUDA_R_16F, N,
61+
CUBLAS_COMPUTE_16F,
62+
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
63+
64+
// cublasDestroy(handle);
65+
}
4466

4567
// --------------------- PyTorch bindings for custom kernel -----------------------
4668
#define STRINGFY(str) #str
@@ -58,8 +80,8 @@ if (((T).size(0) != (S0)) || ((T).size(1) != (S1))) { \
5880
throw std::runtime_error("Tensor size mismatch!"); \
5981
}
6082

61-
// cublas tensor op with row major B matrix
62-
void hgemm_cublas_tensor_op_row_major(
83+
// NN: A/B/C All row major
84+
void hgemm_cublas_tensor_op_nn(
6385
torch::Tensor a, torch::Tensor b, torch::Tensor c) {
6486
CHECK_TORCH_TENSOR_DTYPE(a, torch::kHalf)
6587
CHECK_TORCH_TENSOR_DTYPE(b, torch::kHalf)
@@ -71,12 +93,31 @@ void hgemm_cublas_tensor_op_row_major(
7193
CHECK_TORCH_TENSOR_SHAPE(b, K, N)
7294
CHECK_TORCH_TENSOR_SHAPE(c, M, N)
7395

74-
cublas_tensor_op_row_major(
96+
cublas_tensor_op_nn(
7597
reinterpret_cast<half*>(a.data_ptr()),
7698
reinterpret_cast<half*>(b.data_ptr()),
7799
reinterpret_cast<half*>(c.data_ptr()),
78100
M, N, K
79101
);
80102
}
81103

82-
// TODO: add cublas_tensor_op_col_major
104+
// TN: A row major MxK, B col major NxK, C row major MxN
105+
void hgemm_cublas_tensor_op_tn(
106+
torch::Tensor a, torch::Tensor b, torch::Tensor c) {
107+
CHECK_TORCH_TENSOR_DTYPE(a, torch::kHalf)
108+
CHECK_TORCH_TENSOR_DTYPE(b, torch::kHalf)
109+
CHECK_TORCH_TENSOR_DTYPE(c, torch::kHalf)
110+
const int M = a.size(0);
111+
const int K = a.size(1);
112+
const int N = b.size(0);
113+
CHECK_TORCH_TENSOR_SHAPE(a, M, K)
114+
CHECK_TORCH_TENSOR_SHAPE(b, N, K)
115+
CHECK_TORCH_TENSOR_SHAPE(c, M, N)
116+
117+
cublas_tensor_op_tn(
118+
reinterpret_cast<half*>(a.data_ptr()),
119+
reinterpret_cast<half*>(b.data_ptr()),
120+
reinterpret_cast<half*>(c.data_ptr()),
121+
M, N, K
122+
);
123+
}

hgemm/hgemm_mma_stage_col_major.cu

Lines changed: 0 additions & 57 deletions
This file was deleted.

0 commit comments

Comments (0)