Skip to content

Commit 276a131

Browse files
committed
Compile residual function including its jacobian ourselves
1 parent 6117c26 commit 276a131

File tree

1 file changed

+126
-88
lines changed

1 file changed

+126
-88
lines changed

examples/demo_planar_hsa_motor2ee_jacobian.py

Lines changed: 126 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -23,105 +23,143 @@
2323
/ f"planar_hsa_ns-{num_segments}_nrs-{num_rods_per_segment}.dill"
2424
)
2525

26-
# activate all strains (i.e. bending, shear, and axial)
27-
strain_selector = jnp.ones((3 * num_segments,), dtype=bool)
28-
params = PARAMS_FPU_CONTROL
29-
phi_max = params["phi_max"].flatten()
30-
31-
# define initial configuration
32-
q0 = jnp.array([0.0, 0.0, 0.0])
33-
34-
# increase damping for simulation stability
35-
params["zetab"] = 5 * params["zetab"]
36-
params["zetash"] = 5 * params["zetash"]
37-
params["zetaa"] = 5 * params["zetaa"]
38-
39-
# nonlinear least squares solver settings
40-
nlq_tol = 1e-5 # tolerance for the nonlinear least squares solver
41-
42-
(
43-
forward_kinematics_virtual_backbone_fn,
44-
forward_kinematics_end_effector_fn,
45-
jacobian_end_effector_fn,
46-
inverse_kinematics_end_effector_fn,
47-
dynamical_matrices_fn,
48-
sys_helpers,
49-
) = planar_hsa.factory(sym_exp_filepath, strain_selector)
50-
jitted_dynamical_matrics_fn = jit(partial(dynamical_matrices_fn, params))
51-
52-
53-
def phi2q_static_model(phi: Array, q0: Array = jnp.zeros((3, ))) -> Tuple[Array, Dict[str, Array]]:
26+
27+
def factory_fn(params: Dict[str, Array], nlq_tol: float = 1e-5):
5428
"""
55-
A static model mapping the motor angles to the planar HSA configuration.
56-
Arguments:
57-
phi: motor angles
29+
Factory function for the planar HSA.
30+
Args:
31+
params: dictionary with robot parameters
32+
nlq_tol: tolerance for the nonlinear least squares solver
33+
5834
Returns:
59-
q: planar HSA configuration consisting of (k_be, sigma_sh, sigma_ax)
60-
aux: dictionary with auxiliary data
61-
"""
62-
q_d = jnp.zeros((3,))
6335
64-
def residual_fn(_q: Array) -> Array:
65-
_, _, _G, _K, _, _alpha = jitted_dynamical_matrics_fn(_q, q_d, phi=phi)
66-
res = _alpha - _G - _K
36+
"""
37+
(
38+
forward_kinematics_virtual_backbone_fn,
39+
forward_kinematics_end_effector_fn,
40+
jacobian_end_effector_fn,
41+
inverse_kinematics_end_effector_fn,
42+
dynamical_matrices_fn,
43+
sys_helpers,
44+
) = planar_hsa.factory(sym_exp_filepath, strain_selector)
45+
dynamical_matrices_fn = partial(dynamical_matrices_fn, params)
46+
forward_kinematics_end_effector_fn = jit(partial(forward_kinematics_end_effector_fn, params))
47+
48+
def residual_fn(q: Array, phi: Array) -> Array:
49+
q_d = jnp.zeros_like(q)
50+
_, _, G, K, _, alpha = dynamical_matrices_fn(q, q_d, phi=phi)
51+
res = alpha - G - K
6752
return res
6853

69-
# solve the nonlinear least squares problem
70-
lm = LevenbergMarquardt(residual_fun=residual_fn, tol=nlq_tol, jit=True, unroll=True, verbose=True)
71-
sol = lm.run(q0)
54+
# jit the residual function
55+
residual_fn = jit(residual_fn)
56+
print("Compiling residual_fn...")
57+
print(residual_fn(jnp.zeros((3,)), jnp.zeros((2,))))
58+
59+
# define the Jacobian of the residual function
60+
jac_residual_fn = jit(jacrev(residual_fn, argnums=0))
61+
print("Compiling jac_residual_fn...")
62+
print(jac_residual_fn(jnp.zeros((3,)), jnp.zeros((2,))))
63+
64+
def phi2q_static_model_jaxopt_fn(phi: Array, q0: Array = jnp.zeros((3, ))) -> Tuple[Array, Dict[str, Array]]:
65+
"""
66+
A static model mapping the motor angles to the planar HSA configuration.
67+
Arguments:
68+
phi: motor angles
69+
q0: initial guess for the configuration
70+
Returns:
71+
q: planar HSA configuration consisting of (k_be, sigma_sh, sigma_ax)
72+
aux: dictionary with auxiliary data
73+
"""
74+
# solve the nonlinear least squares problem
75+
lm = LevenbergMarquardt(
76+
residual_fun=partial(residual_fn, phi=phi),
77+
jac_fun=partial(jac_residual_fn, phi=phi),
78+
tol=nlq_tol,
79+
jit=False,
80+
verbose=True
81+
)
82+
sol = lm.run(q0)
83+
84+
# configuration that minimizes the residual
85+
q = sol.params
86+
87+
# compute the L2 optimality
88+
optimality_error = lm.l2_optimality_error(sol.params)
89+
90+
aux = dict(
91+
phi=phi,
92+
q=q,
93+
optimality_error=optimality_error,
94+
)
95+
96+
return q, aux
97+
98+
99+
def phi2chi_static_model_fn(phi: Array, q0: Array = jnp.zeros((3, ))) -> Tuple[Array, Dict[str, Array]]:
100+
"""
101+
A static model mapping the motor angles to the planar end-effector pose.
102+
Arguments:
103+
phi: motor angles
104+
q0: initial guess for the configuration
105+
Returns:
106+
chi: end-effector pose
107+
aux: dictionary with auxiliary data
108+
"""
109+
q, aux = phi2q_static_model_jaxopt_fn(phi, q0=q0)
110+
chi = forward_kinematics_end_effector_fn(q)
111+
aux["chi"] = chi
112+
return chi, aux
113+
114+
def jac_phi2chi_static_model_fn(phi: Array) -> Tuple[Array, Dict[str, Array]]:
115+
"""
116+
Compute the Jacobian between the actuation space and the task-space.
117+
Arguments:
118+
phi: motor angles
119+
"""
120+
# evaluate the static model to convert motor angles into a configuration
121+
q = phi2q_static_model_jaxopt_fn(phi)
122+
# take the Jacobian between actuation and configuration space
123+
J_phi2q, aux = jacfwd(phi2q_static_model_jaxopt_fn, has_aux=True)(phi)
124+
125+
# evaluate the closed-form, analytical jacobian of the forward kinematics
126+
J_q2chi = jacobian_end_effector_fn(params, q)
127+
128+
# evaluate the Jacobian between the actuation and the task-space
129+
J_phi2chi = J_q2chi @ J_phi2q
130+
131+
return J_phi2chi, aux
132+
133+
return phi2chi_static_model_fn
72134

73-
# configuration that minimizes the residual
74-
q = sol.params
75135

76-
# compute the L2 optimality
77-
optimality_error = lm.l2_optimality_error(sol.params)
78-
79-
aux = dict(
80-
phi=phi,
81-
q=q,
82-
optimality_error=optimality_error,
83-
)
84-
85-
return q, aux
86-
87-
def phi2chi_static_model(phi: Array) -> Tuple[Array, Dict[str, Array]]:
88-
"""
89-
A static model mapping the motor angles to the planar end-effector pose.
90-
Arguments:
91-
phi: motor angles
92-
Returns:
93-
chi: end-effector pose
94-
aux: dictionary with auxiliary data
95-
"""
96-
q, aux = phi2q_static_model(phi)
97-
chi = forward_kinematics_end_effector_fn(params, q)
98-
aux["chi"] = chi
99-
return chi, aux
100-
101-
def jac_phi2chi_static_model(phi: Array) -> Tuple[Array, Dict[str, Array]]:
102-
"""
103-
Compute the Jacobian between the actuation space and the task-space.
104-
Arguments:
105-
phi: motor angles
106-
"""
107-
# evaluate the static model to convert motor angles into a configuration
108-
q = phi2q_static_model(phi)
109-
# take the Jacobian between actuation and configuration space
110-
J_phi2q, aux = jacfwd(phi2q_static_model, has_aux=True)(phi)
136+
if __name__ == "__main__":
137+
# activate all strains (i.e. bending, shear, and axial)
138+
strain_selector = jnp.ones((3 * num_segments,), dtype=bool)
139+
params = PARAMS_FPU_CONTROL
140+
phi_max = params["phi_max"].flatten()
111141

112-
# evaluate the closed-form, analytical jacobian of the forward kinematics
113-
J_q2chi = jacobian_end_effector_fn(params, q)
142+
# increase damping for simulation stability
143+
params["zetab"] = 5 * params["zetab"]
144+
params["zetash"] = 5 * params["zetash"]
145+
params["zetaa"] = 5 * params["zetaa"]
114146

115-
# evaluate the Jacobian between the actuation and the task-space
116-
J_phi2chi = J_q2chi @ J_phi2q
147+
# call the factory function
148+
phi2chi_static_model_fn = factory_fn(params)
117149

118-
return J_phi2chi, aux
150+
# define initial configuration
151+
q0 = jnp.array([0.0, 0.0, 0.0])
119152

153+
# phi2q_static_model_fn = jit(phi2q_static_model)
154+
# print("Compiling phi2q_static_model_fn...")
155+
# print(phi2q_static_model_fn(jnp.zeros((2,))))
120156

121-
if __name__ == "__main__":
122-
jitted_phi2q_static_model_fn = jit(phi2q_static_model)
123-
J_phi2chi_autodiff_fn = jit(jacfwd(phi2chi_static_model, has_aux=True))
124-
J_phi2chi_fn = jit(jac_phi2chi_static_model)
157+
# print("Compiling J_phi2chi_autodiff_fn...")
158+
# J_phi2chi_autodiff_fn = jit(jacfwd(phi2chi_static_model, has_aux=True))
159+
# print(J_phi2chi_autodiff_fn(jnp.zeros((2,))))
160+
# J_phi2chi_fn = jit(jac_phi2chi_static_model)
161+
# print("Compiling J_phi2chi_fn...")
162+
# print(J_phi2chi_fn(jnp.zeros((2,))))
125163

126164
rng = random.key(seed=0)
127165
for i in range(10):
@@ -141,7 +179,7 @@ def jac_phi2chi_static_model(phi: Array) -> Tuple[Array, Dict[str, Array]]:
141179

142180
print("i", i, "phi", phi)
143181

144-
q, aux = jitted_phi2q_static_model_fn(phi)
182+
q, aux = phi2chi_static_model_fn(phi, q0=q0)
145183
print("phi", phi, "q", q)
146184

147185
# J_phi2chi_autodiff, aux = J_phi2chi_autodiff_fn(phi)

0 commit comments

Comments
 (0)