1
+ import jax
2
+
3
+ jax .config .update ("jax_enable_x64" , True ) # double precision
4
+ from jax import Array , jacfwd , jacrev , jit , random , vmap
5
+ from jax import numpy as jnp
6
+ from jaxopt import GaussNewton , LevenbergMarquardt
7
+ from functools import partial
8
+ import numpy as onp
9
+ from pathlib import Path
10
+ import scipy as sp
11
+ from typing import Callable , Dict , Tuple
12
+
13
+ import jsrm
14
+ from jsrm .parameters .hsa_params import PARAMS_FPU_CONTROL
15
+ from jsrm .systems import planar_hsa
16
+ from jsrm .utils .numerical_jacobian import approx_derivative
17
+
18
+ num_segments = 1
19
+ num_rods_per_segment = 2
20
+
21
+ # filepath to symbolic expressions
22
+ sym_exp_filepath = (
23
+ Path (jsrm .__file__ ).parent
24
+ / "symbolic_expressions"
25
+ / f"planar_hsa_ns-{ num_segments } _nrs-{ num_rods_per_segment } .dill"
26
+ )
27
+
28
+
29
+ def factory_fn (params : Dict [str , Array ], verbose : bool = False ) -> Tuple [Callable , Callable ]:
30
+ """
31
+ Factory function for the planar HSA.
32
+ Args:
33
+ params: dictionary with robot parameters
34
+ verbose: flag to print additional information
35
+ Returns:
36
+ phi2chi_static_model_fn: function that maps motor angles to the end-effector pose
37
+ jac_phi2chi_static_model_fn: function that computes the Jacobian between the actuation space and the task-space
38
+ """
39
+ (
40
+ forward_kinematics_virtual_backbone_fn ,
41
+ forward_kinematics_end_effector_fn ,
42
+ jacobian_end_effector_fn ,
43
+ inverse_kinematics_end_effector_fn ,
44
+ dynamical_matrices_fn ,
45
+ sys_helpers ,
46
+ ) = planar_hsa .factory (sym_exp_filepath , strain_selector )
47
+ dynamical_matrices_fn = partial (dynamical_matrices_fn , params )
48
+ forward_kinematics_end_effector_fn = jit (partial (forward_kinematics_end_effector_fn , params ))
49
+ jacobian_end_effector_fn = jit (partial (jacobian_end_effector_fn , params ))
50
+
51
+ def residual_fn (q : Array , phi : Array ) -> Array :
52
+ q_d = jnp .zeros_like (q )
53
+ _ , _ , G , K , _ , alpha = dynamical_matrices_fn (q , q_d , phi = phi )
54
+ res = alpha - G - K
55
+ return jnp .square (res ).mean ()
56
+
57
+ # jit the residual function
58
+ residual_fn = jit (residual_fn )
59
+ print ("Compiling residual_fn..." )
60
+ print (residual_fn (jnp .zeros ((3 ,)), jnp .zeros ((2 ,))))
61
+
62
+ # define the Jacobian of the residual function
63
+ jac_residual_fn = jit (jacrev (residual_fn , argnums = 0 ))
64
+ print ("Compiling jac_residual_fn..." )
65
+ print (jac_residual_fn (jnp .zeros ((3 ,)), jnp .zeros ((2 ,))))
66
+
67
+ def phi2q_static_model_fn (phi : Array , q0 : Array = jnp .zeros ((3 , ))) -> Tuple [Array , Dict [str , Array ]]:
68
+ """
69
+ A static model mapping the motor angles to the planar HSA configuration using scipy.optimize.minimize.
70
+ Arguments:
71
+ phi: motor angles
72
+ q0: initial guess for the configuration
73
+ Returns:
74
+ q: planar HSA configuration consisting of (k_be, sigma_sh, sigma_ax)
75
+ aux: dictionary with auxiliary data
76
+ """
77
+ # solve the nonlinear least squares problem
78
+ sol = sp .optimize .minimize (
79
+ fun = lambda q : residual_fn (q , phi ).item (),
80
+ x0 = q0 ,
81
+ jac = lambda q : jac_residual_fn (q , phi ),
82
+ options = {"disp" : True } if verbose else None ,
83
+ )
84
+ if verbose :
85
+ print ("Optimization converged after" , sol .nit , "iterations with residual" , sol .fun )
86
+
87
+ # configuration that minimizes the residual
88
+ q = jnp .array (sol .x )
89
+
90
+ aux = dict (
91
+ phi = phi ,
92
+ q = q ,
93
+ residual = sol .fun ,
94
+ )
95
+
96
+ return q , aux
97
+
98
+ def phi2chi_static_model_fn (phi : Array , q0 : Array = jnp .zeros ((3 , ))) -> Tuple [Array , Dict [str , Array ]]:
99
+ """
100
+ A static model mapping the motor angles to the planar end-effector pose.
101
+ Arguments:
102
+ phi: motor angles
103
+ q0: initial guess for the configuration
104
+ Returns:
105
+ chi: end-effector pose
106
+ aux: dictionary with auxiliary data
107
+ """
108
+ q , aux = phi2q_static_model_fn (phi , q0 = q0 )
109
+ chi = forward_kinematics_end_effector_fn (q )
110
+ aux ["chi" ] = chi
111
+ return chi , aux
112
+
113
+ def jac_phi2chi_static_model_fn (phi : Array , q0 : Array = jnp .zeros ((3 , ))) -> Tuple [Array , Dict [str , Array ]]:
114
+ """
115
+ Compute the Jacobian between the actuation space and the task-space.
116
+ Arguments:
117
+ phi: motor angles
118
+ """
119
+ # evaluate the static model to convert motor angles into a configuration
120
+ q , aux = phi2q_static_model_fn (phi , q0 = q0 )
121
+ # approximate the Jacobian between the actuation and the task-space using finite differences
122
+ J_phi2q = approx_derivative (
123
+ fun = lambda _phi : phi2q_static_model_fn (_phi , q0 = q0 )[0 ],
124
+ x0 = phi ,
125
+ f0 = q ,
126
+ )
127
+
128
+ # evaluate the closed-form, analytical jacobian of the forward kinematics
129
+ J_q2chi = jacobian_end_effector_fn (q )
130
+
131
+ # evaluate the Jacobian between the actuation and the task-space
132
+ J_phi2chi = J_q2chi @ J_phi2q
133
+
134
+ return J_phi2chi , aux
135
+
136
+ return phi2chi_static_model_fn , jac_phi2chi_static_model_fn
137
+
138
+
139
+ if __name__ == "__main__" :
140
+ # activate all strains (i.e. bending, shear, and axial)
141
+ strain_selector = jnp .ones ((3 * num_segments ,), dtype = bool )
142
+ params = PARAMS_FPU_CONTROL
143
+ phi_max = params ["phi_max" ].flatten ()
144
+
145
+ # call the factory function
146
+ phi2chi_static_model_fn , jac_phi2chi_static_model_fn = factory_fn (params )
147
+
148
+ # define initial configuration
149
+ q0 = jnp .array ([0.0 , 0.0 , 0.0 ])
150
+
151
+ rng = random .key (seed = 0 )
152
+ for i in range (10 ):
153
+ match i :
154
+ case 0 :
155
+ phi = jnp .array ([0.0 , 0.0 ])
156
+ case 1 :
157
+ phi = jnp .array ([1.0 , 1.0 ])
158
+ case _:
159
+ rng , subkey = random .split (rng )
160
+ phi = random .uniform (
161
+ subkey ,
162
+ phi_max .shape ,
163
+ minval = 0.0 ,
164
+ maxval = phi_max
165
+ )
166
+
167
+ print ("i" , i )
168
+
169
+ chi , aux = phi2chi_static_model_fn (phi , q0 = q0 )
170
+ print ("phi" , phi , "q" , aux ["q" ], "chi" , chi )
171
+
172
+ J_phi2chi , aux = jac_phi2chi_static_model_fn (phi , q0 = q0 )
173
+ print ("J_phi2chi:\n " , J_phi2chi )
0 commit comments