5757 * functions from its derived classes to map the suitable path of
5858 * execution - CPU or GPU.
5959 */
60- template <class ScalarType >
61- class CMatrixVectorProduct {
62- public:
63- virtual ~CMatrixVectorProduct () = 0 ;
64- virtual void operator ()(const CSysVector<ScalarType>& u, CSysVector<ScalarType>& v) const = 0;
65- };
66- template <class ScalarType >
67- CMatrixVectorProduct<ScalarType>::~CMatrixVectorProduct () {}
6860
6961/* !
7062 * \class CExecutionPath
7163 * \brief Dummy super class that holds the correct member functions in its child classes
7264 */
7365
7466template <class ScalarType >
75- class CExecutionPath {
76- public:
77- virtual void mat_vec_prod (const CSysVector<ScalarType>& u, CSysVector<ScalarType>& v, CGeometry* geometry,
78- const CConfig* config, const CSysMatrix<ScalarType>& matrix) = 0;
79- };
80-
81- /* !
82- * \class CCpuExecution
83- * \brief Derived class containing the CPU Matrix Vector Product Function
84- */
85- template <class ScalarType >
86- class CCpuExecution : public CExecutionPath <ScalarType> {
67+ class CMatrixVectorProduct {
8768 public:
88- void mat_vec_prod (const CSysVector<ScalarType>& u, CSysVector<ScalarType>& v, CGeometry* geometry,
89- const CConfig* config, const CSysMatrix<ScalarType>& matrix) override {
90- matrix.MatrixVectorProduct (u, v, geometry, config);
91- }
69+ virtual ~CMatrixVectorProduct () = 0 ;
70+ virtual void operator ()(const CSysVector<ScalarType>& u, CSysVector<ScalarType>& v) const = 0;
9271};
93-
94- /* !
95- * \class CGpuExecution
96- * \brief Derived class containing the GPU Matrix Vector Product Function
97- */
9872template <class ScalarType >
99- class CGpuExecution : public CExecutionPath <ScalarType> {
100- public:
101- void mat_vec_prod (const CSysVector<ScalarType>& u, CSysVector<ScalarType>& v, CGeometry* geometry,
102- const CConfig* config, const CSysMatrix<ScalarType>& matrix) override {
103- #ifdef HAVE_CUDA
104- matrix.GPUMatrixVectorProduct (u, v, geometry, config);
105- #else
106- SU2_MPI::Error (
107- " \n Error in launching Matrix-Vector Product Function\n ENABLE_CUDA is set to YES\n Please compile with CUDA "
108- " options enabled in Meson to access GPU Functions" ,
109- CURRENT_FUNCTION);
110- #endif
111- }
112- };
73+ CMatrixVectorProduct<ScalarType>::~CMatrixVectorProduct () {}
11374
11475/* !
11576 * \class CSysMatrixVectorProduct
@@ -122,7 +83,6 @@ class CSysMatrixVectorProduct final : public CMatrixVectorProduct<ScalarType> {
12283 const CSysMatrix<ScalarType>& matrix; /* !< \brief pointer to matrix that defines the product. */
12384 CGeometry* geometry; /* !< \brief geometry associated with the matrix. */
12485 const CConfig* config; /* !< \brief config of the problem. */
125- CExecutionPath<ScalarType>* exec; /* !< \brief interface that decides which path of execution to choose from. */
12686
12787 public:
12888 /* !
@@ -133,13 +93,7 @@ class CSysMatrixVectorProduct final : public CMatrixVectorProduct<ScalarType> {
13393 */
13494 inline CSysMatrixVectorProduct (const CSysMatrix<ScalarType>& matrix_ref, CGeometry* geometry_ref,
13595 const CConfig* config_ref)
136- : matrix(matrix_ref), geometry(geometry_ref), config(config_ref) {
137- if (config->GetCUDA ()) {
138- exec = new CGpuExecution<ScalarType>;
139- } else {
140- exec = new CCpuExecution<ScalarType>;
141- }
142- }
96+ : matrix(matrix_ref), geometry(geometry_ref), config(config_ref) {}
14397
14498 /* !
14599 * \note This class cannot be default constructed as that would leave us with invalid pointers.
@@ -152,6 +106,14 @@ class CSysMatrixVectorProduct final : public CMatrixVectorProduct<ScalarType> {
152106 * \param[out] v - CSysVector that is the result of the product
153107 */
154108 inline void operator ()(const CSysVector<ScalarType>& u, CSysVector<ScalarType>& v) const override {
155- exec->mat_vec_prod (u, v, geometry, config, matrix);
109+ #ifdef HAVE_CUDA
110+ if (config->GetCUDA ()) {
111+ matrix.GPUMatrixVectorProduct (u, v, geometry, config);
112+ } else {
113+ matrix.MatrixVectorProduct (u, v, geometry, config);
114+ }
115+ #else
116+ matrix.MatrixVectorProduct (u, v, geometry, config)
117+ #endif
156118 }
157119};
0 commit comments