Skip to content

Commit 14bd079

Browse files
committed
Migrate to scipy.optimize.minimize
1 parent c5bbb3d commit 14bd079

File tree

2 files changed

+14
-28
lines changed

2 files changed

+14
-28
lines changed

examples/demo_planar_hsa_motor2ee_jacobian.py

Lines changed: 13 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from functools import partial
88
import numpy as onp
99
from pathlib import Path
10+
import scipy as sp
1011
from typing import Callable, Dict, Tuple
1112

1213
import jsrm
@@ -49,7 +50,7 @@ def residual_fn(q: Array, phi: Array) -> Array:
4950
q_d = jnp.zeros_like(q)
5051
_, _, G, K, _, alpha = dynamical_matrices_fn(q, q_d, phi=phi)
5152
res = alpha - G - K
52-
return res
53+
return jnp.square(res).mean()
5354

5455
# jit the residual function
5556
residual_fn = jit(residual_fn)
@@ -61,9 +62,9 @@ def residual_fn(q: Array, phi: Array) -> Array:
6162
print("Compiling jac_residual_fn...")
6263
print(jac_residual_fn(jnp.zeros((3,)), jnp.zeros((2,))))
6364

64-
def phi2q_static_model_jaxopt_fn(phi: Array, q0: Array = jnp.zeros((3, ))) -> Tuple[Array, Dict[str, Array]]:
65+
def phi2q_static_model_fn(phi: Array, q0: Array = jnp.zeros((3, ))) -> Tuple[Array, Dict[str, Array]]:
6566
"""
66-
A static model mapping the motor angles to the planar HSA configuration.
67+
A static model mapping the motor angles to the planar HSA configuration using scipy.optimize.minimize.
6768
Arguments:
6869
phi: motor angles
6970
q0: initial guess for the configuration
@@ -72,30 +73,25 @@ def phi2q_static_model_jaxopt_fn(phi: Array, q0: Array = jnp.zeros((3, ))) -> Tu
7273
aux: dictionary with auxiliary data
7374
"""
7475
# solve the nonlinear least squares problem
75-
lm = LevenbergMarquardt(
76-
residual_fun=partial(residual_fn, phi=phi),
77-
jac_fun=partial(jac_residual_fn, phi=phi),
78-
tol=nlq_tol,
79-
jit=False,
80-
verbose=True
76+
sol = sp.optimize.minimize(
77+
fun=lambda q: residual_fn(q, phi).item(),
78+
x0=q0,
79+
jac=lambda q: jac_residual_fn(q, phi),
80+
# options={"disp": True},
8181
)
82-
sol = lm.run(q0)
82+
print("Optimization converged after", sol.nit, "iterations with residual", sol.fun)
8383

8484
# configuration that minimizes the residual
85-
q = sol.params
86-
87-
# compute the L2 optimality
88-
optimality_error = lm.l2_optimality_error(sol.params)
85+
q = jnp.array(sol.x)
8986

9087
aux = dict(
9188
phi=phi,
9289
q=q,
93-
optimality_error=optimality_error,
90+
residual=sol.fun,
9491
)
9592

9693
return q, aux
9794

98-
9995
def phi2chi_static_model_fn(phi: Array, q0: Array = jnp.zeros((3, ))) -> Tuple[Array, Dict[str, Array]]:
10096
"""
10197
A static model mapping the motor angles to the planar end-effector pose.
@@ -106,7 +102,7 @@ def phi2chi_static_model_fn(phi: Array, q0: Array = jnp.zeros((3, ))) -> Tuple[A
106102
chi: end-effector pose
107103
aux: dictionary with auxiliary data
108104
"""
109-
q, aux = phi2q_static_model_jaxopt_fn(phi, q0=q0)
105+
q, aux = phi2q_static_model_fn(phi, q0=q0)
110106
chi = forward_kinematics_end_effector_fn(q)
111107
aux["chi"] = chi
112108
return chi, aux
@@ -150,17 +146,6 @@ def jac_phi2chi_static_model_fn(phi: Array) -> Tuple[Array, Dict[str, Array]]:
150146
# define initial configuration
151147
q0 = jnp.array([0.0, 0.0, 0.0])
152148

153-
# phi2q_static_model_fn = jit(phi2q_static_model)
154-
# print("Compiling phi2q_static_model_fn...")
155-
# print(phi2q_static_model_fn(jnp.zeros((2,))))
156-
157-
# print("Compiling J_phi2chi_autodiff_fn...")
158-
# J_phi2chi_autodiff_fn = jit(jacfwd(phi2chi_static_model, has_aux=True))
159-
# print(J_phi2chi_autodiff_fn(jnp.zeros((2,))))
160-
# J_phi2chi_fn = jit(jac_phi2chi_static_model)
161-
# print("Compiling J_phi2chi_fn...")
162-
# print(J_phi2chi_fn(jnp.zeros((2,))))
163-
164149
rng = random.key(seed=0)
165150
for i in range(10):
166151
match i:

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ examples = [
134134
"jaxopt",
135135
"matplotlib",
136136
"opencv-python",
137+
"scipy",
137138
]
138139
test = [
139140
"codecov",

0 commit comments

Comments
 (0)