
Commit 0c29631

[HGEMM] update HGEMM benchmark option (#95)
* update hgemm benchmark option
1 parent ce095b5 commit 0c29631

File tree: 4 files changed (+18, -13 lines)


hgemm/README.md

Lines changed: 2 additions & 0 deletions
````diff
@@ -238,12 +238,14 @@ export TORCH_CUDA_ARCH_LIST=Ada
 python3 hgemm.py # default, test some wmma kernels for all MNK
 python3 hgemm.py --wmma # test all wmma kernels for all MNK
 python3 hgemm.py --M 16384 --N 16384 --K 8192 --i 10 --wmma # test all wmma kernels for specific MNK
+python3 hgemm.py --wmma --no-default # test all wmma kernels, but exclude the default part.
 ```
 
 Output:
 
 - NVIDIA L20
 ```bash
+python3 hgemm.py
 ----------------------------------------------------------------------------------------------------------------------------------
 M=4096, N=4096, K=2048
   f16x8pack(t8x8+dbuf): ['1.59863281', '-1.5263671'], time:1.404404ms, swizzle: NOOP, TFLOPS: 48.93 (+0.00%)
````

hgemm/hgemm.cu

Lines changed: 2 additions & 2 deletions
```diff
@@ -1211,7 +1211,7 @@ void hgemm_t_8x8_sliced_k32_f16x8_pack_dbuf_async(torch::Tensor a, torch::Tensor
 void hgemm_t_16x8_sliced_k32_f16x8_pack_dbuf(torch::Tensor a, torch::Tensor b, torch::Tensor c);
 void hgemm_t_16x8_sliced_k32_f16x8_pack_dbuf_async(torch::Tensor a, torch::Tensor b, torch::Tensor c);
 // from hgemm_cublas.cu
-void hgemm_cublas_tensor_op(torch::Tensor a, torch::Tensor b, torch::Tensor c);
+void hgemm_cublas_tensor_op_row_major(torch::Tensor a, torch::Tensor b, torch::Tensor c);
 // from hgemm_wmma.cu
 void hgemm_wmma_m16n16k16_naive(torch::Tensor a, torch::Tensor b, torch::Tensor c);
 void hgemm_wmma_m16n16k16_mma4x2(torch::Tensor a, torch::Tensor b, torch::Tensor c);
@@ -1266,7 +1266,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   TORCH_BINDING_COMMON_EXTENSION(hgemm_t_16x8_sliced_k32_f16x8_pack_dbuf)
   TORCH_BINDING_COMMON_EXTENSION(hgemm_t_16x8_sliced_k32_f16x8_pack_dbuf_async)
   // cuBLAS Tensor Cores
-  TORCH_BINDING_COMMON_EXTENSION(hgemm_cublas_tensor_op)
+  TORCH_BINDING_COMMON_EXTENSION(hgemm_cublas_tensor_op_row_major)
   // WMMA API Tensor Cores
   TORCH_BINDING_COMMON_EXTENSION(hgemm_wmma_m16n16k16_naive)
   TORCH_BINDING_COMMON_EXTENSION(hgemm_wmma_m16n16k16_mma4x2)
```
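The forward declaration and the pybind11 binding are renamed together, so the compiled extension exposes the kernel to `hgemm.py` as `lib.hgemm_cublas_tensor_op_row_major`.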

hgemm/hgemm.py

Lines changed: 4 additions & 4 deletions
```diff
@@ -18,8 +18,8 @@ def get_args():
     parser.add_argument("--enable-wmma-all", "--wmma", action="store_true", help="Enable all WMMA kernel tests")
     parser.add_argument("--enable-cuda-all", "--cuda", action="store_true", help="Enable all CUDA kernel tests")
     parser.add_argument("--enable-torch", "--torch", action="store_true", help="Enable torch matmul")
-    parser.add_argument("--enable-cublas", "--cublas", action="store_true", default=True, help="Enable cublas hgemm")
-    parser.add_argument("--disable-default", "--no-default", action="store_true", default=False, help="Disable default tests")
+    parser.add_argument("--disable-cublas", "--no-cublas", action="store_true", help="Disable cublas hgemm")
+    parser.add_argument("--disable-default", "--no-default", action="store_true", help="Disable default tests")
     return parser.parse_args()
 
 args = get_args()
@@ -205,8 +205,8 @@ def run_benchmark(perf_func: callable,
     if args.enable_mma_all: # more mma kernel tests.
         print("-" * 68 + "MMA" + "-" * 59)
         pass
-    if args.enable_cublas:
-        run_benchmark(lib.hgemm_cublas_tensor_op, a, b, "(cublas)", c)
+    if not args.disable_cublas:
+        run_benchmark(lib.hgemm_cublas_tensor_op_row_major, a, b, "(cublas)", c)
     if args.enable_torch:
         run_benchmark(partial(torch.matmul, out=c), a, b, "(torch)")
     torch.cuda.synchronize()
```
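Note the polarity flip: the old `--enable-cublas` flag combined `action="store_true"` with `default=True`, so passing it changed nothing and there was no way to turn cuBLAS off. The new `--no-cublas` flag keeps cuBLAS enabled by default but makes it genuinely optional, mirroring the new `--no-default` switch.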

hgemm/hgemm_cublas.cu

Lines changed: 10 additions & 7 deletions
```diff
@@ -14,11 +14,10 @@
 
 #include "cublas_v2.h"
 
+void cublas_tensor_op_row_major(half *A, half *B, half *C, size_t M,
+                                size_t N, size_t K) {
 
-void cublas_tensor_op(half *A, half *B, half *C, size_t M,
-                      size_t N, size_t K) {
-
-  cublasHandle_t handle = nullptr;
+  static cublasHandle_t handle = nullptr;
   cublasCreate(&handle);
   cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH);
 
```
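Making the handle `static` lets it persist across calls instead of being a fresh local each time, although as written `cublasCreate` still runs on every invocation. A minimal sketch (not this file's actual code) of the create-once pattern that a `static` handle enables, with `get_cublas_handle` as a hypothetical helper name:

```cpp
#include <cublas_v2.h>

// Sketch: lazily create one long-lived cuBLAS handle and reuse it,
// so repeated benchmark calls don't pay handle-creation cost each time.
static cublasHandle_t get_cublas_handle() {
  static cublasHandle_t handle = nullptr;
  if (handle == nullptr) {  // create only on first use
    cublasCreate(&handle);
    cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH);
  }
  return handle;
}
```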

```diff
@@ -41,6 +40,8 @@ void cublas_tensor_op(half *A, half *B, half *C, size_t M,
   // cublasDestroy(handle);
 }
 
+// TODO: add cublas_tensor_op_col_major
+
 // --------------------- PyTorch bindings for custom kernel -----------------------
 #define STRINGFY(str) #str
 #define TORCH_BINDING_COMMON_EXTENSION(func) \
@@ -57,8 +58,8 @@ if (((T).size(0) != (S0)) || ((T).size(1) != (S1))) { \
   throw std::runtime_error("Tensor size mismatch!"); \
 }
 
-// cublas tensor op
-void hgemm_cublas_tensor_op(
+// cublas tensor op with row major B matrix
+void hgemm_cublas_tensor_op_row_major(
   torch::Tensor a, torch::Tensor b, torch::Tensor c) {
   CHECK_TORCH_TENSOR_DTYPE(a, torch::kHalf)
   CHECK_TORCH_TENSOR_DTYPE(b, torch::kHalf)
@@ -70,10 +71,12 @@ void hgemm_cublas_tensor_op(
   CHECK_TORCH_TENSOR_SHAPE(b, K, N)
   CHECK_TORCH_TENSOR_SHAPE(c, M, N)
 
-  cublas_tensor_op(
+  cublas_tensor_op_row_major(
     reinterpret_cast<half*>(a.data_ptr()),
     reinterpret_cast<half*>(b.data_ptr()),
     reinterpret_cast<half*>(c.data_ptr()),
     M, N, K
   );
 }
+
+// TODO: add cublas_tensor_op_col_major
```
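The `_row_major` suffix refers to the usual trick for running a row-major GEMM on cuBLAS, whose API is column-major: a row-major matrix reinterpreted as column-major is its transpose, so C = A*B can be obtained by asking cuBLAS for C^T = B^T * A^T with no explicit transposes. The `cublasHgemm` call itself is outside these hunks; a minimal sketch of that mapping (`hgemm_row_major_sketch` is a hypothetical name, not this file's code):

```cpp
#include <cublas_v2.h>
#include <cuda_fp16.h>

// Sketch: row-major C[M,N] = A[M,K] * B[K,N] on column-major cuBLAS.
// Swapping the operand order yields C^T = B^T * A^T, which is exactly
// the row-major C when read back with leading dimension N.
void hgemm_row_major_sketch(cublasHandle_t handle, const half *A,
                            const half *B, half *C, int M, int N, int K) {
  const half alpha = __float2half(1.0f);
  const half beta  = __float2half(0.0f);
  cublasHgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N,
              N, M, K,   // dimensions of the transposed product
              &alpha,
              B, N,      // B viewed column-major: N x K, ld = N
              A, K,      // A viewed column-major: K x M, ld = K
              &beta,
              C, N);     // C viewed column-major: N x M, ld = N
}
```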
