Skip to content

Commit 3fc38d5

Browse files
committed
Add aux outputs to functions
1 parent 09754f2 commit 3fc38d5

File tree

2 files changed

+44
-28
lines changed

2 files changed

+44
-28
lines changed

examples/demo_planar_hsa_motor2ee_jacobian.py

Lines changed: 41 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
import jax
22

33
jax.config.update("jax_enable_x64", True) # double precision
4-
from jax import Array, jacrev, jit, random, vmap
4+
from jax import Array, jacfwd, jacrev, jit, random, vmap
55
from jax import numpy as jnp
66
from jaxopt import GaussNewton, LevenbergMarquardt
77
from functools import partial
88
import numpy as onp
99
from pathlib import Path
10-
from typing import Callable, Dict
10+
from typing import Callable, Dict, Tuple
1111

1212
import jsrm
1313
from jsrm.parameters.hsa_params import PARAMS_FPU_CONTROL
@@ -48,13 +48,14 @@
4848
jitted_dynamical_matrics_fn = jit(partial(dynamical_matrices_fn, params))
4949

5050

51-
def phi2q_static_model(phi: Array, q0: Array = jnp.zeros((3, ))) -> Array:
51+
def phi2q_static_model(phi: Array, q0: Array = jnp.zeros((3, ))) -> Tuple[Array, Dict[str, Array]]:
5252
"""
5353
A static model mapping the motor angles to the planar HSA configuration.
5454
Arguments:
55-
u: motor angles
55+
phi: motor angles
5656
Returns:
5757
q: planar HSA configuration consisting of (k_be, sigma_sh, sigma_ax)
58+
aux: dictionary with auxiliary data
5859
"""
5960
q_d = jnp.zeros((3,))
6061

@@ -73,59 +74,74 @@ def residual_fn(_q: Array) -> Array:
7374
# compute the L2 optimality
7475
optimality_error = lm.l2_optimality_error(sol.params)
7576

76-
return q
77+
aux = dict(
78+
phi=phi,
79+
q=q,
80+
optimality_error=optimality_error,
81+
)
7782

78-
def u2chi_static_model(u: Array) -> Array:
83+
return q, aux
84+
85+
def phi2chi_static_model(phi: Array) -> Tuple[Array, Dict[str, Array]]:
7986
"""
8087
A static model mapping the motor angles to the planar end-effector pose.
8188
Arguments:
82-
u: motor angles
89+
phi: motor angles
8390
Returns:
8491
chi: end-effector pose
92+
aux: dictionary with auxiliary data
8593
"""
86-
q = phi2q_static_model(u)
94+
q, aux = phi2q_static_model(phi)
8795
chi = forward_kinematics_end_effector_fn(params, q)
88-
return chi
96+
aux["chi"] = chi
97+
return chi, aux
8998

90-
def jac_u2chi_static_model(u: Array) -> Array:
99+
def jac_phi2chi_static_model(phi: Array) -> Tuple[Array, Dict[str, Array]]:
100+
"""
101+
Compute the Jacobian between the actuation space and the task-space.
102+
Arguments:
103+
phi: motor angles
104+
"""
91105
# evaluate the static model to convert motor angles into a configuration
92-
q = phi2q_static_model(u)
106+
q = phi2q_static_model(phi)
93107
# take the Jacobian between actuation and configuration space
94-
J_u2q = jacrev(phi2q_static_model)(u)
108+
J_phi2q, aux = jacfwd(phi2q_static_model, has_aux=True)(phi)
95109

96110
# evaluate the closed-form, analytical jacobian of the forward kinematics
97111
J_q2chi = jacobian_end_effector_fn(params, q)
98112

99113
# evaluate the Jacobian between the actuation and the task-space
100-
J_u2chi = J_q2chi @ J_u2q
114+
J_phi2chi = J_q2chi @ J_phi2q
101115

102-
return J_u2chi
116+
return J_phi2chi, aux
103117

104118

105119
if __name__ == "__main__":
106120
jitted_phi2q_static_model_fn = jit(phi2q_static_model)
107-
J_u2chi_autodiff_fn = jacrev(u2chi_static_model)
108-
J_u2chi_fn = jac_u2chi_static_model
121+
J_phi2chi_autodiff_fn = jit(jacfwd(phi2chi_static_model, has_aux=True))
122+
J_phi2chi_fn = jit(jac_phi2chi_static_model)
109123

110124
rng = random.key(seed=0)
111125
for i in range(10):
112126
match i:
113127
case 0:
114-
u = jnp.array([0.0, 0.0])
128+
phi = jnp.array([0.0, 0.0])
115129
case 1:
116-
u = jnp.array([1.0, 1.0])
130+
phi = jnp.array([1.0, 1.0])
117131
case _:
118132
rng, subkey = random.split(rng)
119-
u = random.uniform(
133+
phi = random.uniform(
120134
subkey,
121135
phi_max.shape,
122136
minval=0.0,
123137
maxval=phi_max
124138
)
125139

126-
q = jitted_phi2q_static_model_fn(u)
127-
print("u", u, "q", q)
128-
# J_u2chi_autodiff = J_u2chi_autodiff_fn(u)
129-
# J_u2chi = J_u2chi_fn(u)
130-
# print("J_u2chi:\n", J_u2chi, "\nJ_u2chi_autodiff:\n", J_u2chi_autodiff)
131-
# print(J_u2chi.shape)
140+
print("i", i, "phi", phi)
141+
142+
q, aux = jitted_phi2q_static_model_fn(phi)
143+
print("phi", phi, "q", q)
144+
145+
J_u2chi_autodiff, aux = J_phi2chi_autodiff_fn(phi)
146+
J_u2chi, aux = J_phi2chi_fn(phi)
147+
print("J_u2chi:\n", J_u2chi, "\nJ_u2chi_autodiff:\n", J_u2chi_autodiff)

src/jsrm/systems/planar_hsa.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -717,9 +717,9 @@ def ode_factory(
717717
alpha_fn is a function to compute the actuation vector of shape (n_q). It has the following signature:
718718
alpha_fn(phi) -> tau_q where phi is the twist angle vector of shape (n_phi, )
719719
params: Dictionary with robot parameters
720-
control_fn: Callable that returns the forcing function of the form control_fn(t, x) -> u. If consider_underactuation_model is True,
721-
then u is an array of shape (n_q, ) with the configuration-space torques. If consider_underactuation_model is False,
722-
then u is an array of shape (n_phi, ) with the motor positions / twist angles of the proximal end of the rods.
720+
control_fn: Callable that returns the forcing function of the form control_fn(t, x) -> phi. If consider_underactuation_model is True,
721+
then phi is an array of shape (n_q, ) with the configuration-space torques. If consider_underactuation_model is False,
722+
then phi is an array of shape (n_phi, ) with the motor positions / twist angles of the proximal end of the rods.
723723
consider_underactuation_model: If True, the underactuation model is considered. Otherwise, the fully-actuated
724724
model is considered with the identity matrix as the actuation matrix.
725725
consider_hysteresis: If True, Bouc-Wen is used to model hysteresis. Otherwise, hysteresis will be neglected.

0 commit comments

Comments
 (0)