|
| 1 | +import jax |
| 2 | + |
| 3 | +from jsrm.systems.pcs import PCS |
| 4 | +import jax.numpy as jnp |
| 5 | + |
| 6 | +from typing import Callable |
| 7 | +from jax import Array |
| 8 | + |
| 9 | +import numpy as onp |
| 10 | + |
| 11 | +import matplotlib.pyplot as plt |
| 12 | +from matplotlib.animation import FuncAnimation |
| 13 | +from IPython.display import HTML |
| 14 | + |
| 15 | +from diffrax import Tsit5 |
| 16 | + |
| 17 | +from functools import partial |
| 18 | +from matplotlib.widgets import Slider |
| 19 | + |
| 20 | +jax.config.update("jax_enable_x64", True) # double precision |
| 21 | +jnp.set_printoptions( |
| 22 | + threshold=jnp.inf, |
| 23 | + linewidth=jnp.inf, |
| 24 | + formatter={"float_kind": lambda x: "0" if x == 0 else f"{x:.2e}"}, |
| 25 | +) |
| 26 | + |
| 27 | + |
| 28 | +def draw_robot_curve( |
| 29 | + batched_forward_kinematics: Callable, |
| 30 | + L_max: float, |
| 31 | + q: Array, |
| 32 | + num_points: int = 50, |
| 33 | +): |
| 34 | + s_ps = jnp.linspace(0, L_max, num_points) |
| 35 | + g_ps = batched_forward_kinematics(q, s_ps)[:, :3, 3] |
| 36 | + |
| 37 | + curve = onp.array(g_ps, dtype=onp.float64) |
| 38 | + return curve # (N, 3) |
| 39 | + |
| 40 | + |
| 41 | +def animate_robot_matplotlib( |
| 42 | + robot: PCS, |
| 43 | + t_list: Array, # shape (T,) |
| 44 | + q_list: Array, # shape (T, DOF) |
| 45 | + num_points: int = 50, |
| 46 | + interval: int = 50, |
| 47 | + slider: bool = None, |
| 48 | + animation: bool = None, |
| 49 | + show: bool = True, |
| 50 | +): |
| 51 | + if slider is None and animation is None: |
| 52 | + raise ValueError("Either 'slider' or 'animation' must be set to True.") |
| 53 | + if animation and slider: |
| 54 | + raise ValueError( |
| 55 | + "Cannot use both animation and slider at the same time. Choose one." |
| 56 | + ) |
| 57 | + |
| 58 | + batched_forward_kinematics = jax.vmap(robot.forward_kinematics, in_axes=(None, 0)) |
| 59 | + L_max = jnp.sum(robot.L) |
| 60 | + |
| 61 | + width = jnp.linalg.norm(robot.L) * 3 |
| 62 | + height = width |
| 63 | + |
| 64 | + fig = plt.figure() |
| 65 | + ax = fig.add_subplot(111, projection="3d") |
| 66 | + ax_slider = fig.add_axes([0.2, 0.05, 0.6, 0.03]) # [left, bottom, width, height] |
| 67 | + |
| 68 | + if animation: |
| 69 | + (line,) = ax.plot([], [], [], lw=4, color="blue") |
| 70 | + ax.set_xlim(-width / 2, width / 2) |
| 71 | + ax.set_ylim(-width / 2, width / 2) |
| 72 | + ax.set_zlim(0, height) |
| 73 | + title_text = ax.set_title("t = 0.00 s") |
| 74 | + |
| 75 | + def init(): |
| 76 | + line.set_data([], []) |
| 77 | + line.set_3d_properties([]) |
| 78 | + title_text.set_text("t = 0.00 s") |
| 79 | + return line, title_text |
| 80 | + |
| 81 | + def update(frame_idx): |
| 82 | + q = q_list[frame_idx] |
| 83 | + t = t_list[frame_idx] |
| 84 | + curve = draw_robot_curve(batched_forward_kinematics, L_max, q, num_points) |
| 85 | + line.set_data(curve[:, 0], curve[:, 1]) |
| 86 | + line.set_3d_properties(curve[:, 2]) |
| 87 | + title_text.set_text(f"t = {t:.2f} s") |
| 88 | + return line, title_text |
| 89 | + |
| 90 | + ani = FuncAnimation( |
| 91 | + fig, |
| 92 | + update, |
| 93 | + frames=len(q_list), |
| 94 | + init_func=init, |
| 95 | + blit=False, |
| 96 | + interval=interval, |
| 97 | + ) |
| 98 | + |
| 99 | + if show: |
| 100 | + plt.show() |
| 101 | + |
| 102 | + plt.close(fig) |
| 103 | + return HTML(ani.to_jshtml()) |
| 104 | + |
| 105 | + elif slider: |
| 106 | + |
| 107 | + def update_plot(frame_idx): |
| 108 | + ax.cla() # Clear current axes |
| 109 | + ax.set_xlim(-width / 2, width / 2) |
| 110 | + ax.set_ylim(-width / 2, width / 2) |
| 111 | + ax.set_zlim(0, height) |
| 112 | + ax.set_xlabel("X [m]") |
| 113 | + ax.set_ylabel("Y [m]") |
| 114 | + ax.set_zlabel("Z [m]") |
| 115 | + ax.set_title(f"t = {t_list[frame_idx]:.2f} s") |
| 116 | + q = q_list[frame_idx] |
| 117 | + curve = draw_robot_curve(batched_forward_kinematics, L_max, q, num_points) |
| 118 | + ax.plot(curve[:, 0], curve[:, 1], curve[:, 2], lw=4, color="blue") |
| 119 | + fig.canvas.draw_idle() |
| 120 | + |
| 121 | + # Create slider |
| 122 | + slider = Slider( |
| 123 | + ax=ax_slider, |
| 124 | + label="Frame", |
| 125 | + valmin=0, |
| 126 | + valmax=len(t_list) - 1, |
| 127 | + valinit=0, |
| 128 | + valstep=1, |
| 129 | + ) |
| 130 | + slider.on_changed(update_plot) |
| 131 | + |
| 132 | + update_plot(0) # Initial plot |
| 133 | + |
| 134 | + if show: |
| 135 | + plt.show() |
| 136 | + |
| 137 | + plt.close(fig) |
| 138 | + return HTML( |
| 139 | + "Slider animation not implemented in HTML format. Use matplotlib directly to view the slider." |
| 140 | + ) # Slider cannot be converted to HTML |
| 141 | + |
| 142 | + |
| 143 | +if __name__ == "__main__": |
| 144 | + num_segments = 2 |
| 145 | + rho = 1070 * jnp.ones( |
| 146 | + (num_segments,) |
| 147 | + ) # Volumetric density of Dragon Skin 20 [kg/m^3] |
| 148 | + params = { |
| 149 | + "p0": jnp.array( |
| 150 | + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0] |
| 151 | + ), # Initial position and orientation |
| 152 | + "l": 1e-1 * jnp.ones((num_segments,)), |
| 153 | + "r": 2e-2 * jnp.ones((num_segments,)), |
| 154 | + "rho": rho, |
| 155 | + "g": jnp.array([0.0, 0.0, -9.81]), # Gravity vector [m/s^2] |
| 156 | + "E": 2e3 * jnp.ones((num_segments,)), # Elastic modulus [Pa] |
| 157 | + "G": 1e3 * jnp.ones((num_segments,)), # Shear modulus [Pa] |
| 158 | + } |
| 159 | + params["D"] = 1e-3 * jnp.diag( |
| 160 | + ( |
| 161 | + jnp.repeat( |
| 162 | + jnp.array([[1e0, 1e0, 1e0, 1e3, 1e3, 1e3]]), num_segments, axis=0 |
| 163 | + ) |
| 164 | + * params["l"][:, None] |
| 165 | + ).flatten() |
| 166 | + ) |
| 167 | + |
| 168 | + # ====================================================== |
| 169 | + # Robot initialization |
| 170 | + # ====================================================== |
| 171 | + robot = PCS( |
| 172 | + num_segments=num_segments, |
| 173 | + params=params, |
| 174 | + order_gauss=5, |
| 175 | + ) |
| 176 | + |
| 177 | + # ===================================================== |
| 178 | + # Simulation upon time |
| 179 | + # ===================================================== |
| 180 | + # Initial configuration |
| 181 | + q0 = jnp.repeat( |
| 182 | + jnp.array([5.0 * jnp.pi, 0.0, 0.0, 0.0, 0.1, 0.2])[None, :], |
| 183 | + num_segments, |
| 184 | + axis=0, |
| 185 | + ).flatten() |
| 186 | + # Initial velocities |
| 187 | + qd0 = jnp.zeros_like(q0) |
| 188 | + |
| 189 | + # Actuation parameters |
| 190 | + tau = jnp.zeros_like(q0) |
| 191 | + # WARNING: actuation_args need to be a tuple, even if it contains only one element |
| 192 | + # so (tau, ) is necessary NOT (tau) or tau |
| 193 | + actuation_args = (tau,) |
| 194 | + |
| 195 | + # Simulation time parameters |
| 196 | + t0 = 0.0 |
| 197 | + t1 = 2.0 |
| 198 | + dt = 1e-4 |
| 199 | + skip_step = 100 # how many time steps to skip in between video frames |
| 200 | + |
| 201 | + # Solver |
| 202 | + solver = Tsit5() # Runge-Kutta 5(4) method |
| 203 | + |
| 204 | + ts, q_ts, q_d_ts = robot.resolve_upon_time( |
| 205 | + q0=q0, |
| 206 | + qd0=qd0, |
| 207 | + actuation_args=actuation_args, |
| 208 | + t0=t0, |
| 209 | + t1=t1, |
| 210 | + dt=dt, |
| 211 | + skip_steps=skip_step, |
| 212 | + solver=solver, |
| 213 | + max_steps=None, |
| 214 | + ) |
| 215 | + |
| 216 | + # ===================================================== |
| 217 | + # End-effector position upon time |
| 218 | + # ===================================================== |
| 219 | + forward_kinematics_end_effector = jax.jit( |
| 220 | + partial( |
| 221 | + robot.forward_kinematics, |
| 222 | + s=jnp.sum(robot.L), # end-effector position |
| 223 | + ) |
| 224 | + ) |
| 225 | + g_ee_ts = jax.vmap(forward_kinematics_end_effector)(q_ts) |
| 226 | + |
| 227 | + plt.figure() |
| 228 | + plt.plot(ts, g_ee_ts[:, 0, 3], label="End-effector x [m]") |
| 229 | + plt.plot(ts, g_ee_ts[:, 1, 3], label="End-effector y [m]") |
| 230 | + plt.plot(ts, g_ee_ts[:, 2, 3], label="End-effector z [m]") |
| 231 | + plt.xlabel("Time [s]") |
| 232 | + plt.ylabel("End-effector position [m]") |
| 233 | + plt.legend() |
| 234 | + plt.grid(True) |
| 235 | + plt.box(True) |
| 236 | + plt.tight_layout() |
| 237 | + plt.show() |
| 238 | + |
| 239 | + fig = plt.figure() |
| 240 | + ax = fig.add_subplot(111, projection="3d") |
| 241 | + p = ax.scatter( |
| 242 | + g_ee_ts[:, 0, 3], g_ee_ts[:, 1, 3], g_ee_ts[:, 2, 3], c=ts, cmap="viridis" |
| 243 | + ) |
| 244 | + ax.axis("equal") |
| 245 | + ax.set_xlabel("X [m]") |
| 246 | + ax.set_ylabel("Y [m]") |
| 247 | + ax.set_zlabel("Z [m]") |
| 248 | + ax.set_title("End-effector trajectory (3D)") |
| 249 | + fig.colorbar(p, ax=ax, label="Time [s]") |
| 250 | + plt.show() |
| 251 | + |
| 252 | + # ===================================================== |
| 253 | + # Energy computation upon time |
| 254 | + # ===================================================== |
| 255 | + U_ts = jax.vmap(jax.jit(partial(robot.potential_energy)))(q_ts) |
| 256 | + T_ts = jax.vmap(jax.jit(partial(robot.kinetic_energy)))(q_ts, q_d_ts) |
| 257 | + |
| 258 | + plt.figure() |
| 259 | + plt.plot(ts, U_ts, label="Potential Energy") |
| 260 | + plt.plot(ts, T_ts, label="Kinetic Energy") |
| 261 | + plt.xlabel("Time (s)") |
| 262 | + plt.ylabel("Energy (J)") |
| 263 | + plt.legend() |
| 264 | + plt.title("Energy over Time") |
| 265 | + plt.grid(True) |
| 266 | + plt.box(True) |
| 267 | + plt.tight_layout() |
| 268 | + plt.show() |
| 269 | + |
| 270 | + # ===================================================== |
| 271 | + # Plot the robot configuration upon time |
| 272 | + # ===================================================== |
| 273 | + animate_robot_matplotlib( |
| 274 | + robot, |
| 275 | + t_list=ts, # shape (T,) |
| 276 | + q_list=q_ts, # shape (T, DOF) |
| 277 | + num_points=50, |
| 278 | + interval=100, # ms |
| 279 | + slider=True, |
| 280 | + ) |
0 commit comments