Skip to content

Commit 006c6b9

Browse files
committed
Start working on pneumatic planar pcs system
1 parent ab76aea commit 006c6b9

File tree

2 files changed

+229
-0
lines changed

2 files changed

+229
-0
lines changed
Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
from functools import partial
2+
import jax
3+
4+
jax.config.update("jax_enable_x64", True) # double precision
5+
from diffrax import diffeqsolve, Euler, ODETerm, SaveAt, Tsit5
6+
from jax import Array, vmap
7+
from jax import numpy as jnp
8+
import matplotlib.pyplot as plt
9+
import numpy as onp
10+
from pathlib import Path
11+
from typing import Callable, Dict
12+
13+
import jsrm
14+
from jsrm import ode_factory
15+
from jsrm.systems import pneumatic_planar_pcs
16+
17+
num_segments = 1
18+
19+
# filepath to symbolic expressions
20+
sym_exp_filepath = (
21+
Path(jsrm.__file__).parent
22+
/ "symbolic_expressions"
23+
/ f"planar_pcs_ns-{num_segments}.dill"
24+
)
25+
26+
# set parameters
27+
rho = 1070 * jnp.ones((num_segments,)) # Volumetric density of Dragon Skin 20 [kg/m^3]
28+
params = {
29+
"th0": jnp.array(0.0), # initial orientation angle [rad]
30+
"l": 1e-1 * jnp.ones((num_segments,)),
31+
"r": 2e-2 * jnp.ones((num_segments,)),
32+
"rho": rho,
33+
"g": jnp.array([0.0, 9.81]),
34+
"E": 2e3 * jnp.ones((num_segments,)), # Elastic modulus [Pa]
35+
"G": 1e3 * jnp.ones((num_segments,)), # Shear modulus [Pa]
36+
"r_cham_in": 5e-3 * jnp.ones((num_segments,)),
37+
"r_cham_out": 2e-2 - 2e-3 * jnp.ones((num_segments,)),
38+
"varphi_cham": jnp.pi/2 * jnp.ones((num_segments,)),
39+
}
40+
params["D"] = 1e-3 * jnp.diag(
41+
(jnp.repeat(
42+
jnp.array([[1e0, 1e3, 1e3]]), num_segments, axis=0
43+
) * params["l"][:, None]).flatten()
44+
)
45+
46+
# activate all strains (i.e. bending, shear, and axial)
47+
# strain_selector = jnp.ones((3 * num_segments,), dtype=bool)
48+
strain_selector = jnp.array([True, False, True])[None, :].repeat(num_segments, axis=0).flatten()
49+
50+
51+
def simulate_robot():
52+
# define initial configuration
53+
q0 = jnp.repeat(jnp.array([5.0 * jnp.pi, 0.2])[None, :], num_segments, axis=0).flatten()
54+
# number of generalized coordinates
55+
n_q = q0.shape[0]
56+
57+
# set simulation parameters
58+
dt = 1e-3 # time step
59+
sim_dt = 5e-5 # simulation time step
60+
ts = jnp.arange(0.0, 2, dt) # time steps
61+
62+
strain_basis, forward_kinematics_fn, dynamical_matrices_fn, auxiliary_fns = (
63+
pneumatic_planar_pcs.factory(sym_exp_filepath, strain_selector)
64+
)
65+
# jit the functions
66+
dynamical_matrices_fn = jax.jit(partial(dynamical_matrices_fn))
67+
68+
x0 = jnp.concatenate([q0, jnp.zeros_like(q0)]) # initial condition
69+
tau = jnp.zeros_like(q0) # torques
70+
71+
ode_fn = ode_factory(dynamical_matrices_fn, params, tau)
72+
term = ODETerm(ode_fn)
73+
74+
sol = diffeqsolve(
75+
term,
76+
solver=Tsit5(),
77+
t0=ts[0],
78+
t1=ts[-1],
79+
dt0=sim_dt,
80+
y0=x0,
81+
max_steps=None,
82+
saveat=SaveAt(ts=ts),
83+
)
84+
85+
print("sol.ys =\n", sol.ys)
86+
# the evolution of the generalized coordinates
87+
q_ts = sol.ys[:, :n_q]
88+
# the evolution of the generalized velocities
89+
q_d_ts = sol.ys[:, n_q:]
90+
91+
# evaluate the forward kinematics along the trajectory
92+
chi_ee_ts = vmap(forward_kinematics_fn, in_axes=(None, 0, None))(
93+
params, q_ts, jnp.array([jnp.sum(params["l"])])
94+
)
95+
# plot the configuration vs time
96+
plt.figure()
97+
for segment_idx in range(num_segments):
98+
plt.plot(
99+
ts, q_ts[:, 3 * segment_idx + 0],
100+
label=r"$\kappa_\mathrm{be," + str(segment_idx + 1) + "}$ [rad/m]"
101+
)
102+
plt.plot(
103+
ts, q_ts[:, 3 * segment_idx + 1],
104+
label=r"$\sigma_\mathrm{ax," + str(segment_idx + 1) + "}$ [-]"
105+
)
106+
plt.xlabel("Time [s]")
107+
plt.ylabel("Configuration")
108+
plt.legend()
109+
plt.grid(True)
110+
plt.tight_layout()
111+
plt.show()
112+
# plot end-effector position vs time
113+
plt.figure()
114+
plt.plot(ts, chi_ee_ts[:, 0], label="x")
115+
plt.plot(ts, chi_ee_ts[:, 1], label="y")
116+
plt.xlabel("Time [s]")
117+
plt.ylabel("End-effector Position [m]")
118+
plt.legend()
119+
plt.grid(True)
120+
plt.box(True)
121+
plt.tight_layout()
122+
plt.show()
123+
# plot the end-effector position in the x-y plane as a scatter plot with the time as the color
124+
plt.figure()
125+
plt.scatter(chi_ee_ts[:, 0], chi_ee_ts[:, 1], c=ts, cmap="viridis")
126+
plt.axis("equal")
127+
plt.grid(True)
128+
plt.xlabel("End-effector x [m]")
129+
plt.ylabel("End-effector y [m]")
130+
plt.colorbar(label="Time [s]")
131+
plt.tight_layout()
132+
plt.show()
133+
# plt.figure()
134+
# plt.plot(chi_ee_ts[:, 0], chi_ee_ts[:, 1])
135+
# plt.axis("equal")
136+
# plt.grid(True)
137+
# plt.xlabel("End-effector x [m]")
138+
# plt.ylabel("End-effector y [m]")
139+
# plt.tight_layout()
140+
# plt.show()
141+
142+
# plot the energy along the trajectory
143+
kinetic_energy_fn_vmapped = vmap(
144+
partial(auxiliary_fns["kinetic_energy_fn"], params)
145+
)
146+
potential_energy_fn_vmapped = vmap(
147+
partial(auxiliary_fns["potential_energy_fn"], params)
148+
)
149+
U_ts = potential_energy_fn_vmapped(q_ts)
150+
T_ts = kinetic_energy_fn_vmapped(q_ts, q_d_ts)
151+
plt.figure()
152+
plt.plot(ts, U_ts, label="Potential energy")
153+
plt.plot(ts, T_ts, label="Kinetic energy")
154+
plt.xlabel("Time [s]")
155+
plt.ylabel("Energy [J]")
156+
plt.legend()
157+
plt.grid(True)
158+
plt.box(True)
159+
plt.tight_layout()
160+
plt.show()
161+
162+
if __name__ == "__main__":
163+
simulate_robot()
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
__all__ = ["factory", "stiffness_fn"]
2+
from jax import Array, vmap
3+
import jax.numpy as jnp
4+
from jsrm.math_utils import blk_diag
5+
import numpy as onp
6+
from typing import Dict, Tuple, Union
7+
8+
from .planar_pcs import factory as planar_pcs_factory
9+
10+
def factory(
11+
*args, **kwargs
12+
):
13+
return planar_pcs_factory(
14+
*args, stiffness_fn=stiffness_fn, **kwargs
15+
)
16+
17+
def _compute_stiffness_matrix_for_segment(
18+
l: Array, r: Array, r_cham_in: Array, r_cham_out: Array, varphi_cham: Array, E: Array
19+
):
20+
# cross-sectional area [m²] of the material
21+
A_mat = jnp.pi * r ** 2 + 2 * r_cham_in ** 2 * varphi_cham - 2 * r_cham_out ** 2 * varphi_cham
22+
# second moment of area [m⁴] of the material
23+
Ib_mat = jnp.pi * r ** 4 / 4 + r_cham_in ** 4 * varphi_cham / 2 - r_cham_out ** 4 * varphi_cham / 2
24+
# poisson ratio of the material
25+
nu = 0.0
26+
# shear modulus
27+
G = E / (2 * (1 + nu))
28+
29+
# bending stiffness [Nm²]
30+
Sbe = Ib_mat * E * l
31+
# shear stiffness [N]
32+
Ssh = 4 / 3 * A_mat * G * l
33+
# axial stiffness [N]
34+
Sax = A_mat * E * l
35+
36+
S = jnp.diag(jnp.stack([Sbe, Ssh, Sax], axis=0))
37+
38+
return S
39+
40+
def stiffness_fn(
41+
params: Dict[str, Array],
42+
B_xi: Array,
43+
formulate_in_strain_space: bool = False,
44+
) -> Array:
45+
"""
46+
Compute the stiffness matrix of the system.
47+
Args:
48+
params: Dictionary of robot parameters
49+
B_xi: Strain basis matrix
50+
formulate_in_strain_space: whether to formulate the elastic matrix in the strain space
51+
Returns:
52+
K: elastic matrix of shape (n_q, n_q) if formulate_in_strain_space is False or (n_xi, n_xi) otherwise
53+
"""
54+
# stiffness matrix of shape (num_segments, 3, 3)
55+
S = vmap(
56+
_compute_stiffness_matrix_for_segment
57+
)(
58+
params["l"], params["r"], params["r_cham_in"], params["r_cham_out"], params["varphi_cham"], params["E"]
59+
)
60+
# we define the elastic matrix of shape (n_xi, n_xi) as K(xi) = K @ xi where K is equal to
61+
K = blk_diag(S)
62+
63+
if not formulate_in_strain_space:
64+
K = B_xi.T @ K @ B_xi
65+
66+
return K

0 commit comments

Comments
 (0)