Skip to content

Commit d13ea2a

Browse files
committed
More rigor in the scalar types
1 parent 29ffd96 commit d13ea2a

File tree

1 file changed

+37
-23
lines changed

1 file changed

+37
-23
lines changed

bindings/Modules/src/SofaPython3/SofaLinearSystem/Binding_LinearSystem.cpp

Lines changed: 37 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -33,27 +33,41 @@ namespace py { using namespace pybind11; }
3333

3434
namespace sofapython3 {
3535

36-
using EigenSparseMatrix = Eigen::SparseMatrix<SReal, Eigen::RowMajor>;
37-
using EigenMatrixMap = Eigen::Map<Eigen::SparseMatrix<SReal, Eigen::RowMajor> >;
38-
using Vector = Eigen::Matrix<SReal,Eigen::Dynamic, 1>;
39-
using EigenVectorMap = Eigen::Map<Vector>;
36+
template<class Real>
37+
using EigenSparseMatrix = Eigen::SparseMatrix<Real, Eigen::RowMajor>;
4038

41-
template<class TBlock>
42-
EigenSparseMatrix toEigen(sofa::linearalgebra::CompressedRowSparseMatrix<TBlock>& matrix)
43-
{
44-
sofa::linearalgebra::CompressedRowSparseMatrix<typename TBlock::Real> filtered;
45-
filtered.copyNonZeros(matrix);
46-
filtered.compress();
47-
return EigenMatrixMap(filtered.rows(), filtered.cols(), filtered.getColsValue().size(),
48-
(EigenMatrixMap::StorageIndex*)filtered.rowBegin.data(), (EigenMatrixMap::StorageIndex*)filtered.colsIndex.data(), filtered.colsValue.data());
49-
}
39+
template<class Real>
40+
using EigenMatrixMap = Eigen::Map<Eigen::SparseMatrix<Real, Eigen::RowMajor> >;
41+
42+
template<class Real>
43+
using Vector = Eigen::Matrix<Real,Eigen::Dynamic, 1>;
5044

51-
template<>
52-
EigenSparseMatrix toEigen<SReal>(sofa::linearalgebra::CompressedRowSparseMatrix<SReal>& matrix)
45+
template<class Real>
46+
using EigenVectorMap = Eigen::Map<Vector<Real>>;
47+
48+
template<class TBlock>
49+
EigenSparseMatrix<typename sofa::linearalgebra::CompressedRowSparseMatrix<TBlock>::Real>
50+
toEigen(sofa::linearalgebra::CompressedRowSparseMatrix<TBlock>& matrix)
5351
{
54-
matrix.compress();
55-
return EigenMatrixMap(matrix.rows(), matrix.cols(), matrix.getColsValue().size(),
56-
(EigenMatrixMap::StorageIndex*)matrix.rowBegin.data(), (EigenMatrixMap::StorageIndex*)matrix.colsIndex.data(), matrix.colsValue.data());
52+
using Real = typename sofa::linearalgebra::CompressedRowSparseMatrix<TBlock>::Real;
53+
if constexpr (std::is_same_v<TBlock, Real>)
54+
{
55+
matrix.compress();
56+
return EigenMatrixMap<Real>(matrix.rows(), matrix.cols(), matrix.getColsValue().size(),
57+
(typename EigenMatrixMap<Real>::StorageIndex*)matrix.rowBegin.data(),
58+
(typename EigenMatrixMap<Real>::StorageIndex*)matrix.colsIndex.data(),
59+
matrix.colsValue.data());
60+
}
61+
else
62+
{
63+
sofa::linearalgebra::CompressedRowSparseMatrix<typename TBlock::Real> filtered;
64+
filtered.copyNonZeros(matrix);
65+
filtered.compress();
66+
return EigenMatrixMap<Real>(filtered.rows(), filtered.cols(), filtered.getColsValue().size(),
67+
(typename EigenMatrixMap<Real>::StorageIndex*)filtered.rowBegin.data(),
68+
(typename EigenMatrixMap<Real>::StorageIndex*)filtered.colsIndex.data(),
69+
filtered.colsValue.data());
70+
}
5771
}
5872

5973
template<class TBlock>
@@ -69,7 +83,7 @@ void bindLinearSystems(py::module &m)
6983
sofa::core::objectmodel::BaseObject,
7084
sofapython3::py_shared_ptr<CRSLinearSystem> > c(m, typeName.c_str(), sofapython3::doc::linearsystem::linearSystemClass);
7185

72-
c.def("A", [](CRSLinearSystem& self) -> EigenSparseMatrix
86+
c.def("A", [](CRSLinearSystem& self) -> EigenSparseMatrix<Real>
7387
{
7488
if (CRS* matrix = self.getSystemMatrix())
7589
{
@@ -78,20 +92,20 @@ void bindLinearSystems(py::module &m)
7892
return {};
7993
}, sofapython3::doc::linearsystem::linearSystem_A);
8094

81-
c.def("b", [](CRSLinearSystem& self) -> Vector
95+
c.def("b", [](CRSLinearSystem& self) -> Vector<Real>
8296
{
8397
if (auto* vector = self.getRHSVector())
8498
{
85-
return EigenVectorMap(vector->ptr(), vector->size());
99+
return EigenVectorMap<Real>(vector->ptr(), vector->size());
86100
}
87101
return {};
88102
}, sofapython3::doc::linearsystem::linearSystem_b);
89103

90-
c.def("x", [](CRSLinearSystem& self) -> Vector
104+
c.def("x", [](CRSLinearSystem& self) -> Vector<Real>
91105
{
92106
if (auto* vector = self.getSolutionVector())
93107
{
94-
return EigenVectorMap(vector->ptr(), vector->size());
108+
return EigenVectorMap<Real>(vector->ptr(), vector->size());
95109
}
96110
return {};
97111
}, sofapython3::doc::linearsystem::linearSystem_x);

0 commit comments

Comments
 (0)