Skip to content

Commit 9d81090

Browse files
committed
3D PCS and documentation updates
1 parent aa35610 commit 9d81090

File tree

10 files changed

+2293
-2695
lines changed

10 files changed

+2293
-2695
lines changed

examples/simulate_pcs.py

Lines changed: 283 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,283 @@
1+
import jax
2+
3+
from jsrm.systems.pcs import PCS
4+
import jax.numpy as jnp
5+
6+
from typing import Callable
7+
from jax import Array
8+
9+
import numpy as onp
10+
11+
import matplotlib.pyplot as plt
12+
from matplotlib.animation import FuncAnimation
13+
from IPython.display import HTML
14+
15+
from diffrax import Tsit5
16+
17+
from functools import partial
18+
from matplotlib.widgets import Slider
19+
20+
jax.config.update("jax_enable_x64", True) # double precision
21+
jnp.set_printoptions(
22+
threshold=jnp.inf,
23+
linewidth=jnp.inf,
24+
formatter={"float_kind": lambda x: "0" if x == 0 else f"{x:.2e}"},
25+
)
26+
27+
28+
def draw_robot_curve(
29+
batched_forward_kinematics_fn: Callable,
30+
L_max: float,
31+
q: Array,
32+
num_points: int = 50,
33+
):
34+
s_ps = jnp.linspace(0, L_max, num_points)
35+
g_ps = batched_forward_kinematics_fn(q, s_ps)[:, :3, 3]
36+
37+
curve = onp.array(g_ps, dtype=onp.float64)
38+
return curve # (N, 3)
39+
40+
41+
def animate_robot_matplotlib(
42+
robot: PCS,
43+
t_list: Array, # shape (T,)
44+
q_list: Array, # shape (T, DOF)
45+
num_points: int = 50,
46+
interval: int = 50,
47+
slider: bool = None,
48+
animation: bool = None,
49+
show: bool = False,
50+
):
51+
if slider is None and animation is None:
52+
raise ValueError("Either 'slider' or 'animation' must be set to True.")
53+
if animation and slider:
54+
raise ValueError(
55+
"Cannot use both animation and slider at the same time. Choose one."
56+
)
57+
batched_forward_kinematics_fn = jax.vmap(
58+
robot.forward_kinematics_fn, in_axes=(None, 0)
59+
)
60+
L_max = jnp.sum(robot.L)
61+
62+
width = jnp.linalg.norm(robot.L) * 1.5
63+
height = width
64+
65+
fig = plt.figure(figsize=(8, 8))
66+
ax = fig.add_subplot(111, projection="3d")
67+
ax_slider = fig.add_axes([0.2, 0.05, 0.6, 0.03]) # [left, bottom, width, height]
68+
69+
if animation:
70+
(line,) = ax.plot([], [], [], lw=4, color="blue")
71+
ax.set_xlim(-width / 2, width / 2)
72+
ax.set_ylim(-width / 2, width / 2)
73+
ax.set_zlim(0, height)
74+
title_text = ax.set_title("t = 0.00 s")
75+
76+
def init():
77+
line.set_data([], [])
78+
line.set_3d_properties([])
79+
title_text.set_text("t = 0.00 s")
80+
return line, title_text
81+
82+
def update(frame_idx):
83+
q = q_list[frame_idx]
84+
t = t_list[frame_idx]
85+
curve = draw_robot_curve(
86+
batched_forward_kinematics_fn, L_max, q, num_points
87+
)
88+
line.set_data(curve[:, 0], curve[:, 1])
89+
line.set_3d_properties(curve[:, 2])
90+
title_text.set_text(f"t = {t:.2f} s")
91+
return line, title_text
92+
93+
ani = FuncAnimation(
94+
fig,
95+
update,
96+
frames=len(q_list),
97+
init_func=init,
98+
blit=False,
99+
interval=interval,
100+
)
101+
102+
if show:
103+
plt.show()
104+
105+
plt.close(fig)
106+
return HTML(ani.to_jshtml())
107+
if slider:
108+
109+
def update_plot(frame_idx):
110+
ax.cla() # Clear current axes
111+
ax.set_xlim(-width / 2, width / 2)
112+
ax.set_ylim(-width / 2, width / 2)
113+
ax.set_zlim(0, height)
114+
ax.set_xlabel("X [m]")
115+
ax.set_ylabel("Y [m]")
116+
ax.set_zlabel("Z [m]")
117+
ax.set_title(f"t = {t_list[frame_idx]:.2f} s")
118+
q = q_list[frame_idx]
119+
curve = draw_robot_curve(
120+
batched_forward_kinematics_fn, L_max, q, num_points
121+
)
122+
ax.plot(curve[:, 0], curve[:, 1], curve[:, 2], lw=4, color="blue")
123+
fig.canvas.draw_idle()
124+
125+
# Create slider
126+
slider = Slider(
127+
ax=ax_slider,
128+
label="Frame",
129+
valmin=0,
130+
valmax=len(t_list) - 1,
131+
valinit=0,
132+
valstep=1,
133+
)
134+
slider.on_changed(update_plot)
135+
136+
update_plot(0) # Initial plot
137+
138+
if show:
139+
plt.show()
140+
141+
plt.close(fig)
142+
return HTML(
143+
"Slider animation not implemented in HTML format. Use matplotlib directly to view the slider."
144+
) # Slider cannot be converted to HTML
145+
146+
147+
if __name__ == "__main__":
148+
num_segments = 2
149+
rho = 1070 * jnp.ones(
150+
(num_segments,)
151+
) # Volumetric density of Dragon Skin 20 [kg/m^3]
152+
params = {
153+
"p0": jnp.array(
154+
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
155+
), # 1.0, 1.0, 1.0]), # Initial position and orientation
156+
"l": 1e-1 * jnp.ones((num_segments,)),
157+
"r": 2e-2 * jnp.ones((num_segments,)),
158+
"rho": rho,
159+
"g": jnp.array([0.0, 0.0, 9.81]), # Gravity vector [m/s^2]
160+
"E": 2e3 * jnp.ones((num_segments,)), # Elastic modulus [Pa]
161+
"G": 1e3 * jnp.ones((num_segments,)), # Shear modulus [Pa]
162+
}
163+
params["D"] = 1e-3 * jnp.diag(
164+
(
165+
jnp.repeat(
166+
jnp.array([[1e0, 1e0, 1e0, 1e3, 1e3, 1e3]]), num_segments, axis=0
167+
)
168+
* params["l"][:, None]
169+
).flatten()
170+
)
171+
172+
# ======================================================
173+
# Robot initialization
174+
# ======================================================
175+
robot = PCS(
176+
num_segments=num_segments,
177+
params=params,
178+
order_gauss=5,
179+
)
180+
181+
# =====================================================
182+
# Simulation upon time
183+
# =====================================================
184+
# Initial configuration
185+
q0 = jnp.repeat(
186+
jnp.array([5.0 * jnp.pi, 0.0, 0.0, 0.0, 0.1, 0.2])[None, :],
187+
num_segments,
188+
axis=0,
189+
).flatten()
190+
# Initial velocities
191+
qd0 = jnp.zeros_like(q0)
192+
193+
# Actuation parameters
194+
tau = jnp.zeros_like(q0)
195+
# WARNING: actuation_args need to be a tuple, even if it contains only one element
196+
# so (tau, ) is necessary NOT (tau) or tau
197+
actuation_args = (tau,)
198+
199+
# Simulation time parameters
200+
t0 = 0.0
201+
t1 = 2.0
202+
dt = 1e-4
203+
skip_step = 100 # how many time steps to skip in between video frames
204+
205+
# Solver
206+
solver = Tsit5() # Runge-Kutta 5(4) method
207+
208+
ts, q_ts, q_d_ts = robot.resolve_upon_time(
209+
q0=q0,
210+
qd0=qd0,
211+
actuation_args=actuation_args,
212+
t0=t0,
213+
t1=t1,
214+
dt=dt,
215+
skip_steps=skip_step,
216+
max_steps=None,
217+
)
218+
219+
# =====================================================
220+
# End-effector position upon time
221+
# =====================================================
222+
forward_kinematics_end_effector = jax.jit(
223+
partial(
224+
robot.forward_kinematics_fn,
225+
s=jnp.sum(robot.L), # end-effector position
226+
)
227+
)
228+
g_ee_ts = jax.vmap(forward_kinematics_end_effector)(q_ts)
229+
230+
plt.figure()
231+
plt.plot(ts, g_ee_ts[:, 0, 3], label="End-effector x [m]")
232+
plt.plot(ts, g_ee_ts[:, 1, 3], label="End-effector y [m]")
233+
plt.plot(ts, g_ee_ts[:, 2, 3], label="End-effector z [m]")
234+
plt.xlabel("Time [s]")
235+
plt.ylabel("End-effector position [m]")
236+
plt.legend()
237+
plt.grid(True)
238+
plt.box(True)
239+
plt.tight_layout()
240+
plt.show()
241+
242+
fig = plt.figure()
243+
ax = fig.add_subplot(111, projection="3d")
244+
p = ax.scatter(
245+
g_ee_ts[:, 0, 3], g_ee_ts[:, 1, 3], g_ee_ts[:, 2, 3], c=ts, cmap="viridis"
246+
)
247+
ax.axis("equal")
248+
ax.set_xlabel("X [m]")
249+
ax.set_ylabel("Y [m]")
250+
ax.set_zlabel("Z [m]")
251+
ax.set_title("End-effector trajectory (3D)")
252+
fig.colorbar(p, ax=ax, label="Time [s]")
253+
plt.show()
254+
255+
# =====================================================
256+
# Energy computation upon time
257+
# =====================================================
258+
U_ts = jax.vmap(jax.jit(partial(robot.potential_energy)))(q_ts)
259+
T_ts = jax.vmap(jax.jit(partial(robot.kinetic_energy)))(q_ts, q_d_ts)
260+
261+
plt.figure()
262+
plt.plot(ts, U_ts, label="Potential Energy")
263+
plt.plot(ts, T_ts, label="Kinetic Energy")
264+
plt.xlabel("Time (s)")
265+
plt.ylabel("Energy (J)")
266+
plt.legend()
267+
plt.title("Energy over Time")
268+
plt.grid(True)
269+
plt.box(True)
270+
plt.tight_layout()
271+
plt.show()
272+
273+
# =====================================================
274+
# Plot the robot configuration upon time
275+
# =====================================================
276+
animate_robot_matplotlib(
277+
robot,
278+
t_list=ts, # shape (T,)
279+
q_list=q_ts, # shape (T, DOF)
280+
num_points=50,
281+
interval=100, # ms
282+
slider=True,
283+
)

0 commit comments

Comments
 (0)