1
+ import jax
2
+
3
+ jax .config .update ("jax_enable_x64" , True ) # double precision
4
+ from jax import Array , jacrev , jit , random , vmap
5
+ from jax import numpy as jnp
6
+ from jaxopt import GaussNewton , LevenbergMarquardt
7
+ from functools import partial
8
+ import numpy as onp
9
+ from pathlib import Path
10
+ from typing import Callable , Dict
11
+
12
+ import jsrm
13
+ from jsrm .parameters .hsa_params import PARAMS_FPU_CONTROL
14
+ from jsrm .systems import planar_hsa
15
+
16
+ num_segments = 1
17
+ num_rods_per_segment = 2
18
+
19
+ # filepath to symbolic expressions
20
+ sym_exp_filepath = (
21
+ Path (jsrm .__file__ ).parent
22
+ / "symbolic_expressions"
23
+ / f"planar_hsa_ns-{ num_segments } _nrs-{ num_rods_per_segment } .dill"
24
+ )
25
+
26
+ # activate all strains (i.e. bending, shear, and axial)
27
+ strain_selector = jnp .ones ((3 * num_segments ,), dtype = bool )
28
+ params = PARAMS_FPU_CONTROL
29
+ phi_max = params ["phi_max" ].flatten ()
30
+
31
+ # define initial configuration
32
+ q0 = jnp .array ([jnp .pi , 0.0 , 0.0 ])
33
+
34
+ # increase damping for simulation stability
35
+ params ["zetab" ] = 5 * params ["zetab" ]
36
+ params ["zetash" ] = 5 * params ["zetash" ]
37
+ params ["zetaa" ] = 5 * params ["zetaa" ]
38
+
39
+
40
+ (
41
+ forward_kinematics_virtual_backbone_fn ,
42
+ forward_kinematics_end_effector_fn ,
43
+ jacobian_end_effector_fn ,
44
+ inverse_kinematics_end_effector_fn ,
45
+ dynamical_matrices_fn ,
46
+ sys_helpers ,
47
+ ) = planar_hsa .factory (sym_exp_filepath , strain_selector )
48
+ jitted_dynamical_matrics_fn = jit (partial (dynamical_matrices_fn , params ))
49
+
50
+
51
+ def phi2q_static_model (phi : Array , q0 : Array = jnp .zeros ((3 , ))) -> Array :
52
+ """
53
+ A static model mapping the motor angles to the planar HSA configuration.
54
+ Arguments:
55
+ u: motor angles
56
+ Returns:
57
+ q: planar HSA configuration consisting of (k_be, sigma_sh, sigma_ax)
58
+ """
59
+ q_d = jnp .zeros ((3 ,))
60
+
61
+ def residual_fn (_q : Array ) -> Array :
62
+ _ , _ , _G , _K , _ , _alpha = jitted_dynamical_matrics_fn (_q , q_d , phi = phi )
63
+ res = _alpha - _G - _K
64
+ return res
65
+
66
+ # solve the nonlinear least squares problem
67
+ lm = LevenbergMarquardt (residual_fun = residual_fn , jit = True , verbose = True )
68
+ sol = lm .run (q0 )
69
+
70
+ # configuration that minimizes the residual
71
+ q = sol .params
72
+
73
+ # compute the L2 optimality
74
+ optimality_error = lm .l2_optimality_error (sol .params )
75
+
76
+ return q
77
+
78
+ def u2chi_static_model (u : Array ) -> Array :
79
+ """
80
+ A static model mapping the motor angles to the planar end-effector pose.
81
+ Arguments:
82
+ u: motor angles
83
+ Returns:
84
+ chi: end-effector pose
85
+ """
86
+ q = phi2q_static_model (u )
87
+ chi = forward_kinematics_end_effector_fn (params , q )
88
+ return chi
89
+
90
+ def jac_u2chi_static_model (u : Array ) -> Array :
91
+ # evaluate the static model to convert motor angles into a configuration
92
+ q = phi2q_static_model (u )
93
+ # take the Jacobian between actuation and configuration space
94
+ J_u2q = jacrev (phi2q_static_model )(u )
95
+
96
+ # evaluate the closed-form, analytical jacobian of the forward kinematics
97
+ J_q2chi = jacobian_end_effector_fn (params , q )
98
+
99
+ # evaluate the Jacobian between the actuation and the task-space
100
+ J_u2chi = J_q2chi @ J_u2q
101
+
102
+ return J_u2chi
103
+
104
+
105
+ if __name__ == "__main__" :
106
+ jitted_phi2q_static_model_fn = jit (phi2q_static_model )
107
+ J_u2chi_autodiff_fn = jacrev (u2chi_static_model )
108
+ J_u2chi_fn = jac_u2chi_static_model
109
+
110
+ rng = random .key (seed = 0 )
111
+ for i in range (10 ):
112
+ match i :
113
+ case 0 :
114
+ u = jnp .array ([0.0 , 0.0 ])
115
+ case 1 :
116
+ u = jnp .array ([1.0 , 1.0 ])
117
+ case _:
118
+ rng , subkey = random .split (rng )
119
+ u = random .uniform (
120
+ subkey ,
121
+ phi_max .shape ,
122
+ minval = 0.0 ,
123
+ maxval = phi_max
124
+ )
125
+
126
+ q = jitted_phi2q_static_model_fn (u )
127
+ print ("u" , u , "q" , q )
128
+ # J_u2chi_autodiff = J_u2chi_autodiff_fn(u)
129
+ # J_u2chi = J_u2chi_fn(u)
130
+ # print("J_u2chi:\n", J_u2chi, "\nJ_u2chi_autodiff:\n", J_u2chi_autodiff)
131
+ # print(J_u2chi.shape)
0 commit comments