@@ -55,23 +55,25 @@ double tflops = SIZE*SIZE*SIZE*2.0 * 1E-12;
5555int totalErrors = 0 ;
5656std::mutex mtx;
5757
58+
5859#define BLOCK_SIZE 128
59- void dgemm (int device)
60+ template <class T >
61+ void xgemm_test (int device)
6062{
6163 XSetDevice (device);
6264
63- double * A;
64- double * B;
65- double * C;
66- const double alpha = 1.0 ;
67- const double beta = 0.0 ;
65+ T * A;
66+ T * B;
67+ T * C;
68+ const T alpha = 1.0 ;
69+ const T beta = 0.0 ;
6870
69- XMalloc ((void **)&A, sizeof (double )*SIZE*SIZE);
70- XMalloc ((void **)&B, sizeof (double )*SIZE*SIZE);
71- XMalloc ((void **)&C, sizeof (double )*SIZE*SIZE);
71+ XMalloc ((void **)&A, sizeof (T )*SIZE*SIZE);
72+ XMalloc ((void **)&B, sizeof (T )*SIZE*SIZE);
73+ XMalloc ((void **)&C, sizeof (T )*SIZE*SIZE);
7274
73- kernels::init_as_ones<double ><<<(SIZE*SIZE+BLOCK_SIZE-1 )/BLOCK_SIZE, BLOCK_SIZE>>> (A, SIZE*SIZE);
74- kernels::init_as_ones<double ><<<(SIZE*SIZE+BLOCK_SIZE-1 )/BLOCK_SIZE, BLOCK_SIZE>>> (B, SIZE*SIZE);
75+ kernels::init_as_ones<T ><<<(SIZE*SIZE+BLOCK_SIZE-1 )/BLOCK_SIZE, BLOCK_SIZE>>> (A, SIZE*SIZE);
76+ kernels::init_as_ones<T ><<<(SIZE*SIZE+BLOCK_SIZE-1 )/BLOCK_SIZE, BLOCK_SIZE>>> (B, SIZE*SIZE);
7577 XDeviceSynchronize ();
7678
7779 XStream_t stream;
@@ -81,12 +83,13 @@ void dgemm(int device)
8183 XblasSetStream (blas_handle, stream);
8284
8385 // Warmup call
84- XblasDgemm (blas_handle,
86+ // XBLAS_GEMM is expected to be defined as XblasDgemm or XblasSgemm, whichever matches the element type T
87+ XBLAS_GEMM (blas_handle,
8588 XBLAS_OP_N, XBLAS_OP_N,
8689 SIZE, SIZE, SIZE,
8790 &alpha,
88- (const double *)A, SIZE,
89- (const double *)B, SIZE,
91+ (const T *)A, SIZE,
92+ (const T *)B, SIZE,
9093 &beta,
9194 C, SIZE);
9295 XDeviceSynchronize ();
@@ -96,12 +99,13 @@ void dgemm(int device)
9699 t.start ();
97100 for (int i = 0 ; i < REPEAT; i++)
98101 {
99- XblasDgemm (blas_handle,
102+ // XBLAS_GEMM is expected to be defined as XblasDgemm or XblasSgemm, whichever matches the element type T
103+ XBLAS_GEMM (blas_handle,
100104 XBLAS_OP_N, XBLAS_OP_N,
101105 SIZE, SIZE, SIZE,
102106 &alpha,
103- (const double *)A, SIZE,
104- (const double *)B, SIZE,
107+ (const T *)A, SIZE,
108+ (const T *)B, SIZE,
105109 &beta,
106110 C, SIZE);
107111 }
@@ -116,7 +120,7 @@ void dgemm(int device)
116120 int * err, h_err = 0 ;
117121 XMalloc ((void **)&err, sizeof (int ));
118122 XMemcpy ( err, &h_err, sizeof (int ), XMemcpyHostToDevice);
119- kernels::verify<double ><<<(SIZE+BLOCK_SIZE-1 )/BLOCK_SIZE, BLOCK_SIZE>>> (C, SIZE*SIZE, err);
123+ kernels::verify<T ><<<(SIZE+BLOCK_SIZE-1 )/BLOCK_SIZE, BLOCK_SIZE>>> (C, SIZE*SIZE, err);
120124 XMemcpy (&h_err, err, sizeof (int ), XMemcpyDeviceToHost);
121125 {
122126 std::lock_guard<std::mutex> lg (mtx);
@@ -145,10 +149,11 @@ int main(int argc, char **argv)
145149 // Create vector of threads.
146150 std::vector<std::thread> threads;
147151
152+
148153 // Do the dgemm for all devices in the node.
149154 for (int device = 0 ; device < num_devices; device++)
150155 {
151- threads.push_back (std::thread (dgemm ,device));
156+ threads.push_back (std::thread (xgemm_test<GEMM_TYPE> ,device));
152157 }
153158
154159 // Join all threads
0 commit comments