Skip to content

Commit b209f4c

Browse files
committed
3D matrix view
1 parent e14aa3e commit b209f4c

File tree

2 files changed

+81
-0
lines changed

2 files changed

+81
-0
lines changed

Common/include/containers/CPyWrapperMatrixView.hpp

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@
5454
\
5555
/*! \brief Gets the value for a (row, column) pair. */ \
5656
passivedouble operator()(unsigned long row, unsigned long col) const { return Get(row, col); } \
57+
\
58+
/*! \brief Gets the values for a row of the matrix. */ \
5759
std::vector<passivedouble> operator()(unsigned long row) const { return Get(row); } \
5860
\
5961
/*! \brief Gets the value for a (row, column) pair. */ \
@@ -164,3 +166,77 @@ class CPyWrapperMarkerMatrixView {
164166
/*--- Use the macro to generate the interface. ---*/
165167
PY_WRAPPER_MATRIX_INTERFACE
166168
};
169+
170+
/*!
171+
* \class CPyWrapper3DMatrixView
172+
* \ingroup PySU2
173+
* \brief This class wraps C3DDoubleMatrix for the python wrapper matrix interface.
174+
* It is generaly used to wrap access to solver gradients defined for the entire volume.
175+
*/
176+
class CPyWrapper3DMatrixView {
177+
protected:
178+
static_assert(su2activematrix::IsRowMajor, "");
179+
su2double* data_ = nullptr;
180+
unsigned long rows_ = 0, cols_ = 0, dims_ = 0;
181+
std::string name_;
182+
bool read_only_ = false;
183+
184+
/*--- Define the functions required by the interface macro. ---*/
185+
inline const su2double& Access(unsigned long row, unsigned long col, unsigned long dim) const {
186+
if (row > rows_ || col > cols_ || dim > dims_) SU2_MPI::Error(name_ + " out of bounds", "CPyWrapper3DMatrixView");
187+
return data_[row * (cols_ * dims_) + col * dims_ + dim];
188+
}
189+
inline su2double& Access(unsigned long row, unsigned long col, unsigned long dim) {
190+
if (read_only_) SU2_MPI::Error(name_ + " is read-only", "CPyWrapper3DMatrixView");
191+
const auto& const_me = *this;
192+
return const_cast<su2double&>(const_me.Access(row, col, dim));
193+
}
194+
195+
public:
196+
CPyWrapper3DMatrixView() = default;
197+
198+
/*!
199+
* \brief Construct the view of the matrix.
200+
* \note "name" should be set to the variable name being returned to give better information to users.
201+
* \note "read_only" can be set to true to prevent the data from being modified.
202+
*/
203+
CPyWrapper3DMatrixView(C3DDoubleMatrix& mat, const std::string& name, bool read_only)
204+
: data_(mat.data()),
205+
rows_(mat.length()),
206+
cols_(mat.rows()),
207+
dims_(mat.cols()),
208+
name_(name),
209+
read_only_(read_only) {}
210+
211+
/*! \brief Returns the shape of the matrix. */
212+
std::vector<unsigned long> Shape() const { return {rows_, cols_, dims_}; }
213+
214+
/*! \brief Returns whether the data is read-only [true] or if it can be modified [false]. */
215+
bool IsReadOnly() const { return read_only_; }
216+
217+
/*! \brief Gets the value for a (row, column, dimension) triplet. */
218+
passivedouble operator()(unsigned long row, unsigned long col, unsigned long dim) const { return Get(row, col, dim); }
219+
220+
/*! \brief Gets the values for a row and column of the matrix. */
221+
std::vector<passivedouble> operator()(unsigned long row, unsigned long col) const { return Get(row, col); }
222+
223+
/*! \brief Gets the value for a (row, column, dimension) triplet. */
224+
passivedouble Get(unsigned long row, unsigned long col, unsigned long dim) const {
225+
return SU2_TYPE::GetValue(Access(row, col, dim));
226+
}
227+
228+
/*! \brief Gets the values for a row and column of the matrix. */
229+
std::vector<passivedouble> Get(unsigned long row, unsigned long col) const {
230+
std::vector<passivedouble> vals(dims_);
231+
for (unsigned long j = 0; j < dims_; ++j) vals[j] = Get(row, col, j);
232+
return vals;
233+
}
234+
/*! \brief Sets the value for a (row, column, dimension) triplet. This clears derivative information. */
235+
void Set(unsigned long row, unsigned long col, unsigned long dim, passivedouble val) { Access(row, col, dim) = val; }
236+
237+
/*! \brief Sets the values for a row and column of the matrix. */
238+
void Set(unsigned long row, unsigned long col, std::vector<passivedouble> vals) {
239+
unsigned long j = 0;
240+
for (const auto& val : vals) Set(row, col, j++, val);
241+
}
242+
};

Common/include/containers/container_decorators.hpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,11 @@ class C3DContainerDecorator {
165165
FORCEINLINE StaticContainer get(Int i, Index j = 0) const noexcept {
166166
return m_storage.template get<StaticContainer>(i, j * m_innerSz);
167167
}
168+
169+
/*!
170+
* \brief Raw data access, for Python wrapper.
171+
*/
172+
FORCEINLINE Scalar* data() { return m_storage.data(); }
168173
};
169174

170175
/*!

0 commit comments

Comments
 (0)