Skip to content

Commit 91f7b10

Browse files
authored
[HGEMM] fix cublas hgemm handle error (#138)
* Update hgemm_cublas.cu
* Update hgemm_mma_stage_tn_cute.cu
* Update hgemm_mma_stage_tn_cute.cu
* Update utils.h
* Update utils.h
* Update hgemm_mma_stage_tn_cute.cu
* Update hgemm_cublas.cu
* Update hgemm_mma_stage_tn_cute.cu
1 parent ed1d100 commit 91f7b10

File tree

3 files changed

+42
-103
lines changed

3 files changed

+42
-103
lines changed

hgemm/hgemm_cublas.cu

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,27 @@ void cublas_tensor_op_tn(half *A, half *B, half *C, size_t M, size_t N, size_t
6363
// build cpp binary
6464
#ifndef NO_CUBLAS_HGEMM_BIN
6565

66-
float perf_cublas(int M, int N, int K, int repeat) {
66+
// pass the cuBLAS handle from outside to avoid error.
67+
void cublas_tensor_op_tn_v2(cublasHandle_t handle,
68+
half *A, half *B, half *C,
69+
size_t M, size_t N, size_t K) {
70+
half alpha = 1.0;
71+
half beta = 0.0;
72+
73+
cublasGemmEx(handle,
74+
CUBLAS_OP_T,
75+
CUBLAS_OP_N,
76+
N, M, K,
77+
&alpha,
78+
B, CUDA_R_16F, K,
79+
A, CUDA_R_16F, K,
80+
&beta,
81+
C, CUDA_R_16F, N,
82+
CUBLAS_COMPUTE_16F,
83+
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
84+
}
85+
86+
float perf_cublas_tn(int M, int N, int K, int repeat) {
6787
size_t size_a = M * K * sizeof(half);
6888
size_t size_b = K * N * sizeof(half);
6989
size_t size_c = M * N * sizeof(half);
@@ -74,9 +94,13 @@ float perf_cublas(int M, int N, int K, int repeat) {
7494
cudaMalloc(&d_b, size_b);
7595
cudaMalloc(&d_c, size_c);
7696

97+
cublasHandle_t handle = nullptr;
98+
cublasCreate(&handle);
99+
cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH);
100+
77101
// warmup
78102
for (int i = 0; i < 10; ++i) {
79-
cublas_tensor_op_tn(d_a, d_b, d_c, M, N, K);
103+
cublas_tensor_op_tn_v2(handle, d_a, d_b, d_c, M, N, K);
80104
}
81105
cudaDeviceSynchronize();
82106

@@ -86,7 +110,7 @@ float perf_cublas(int M, int N, int K, int repeat) {
86110
cudaEventRecord(start);
87111

88112
for (int i = 0; i < repeat; i++) {
89-
cublas_tensor_op_tn(d_a, d_b, d_c, M, N, K);
113+
cublas_tensor_op_tn_v2(handle, d_a, d_b, d_c, M, N, K);
90114
}
91115

92116
cudaEventRecord(end);
@@ -102,12 +126,13 @@ float perf_cublas(int M, int N, int K, int repeat) {
102126
cudaFree(d_c);
103127
cudaEventDestroy(start);
104128
cudaEventDestroy(end);
129+
cublasDestroy(handle);
105130

106131
return sec;
107132
}
108133

109134
int main(int argc, char *argv[]) {
110-
const int test_num = 50;
135+
const int test_num = 64;
111136
int M_list[test_num];
112137
int N_list[test_num];
113138
int K_list[test_num];
@@ -120,7 +145,7 @@ int main(int argc, char *argv[]) {
120145

121146
const int outer_repeat = 10, inner_repeat = 1;
122147

123-
printf("\nalgo = Cublas TN\n");
148+
printf("ALGO = cuBLAS CUBLAS_GEMM_DEFAULT_TENSOR_OP TN\n");
124149

125150
for (int j = 0; j < test_num; j++) {
126151
int M = M_list[j], N = N_list[j], K = K_list[j];
@@ -130,7 +155,7 @@ int main(int argc, char *argv[]) {
130155
double total_sec = 0.0;
131156

132157
for (int k = 0; k < outer_repeat; k++) {
133-
double this_sec = perf_cublas(M, N, K, inner_repeat);
158+
double this_sec = perf_cublas_tn(M, N, K, inner_repeat);
134159
max_sec = max(max_sec, this_sec);
135160
min_sec = min(min_sec, this_sec);
136161
total_sec += this_sec;

hgemm/hgemm_mma_stage_tn_cute.cu

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
#include <stdlib.h>
44
#include <cute/tensor.hpp>
55
#include <float.h>
6-
// modifide from: https://github.com/weishengying/cute_gemm/blob/main/gemm_4/gemm.cu
76

87
// TODO: thread block swizzle, cute hgemm nn
98
template <
@@ -349,7 +348,7 @@ int main() {
349348
using T = cute::half_t;
350349
using namespace cute;
351350

352-
const int test_num = 50;
351+
const int test_num = 64;
353352
int M_list[test_num];
354353
int N_list[test_num];
355354
int K_list[test_num];
@@ -362,10 +361,10 @@ int main() {
362361

363362
const int outer_repeat = 10, inner_repeat = 1;
364363

365-
printf("\nalgo = CuTe HGEMM Stages 2\n");
364+
printf("ALGO = CuTe HGEMM Stages 2\n");
366365
for (int j = 0; j < 5; j++) {
367366
int M = M_list[j], N = N_list[j], K = K_list[j];
368-
float max_error = gemm_error_check_v2<T>(
367+
float max_error = gemm_error_check<T>(
369368
launch_hgemm_mma_stages_tn_cute, M, N, K);
370369
printf("M N K = %6d %6d %6d, ", M, N, K);
371370
printf("Max Error = %f\n", max_error);

hgemm/utils.h

Lines changed: 8 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -2,38 +2,6 @@
22
#include <cstdlib>
33
#include <cuda.h>
44
#include <cublas_v2.h>
5-
// modified from: https://github.com/weishengying/cute_gemm/blob/main/utils.h
6-
7-
#define OFFSET(row_idx, col_idx, stride_0, stride_1) \
8-
row_idx*stride_0 + col_idx*stride_1
9-
10-
#define PRINT(name, content) \
11-
print(name); \
12-
print(" : "); \
13-
print(content); \
14-
print("\n");
15-
16-
#define PRINTTENSOR(name, content) \
17-
print(name); \
18-
print(" : "); \
19-
print_tensor(content); \
20-
print("\n");
21-
22-
template<class T>
23-
void cpu_hgemm(const T* A, const T* B, T* C,
24-
const int M, const int N, const int K) {
25-
// A(M,K):(K,1) B(K,N):(1,K)
26-
for(int m = 0; m < M; m++) {
27-
for(int n = 0; n < N; n++) {
28-
float tmp = 0.0;
29-
for(int k = 0; k < K; k++) {
30-
tmp += float(A[OFFSET(m, k, K, 1)]) * float(B[OFFSET(k, n, 1, K)]);
31-
}
32-
C[OFFSET(m, n, N, 1)] = T(tmp);
33-
}
34-
}
35-
return;
36-
}
375

386
template <typename T>
397
float perf_gemm(
@@ -89,59 +57,6 @@ float gemm_error_check(
8957
size_t size_b = K * N * sizeof(T);
9058
size_t size_c = M * N * sizeof(T);
9159

92-
T *h_a, *h_b, *d_a, *d_b;
93-
T *h_c, *d_c, *h_d_c;
94-
95-
h_a = (T *)malloc(size_a);
96-
h_b = (T *)malloc(size_b);
97-
h_c = (T *)malloc(size_c);
98-
cudaMalloc(&d_a, size_a);
99-
cudaMalloc(&d_b, size_b);
100-
cudaMalloc(&d_c, size_c);
101-
102-
h_d_c = (T *)malloc(size_c);
103-
104-
srand(time(0));
105-
for (int i = 0; i < M * K; i++)
106-
h_a[i] = (T)((rand() % 200 - 100) * 0.01); // -1 ~ 1
107-
for (int i = 0; i < K * N; i++)
108-
h_b[i] = (T)((rand() % 200 - 100) * 0.01);
109-
110-
cpu_hgemm(h_a, h_b, h_c, M, N, K);
111-
112-
cudaMemcpy(d_a, h_a, size_a, cudaMemcpyHostToDevice);
113-
cudaMemcpy(d_b, h_b, size_b, cudaMemcpyHostToDevice);
114-
115-
gpu_hgemm(d_a, d_b, d_c, M, N, K);
116-
117-
cudaMemcpy(h_d_c, d_c, size_c, cudaMemcpyDeviceToHost);
118-
119-
float max_error = 0.0;
120-
for (int i = 0; i < M * N; i++) {
121-
float this_error = abs((float)h_d_c[i] - (float)h_c[i]);
122-
max_error = max(max_error, this_error);
123-
}
124-
125-
free(h_a);
126-
free(h_b);
127-
free(h_c);
128-
cudaFree(d_a);
129-
cudaFree(d_b);
130-
cudaFree(d_c);
131-
free(h_d_c);
132-
133-
return max_error;
134-
}
135-
136-
template <typename T>
137-
float gemm_error_check_v2(
138-
void (*gpu_hgemm) (const T *, const T *, T *, int, int, int),
139-
int M, int N, int K) {
140-
141-
size_t size_a = M * K * sizeof(T);
142-
size_t size_b = K * N * sizeof(T);
143-
size_t size_c = M * N * sizeof(T);
144-
14560
T *h_a, *h_b, *h_c, *h_c_ref;
14661
T *d_a, *d_b, *d_c, *d_c_ref;
14762

@@ -170,14 +85,14 @@ float gemm_error_check_v2(
17085
cudaMemcpy(d_b, h_b, size_b, cudaMemcpyHostToDevice);
17186

17287
cublasHgemm(handle,
173-
CUBLAS_OP_T,
174-
CUBLAS_OP_N,
175-
N, M, K,
176-
&alpha,
177-
(half *)d_b, K,
178-
(half *)d_a, K,
179-
&beta,
180-
(half *)d_c_ref, N);
88+
CUBLAS_OP_T,
89+
CUBLAS_OP_N,
90+
N, M, K,
91+
&alpha,
92+
(half *)d_b, K,
93+
(half *)d_a, K,
94+
&beta,
95+
(half *)d_c_ref, N);
18196

18297
gpu_hgemm(d_a, d_b, d_c, M, N, K);
18398

0 commit comments

Comments (0)