7
7
from functools import partial
8
8
import numpy as onp
9
9
from pathlib import Path
10
+ import scipy as sp
10
11
from typing import Callable , Dict , Tuple
11
12
12
13
import jsrm
@@ -49,7 +50,7 @@ def residual_fn(q: Array, phi: Array) -> Array:
49
50
q_d = jnp .zeros_like (q )
50
51
_ , _ , G , K , _ , alpha = dynamical_matrices_fn (q , q_d , phi = phi )
51
52
res = alpha - G - K
52
- return res
53
+ return jnp . square ( res ). mean ()
53
54
54
55
# jit the residual function
55
56
residual_fn = jit (residual_fn )
@@ -61,9 +62,9 @@ def residual_fn(q: Array, phi: Array) -> Array:
61
62
print ("Compiling jac_residual_fn..." )
62
63
print (jac_residual_fn (jnp .zeros ((3 ,)), jnp .zeros ((2 ,))))
63
64
64
- def phi2q_static_model_jaxopt_fn (phi : Array , q0 : Array = jnp .zeros ((3 , ))) -> Tuple [Array , Dict [str , Array ]]:
65
+ def phi2q_static_model_fn (phi : Array , q0 : Array = jnp .zeros ((3 , ))) -> Tuple [Array , Dict [str , Array ]]:
65
66
"""
66
- A static model mapping the motor angles to the planar HSA configuration.
67
+ A static model mapping the motor angles to the planar HSA configuration using scipy.optimize.minimize .
67
68
Arguments:
68
69
phi: motor angles
69
70
q0: initial guess for the configuration
@@ -72,30 +73,25 @@ def phi2q_static_model_jaxopt_fn(phi: Array, q0: Array = jnp.zeros((3, ))) -> Tu
72
73
aux: dictionary with auxiliary data
73
74
"""
74
75
# solve the nonlinear least squares problem
75
- lm = LevenbergMarquardt (
76
- residual_fun = partial (residual_fn , phi = phi ),
77
- jac_fun = partial (jac_residual_fn , phi = phi ),
78
- tol = nlq_tol ,
79
- jit = False ,
80
- verbose = True
76
+ sol = sp .optimize .minimize (
77
+ fun = lambda q : residual_fn (q , phi ).item (),
78
+ x0 = q0 ,
79
+ jac = lambda q : jac_residual_fn (q , phi ),
80
+ # options={"disp": True},
81
81
)
82
- sol = lm . run ( q0 )
82
+ print ( "Optimization converged after" , sol . nit , "iterations with residual" , sol . fun )
83
83
84
84
# configuration that minimizes the residual
85
- q = sol .params
86
-
87
- # compute the L2 optimality
88
- optimality_error = lm .l2_optimality_error (sol .params )
85
+ q = jnp .array (sol .x )
89
86
90
87
aux = dict (
91
88
phi = phi ,
92
89
q = q ,
93
- optimality_error = optimality_error ,
90
+ residual = sol . fun ,
94
91
)
95
92
96
93
return q , aux
97
94
98
-
99
95
def phi2chi_static_model_fn (phi : Array , q0 : Array = jnp .zeros ((3 , ))) -> Tuple [Array , Dict [str , Array ]]:
100
96
"""
101
97
A static model mapping the motor angles to the planar end-effector pose.
@@ -106,7 +102,7 @@ def phi2chi_static_model_fn(phi: Array, q0: Array = jnp.zeros((3, ))) -> Tuple[A
106
102
chi: end-effector pose
107
103
aux: dictionary with auxiliary data
108
104
"""
109
- q , aux = phi2q_static_model_jaxopt_fn (phi , q0 = q0 )
105
+ q , aux = phi2q_static_model_fn (phi , q0 = q0 )
110
106
chi = forward_kinematics_end_effector_fn (q )
111
107
aux ["chi" ] = chi
112
108
return chi , aux
@@ -150,17 +146,6 @@ def jac_phi2chi_static_model_fn(phi: Array) -> Tuple[Array, Dict[str, Array]]:
150
146
# define initial configuration
151
147
q0 = jnp .array ([0.0 , 0.0 , 0.0 ])
152
148
153
- # phi2q_static_model_fn = jit(phi2q_static_model)
154
- # print("Compiling phi2q_static_model_fn...")
155
- # print(phi2q_static_model_fn(jnp.zeros((2,))))
156
-
157
- # print("Compiling J_phi2chi_autodiff_fn...")
158
- # J_phi2chi_autodiff_fn = jit(jacfwd(phi2chi_static_model, has_aux=True))
159
- # print(J_phi2chi_autodiff_fn(jnp.zeros((2,))))
160
- # J_phi2chi_fn = jit(jac_phi2chi_static_model)
161
- # print("Compiling J_phi2chi_fn...")
162
- # print(J_phi2chi_fn(jnp.zeros((2,))))
163
-
164
149
rng = random .key (seed = 0 )
165
150
for i in range (10 ):
166
151
match i :
0 commit comments