Skip to content

Commit 366f324

Browse files
[SofaImplicitField] Add getHessian and getGradient to the python binding (#5655)
* Add binding for hessian & gradient. * Add getHessian, getGradient and getValue with an unified API to Binding_ScalarField I unified the interface so it is conformant with the one from sofa expect for getHessian. * Update applications/plugins/SofaImplicitField/python/src/Binding_ScalarField.cpp
1 parent 47cf52a commit 366f324

File tree

2 files changed

+93
-12
lines changed

2 files changed

+93
-12
lines changed
Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import Sofa
22
from SofaImplicitField import ScalarField
3+
from SofaTypes.SofaTypes import Vec3d, Mat3x3
34
import numpy
45

56
class Sphere(ScalarField):
@@ -9,20 +10,41 @@ def __init__(self, *args, **kwargs):
910
self.addData("center", type="Vec3d",value=kwargs.get("center", [0.0,0.0,0.0]), default=[0.0,0.0,0.0], help="center of the sphere", group="Geometry")
1011
self.addData("radius", type="double",value=kwargs.get("radius", 1.0), default=1, help="radius of the sphere", group="Geometry")
1112

12-
def getValue(self, x, y, z):
13+
def getValue(self, position):
14+
x,y,z = position
1315
return numpy.sqrt( numpy.sum((self.center.value - numpy.array([x,y,z]))**2) ) - self.radius.value
1416

17+
class SphereWithCustomHessianAndGradient(Sphere):
18+
def __init__(self, *args, **kwargs):
19+
Sphere.__init__(self, *args, **kwargs)#
20+
21+
def getGradient(self, position):
22+
return Vec3d(3.0,2.0,1.0)
23+
24+
def getHessian(self, position):
25+
return Mat3x3([[1,1,1],[2,1,1],[3,1,1]])
26+
1527
class FieldController(Sofa.Core.Controller):
1628
def __init__(self, *args, **kwargs):
1729
Sofa.Core.Controller.__init__(self, *args, **kwargs)
1830
self.field = kwargs.get("target")
1931

2032
def onAnimateEndEvent(self, event):
21-
print("Animation end event")
22-
print("Field value at 0,0,0 is: ", self.field.getValue(0.0,0.0,0.0) )
23-
print("Field value at 1,0,0 is: ", self.field.getValue(1.0,0.0,0.0) )
24-
print("Field value at 2,0,0 is: ", self.field.getValue(2.0,0.0,0.0) )
33+
print("Animation end event, ")
34+
print("Field value at 0,0,0 is: ", self.field.getValue(Vec3d(0.0,0.0,0.0)) )
35+
print("Field value at 1,0,0 is: ", self.field.getValue(Vec3d(1.0,0.0,0.0)) )
36+
print("Field value at 2,0,0 is: ", self.field.getValue(Vec3d(2.0,0.0,0.0)) )
37+
38+
print("Gradient value at 0,0,0 is: ", type(self.field.getGradient(Vec3d(0.0,0.0,0.0))))
39+
print("Hessian value at 0,0,0 is: ", type(self.field.getHessian(Vec3d(0.0,0.0,0.0))))
2540

2641
def createScene(root):
27-
root.addObject(Sphere("field"))
28-
root.addObject(FieldController(target=root.field))
42+
"""In this scene we create two scalar field of spherical shape, the two are implemented using
43+
python. The first one is overriding only the getValue, the hessian and gradient is thus computed using
44+
finite difference in the c++ code. The second field is overriding the hessian and gradient function
45+
"""
46+
root.addObject(Sphere("field1"))
47+
root.addObject(FieldController(name="controller1", target=root.field1))
48+
49+
root.addObject(SphereWithCustomHessianAndGradient("field2"))
50+
root.addObject(FieldController(name="controller2", target=root.field2))

applications/plugins/SofaImplicitField/python/src/Binding_ScalarField.cpp

Lines changed: 64 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,22 +33,65 @@ namespace sofaimplicitfield {
3333
using namespace sofapython3;
3434
using sofa::component::geometry::ScalarField;
3535
using sofa::core::objectmodel::BaseObject;
36-
using sofa::type::Vec3d;
36+
using sofa::type::Vec3;
37+
using sofa::type::Mat3x3;
3738

3839
class ScalarField_Trampoline : public ScalarField {
3940
public:
4041
SOFA_CLASS(ScalarField_Trampoline, ScalarField);
4142

42-
double getValue(Vec3d& pos, int& domain) override{
43+
// Override this function so that it returns the actual python class name instead of
44+
// "ScalarField_Trampoline" which correspond to this utility class.
45+
std::string getClassName() const override
46+
{
47+
PythonEnvironment::gil acquire;
48+
49+
// Get the actual class name from python.
50+
return py::str(py::cast(this).get_type().attr("__name__"));
51+
}
52+
53+
double getValue(Vec3& pos, int& domain) override
54+
{
55+
SOFA_UNUSED(domain);
56+
PythonEnvironment::gil acquire;
57+
58+
PYBIND11_OVERLOAD_PURE(double, ScalarField, getValue, pos);
59+
}
60+
61+
Vec3 getGradient(Vec3& pos, int& domain) override
62+
{
4363
SOFA_UNUSED(domain);
4464
PythonEnvironment::gil acquire;
4565

46-
PYBIND11_OVERLOAD_PURE(double, ScalarField, getValue, pos.x(), pos.y(), pos.z());
66+
PYBIND11_OVERLOAD(Vec3, ScalarField, getGradient, pos);
67+
}
68+
69+
void getHessian(Vec3 &pos, Mat3x3& h) override
70+
{
71+
/// The implementation is a bit more complex compared to getGradient. This is because we change de signature between the c++ API and the python one.
72+
PythonEnvironment::gil acquire;
73+
74+
// Search if there is a python override,
75+
pybind11::function override = pybind11::get_override(static_cast<const ScalarField*>(this),"getHessian");
76+
if(!override){
77+
return ScalarField::getHessian(pos, h);
78+
}
79+
// as there is one override, we call it, passing the "pos" argument and storing the return of the
80+
// value in the "o" variable.
81+
auto o = override(pos);
82+
83+
// then we check that the function correctly returned a Mat3x3 object and copy its value
84+
// in case there is no Mat3x3 returned values, rise an error
85+
if(py::isinstance<Mat3x3>(o))
86+
h = py::cast<Mat3x3>(o);
87+
else
88+
throw py::type_error("The function getHessian must return a Mat3x3");
89+
return;
4790
}
4891
};
4992

5093
void moduleAddScalarField(py::module &m) {
51-
py::class_<ScalarField, BaseObject, ScalarField_Trampoline,
94+
py::class_<ScalarField, ScalarField_Trampoline, BaseObject,
5295
py_shared_ptr<ScalarField>> f(m, "ScalarField", py::dynamic_attr(), "");
5396

5497
f.def(py::init([](py::args &args, py::kwargs &kwargs) {
@@ -74,7 +117,23 @@ void moduleAddScalarField(py::module &m) {
74117
return ff;
75118
}));
76119

77-
m.def("getValue", &ScalarField_Trampoline::getValue);
120+
f.def("getValue", [](ScalarField* self, Vec3 pos){
121+
int domain=-1;
122+
// This shouldn't be self->ScalarField::getValue because it is a pure function
123+
// so there is not ScalarField::getValue emitted.
124+
return self->getValue(pos, domain);
125+
});
126+
127+
f.def("getGradient", [](ScalarField* self, Vec3 pos){
128+
int domain=-1;
129+
return self->ScalarField::getGradient(pos, domain);
130+
});
131+
132+
f.def("getHessian", [](ScalarField* self, Vec3 pos){
133+
Mat3x3 result;
134+
self->getHessian(pos, result);
135+
return result;
136+
});
78137
}
79138

80139
}

0 commit comments

Comments
 (0)