|
55 | 55 | /*! \brief Gets the value for a (row, column) pair. */ \ |
56 | 56 | passivedouble operator()(unsigned long row, unsigned long col) const { return Get(row, col); } \ |
57 | 57 | \ |
| 58 | + /*! \brief Gets the values for a row of the matrix. */ \ |
| 59 | + std::vector<passivedouble> operator()(unsigned long row) const { return Get(row); } \ |
| 60 | + \ |
58 | 61 | /*! \brief Gets the value for a (row, column) pair. */ \ |
59 | 62 | passivedouble Get(unsigned long row, unsigned long col) const { return SU2_TYPE::GetValue(Access(row, col)); } \ |
60 | 63 | \ |
@@ -163,3 +166,77 @@ class CPyWrapperMarkerMatrixView { |
163 | 166 | /*--- Use the macro to generate the interface. ---*/ |
164 | 167 | PY_WRAPPER_MATRIX_INTERFACE |
165 | 168 | }; |
| 169 | + |
| 170 | +/*! |
| 171 | + * \class CPyWrapper3DMatrixView |
| 172 | + * \ingroup PySU2 |
| 173 | + * \brief This class wraps C3DDoubleMatrix for the python wrapper matrix interface. |
| 174 | + * It is generaly used to wrap access to solver gradients defined for the entire volume. |
| 175 | + */ |
| 176 | +class CPyWrapper3DMatrixView { |
| 177 | + protected: |
| 178 | + static_assert(su2activematrix::IsRowMajor, ""); |
| 179 | + su2double* data_ = nullptr; |
| 180 | + unsigned long rows_ = 0, cols_ = 0, dims_ = 0; |
| 181 | + std::string name_; |
| 182 | + bool read_only_ = false; |
| 183 | + |
| 184 | + /*--- Define the functions required by the interface macro. ---*/ |
| 185 | + inline const su2double& Access(unsigned long row, unsigned long col, unsigned long dim) const { |
| 186 | + if (row > rows_ || col > cols_ || dim > dims_) SU2_MPI::Error(name_ + " out of bounds", "CPyWrapper3DMatrixView"); |
| 187 | + return data_[row * (cols_ * dims_) + col * dims_ + dim]; |
| 188 | + } |
| 189 | + inline su2double& Access(unsigned long row, unsigned long col, unsigned long dim) { |
| 190 | + if (read_only_) SU2_MPI::Error(name_ + " is read-only", "CPyWrapper3DMatrixView"); |
| 191 | + const auto& const_me = *this; |
| 192 | + return const_cast<su2double&>(const_me.Access(row, col, dim)); |
| 193 | + } |
| 194 | + |
| 195 | + public: |
| 196 | + CPyWrapper3DMatrixView() = default; |
| 197 | + |
| 198 | + /*! |
| 199 | + * \brief Construct the view of the matrix. |
| 200 | + * \note "name" should be set to the variable name being returned to give better information to users. |
| 201 | + * \note "read_only" can be set to true to prevent the data from being modified. |
| 202 | + */ |
| 203 | + CPyWrapper3DMatrixView(C3DDoubleMatrix& mat, const std::string& name, bool read_only) |
| 204 | + : data_(mat.data()), |
| 205 | + rows_(mat.length()), |
| 206 | + cols_(mat.rows()), |
| 207 | + dims_(mat.cols()), |
| 208 | + name_(name), |
| 209 | + read_only_(read_only) {} |
| 210 | + |
| 211 | + /*! \brief Returns the shape of the matrix. */ |
| 212 | + std::vector<unsigned long> Shape() const { return {rows_, cols_, dims_}; } |
| 213 | + |
| 214 | + /*! \brief Returns whether the data is read-only [true] or if it can be modified [false]. */ |
| 215 | + bool IsReadOnly() const { return read_only_; } |
| 216 | + |
| 217 | + /*! \brief Gets the value for a (row, column, dimension) triplet. */ |
| 218 | + passivedouble operator()(unsigned long row, unsigned long col, unsigned long dim) const { return Get(row, col, dim); } |
| 219 | + |
| 220 | + /*! \brief Gets the values for a row and column of the matrix. */ |
| 221 | + std::vector<passivedouble> operator()(unsigned long row, unsigned long col) const { return Get(row, col); } |
| 222 | + |
| 223 | + /*! \brief Gets the value for a (row, column, dimension) triplet. */ |
| 224 | + passivedouble Get(unsigned long row, unsigned long col, unsigned long dim) const { |
| 225 | + return SU2_TYPE::GetValue(Access(row, col, dim)); |
| 226 | + } |
| 227 | + |
| 228 | + /*! \brief Gets the values for a row and column of the matrix. */ |
| 229 | + std::vector<passivedouble> Get(unsigned long row, unsigned long col) const { |
| 230 | + std::vector<passivedouble> vals(dims_); |
| 231 | + for (unsigned long j = 0; j < dims_; ++j) vals[j] = Get(row, col, j); |
| 232 | + return vals; |
| 233 | + } |
| 234 | + /*! \brief Sets the value for a (row, column, dimension) triplet. This clears derivative information. */ |
| 235 | + void Set(unsigned long row, unsigned long col, unsigned long dim, passivedouble val) { Access(row, col, dim) = val; } |
| 236 | + |
| 237 | + /*! \brief Sets the values for a row and column of the matrix. */ |
| 238 | + void Set(unsigned long row, unsigned long col, std::vector<passivedouble> vals) { |
| 239 | + unsigned long j = 0; |
| 240 | + for (const auto& val : vals) Set(row, col, j++, val); |
| 241 | + } |
| 242 | +}; |
0 commit comments