Skip to content

Commit 5663da6

Browse files
committed
Merge branch 'master' of https://github.com/eth-cscs/reframe into bugfix/var_options_type
2 parents f278859 + b04eae5 commit 5663da6

File tree

3 files changed

+35
-21
lines changed

3 files changed

+35
-21
lines changed
Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,6 @@
1+
all: dgemm sgemm
12
dgemm:
2-
nvcc $@.cu -o $@.x ${CXXFLAGS} -lnvidia-ml -lcublas -std=c++14
3+
nvcc xgemm.cu -o $@.x -DGEMM_TYPE=double -DXBLAS_GEMM=XblasDgemm ${CXXFLAGS} -lnvidia-ml -lcublas -std=c++14
4+
sgemm:
5+
nvcc xgemm.cu -o $@.x -DGEMM_TYPE=float -DXBLAS_GEMM=XblasSgemm ${CXXFLAGS} -lnvidia-ml -lcublas -std=c++14
6+
Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
ROCM_ROOT?=/opt/rocm
22
RSMI_ROOT?=/opt/rocm/rocm_smi
33

4+
all: dgemm sgemm
5+
46
dgemm:
5-
hipcc -O3 $@.cu -o $@.x -DTARGET_HIP ${CXXFLAGS} -std=c++14 -I${ROCM_ROOT} -I${RSMI_ROOT}/include -lnuma -lrocm_smi64 -lrocblas
7+
hipcc -O3 xgemm.cu -o $@.x -DTARGET_HIP -DGEMM_TYPE=double -DXBLAS_GEMM=XblasDgemm ${CXXFLAGS} -std=c++14 -I${ROCM_ROOT} -I${RSMI_ROOT}/include -lnuma -lrocm_smi64 -lrocblas
8+
9+
sgemm:
10+
hipcc -O3 xgemm.cu -o $@.x -DTARGET_HIP -DGEMM_TYPE=float -DXBLAS_GEMM=XblasSgemm ${CXXFLAGS} -std=c++14 -I${ROCM_ROOT} -I${RSMI_ROOT}/include -lnuma -lrocm_smi64 -lrocblas

hpctestlib/microbenchmarks/gpu/dgemm/src/dgemm.cu renamed to hpctestlib/microbenchmarks/gpu/dgemm/src/xgemm.cu

Lines changed: 24 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -55,23 +55,25 @@ double tflops = SIZE*SIZE*SIZE*2.0 * 1E-12;
5555
int totalErrors = 0;
5656
std::mutex mtx;
5757

58+
5859
#define BLOCK_SIZE 128
59-
void dgemm(int device)
60+
template<class T>
61+
void xgemm_test(int device)
6062
{
6163
XSetDevice(device);
6264

63-
double * A;
64-
double * B;
65-
double * C;
66-
const double alpha = 1.0;
67-
const double beta = 0.0;
65+
T * A;
66+
T * B;
67+
T * C;
68+
const T alpha = 1.0;
69+
const T beta = 0.0;
6870

69-
XMalloc((void**)&A, sizeof(double)*SIZE*SIZE);
70-
XMalloc((void**)&B, sizeof(double)*SIZE*SIZE);
71-
XMalloc((void**)&C, sizeof(double)*SIZE*SIZE);
71+
XMalloc((void**)&A, sizeof(T)*SIZE*SIZE);
72+
XMalloc((void**)&B, sizeof(T)*SIZE*SIZE);
73+
XMalloc((void**)&C, sizeof(T)*SIZE*SIZE);
7274

73-
kernels::init_as_ones<double><<<(SIZE*SIZE+BLOCK_SIZE-1)/BLOCK_SIZE, BLOCK_SIZE>>>(A, SIZE*SIZE);
74-
kernels::init_as_ones<double><<<(SIZE*SIZE+BLOCK_SIZE-1)/BLOCK_SIZE, BLOCK_SIZE>>>(B, SIZE*SIZE);
75+
kernels::init_as_ones<T><<<(SIZE*SIZE+BLOCK_SIZE-1)/BLOCK_SIZE, BLOCK_SIZE>>>(A, SIZE*SIZE);
76+
kernels::init_as_ones<T><<<(SIZE*SIZE+BLOCK_SIZE-1)/BLOCK_SIZE, BLOCK_SIZE>>>(B, SIZE*SIZE);
7577
XDeviceSynchronize();
7678

7779
XStream_t stream;
@@ -81,12 +83,13 @@ void dgemm(int device)
8183
XblasSetStream(blas_handle, stream);
8284

8385
// Warmup call
84-
XblasDgemm(blas_handle,
86+
// define either as XblasDgemm or XblasSgemm
87+
XBLAS_GEMM(blas_handle,
8588
XBLAS_OP_N, XBLAS_OP_N,
8689
SIZE, SIZE, SIZE,
8790
&alpha,
88-
(const double*)A, SIZE,
89-
(const double*)B, SIZE,
91+
(const T*)A, SIZE,
92+
(const T*)B, SIZE,
9093
&beta,
9194
C, SIZE);
9295
XDeviceSynchronize();
@@ -96,12 +99,13 @@ void dgemm(int device)
9699
t.start();
97100
for (int i = 0; i < REPEAT; i++)
98101
{
99-
XblasDgemm(blas_handle,
102+
// define either as XblasDgemm or XblasSgemm
103+
XBLAS_GEMM(blas_handle,
100104
XBLAS_OP_N, XBLAS_OP_N,
101105
SIZE, SIZE, SIZE,
102106
&alpha,
103-
(const double*)A, SIZE,
104-
(const double*)B, SIZE,
107+
(const T*)A, SIZE,
108+
(const T*)B, SIZE,
105109
&beta,
106110
C, SIZE);
107111
}
@@ -116,7 +120,7 @@ void dgemm(int device)
116120
int * err, h_err = 0;
117121
XMalloc((void**)&err, sizeof(int));
118122
XMemcpy( err, &h_err, sizeof(int), XMemcpyHostToDevice);
119-
kernels::verify<double><<<(SIZE+BLOCK_SIZE-1)/BLOCK_SIZE, BLOCK_SIZE>>>(C, SIZE*SIZE, err);
123+
kernels::verify<T><<<(SIZE+BLOCK_SIZE-1)/BLOCK_SIZE, BLOCK_SIZE>>>(C, SIZE*SIZE, err);
120124
XMemcpy(&h_err, err, sizeof(int), XMemcpyDeviceToHost);
121125
{
122126
std::lock_guard<std::mutex> lg(mtx);
@@ -145,10 +149,11 @@ int main(int argc, char **argv)
145149
// Create vector of threads.
146150
std::vector<std::thread> threads;
147151

152+
148153
// Do the dgemm for all devices in the node.
149154
for (int device = 0; device < num_devices; device++)
150155
{
151-
threads.push_back(std::thread(dgemm,device));
156+
threads.push_back(std::thread(xgemm_test<GEMM_TYPE>,device));
152157
}
153158

154159
// Join all threads

0 commit comments

Comments
 (0)