Skip to content

Commit f34aad4

Browse files
committed
Add Planar PCS class and documentation correction
1 parent 9739918 commit f34aad4

File tree

9 files changed

+2088
-2606
lines changed

9 files changed

+2088
-2606
lines changed

examples/simulate_pcs.py

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,13 @@
2626

2727

2828
def draw_robot_curve(
29-
batched_forward_kinematics_fn: Callable,
29+
batched_forward_kinematics: Callable,
3030
L_max: float,
3131
q: Array,
3232
num_points: int = 50,
3333
):
3434
s_ps = jnp.linspace(0, L_max, num_points)
35-
g_ps = batched_forward_kinematics_fn(q, s_ps)[:, :3, 3]
35+
g_ps = batched_forward_kinematics(q, s_ps)[:, :3, 3]
3636

3737
curve = onp.array(g_ps, dtype=onp.float64)
3838
return curve # (N, 3)
@@ -46,23 +46,22 @@ def animate_robot_matplotlib(
4646
interval: int = 50,
4747
slider: bool = None,
4848
animation: bool = None,
49-
show: bool = False,
49+
show: bool = True,
5050
):
5151
if slider is None and animation is None:
5252
raise ValueError("Either 'slider' or 'animation' must be set to True.")
5353
if animation and slider:
5454
raise ValueError(
5555
"Cannot use both animation and slider at the same time. Choose one."
5656
)
57-
batched_forward_kinematics_fn = jax.vmap(
58-
robot.forward_kinematics_fn, in_axes=(None, 0)
59-
)
57+
58+
batched_forward_kinematics = jax.vmap(robot.forward_kinematics, in_axes=(None, 0))
6059
L_max = jnp.sum(robot.L)
6160

62-
width = jnp.linalg.norm(robot.L) * 1.5
61+
width = jnp.linalg.norm(robot.L) * 3
6362
height = width
6463

65-
fig = plt.figure(figsize=(8, 8))
64+
fig = plt.figure()
6665
ax = fig.add_subplot(111, projection="3d")
6766
ax_slider = fig.add_axes([0.2, 0.05, 0.6, 0.03]) # [left, bottom, width, height]
6867

@@ -82,9 +81,7 @@ def init():
8281
def update(frame_idx):
8382
q = q_list[frame_idx]
8483
t = t_list[frame_idx]
85-
curve = draw_robot_curve(
86-
batched_forward_kinematics_fn, L_max, q, num_points
87-
)
84+
curve = draw_robot_curve(batched_forward_kinematics, L_max, q, num_points)
8885
line.set_data(curve[:, 0], curve[:, 1])
8986
line.set_3d_properties(curve[:, 2])
9087
title_text.set_text(f"t = {t:.2f} s")
@@ -104,7 +101,8 @@ def update(frame_idx):
104101

105102
plt.close(fig)
106103
return HTML(ani.to_jshtml())
107-
if slider:
104+
105+
elif slider:
108106

109107
def update_plot(frame_idx):
110108
ax.cla() # Clear current axes
@@ -116,9 +114,7 @@ def update_plot(frame_idx):
116114
ax.set_zlabel("Z [m]")
117115
ax.set_title(f"t = {t_list[frame_idx]:.2f} s")
118116
q = q_list[frame_idx]
119-
curve = draw_robot_curve(
120-
batched_forward_kinematics_fn, L_max, q, num_points
121-
)
117+
curve = draw_robot_curve(batched_forward_kinematics, L_max, q, num_points)
122118
ax.plot(curve[:, 0], curve[:, 1], curve[:, 2], lw=4, color="blue")
123119
fig.canvas.draw_idle()
124120

@@ -152,11 +148,11 @@ def update_plot(frame_idx):
152148
params = {
153149
"p0": jnp.array(
154150
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
155-
), # 1.0, 1.0, 1.0]), # Initial position and orientation
151+
), # Initial position and orientation
156152
"l": 1e-1 * jnp.ones((num_segments,)),
157153
"r": 2e-2 * jnp.ones((num_segments,)),
158154
"rho": rho,
159-
"g": jnp.array([0.0, 0.0, 9.81]), # Gravity vector [m/s^2]
155+
"g": jnp.array([0.0, 0.0, -9.81]), # Gravity vector [m/s^2]
160156
"E": 2e3 * jnp.ones((num_segments,)), # Elastic modulus [Pa]
161157
"G": 1e3 * jnp.ones((num_segments,)), # Shear modulus [Pa]
162158
}
@@ -213,6 +209,7 @@ def update_plot(frame_idx):
213209
t1=t1,
214210
dt=dt,
215211
skip_steps=skip_step,
212+
solver=solver,
216213
max_steps=None,
217214
)
218215

@@ -221,7 +218,7 @@ def update_plot(frame_idx):
221218
# =====================================================
222219
forward_kinematics_end_effector = jax.jit(
223220
partial(
224-
robot.forward_kinematics_fn,
221+
robot.forward_kinematics,
225222
s=jnp.sum(robot.L), # end-effector position
226223
)
227224
)

