Skip to content

Commit aa35610

Browse files
committed
Class for Planar PCS
1 parent 7107cbd commit aa35610

File tree

5 files changed

+1921
-9
lines changed

5 files changed

+1921
-9
lines changed
Lines changed: 232 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,232 @@
1+
import jax
2+
3+
from jsrm.systems.planar_pcs import PlanarPCSNum
4+
import jax.numpy as jnp
5+
6+
from typing import Callable, Dict
7+
from jax import Array
8+
9+
import numpy as onp
10+
11+
import matplotlib.pyplot as plt
12+
from matplotlib.animation import FuncAnimation
13+
from IPython.display import HTML
14+
15+
from diffrax import Tsit5
16+
17+
from functools import partial
18+
19+
jax.config.update("jax_enable_x64", True) # double precision
20+
jnp.set_printoptions(
21+
threshold =jnp.inf,
22+
linewidth =jnp.inf,
23+
formatter ={'float_kind': lambda x: '0' if x==0 else f'{x:.2e}'}
24+
)
25+
26+
def draw_robot_curve_class(
27+
batched_forward_kinematics_fn: Callable,
28+
params: Dict[str, Array],
29+
q: Array,
30+
width: int,
31+
height: int,
32+
num_points: int = 50,
33+
):
34+
h, w = height, width
35+
ppm = h / (2.0 * jnp.sum(params["l"]))
36+
s_ps = jnp.linspace(0, jnp.sum(params["l"]), num_points)
37+
chi_ps = batched_forward_kinematics_fn(q, s_ps)
38+
39+
# Position du robot dans les coordonnées pixel
40+
curve_origin = onp.array([w // 2, 0.1 * h])
41+
curve = onp.array((curve_origin[:, None] + chi_ps[1:, :] * ppm), dtype=onp.float32).T
42+
curve[:, 1] = h - curve[:, 1]
43+
44+
return curve # (N, 2)
45+
46+
def plot_robot_matplotlib(
47+
batched_forward_kinematics_fn: Callable,
48+
params: Dict[str, Array],
49+
q: Array,
50+
width: int = 500,
51+
height: int = 500,
52+
num_points: int = 50,
53+
show: bool = False,
54+
):
55+
fig, ax = plt.subplots()
56+
ax.set_xlim(0, width)
57+
ax.set_ylim(0, height)
58+
ax.invert_yaxis()
59+
line, = ax.plot([], [], lw=4, color="blue")
60+
curve = draw_robot_curve_class(batched_forward_kinematics_fn, params, q, width, height, num_points)
61+
line.set_data(curve[:, 0], curve[:, 1])
62+
63+
if show:
64+
plt.show(fig)
65+
66+
return fig
67+
68+
def animate_robot_matplotlib(
69+
batched_forward_kinematics_fn: Callable,
70+
params: Dict[str, Array],
71+
t_list: Array, # shape (T,)
72+
q_list: Array, # shape (T, DOF)
73+
width: int = 500,
74+
height: int = 500,
75+
num_points: int = 50,
76+
interval: int = 50,
77+
boolshow: bool = True,
78+
):
79+
fig, ax = plt.subplots()
80+
ax.set_xlim(0, width)
81+
ax.set_ylim(0, height)
82+
ax.invert_yaxis()
83+
line, = ax.plot([], [], lw=4, color="blue")
84+
title_text = ax.set_title("t = 0.00 s")
85+
86+
def init():
87+
line.set_data([], [])
88+
title_text.set_text("t = 0.00 s")
89+
return line, title_text
90+
91+
def update(frame_idx):
92+
q = q_list[frame_idx]
93+
t = t_list[frame_idx]
94+
curve = draw_robot_curve_class(batched_forward_kinematics_fn, params, q, width, height, num_points)
95+
line.set_data(curve[:, 0], curve[:, 1])
96+
title_text.set_text(f"t = {t:.2f} s")
97+
return line, title_text
98+
99+
ani = FuncAnimation(
100+
fig,
101+
update,
102+
frames=len(q_list),
103+
init_func=init,
104+
blit=False,
105+
interval=interval)
106+
107+
if boolshow:
108+
plt.show()
109+
plt.close(fig)
110+
return HTML(ani.to_jshtml())
111+
112+
if __name__ == "__main__":
113+
num_segments = 2
114+
rho = 1070 * jnp.ones((num_segments,)) # Volumetric density of Dragon Skin 20 [kg/m^3]
115+
params = {
116+
"th0": jnp.array(0.0), # initial orientation angle [rad]
117+
"l": 1e-1 * jnp.ones((num_segments,)),
118+
"r": 2e-2 * jnp.ones((num_segments,)),
119+
"rho": rho,
120+
"g": jnp.array([0.0, 9.81]),
121+
"E": 2e3 * jnp.ones((num_segments,)), # Elastic modulus [Pa]
122+
"G": 1e3 * jnp.ones((num_segments,)), # Shear modulus [Pa]
123+
}
124+
params["D"] = 1e-3 * jnp.diag(
125+
(jnp.repeat(
126+
jnp.array([[1e0, 1e3, 1e3]]), num_segments, axis=0
127+
) * params["l"][:, None]).flatten()
128+
)
129+
130+
# ======================================================
131+
# Robot initialization
132+
# ======================================================
133+
robot = PlanarPCSNum(
134+
num_segments=num_segments,
135+
params=params,
136+
order_gauss=5,
137+
)
138+
139+
# =====================================================
140+
# Simulation upon time
141+
# =====================================================
142+
# Initial configuration
143+
q0 = jnp.repeat(jnp.array([5.0 * jnp.pi, 0.1, 0.2])[None, :], num_segments, axis=0).flatten()
144+
# Initial velocities
145+
qd0 = jnp.zeros_like(q0)
146+
147+
# Actuation parameters
148+
tau = jnp.zeros_like(q0)
149+
# WARNING: actuation_args need to be a tuple, even if it contains only one element
150+
actuation_args = (tau,)
151+
152+
# Simulation time parameters
153+
t0 = 0.0
154+
t1 = 2.0
155+
dt = 1e-4
156+
skip_step = 100 # how many time steps to skip in between video frames
157+
158+
# Solver
159+
solver = Tsit5() # Runge-Kutta 5(4) method
160+
161+
ts, q_ts, q_d_ts = robot.resolve_upon_time(
162+
q0=q0,
163+
qd0=qd0,
164+
actuation_args=actuation_args,
165+
t0=t0,
166+
t1=t1,
167+
dt=dt,
168+
skip_steps=skip_step,
169+
max_steps=None
170+
)
171+
172+
# =====================================================
173+
# End-effector position upon time
174+
# =====================================================
175+
forward_kinematics_end_effector = jax.jit(partial(
176+
robot.forward_kinematics_fn,
177+
s=jnp.sum(robot.l) # end-effector position
178+
))
179+
chi_ee_ts = jax.vmap(forward_kinematics_end_effector)(q_ts)
180+
181+
plt.figure()
182+
plt.plot(ts, chi_ee_ts[:, 1], label="End-effector x [m]")
183+
plt.plot(ts, chi_ee_ts[:, 2], label="End-effector y [m]")
184+
plt.xlabel("Time [s]")
185+
plt.ylabel("End-effector position [m]")
186+
plt.legend()
187+
plt.grid(True)
188+
plt.box(True)
189+
plt.tight_layout()
190+
plt.show()
191+
192+
plt.figure()
193+
plt.scatter(chi_ee_ts[:, 1], chi_ee_ts[:, 2], c=ts, cmap="viridis")
194+
plt.axis("equal")
195+
plt.grid(True)
196+
plt.xlabel("End-effector x [m]")
197+
plt.ylabel("End-effector y [m]")
198+
plt.colorbar(label="Time [s]")
199+
plt.tight_layout()
200+
plt.show()
201+
202+
# =====================================================
203+
# Energy computation upon time
204+
# =====================================================
205+
U_ts = jax.vmap(jax.jit(partial(robot.potential_energy)))(q_ts)
206+
T_ts = jax.vmap(jax.jit(partial(robot.kinetic_energy)))(q_ts, q_d_ts)
207+
208+
plt.figure()
209+
plt.plot(ts, U_ts, label="Potential Energy")
210+
plt.plot(ts, T_ts, label="Kinetic Energy")
211+
plt.xlabel("Time (s)")
212+
plt.ylabel("Energy (J)")
213+
plt.legend()
214+
plt.title("Energy over Time")
215+
plt.grid(True)
216+
plt.box(True)
217+
plt.tight_layout()
218+
plt.show()
219+
220+
# =====================================================
221+
# Plot the robot configuration upon time
222+
# =====================================================
223+
animate_robot_matplotlib(
224+
batched_forward_kinematics_fn=jax.vmap(robot.forward_kinematics_fn, in_axes=(None, 0), out_axes=-1),
225+
params=params,
226+
t_list=ts, # shape (T,)
227+
q_list=q_ts, # shape (T, DOF)
228+
width=700,
229+
height=700,
230+
num_points=50,
231+
interval=100, #ms
232+
)

0 commit comments

Comments
 (0)