Skip to content

Commit 83888b5

Browse files
committed
Allow easy simulation of more than one segment and improve plotting
1 parent d7b7882 commit 83888b5

File tree

1 file changed

+41
-13
lines changed

1 file changed

+41
-13
lines changed

examples/simulate_planar_pcs.py

Lines changed: 41 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import jax
44

55
jax.config.update("jax_enable_x64", True) # double precision
6-
from diffrax import diffeqsolve, Dopri5, Euler, ODETerm, SaveAt
6+
from diffrax import diffeqsolve, Euler, ODETerm, SaveAt, Tsit5
77
from jax import Array, vmap
88
from jax import numpy as jnp
99
import matplotlib.pyplot as plt
@@ -26,7 +26,7 @@
2626

2727
# set parameters
2828
rho = 1070 * jnp.ones((num_segments,)) # Volumetric density of Dragon Skin 20 [kg/m^3]
29-
D = 1e-5 * jnp.diag(
29+
D = 1e-4 * jnp.diag(
3030
jnp.repeat(
3131
jnp.array([[1e0, 1e3, 1e3]]), num_segments, axis=0
3232
).flatten(),
@@ -44,15 +44,14 @@
4444

4545
# activate all strains (i.e. bending, shear, and axial)
4646
strain_selector = jnp.ones((3 * num_segments,), dtype=bool)
47-
strain_selector = jnp.array([True, False, False])
4847

4948
# define initial configuration
50-
q0 = jnp.array([5 * jnp.pi])
49+
q0 = jnp.repeat(jnp.array([5.0 * jnp.pi, 0.1, 0.2])[None, :], num_segments, axis=0).flatten()
5150
# number of generalized coordinates
5251
n_q = q0.shape[0]
5352

5453
# set simulation parameters
55-
dt = 1e-3 # time step
54+
dt = 1e-4 # time step
5655
ts = jnp.arange(0.0, 2, dt) # time steps
5756
skip_step = 10 # how many time steps to skip in between video frames
5857
video_ts = ts[::skip_step] # time steps for video
@@ -122,16 +121,15 @@ def draw_robot(
122121
# cv2.waitKey()
123122
# cv2.destroyWindow(window_name)
124123

125-
x0 = jnp.zeros((2 * q0.shape[0],)) # initial condition
126-
x0 = x0.at[: q0.shape[0]].set(q0) # set initial configuration
124+
x0 = jnp.concatenate([q0, jnp.zeros_like(q0)]) # initial condition
127125
tau = jnp.zeros_like(q0) # torques
128126

129127
ode_fn = ode_factory(dynamical_matrices_fn, params, tau)
130128
term = ODETerm(ode_fn)
131129

132130
sol = diffeqsolve(
133131
term,
134-
solver=Dopri5(),
132+
solver=Tsit5(),
135133
t0=ts[0],
136134
t1=ts[-1],
137135
dt0=dt,
@@ -150,13 +148,25 @@ def draw_robot(
150148
chi_ee_ts = vmap(forward_kinematics_fn, in_axes=(None, 0, None))(
151149
params, q_ts, jnp.array([jnp.sum(params["l"])])
152150
)
153-
# plot the end-effector position along the trajectory
151+
# plot the configuration vs time
154152
plt.figure()
155-
plt.plot(chi_ee_ts[0, :], chi_ee_ts[1, :])
156-
plt.axis("equal")
153+
for segment_idx in range(num_segments):
154+
plt.plot(
155+
video_ts, q_ts[:, 3 * segment_idx + 0],
156+
label=r"$\kappa_\mathrm{be," + str(segment_idx + 1) + "}$ [rad/m]"
157+
)
158+
plt.plot(
159+
video_ts, q_ts[:, 3 * segment_idx + 1],
160+
label=r"$\sigma_\mathrm{sh," + str(segment_idx + 1) + "}$ [-]"
161+
)
162+
plt.plot(
163+
video_ts, q_ts[:, 3 * segment_idx + 2],
164+
label=r"$\sigma_\mathrm{el," + str(segment_idx + 1) + "}$ [-]"
165+
)
166+
plt.xlabel("Time [s]")
167+
plt.ylabel("Configuration")
168+
plt.legend()
157169
plt.grid(True)
158-
plt.xlabel("End-effector x [m]")
159-
plt.ylabel("End-effector y [m]")
160170
plt.tight_layout()
161171
plt.show()
162172
# plot end-effector position vs time
@@ -170,6 +180,24 @@ def draw_robot(
170180
plt.box(True)
171181
plt.tight_layout()
172182
plt.show()
183+
# plot the end-effector position in the x-y plane as a scatter plot with the time as the color
184+
plt.figure()
185+
plt.scatter(chi_ee_ts[:, 0], chi_ee_ts[:, 1], c=video_ts, cmap="viridis")
186+
plt.axis("equal")
187+
plt.grid(True)
188+
plt.xlabel("End-effector x [m]")
189+
plt.ylabel("End-effector y [m]")
190+
plt.colorbar(label="Time [s]")
191+
plt.tight_layout()
192+
plt.show()
193+
# plt.figure()
194+
# plt.plot(chi_ee_ts[:, 0], chi_ee_ts[:, 1])
195+
# plt.axis("equal")
196+
# plt.grid(True)
197+
# plt.xlabel("End-effector x [m]")
198+
# plt.ylabel("End-effector y [m]")
199+
# plt.tight_layout()
200+
# plt.show()
173201

174202
# plot the energy along the trajectory
175203
kinetic_energy_fn_vmapped = vmap(

0 commit comments

Comments
 (0)