Skip to content

Commit 4ec9302

Browse files
mstoelzleCopilot
andauthored
Add a planar pcs system with pneumatic actuation (#5)
* Add `B_xi` to `stiffness_fn` interface * Start working on pneumatic planar pcs system * Start working on `actuation_mapping_fn` * Already bump version such as that we release a new version when merging into main * Fix bug in `energy_fn` * Allow setting custom `actuation_mapping_fn` for planar pcs systems * Properly implement `actuation_basis` * Continue implementing `actuation_mapping_fn` * Continue working on `compute_actuation_matrix_for_segment` * Fully implement pneumatic actuation model * Start working on `sweep_local_tip_force_to_bending_torque_mapping` * Update `sweep_actuation_mapping` * Fix small bug * Update src/jsrm/systems/pneumatic_planar_pcs.py Co-authored-by: Copilot <[email protected]> --------- Co-authored-by: Copilot <[email protected]>
1 parent e192e5d commit 4ec9302

File tree

6 files changed

+503
-9
lines changed

6 files changed

+503
-9
lines changed

examples/simulate_planar_pcs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ def draw_robot(
162162
)
163163
plt.plot(
164164
video_ts, q_ts[:, 3 * segment_idx + 2],
165-
label=r"$\sigma_\mathrm{el," + str(segment_idx + 1) + "}$ [-]"
165+
label=r"$\sigma_\mathrm{ax," + str(segment_idx + 1) + "}$ [-]"
166166
)
167167
plt.xlabel("Time [s]")
168168
plt.ylabel("Configuration")
Lines changed: 273 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,273 @@
1+
from functools import partial
2+
import jax
3+
4+
jax.config.update("jax_enable_x64", True) # double precision
5+
from diffrax import diffeqsolve, Euler, ODETerm, SaveAt, Tsit5
6+
from jax import Array, vmap
7+
from jax import numpy as jnp
8+
import matplotlib.pyplot as plt
9+
import numpy as onp
10+
from pathlib import Path
11+
from typing import Callable, Dict
12+
13+
import jsrm
14+
from jsrm import ode_factory
15+
from jsrm.systems import pneumatic_planar_pcs
16+
17+
num_segments = 1
18+
19+
# filepath to symbolic expressions
20+
sym_exp_filepath = (
21+
Path(jsrm.__file__).parent
22+
/ "symbolic_expressions"
23+
/ f"planar_pcs_ns-{num_segments}.dill"
24+
)
25+
26+
# set parameters
27+
rho = 1070 * jnp.ones((num_segments,)) # Volumetric density of Dragon Skin 20 [kg/m^3]
28+
params = {
29+
"th0": jnp.array(0.0), # initial orientation angle [rad]
30+
"l": 1e-1 * jnp.ones((num_segments,)),
31+
"r": 2e-2 * jnp.ones((num_segments,)),
32+
"rho": rho,
33+
"g": jnp.array([0.0, 9.81]),
34+
"E": 2e3 * jnp.ones((num_segments,)), # Elastic modulus [Pa]
35+
"G": 1e3 * jnp.ones((num_segments,)), # Shear modulus [Pa]
36+
"r_cham_in": 5e-3 * jnp.ones((num_segments,)),
37+
"r_cham_out": 2e-2 - 2e-3 * jnp.ones((num_segments,)),
38+
"varphi_cham": jnp.pi/2 * jnp.ones((num_segments,)),
39+
}
40+
params["D"] = 5e-4 * jnp.diag(
41+
(jnp.repeat(
42+
jnp.array([[1e0, 1e3, 1e3]]), num_segments, axis=0
43+
) * params["l"][:, None]).flatten()
44+
)
45+
46+
# activate all strains (i.e. bending, shear, and axial)
47+
# strain_selector = jnp.ones((3 * num_segments,), dtype=bool)
48+
strain_selector = jnp.array([True, False, True])[None, :].repeat(num_segments, axis=0).flatten()
49+
50+
B_xi, forward_kinematics_fn, dynamical_matrices_fn, auxiliary_fns = (
51+
pneumatic_planar_pcs.factory(num_segments, sym_exp_filepath, strain_selector)
52+
)
53+
# jit the functions
54+
dynamical_matrices_fn = jax.jit(dynamical_matrices_fn)
55+
actuation_mapping_fn = partial(
56+
auxiliary_fns["actuation_mapping_fn"],
57+
forward_kinematics_fn,
58+
auxiliary_fns["jacobian_fn"],
59+
)
60+
61+
def sweep_local_tip_force_to_bending_torque_mapping():
62+
def compute_bending_torque(q: Array) -> Array:
63+
# backbone coordinate of the end-effector
64+
s_ee = jnp.sum(params["l"])
65+
# compute the pose of the end-effector
66+
chi_ee = forward_kinematics_fn(params, q, s_ee)
67+
# orientation of the end-effector
68+
th_ee = chi_ee[2]
69+
# compute the jacobian of the end-effector
70+
J_ee = auxiliary_fns["jacobian_fn"](params, q, s_ee)
71+
# local tip force
72+
f_ee_local = jnp.array([0.0, 1.0])
73+
# tip force in inertial frame
74+
f_ee = jnp.array([[jnp.cos(th_ee), -jnp.sin(th_ee)], [jnp.sin(th_ee), jnp.cos(th_ee)]]) @ f_ee_local
75+
# compute the generalized torque
76+
tau_be = J_ee[:2, 0].T @ f_ee
77+
return tau_be
78+
79+
kappa_be_pts = jnp.arange(-2*jnp.pi, 2*jnp.pi, 0.01)
80+
sigma_ax_pts = jnp.zeros_like(kappa_be_pts)
81+
q_pts = jnp.stack([kappa_be_pts, sigma_ax_pts], axis=-1)
82+
83+
tau_be_pts = vmap(compute_bending_torque)(q_pts)
84+
85+
# plot the mapping on the bending strain
86+
fig, ax = plt.subplots(num="planar_pcs_local_tip_force_to_bending_torque_mapping")
87+
plt.title(r"Mapping from $f_\mathrm{ee}$ to $\tau_\mathrm{be}$")
88+
ax.plot(kappa_be_pts, tau_be_pts, linewidth=2.5)
89+
ax.set_xlabel(r"$\kappa_\mathrm{be}$ [rad/m]")
90+
ax.set_ylabel(r"$\tau_\mathrm{be}$ [N m]")
91+
plt.grid(True)
92+
plt.tight_layout()
93+
plt.show()
94+
95+
96+
def sweep_actuation_mapping():
97+
# evaluate the actuation matrix for a straight backbone
98+
q = jnp.zeros((2 * num_segments,))
99+
A = actuation_mapping_fn(params, B_xi, q)
100+
print("Evaluating actuation matrix for straight backbone: A =\n", A)
101+
102+
kappa_be_pts = jnp.linspace(-3*jnp.pi, 3*jnp.pi, 500)
103+
sigma_ax_pts = jnp.zeros_like(kappa_be_pts)
104+
q_pts = jnp.stack([kappa_be_pts, sigma_ax_pts], axis=-1)
105+
A_pts = vmap(actuation_mapping_fn, in_axes=(None, None, 0))(params, B_xi, q_pts)
106+
# mark the points that are not controllable as the u1 and u2 terms share the same sign
107+
non_controllable_selector = A_pts[..., 0, 0] * A_pts[..., 0, 1] >= 0.0
108+
non_controllable_indices = jnp.where(non_controllable_selector)[0]
109+
non_controllable_boundary_indices = jnp.where(non_controllable_selector[:-1] != non_controllable_selector[1:])[0]
110+
# plot the mapping on the bending strain for various bending strains
111+
fig, ax = plt.subplots(num="pneumatic_planar_pcs_actuation_mapping_bending_torque_vs_bending_strain")
112+
plt.title(r"Actuation mapping from $u$ to $\tau_\mathrm{be}$")
113+
# # shade the region where the actuation mapping is negative as we are not able to bend the robot further
114+
# ax.axhspan(A_pts[:, 0, 0:2].min(), 0.0, facecolor='red', alpha=0.2)
115+
for idx in non_controllable_indices:
116+
ax.axvspan(kappa_be_pts[idx], kappa_be_pts[idx+1], facecolor='red', alpha=0.2)
117+
ax.plot(kappa_be_pts, A_pts[:, 0, 0], linewidth=2, label=r"$\frac{\partial \tau_\mathrm{be}}{\partial u_1}$")
118+
ax.plot(kappa_be_pts, A_pts[:, 0, 1], linewidth=2, label=r"$\frac{\partial \tau_\mathrm{ax}}{\partial u_2}$")
119+
ax.set_xlabel(r"$\kappa_\mathrm{be}$ [rad/m]")
120+
ax.set_ylabel(r"$\frac{\partial \tau_\mathrm{be}}{\partial u_1}$")
121+
plt.legend()
122+
plt.grid(True)
123+
plt.tight_layout()
124+
plt.show()
125+
126+
# create grid for bending and axial strains
127+
kappa_be_grid, sigma_ax_grid = jnp.meshgrid(
128+
jnp.linspace(-jnp.pi, jnp.pi, 20),
129+
jnp.linspace(-0.2, 0.2, 20),
130+
)
131+
q_pts = jnp.stack([kappa_be_grid.flatten(), sigma_ax_grid.flatten()], axis=-1)
132+
133+
# evaluate the actuation mapping on the grid
134+
A_pts = vmap(actuation_mapping_fn, in_axes=(None, None, 0))(params, B_xi, q_pts)
135+
# reshape A_pts to match the grid shape
136+
A_grid = A_pts.reshape(kappa_be_grid.shape[:2] + A_pts.shape[-2:])
137+
138+
# plot the mapping on the bending strain
139+
fig, ax = plt.subplots(num="pneumatic_planar_pcs_actuation_mapping_bending_torque_vs_axial_vs_bending_strain")
140+
plt.title(r"Actuation mapping from $u_1$ to $\tau_\mathrm{be}$")
141+
# contourf plot
142+
c = ax.contourf(kappa_be_grid, sigma_ax_grid, A_grid[..., 0, 0], levels=100)
143+
fig.colorbar(c, ax=ax, label=r"$\frac{\partial \tau_\mathrm{be}}{\partial u_1}$")
144+
# contour plot
145+
ax.contour(kappa_be_grid, sigma_ax_grid, A_grid[..., 0, 0], levels=20, colors="k", linewidths=0.5)
146+
ax.set_xlabel(r"$\kappa_\mathrm{be}$ [rad/m]")
147+
ax.set_ylabel(r"$\sigma_\mathrm{ax}$ [-]")
148+
plt.tight_layout()
149+
plt.show()
150+
151+
# plot the mapping on the axial strain
152+
fig, ax = plt.subplots(num="pneumatic_planar_pcs_actuation_mapping_axial_torque_vs_axial_vs_bending_strain")
153+
plt.title(r"Actuation mapping from $u_1$ to $\tau_\mathrm{ax}$")
154+
# contourf plot
155+
c = ax.contourf(kappa_be_grid, sigma_ax_grid, A_grid[..., 1, 0], levels=100)
156+
fig.colorbar(c, ax=ax, label=r"$\frac{\partial \tau_\mathrm{ax}}{\partial u_1}$")
157+
# contour plot
158+
ax.contour(kappa_be_grid, sigma_ax_grid, A_grid[..., 1, 0], levels=20, colors="k", linewidths=0.5)
159+
ax.set_xlabel(r"$\kappa_\mathrm{be}$ [rad/m]")
160+
ax.set_ylabel(r"$\sigma_\mathrm{ax}$ [-]")
161+
plt.tight_layout()
162+
plt.show()
163+
164+
165+
def simulate_robot():
166+
# define initial configuration
167+
q0 = jnp.repeat(jnp.array([-5.0 * jnp.pi, -0.2])[None, :], num_segments, axis=0).flatten()
168+
# number of generalized coordinates
169+
n_q = q0.shape[0]
170+
171+
# set simulation parameters
172+
dt = 1e-3 # time step
173+
sim_dt = 5e-5 # simulation time step
174+
ts = jnp.arange(0.0, 7.0, dt) # time steps
175+
176+
x0 = jnp.concatenate([q0, jnp.zeros_like(q0)]) # initial condition
177+
u = jnp.array([1.2e3, 0e0]) # control inputs (pressures in the right and left chambers)
178+
179+
ode_fn = ode_factory(dynamical_matrices_fn, params, u)
180+
term = ODETerm(ode_fn)
181+
182+
sol = diffeqsolve(
183+
term,
184+
solver=Tsit5(),
185+
t0=ts[0],
186+
t1=ts[-1],
187+
dt0=sim_dt,
188+
y0=x0,
189+
max_steps=None,
190+
saveat=SaveAt(ts=ts),
191+
)
192+
193+
print("sol.ys =\n", sol.ys)
194+
# the evolution of the generalized coordinates
195+
q_ts = sol.ys[:, :n_q]
196+
# the evolution of the generalized velocities
197+
q_d_ts = sol.ys[:, n_q:]
198+
199+
# evaluate the forward kinematics along the trajectory
200+
chi_ee_ts = vmap(forward_kinematics_fn, in_axes=(None, 0, None))(
201+
params, q_ts, jnp.array([jnp.sum(params["l"])])
202+
)
203+
# plot the configuration vs time
204+
plt.figure()
205+
for segment_idx in range(num_segments):
206+
plt.plot(
207+
ts, q_ts[:, 2 * segment_idx + 0],
208+
label=r"$\kappa_\mathrm{be," + str(segment_idx + 1) + "}$ [rad/m]"
209+
)
210+
plt.plot(
211+
ts, q_ts[:, 2 * segment_idx + 1],
212+
label=r"$\sigma_\mathrm{ax," + str(segment_idx + 1) + "}$ [-]"
213+
)
214+
plt.xlabel("Time [s]")
215+
plt.ylabel("Configuration")
216+
plt.legend()
217+
plt.grid(True)
218+
plt.tight_layout()
219+
plt.show()
220+
# plot end-effector position vs time
221+
plt.figure()
222+
plt.plot(ts, chi_ee_ts[:, 0], label="x")
223+
plt.plot(ts, chi_ee_ts[:, 1], label="y")
224+
plt.xlabel("Time [s]")
225+
plt.ylabel("End-effector Position [m]")
226+
plt.legend()
227+
plt.grid(True)
228+
plt.box(True)
229+
plt.tight_layout()
230+
plt.show()
231+
# plot the end-effector position in the x-y plane as a scatter plot with the time as the color
232+
plt.figure()
233+
plt.scatter(chi_ee_ts[:, 0], chi_ee_ts[:, 1], c=ts, cmap="viridis")
234+
plt.axis("equal")
235+
plt.grid(True)
236+
plt.xlabel("End-effector x [m]")
237+
plt.ylabel("End-effector y [m]")
238+
plt.colorbar(label="Time [s]")
239+
plt.tight_layout()
240+
plt.show()
241+
# plt.figure()
242+
# plt.plot(chi_ee_ts[:, 0], chi_ee_ts[:, 1])
243+
# plt.axis("equal")
244+
# plt.grid(True)
245+
# plt.xlabel("End-effector x [m]")
246+
# plt.ylabel("End-effector y [m]")
247+
# plt.tight_layout()
248+
# plt.show()
249+
250+
# plot the energy along the trajectory
251+
kinetic_energy_fn_vmapped = vmap(
252+
partial(auxiliary_fns["kinetic_energy_fn"], params)
253+
)
254+
potential_energy_fn_vmapped = vmap(
255+
partial(auxiliary_fns["potential_energy_fn"], params)
256+
)
257+
U_ts = potential_energy_fn_vmapped(q_ts)
258+
T_ts = kinetic_energy_fn_vmapped(q_ts, q_d_ts)
259+
plt.figure()
260+
plt.plot(ts, U_ts, label="Potential energy")
261+
plt.plot(ts, T_ts, label="Kinetic energy")
262+
plt.xlabel("Time [s]")
263+
plt.ylabel("Energy [J]")
264+
plt.legend()
265+
plt.grid(True)
266+
plt.box(True)
267+
plt.tight_layout()
268+
plt.show()
269+
270+
if __name__ == "__main__":
271+
sweep_local_tip_force_to_bending_torque_mapping()
272+
sweep_actuation_mapping()
273+
simulate_robot()

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ name = "jsrm" # Required
1717
#
1818
# For a discussion on single-sourcing the version, see
1919
# https://packaging.python.org/guides/single-sourcing-package-version/
20-
version = "0.0.12" # Required
20+
version = "0.0.13" # Required
2121

2222
# This is a one-line description or tagline of what your project does. This
2323
# corresponds to the "Summary" metadata field:

0 commit comments

Comments
 (0)