examples/simulate_planar_pcs.py

Lines changed: 266 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,266 @@
1+
import jax
2+
3+
from jsrm.systems.planar_pcs import PlanarPCS
4+
import jax.numpy as jnp
5+
6+
from typing import Callable, Dict
7+
from jax import Array
8+
9+
import numpy as onp
10+
11+
import matplotlib.pyplot as plt
12+
from matplotlib.animation import FuncAnimation
13+
from IPython.display import HTML
14+
15+
from diffrax import Tsit5
16+
17+
from functools import partial
18+
from matplotlib.widgets import Slider
19+
20+
jax.config.update("jax_enable_x64", True) # double precision
21+
jnp.set_printoptions(
22+
threshold=jnp.inf,
23+
linewidth=jnp.inf,
24+
formatter={"float_kind": lambda x: "0" if x == 0 else f"{x:.2e}"},
25+
)
26+
27+
28+
def draw_robot_curve(
29+
batched_forward_kinematics: Callable,
30+
L_max: float,
31+
q: Array,
32+
num_points: int = 50,
33+
):
34+
s_ps = jnp.linspace(0, L_max, num_points)
35+
chi_ps = batched_forward_kinematics(q, s_ps)
36+
37+
curve = onp.array(chi_ps[1:, :], dtype=onp.float32).T
38+
39+
return curve # (N, 2)
40+
41+
42+
def animate_robot_matplotlib(
43+
robot: PlanarPCS,
44+
t_list: Array, # shape (T,)
45+
q_list: Array, # shape (T, DOF)
46+
num_points: int = 50,
47+
interval: int = 50,
48+
slider: bool = None,
49+
animation: bool = None,
50+
show: bool = True,
51+
):
52+
if slider is None and animation is None:
53+
raise ValueError("Either 'slider' or 'animation' must be set to True.")
54+
if animation and slider:
55+
raise ValueError(
56+
"Cannot use both animation and slider at the same time. Choose one."
57+
)
58+
59+
batched_forward_kinematics = jax.vmap(
60+
robot.forward_kinematics, in_axes=(None, 0), out_axes=-1
61+
)
62+
L_max = jnp.sum(robot.L)
63+
64+
width = jnp.linalg.norm(robot.L) * 3
65+
height = width
66+
67+
fig = plt.figure()
68+
ax = fig.add_subplot(111)
69+
ax_slider = fig.add_axes([0.2, 0.05, 0.6, 0.03]) # [left, bottom, width, height]
70+
71+
if animation:
72+
(line,) = ax.plot([], [], lw=4, color="blue")
73+
ax.set_xlim(-width / 2, width / 2)
74+
ax.set_ylim(0, height)
75+
title_text = ax.set_title("t = 0.00 s")
76+
77+
def init():
78+
line.set_data([], [])
79+
title_text.set_text("t = 0.00 s")
80+
return line, title_text
81+
82+
def update(frame_idx):
83+
q = q_list[frame_idx]
84+
t = t_list[frame_idx]
85+
curve = draw_robot_curve(batched_forward_kinematics, L_max, q, num_points)
86+
line.set_data(curve[:, 0], curve[:, 1])
87+
title_text.set_text(f"t = {t:.2f} s")
88+
return line, title_text
89+
90+
ani = FuncAnimation(
91+
fig,
92+
update,
93+
frames=len(q_list),
94+
init_func=init,
95+
blit=False,
96+
interval=interval,
97+
)
98+
99+
if show:
100+
plt.show()
101+
plt.close(fig)
102+
return HTML(ani.to_jshtml())
103+
104+
elif slider:
105+
106+
def update_plot(frame_idx):
107+
ax.cla() # Clear current axes
108+
ax.set_xlim(-width / 2, width / 2)
109+
ax.set_ylim(0, height)
110+
ax.set_xlabel("X [m]")
111+
ax.set_ylabel("Y [m]")
112+
ax.set_title(f"t = {t_list[frame_idx]:.2f} s")
113+
q = q_list[frame_idx]
114+
curve = draw_robot_curve(batched_forward_kinematics, L_max, q, num_points)
115+
ax.plot(curve[:, 0], curve[:, 1], lw=4, color="blue")
116+
fig.canvas.draw_idle()
117+
118+
# Create slider
119+
slider = Slider(
120+
ax=ax_slider,
121+
label="Frame",
122+
valmin=0,
123+
valmax=len(t_list) - 1,
124+
valinit=0,
125+
valstep=1,
126+
)
127+
slider.on_changed(update_plot)
128+
129+
update_plot(0) # Initial plot
130+
131+
if show:
132+
plt.show()
133+
134+
plt.close(fig)
135+
return HTML(
136+
"Slider animation not implemented in HTML format. Use matplotlib directly to view the slider."
137+
) # Slider cannot be converted to HTML
138+
139+
140+
if __name__ == "__main__":
141+
num_segments = 2
142+
rho = 1070 * jnp.ones(
143+
(num_segments,)
144+
) # Volumetric density of Dragon Skin 20 [kg/m^3]
145+
params = {
146+
"th0": jnp.array(0.0), # initial orientation angle [rad]
147+
"l": 1e-1 * jnp.ones((num_segments,)),
148+
"r": 2e-2 * jnp.ones((num_segments,)),
149+
"rho": rho,
150+
"g": jnp.array([0.0, -9.81]),
151+
"E": 2e3 * jnp.ones((num_segments,)), # Elastic modulus [Pa]
152+
"G": 1e3 * jnp.ones((num_segments,)), # Shear modulus [Pa]
153+
}
154+
params["D"] = 1e-3 * jnp.diag(
155+
(
156+
jnp.repeat(jnp.array([[1e0, 1e3, 1e3]]), num_segments, axis=0)
157+
* params["l"][:, None]
158+
).flatten()
159+
)
160+
161+
# ======================================================
162+
# Robot initialization
163+
# ======================================================
164+
robot = PlanarPCS(
165+
num_segments=num_segments,
166+
params=params,
167+
order_gauss=5,
168+
)
169+
170+
# =====================================================
171+
# Simulation upon time
172+
# =====================================================
173+
# Initial configuration
174+
q0 = jnp.repeat(
175+
jnp.array([5.0 * jnp.pi, 0.1, 0.2])[None, :], num_segments, axis=0
176+
).flatten()
177+
# Initial velocities
178+
qd0 = jnp.zeros_like(q0)
179+
180+
# Actuation parameters
181+
tau = jnp.zeros_like(q0)
182+
# WARNING: actuation_args need to be a tuple, even if it contains only one element
183+
actuation_args = (tau,)
184+
185+
# Simulation time parameters
186+
t0 = 0.0
187+
t1 = 2.0
188+
dt = 1e-4
189+
skip_step = 100 # how many time steps to skip in between video frames
190+
191+
# Solver
192+
solver = Tsit5() # Runge-Kutta 5(4) method
193+
194+
ts, q_ts, q_d_ts = robot.resolve_upon_time(
195+
q0=q0,
196+
qd0=qd0,
197+
actuation_args=actuation_args,
198+
t0=t0,
199+
t1=t1,
200+
dt=dt,
201+
skip_steps=skip_step,
202+
solver=solver,
203+
max_steps=None,
204+
)
205+
206+
# =====================================================
207+
# End-effector position upon time
208+
# =====================================================
209+
forward_kinematics_end_effector = jax.jit(
210+
partial(
211+
robot.forward_kinematics,
212+
s=jnp.sum(robot.L), # end-effector position
213+
)
214+
)
215+
chi_ee_ts = jax.vmap(forward_kinematics_end_effector)(q_ts)
216+
217+
plt.figure()
218+
plt.plot(ts, chi_ee_ts[:, 1], label="End-effector x [m]")
219+
plt.plot(ts, chi_ee_ts[:, 2], label="End-effector y [m]")
220+
plt.xlabel("Time [s]")
221+
plt.ylabel("End-effector position [m]")
222+
plt.legend()
223+
plt.grid(True)
224+
plt.box(True)
225+
plt.tight_layout()
226+
plt.show()
227+
228+
plt.figure()
229+
plt.scatter(chi_ee_ts[:, 1], chi_ee_ts[:, 2], c=ts, cmap="viridis")
230+
plt.axis("equal")
231+
plt.grid(True)
232+
plt.xlabel("End-effector x [m]")
233+
plt.ylabel("End-effector y [m]")
234+
plt.colorbar(label="Time [s]")
235+
plt.tight_layout()
236+
plt.show()
237+
238+
# =====================================================
239+
# Energy computation upon time
240+
# =====================================================
241+
U_ts = jax.vmap(jax.jit(partial(robot.potential_energy)))(q_ts)
242+
T_ts = jax.vmap(jax.jit(partial(robot.kinetic_energy)))(q_ts, q_d_ts)
243+
244+
plt.figure()
245+
plt.plot(ts, U_ts, label="Potential Energy")
246+
plt.plot(ts, T_ts, label="Kinetic Energy")
247+
plt.xlabel("Time (s)")
248+
plt.ylabel("Energy (J)")
249+
plt.legend()
250+
plt.title("Energy over Time")
251+
plt.grid(True)
252+
plt.box(True)
253+
plt.tight_layout()
254+
plt.show()
255+
256+
# =====================================================
257+
# Plot the robot configuration upon time
258+
# =====================================================
259+
animate_robot_matplotlib(
260+
robot=robot,
261+
t_list=ts, # shape (T,)
262+
q_list=q_ts, # shape (T, DOF)
263+
num_points=50,
264+
interval=100, # ms
265+
slider=True,
266+
)

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ dependencies = [ # Optional
107107
"jax",
108108
"numpy",
109109
"quadax",
110+
"equinox",
110111
"peppercorn",
111112
"sympy>=1.11"
112113
]

0 commit comments

Comments
 (0)