1
1
import jax
2
2
3
3
jax .config .update ("jax_enable_x64" , True ) # double precision
4
- from jax import Array , jacrev , jit , random , vmap
4
+ from jax import Array , jacfwd , jacrev , jit , random , vmap
5
5
from jax import numpy as jnp
6
6
from jaxopt import GaussNewton , LevenbergMarquardt
7
7
from functools import partial
8
8
import numpy as onp
9
9
from pathlib import Path
10
- from typing import Callable , Dict
10
+ from typing import Callable , Dict , Tuple
11
11
12
12
import jsrm
13
13
from jsrm .parameters .hsa_params import PARAMS_FPU_CONTROL
48
48
jitted_dynamical_matrics_fn = jit (partial (dynamical_matrices_fn , params ))
49
49
50
50
51
- def phi2q_static_model (phi : Array , q0 : Array = jnp .zeros ((3 , ))) -> Array :
51
+ def phi2q_static_model (phi : Array , q0 : Array = jnp .zeros ((3 , ))) -> Tuple [ Array , Dict [ str , Array ]] :
52
52
"""
53
53
A static model mapping the motor angles to the planar HSA configuration.
54
54
Arguments:
55
- u : motor angles
55
+ phi : motor angles
56
56
Returns:
57
57
q: planar HSA configuration consisting of (k_be, sigma_sh, sigma_ax)
58
+ aux: dictionary with auxiliary data
58
59
"""
59
60
q_d = jnp .zeros ((3 ,))
60
61
@@ -73,59 +74,74 @@ def residual_fn(_q: Array) -> Array:
73
74
# compute the L2 optimality
74
75
optimality_error = lm .l2_optimality_error (sol .params )
75
76
76
- return q
77
+ aux = dict (
78
+ phi = phi ,
79
+ q = q ,
80
+ optimality_error = optimality_error ,
81
+ )
77
82
78
- def u2chi_static_model (u : Array ) -> Array :
83
+ return q , aux
84
+
85
+ def phi2chi_static_model (phi : Array ) -> Tuple [Array , Dict [str , Array ]]:
79
86
"""
80
87
A static model mapping the motor angles to the planar end-effector pose.
81
88
Arguments:
82
- u : motor angles
89
+ phi : motor angles
83
90
Returns:
84
91
chi: end-effector pose
92
+ aux: dictionary with auxiliary data
85
93
"""
86
- q = phi2q_static_model (u )
94
+ q , aux = phi2q_static_model (phi )
87
95
chi = forward_kinematics_end_effector_fn (params , q )
88
- return chi
96
+ aux ["chi" ] = chi
97
+ return chi , aux
89
98
90
- def jac_u2chi_static_model (u : Array ) -> Array :
99
+ def jac_phi2chi_static_model (phi : Array ) -> Tuple [Array , Dict [str , Array ]]:
100
+ """
101
+ Compute the Jacobian between the actuation space and the task-space.
102
+ Arguments:
103
+ phi: motor angles
104
+ """
91
105
# evaluate the static model to convert motor angles into a configuration
92
- q = phi2q_static_model (u )
106
+ q = phi2q_static_model (phi )
93
107
# take the Jacobian between actuation and configuration space
94
- J_u2q = jacrev (phi2q_static_model )( u )
108
+ J_phi2q , aux = jacfwd (phi2q_static_model , has_aux = True )( phi )
95
109
96
110
# evaluate the closed-form, analytical jacobian of the forward kinematics
97
111
J_q2chi = jacobian_end_effector_fn (params , q )
98
112
99
113
# evaluate the Jacobian between the actuation and the task-space
100
- J_u2chi = J_q2chi @ J_u2q
114
+ J_phi2chi = J_q2chi @ J_phi2q
101
115
102
- return J_u2chi
116
+ return J_phi2chi , aux
103
117
104
118
105
119
if __name__ == "__main__" :
106
120
jitted_phi2q_static_model_fn = jit (phi2q_static_model )
107
- J_u2chi_autodiff_fn = jacrev ( u2chi_static_model )
108
- J_u2chi_fn = jac_u2chi_static_model
121
+ J_phi2chi_autodiff_fn = jit ( jacfwd ( phi2chi_static_model , has_aux = True ) )
122
+ J_phi2chi_fn = jit ( jac_phi2chi_static_model )
109
123
110
124
rng = random .key (seed = 0 )
111
125
for i in range (10 ):
112
126
match i :
113
127
case 0 :
114
- u = jnp .array ([0.0 , 0.0 ])
128
+ phi = jnp .array ([0.0 , 0.0 ])
115
129
case 1 :
116
- u = jnp .array ([1.0 , 1.0 ])
130
+ phi = jnp .array ([1.0 , 1.0 ])
117
131
case _:
118
132
rng , subkey = random .split (rng )
119
- u = random .uniform (
133
+ phi = random .uniform (
120
134
subkey ,
121
135
phi_max .shape ,
122
136
minval = 0.0 ,
123
137
maxval = phi_max
124
138
)
125
139
126
- q = jitted_phi2q_static_model_fn (u )
127
- print ("u" , u , "q" , q )
128
- # J_u2chi_autodiff = J_u2chi_autodiff_fn(u)
129
- # J_u2chi = J_u2chi_fn(u)
130
- # print("J_u2chi:\n", J_u2chi, "\nJ_u2chi_autodiff:\n", J_u2chi_autodiff)
131
- # print(J_u2chi.shape)
140
+ print ("i" , i , "phi" , phi )
141
+
142
+ q , aux = jitted_phi2q_static_model_fn (phi )
143
+ print ("phi" , phi , "q" , q )
144
+
145
+ J_u2chi_autodiff , aux = J_phi2chi_autodiff_fn (phi )
146
+ J_u2chi , aux = J_phi2chi_fn (phi )
147
+ print ("J_u2chi:\n " , J_u2chi , "\n J_u2chi_autodiff:\n " , J_u2chi_autodiff )
0 commit comments