Skip to content

Commit 1bf71b5

Browse files
Numerical derivation with class (#9)
* Add plot backup, full-JAX PCS planar + Gaussian quadrature integration scheme * Change fori_loop to vmap/scan and cond to min in the planar_pcs_num.py file. Correct a description of function in the planar_pcs.py file. * Transformed simulate_planar_pcs into a function callable from other files for comparison purposes with the users choice parameters : - option of saving or not the results, figures, videos - option of plotting/printing or not the results, figures - option on the type of derivation to use : symbolic, numeric - option on the type of integration and parameter of integration to use : gauss, trapezoid - option on the type of jacobian to use : explicit or autodifferentiation Added the ability to save simulation results in pickle files (.pkl) for later comparison Set up an explicit Jacobian in SE3 to compute B and G - SE(3) Lie algebra operators - convert SE(2) to SE(3) to use operators for the planar case * Replacement of at.set by block Implementation of a Gauss-Kronrad quadrature integration using the quadax library Function documentation * Add quadax dependency, recover original files, remove jit for math_utils * Changing Coriolis for loop to vmap, get rid of jnp.array when possible, benchmark and tests on eps. * Corrected type error in planar_pcs_num.py * Added formulas in SE2, Coriolis corrections Autodiff: for loop calculation corrected Explicit: implementation of explicit calculation using Lie algebra Various tests * Correction of the kinetic energy function kinetic energy depends only on B and does not need to calculate other dynamic matrices Correction of documentations * Fix jit decorators to apply JIT-compilation to last level * Creation of a test file for planar_pcs_num.py Corrected documentation Removal of unnecessary imports Ready to merge * Roll-back changes to symbolic expressions * Bumpy version number and add Solange as an author * Rename `planar_pcs` system to `planar_pcs_sym` * Fix some type hinting errors * Fix missing changes in last commit * Rename `test_planar_pcs.py` to `test_planar_pcs_sym` * Fix some bugs * Format systems * Format the `tests` files * Format the `utils` files * Exclude some test scripts from automated testing if they require gui * Class for Planar PCS * 3D PCS and documentation updates * Documentation 3D-PCS * Correction for merging * Ruff format of files * Add Planar PCS class and documentation correction * Correct Planar PCS documentation * Remove alias for planar_pcs_sym import * Add comment about operational_space_dynamical_matrices function * Change version to 0.1.0 --------- Co-authored-by: Maximilian Stölzle <[email protected]>
1 parent b7944dc commit 1bf71b5

File tree

12 files changed

+4615
-2601
lines changed

12 files changed

+4615
-2601
lines changed

examples/simulate_pcs.py

Lines changed: 280 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,280 @@
1+
import jax
2+
3+
from jsrm.systems.pcs import PCS
4+
import jax.numpy as jnp
5+
6+
from typing import Callable
7+
from jax import Array
8+
9+
import numpy as onp
10+
11+
import matplotlib.pyplot as plt
12+
from matplotlib.animation import FuncAnimation
13+
from IPython.display import HTML
14+
15+
from diffrax import Tsit5
16+
17+
from functools import partial
18+
from matplotlib.widgets import Slider
19+
20+
jax.config.update("jax_enable_x64", True) # double precision
21+
jnp.set_printoptions(
22+
threshold=jnp.inf,
23+
linewidth=jnp.inf,
24+
formatter={"float_kind": lambda x: "0" if x == 0 else f"{x:.2e}"},
25+
)
26+
27+
28+
def draw_robot_curve(
29+
batched_forward_kinematics: Callable,
30+
L_max: float,
31+
q: Array,
32+
num_points: int = 50,
33+
):
34+
s_ps = jnp.linspace(0, L_max, num_points)
35+
g_ps = batched_forward_kinematics(q, s_ps)[:, :3, 3]
36+
37+
curve = onp.array(g_ps, dtype=onp.float64)
38+
return curve # (N, 3)
39+
40+
41+
def animate_robot_matplotlib(
42+
robot: PCS,
43+
t_list: Array, # shape (T,)
44+
q_list: Array, # shape (T, DOF)
45+
num_points: int = 50,
46+
interval: int = 50,
47+
slider: bool = None,
48+
animation: bool = None,
49+
show: bool = True,
50+
):
51+
if slider is None and animation is None:
52+
raise ValueError("Either 'slider' or 'animation' must be set to True.")
53+
if animation and slider:
54+
raise ValueError(
55+
"Cannot use both animation and slider at the same time. Choose one."
56+
)
57+
58+
batched_forward_kinematics = jax.vmap(robot.forward_kinematics, in_axes=(None, 0))
59+
L_max = jnp.sum(robot.L)
60+
61+
width = jnp.linalg.norm(robot.L) * 3
62+
height = width
63+
64+
fig = plt.figure()
65+
ax = fig.add_subplot(111, projection="3d")
66+
ax_slider = fig.add_axes([0.2, 0.05, 0.6, 0.03]) # [left, bottom, width, height]
67+
68+
if animation:
69+
(line,) = ax.plot([], [], [], lw=4, color="blue")
70+
ax.set_xlim(-width / 2, width / 2)
71+
ax.set_ylim(-width / 2, width / 2)
72+
ax.set_zlim(0, height)
73+
title_text = ax.set_title("t = 0.00 s")
74+
75+
def init():
76+
line.set_data([], [])
77+
line.set_3d_properties([])
78+
title_text.set_text("t = 0.00 s")
79+
return line, title_text
80+
81+
def update(frame_idx):
82+
q = q_list[frame_idx]
83+
t = t_list[frame_idx]
84+
curve = draw_robot_curve(batched_forward_kinematics, L_max, q, num_points)
85+
line.set_data(curve[:, 0], curve[:, 1])
86+
line.set_3d_properties(curve[:, 2])
87+
title_text.set_text(f"t = {t:.2f} s")
88+
return line, title_text
89+
90+
ani = FuncAnimation(
91+
fig,
92+
update,
93+
frames=len(q_list),
94+
init_func=init,
95+
blit=False,
96+
interval=interval,
97+
)
98+
99+
if show:
100+
plt.show()
101+
102+
plt.close(fig)
103+
return HTML(ani.to_jshtml())
104+
105+
elif slider:
106+
107+
def update_plot(frame_idx):
108+
ax.cla() # Clear current axes
109+
ax.set_xlim(-width / 2, width / 2)
110+
ax.set_ylim(-width / 2, width / 2)
111+
ax.set_zlim(0, height)
112+
ax.set_xlabel("X [m]")
113+
ax.set_ylabel("Y [m]")
114+
ax.set_zlabel("Z [m]")
115+
ax.set_title(f"t = {t_list[frame_idx]:.2f} s")
116+
q = q_list[frame_idx]
117+
curve = draw_robot_curve(batched_forward_kinematics, L_max, q, num_points)
118+
ax.plot(curve[:, 0], curve[:, 1], curve[:, 2], lw=4, color="blue")
119+
fig.canvas.draw_idle()
120+
121+
# Create slider
122+
slider = Slider(
123+
ax=ax_slider,
124+
label="Frame",
125+
valmin=0,
126+
valmax=len(t_list) - 1,
127+
valinit=0,
128+
valstep=1,
129+
)
130+
slider.on_changed(update_plot)
131+
132+
update_plot(0) # Initial plot
133+
134+
if show:
135+
plt.show()
136+
137+
plt.close(fig)
138+
return HTML(
139+
"Slider animation not implemented in HTML format. Use matplotlib directly to view the slider."
140+
) # Slider cannot be converted to HTML
141+
142+
143+
if __name__ == "__main__":
144+
num_segments = 2
145+
rho = 1070 * jnp.ones(
146+
(num_segments,)
147+
) # Volumetric density of Dragon Skin 20 [kg/m^3]
148+
params = {
149+
"p0": jnp.array(
150+
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
151+
), # Initial position and orientation
152+
"l": 1e-1 * jnp.ones((num_segments,)),
153+
"r": 2e-2 * jnp.ones((num_segments,)),
154+
"rho": rho,
155+
"g": jnp.array([0.0, 0.0, -9.81]), # Gravity vector [m/s^2]
156+
"E": 2e3 * jnp.ones((num_segments,)), # Elastic modulus [Pa]
157+
"G": 1e3 * jnp.ones((num_segments,)), # Shear modulus [Pa]
158+
}
159+
params["D"] = 1e-3 * jnp.diag(
160+
(
161+
jnp.repeat(
162+
jnp.array([[1e0, 1e0, 1e0, 1e3, 1e3, 1e3]]), num_segments, axis=0
163+
)
164+
* params["l"][:, None]
165+
).flatten()
166+
)
167+
168+
# ======================================================
169+
# Robot initialization
170+
# ======================================================
171+
robot = PCS(
172+
num_segments=num_segments,
173+
params=params,
174+
order_gauss=5,
175+
)
176+
177+
# =====================================================
178+
# Simulation upon time
179+
# =====================================================
180+
# Initial configuration
181+
q0 = jnp.repeat(
182+
jnp.array([5.0 * jnp.pi, 0.0, 0.0, 0.0, 0.1, 0.2])[None, :],
183+
num_segments,
184+
axis=0,
185+
).flatten()
186+
# Initial velocities
187+
qd0 = jnp.zeros_like(q0)
188+
189+
# Actuation parameters
190+
tau = jnp.zeros_like(q0)
191+
# WARNING: actuation_args need to be a tuple, even if it contains only one element
192+
# so (tau, ) is necessary NOT (tau) or tau
193+
actuation_args = (tau,)
194+
195+
# Simulation time parameters
196+
t0 = 0.0
197+
t1 = 2.0
198+
dt = 1e-4
199+
skip_step = 100 # how many time steps to skip in between video frames
200+
201+
# Solver
202+
solver = Tsit5() # Runge-Kutta 5(4) method
203+
204+
ts, q_ts, q_d_ts = robot.resolve_upon_time(
205+
q0=q0,
206+
qd0=qd0,
207+
actuation_args=actuation_args,
208+
t0=t0,
209+
t1=t1,
210+
dt=dt,
211+
skip_steps=skip_step,
212+
solver=solver,
213+
max_steps=None,
214+
)
215+
216+
# =====================================================
217+
# End-effector position upon time
218+
# =====================================================
219+
forward_kinematics_end_effector = jax.jit(
220+
partial(
221+
robot.forward_kinematics,
222+
s=jnp.sum(robot.L), # end-effector position
223+
)
224+
)
225+
g_ee_ts = jax.vmap(forward_kinematics_end_effector)(q_ts)
226+
227+
plt.figure()
228+
plt.plot(ts, g_ee_ts[:, 0, 3], label="End-effector x [m]")
229+
plt.plot(ts, g_ee_ts[:, 1, 3], label="End-effector y [m]")
230+
plt.plot(ts, g_ee_ts[:, 2, 3], label="End-effector z [m]")
231+
plt.xlabel("Time [s]")
232+
plt.ylabel("End-effector position [m]")
233+
plt.legend()
234+
plt.grid(True)
235+
plt.box(True)
236+
plt.tight_layout()
237+
plt.show()
238+
239+
fig = plt.figure()
240+
ax = fig.add_subplot(111, projection="3d")
241+
p = ax.scatter(
242+
g_ee_ts[:, 0, 3], g_ee_ts[:, 1, 3], g_ee_ts[:, 2, 3], c=ts, cmap="viridis"
243+
)
244+
ax.axis("equal")
245+
ax.set_xlabel("X [m]")
246+
ax.set_ylabel("Y [m]")
247+
ax.set_zlabel("Z [m]")
248+
ax.set_title("End-effector trajectory (3D)")
249+
fig.colorbar(p, ax=ax, label="Time [s]")
250+
plt.show()
251+
252+
# =====================================================
253+
# Energy computation upon time
254+
# =====================================================
255+
U_ts = jax.vmap(jax.jit(partial(robot.potential_energy)))(q_ts)
256+
T_ts = jax.vmap(jax.jit(partial(robot.kinetic_energy)))(q_ts, q_d_ts)
257+
258+
plt.figure()
259+
plt.plot(ts, U_ts, label="Potential Energy")
260+
plt.plot(ts, T_ts, label="Kinetic Energy")
261+
plt.xlabel("Time (s)")
262+
plt.ylabel("Energy (J)")
263+
plt.legend()
264+
plt.title("Energy over Time")
265+
plt.grid(True)
266+
plt.box(True)
267+
plt.tight_layout()
268+
plt.show()
269+
270+
# =====================================================
271+
# Plot the robot configuration upon time
272+
# =====================================================
273+
animate_robot_matplotlib(
274+
robot,
275+
t_list=ts, # shape (T,)
276+
q_list=q_ts, # shape (T, DOF)
277+
num_points=50,
278+
interval=100, # ms
279+
slider=True,
280+
)

0 commit comments

Comments
 (0)