Skip to content
Merged
Show file tree
Hide file tree
Changes from 29 commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
ab00835
Add plot backup, full-JAX PCS planar + Gaussian quadrature integratio…
solangegbv May 12, 2025
231c9e3
Change fori_loop to vmap/scan and cond to min in
solangegbv May 13, 2025
19d91ea
Transformed simulate_planar_pcs into a function
solangegbv May 28, 2025
63ed97f
Replacement of at.set by block
solangegbv Jun 4, 2025
6dd50a4
Add quadax dependency, recover original files, remove jit for math_utils
solangegbv Jun 5, 2025
7c17c6a
Changing Coriolis for loop to vmap,
solangegbv Jun 6, 2025
3bb65cf
Corrected type error in planar_pcs_num.py
solangegbv Jun 10, 2025
7bb0d43
Added formulas in SE2, Coriolis corrections
solangegbv Jun 16, 2025
6952264
Correction of the kinetic energy function
solangegbv Jun 17, 2025
81e5895
Fix jit decorators to apply JIT-compilation to
solangegbv Jun 20, 2025
62a7aa4
Creation of a test file for planar_pcs_num.py
solangegbv Jul 7, 2025
d72094a
Roll-back changes to symbolic expressions
mstoelzle Jul 21, 2025
877634a
Bumpy version number and add Solange as an author
mstoelzle Jul 21, 2025
b4beba7
Merge branch 'main' into numerical-derivation
mstoelzle Jul 21, 2025
b8c1fe6
Rename `planar_pcs` system to `planar_pcs_sym`
mstoelzle Jul 21, 2025
5fec620
Fix some type hinting errors
mstoelzle Jul 21, 2025
6838cec
Fix missing changes in last commit
mstoelzle Jul 21, 2025
9baef51
Rename `test_planar_pcs.py` to `test_planar_pcs_sym`
mstoelzle Jul 21, 2025
3eb5391
Fix some bugs
mstoelzle Jul 21, 2025
94c4928
Format systems
mstoelzle Jul 21, 2025
44242a2
Format the `tests` files
mstoelzle Jul 21, 2025
e021a71
Format the `utils` files
mstoelzle Jul 21, 2025
7107cbd
Exclude some test scripts from automated testing if they require gui
mstoelzle Jul 21, 2025
aa35610
Class for Planar PCS
solangegbv Jul 25, 2025
9d81090
3D PCS and documentation updates
solangegbv Jul 30, 2025
fc144e5
Documentation 3D-PCS
solangegbv Jul 30, 2025
df770ce
Merge branch 'main' into numerical-derivation
solangegbv Jul 30, 2025
18cde88
Correction for merging
solangegbv Jul 30, 2025
9739918
Ruff format of files
solangegbv Jul 30, 2025
f34aad4
Add Planar PCS class and documentation correction
solangegbv Jul 31, 2025
b8ea679
Correct Planar PCS documentation
solangegbv Jul 31, 2025
420e77d
Remove alias for planar_pcs_sym import
solangegbv Jul 31, 2025
0a13331
Add comment about operational_space_dynamical_matrices function
mstoelzle Jul 31, 2025
e6c0343
Change version to 0.1.0
solangegbv Jul 31, 2025
d9a996a
Merge branch 'numerical-derivation' of https://github.com/tud-phi/jax…
solangegbv Jul 31, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
283 changes: 283 additions & 0 deletions examples/simulate_pcs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,283 @@
import jax

from jsrm.systems.pcs import PCS
import jax.numpy as jnp

from typing import Callable
from jax import Array

import numpy as onp

import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from IPython.display import HTML

from diffrax import Tsit5

from functools import partial
from matplotlib.widgets import Slider

jax.config.update("jax_enable_x64", True) # double precision
jnp.set_printoptions(
threshold=jnp.inf,
linewidth=jnp.inf,
formatter={"float_kind": lambda x: "0" if x == 0 else f"{x:.2e}"},
)


def draw_robot_curve(
batched_forward_kinematics_fn: Callable,
L_max: float,
q: Array,
num_points: int = 50,
):
s_ps = jnp.linspace(0, L_max, num_points)
g_ps = batched_forward_kinematics_fn(q, s_ps)[:, :3, 3]

curve = onp.array(g_ps, dtype=onp.float64)
return curve # (N, 3)


def animate_robot_matplotlib(
robot: PCS,
t_list: Array, # shape (T,)
q_list: Array, # shape (T, DOF)
num_points: int = 50,
interval: int = 50,
slider: bool = None,
animation: bool = None,
show: bool = False,
):
if slider is None and animation is None:
raise ValueError("Either 'slider' or 'animation' must be set to True.")
if animation and slider:
raise ValueError(
"Cannot use both animation and slider at the same time. Choose one."
)
batched_forward_kinematics_fn = jax.vmap(
robot.forward_kinematics_fn, in_axes=(None, 0)
)
L_max = jnp.sum(robot.L)

width = jnp.linalg.norm(robot.L) * 1.5
height = width

fig = plt.figure(figsize=(8, 8))
ax = fig.add_subplot(111, projection="3d")
ax_slider = fig.add_axes([0.2, 0.05, 0.6, 0.03]) # [left, bottom, width, height]

if animation:
(line,) = ax.plot([], [], [], lw=4, color="blue")
ax.set_xlim(-width / 2, width / 2)
ax.set_ylim(-width / 2, width / 2)
ax.set_zlim(0, height)
title_text = ax.set_title("t = 0.00 s")

def init():
line.set_data([], [])
line.set_3d_properties([])
title_text.set_text("t = 0.00 s")
return line, title_text

def update(frame_idx):
q = q_list[frame_idx]
t = t_list[frame_idx]
curve = draw_robot_curve(
batched_forward_kinematics_fn, L_max, q, num_points
)
line.set_data(curve[:, 0], curve[:, 1])
line.set_3d_properties(curve[:, 2])
title_text.set_text(f"t = {t:.2f} s")
return line, title_text

ani = FuncAnimation(
fig,
update,
frames=len(q_list),
init_func=init,
blit=False,
interval=interval,
)

if show:
plt.show()

plt.close(fig)
return HTML(ani.to_jshtml())
if slider:

def update_plot(frame_idx):
ax.cla() # Clear current axes
ax.set_xlim(-width / 2, width / 2)
ax.set_ylim(-width / 2, width / 2)
ax.set_zlim(0, height)
ax.set_xlabel("X [m]")
ax.set_ylabel("Y [m]")
ax.set_zlabel("Z [m]")
ax.set_title(f"t = {t_list[frame_idx]:.2f} s")
q = q_list[frame_idx]
curve = draw_robot_curve(
batched_forward_kinematics_fn, L_max, q, num_points
)
ax.plot(curve[:, 0], curve[:, 1], curve[:, 2], lw=4, color="blue")
fig.canvas.draw_idle()

# Create slider
slider = Slider(
ax=ax_slider,
label="Frame",
valmin=0,
valmax=len(t_list) - 1,
valinit=0,
valstep=1,
)
slider.on_changed(update_plot)

update_plot(0) # Initial plot

if show:
plt.show()

plt.close(fig)
return HTML(
"Slider animation not implemented in HTML format. Use matplotlib directly to view the slider."
) # Slider cannot be converted to HTML


if __name__ == "__main__":
num_segments = 2
rho = 1070 * jnp.ones(
(num_segments,)
) # Volumetric density of Dragon Skin 20 [kg/m^3]
params = {
"p0": jnp.array(
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
), # 1.0, 1.0, 1.0]), # Initial position and orientation
"l": 1e-1 * jnp.ones((num_segments,)),
"r": 2e-2 * jnp.ones((num_segments,)),
"rho": rho,
"g": jnp.array([0.0, 0.0, 9.81]), # Gravity vector [m/s^2]
"E": 2e3 * jnp.ones((num_segments,)), # Elastic modulus [Pa]
"G": 1e3 * jnp.ones((num_segments,)), # Shear modulus [Pa]
}
params["D"] = 1e-3 * jnp.diag(
(
jnp.repeat(
jnp.array([[1e0, 1e0, 1e0, 1e3, 1e3, 1e3]]), num_segments, axis=0
)
* params["l"][:, None]
).flatten()
)

# ======================================================
# Robot initialization
# ======================================================
robot = PCS(
num_segments=num_segments,
params=params,
order_gauss=5,
)

# =====================================================
# Simulation upon time
# =====================================================
# Initial configuration
q0 = jnp.repeat(
jnp.array([5.0 * jnp.pi, 0.0, 0.0, 0.0, 0.1, 0.2])[None, :],
num_segments,
axis=0,
).flatten()
# Initial velocities
qd0 = jnp.zeros_like(q0)

# Actuation parameters
tau = jnp.zeros_like(q0)
# WARNING: actuation_args need to be a tuple, even if it contains only one element
# so (tau, ) is necessary NOT (tau) or tau
actuation_args = (tau,)

# Simulation time parameters
t0 = 0.0
t1 = 2.0
dt = 1e-4
skip_step = 100 # how many time steps to skip in between video frames

# Solver
solver = Tsit5() # Runge-Kutta 5(4) method

ts, q_ts, q_d_ts = robot.resolve_upon_time(
q0=q0,
qd0=qd0,
actuation_args=actuation_args,
t0=t0,
t1=t1,
dt=dt,
skip_steps=skip_step,
max_steps=None,
)

# =====================================================
# End-effector position upon time
# =====================================================
forward_kinematics_end_effector = jax.jit(
partial(
robot.forward_kinematics_fn,
s=jnp.sum(robot.L), # end-effector position
)
)
g_ee_ts = jax.vmap(forward_kinematics_end_effector)(q_ts)

plt.figure()
plt.plot(ts, g_ee_ts[:, 0, 3], label="End-effector x [m]")
plt.plot(ts, g_ee_ts[:, 1, 3], label="End-effector y [m]")
plt.plot(ts, g_ee_ts[:, 2, 3], label="End-effector z [m]")
plt.xlabel("Time [s]")
plt.ylabel("End-effector position [m]")
plt.legend()
plt.grid(True)
plt.box(True)
plt.tight_layout()
plt.show()

fig = plt.figure()
ax = fig.add_subplot(111, projection="3d")
p = ax.scatter(
g_ee_ts[:, 0, 3], g_ee_ts[:, 1, 3], g_ee_ts[:, 2, 3], c=ts, cmap="viridis"
)
ax.axis("equal")
ax.set_xlabel("X [m]")
ax.set_ylabel("Y [m]")
ax.set_zlabel("Z [m]")
ax.set_title("End-effector trajectory (3D)")
fig.colorbar(p, ax=ax, label="Time [s]")
plt.show()

# =====================================================
# Energy computation upon time
# =====================================================
U_ts = jax.vmap(jax.jit(partial(robot.potential_energy)))(q_ts)
T_ts = jax.vmap(jax.jit(partial(robot.kinetic_energy)))(q_ts, q_d_ts)

plt.figure()
plt.plot(ts, U_ts, label="Potential Energy")
plt.plot(ts, T_ts, label="Kinetic Energy")
plt.xlabel("Time (s)")
plt.ylabel("Energy (J)")
plt.legend()
plt.title("Energy over Time")
plt.grid(True)
plt.box(True)
plt.tight_layout()
plt.show()

# =====================================================
# Plot the robot configuration upon time
# =====================================================
animate_robot_matplotlib(
robot,
t_list=ts, # shape (T,)
q_list=q_ts, # shape (T, DOF)
num_points=50,
interval=100, # ms
slider=True,
)
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

import jsrm
from jsrm import ode_factory
from jsrm.systems import planar_pcs
from jsrm.systems import planar_pcs_sym

num_segments = 1

Expand All @@ -36,16 +36,19 @@
"G": 1e3 * jnp.ones((num_segments,)), # Shear modulus [Pa]
}
params["D"] = 1e-3 * jnp.diag(
(jnp.repeat(
jnp.array([[1e0, 1e3, 1e3]]), num_segments, axis=0
) * params["l"][:, None]).flatten()
(
jnp.repeat(jnp.array([[1e0, 1e3, 1e3]]), num_segments, axis=0)
* params["l"][:, None]
).flatten()
)

# activate all strains (i.e. bending, shear, and axial)
strain_selector = jnp.ones((3 * num_segments,), dtype=bool)

# define initial configuration
q0 = jnp.repeat(jnp.array([5.0 * jnp.pi, 0.1, 0.2])[None, :], num_segments, axis=0).flatten()
q0 = jnp.repeat(
jnp.array([5.0 * jnp.pi, 0.1, 0.2])[None, :], num_segments, axis=0
).flatten()
# number of generalized coordinates
n_q = q0.shape[0]

Expand Down Expand Up @@ -98,7 +101,7 @@ def draw_robot(

if __name__ == "__main__":
strain_basis, forward_kinematics_fn, dynamical_matrices_fn, auxiliary_fns = (
planar_pcs.factory(sym_exp_filepath, strain_selector)
planar_pcs_sym.factory(sym_exp_filepath, strain_selector)
)
# jit the functions
dynamical_matrices_fn = jax.jit(partial(dynamical_matrices_fn))
Expand Down Expand Up @@ -128,6 +131,8 @@ def draw_robot(
ode_fn = ode_factory(dynamical_matrices_fn, params, tau)
# jit the ODE function
ode_fn = jax.jit(ode_fn)
# jit the ODE function
ode_fn = jax.jit(ode_fn)
term = ODETerm(ode_fn)

sol = diffeqsolve(
Expand All @@ -145,30 +150,33 @@ def draw_robot(
# the evolution of the generalized coordinates
q_ts = sol.ys[:, :n_q]
# the evolution of the generalized velocities
q_d_ts = sol.ys[:, n_q:]
q_d_ts = sol.ys[:, n_q:]

s_max = jnp.array([jnp.sum(params["l"])])

forward_kinematics_fn_end_effector = partial(forward_kinematics_fn, params, s=s_max)
forward_kinematics_fn_end_effector = jax.jit(forward_kinematics_fn_end_effector)
forward_kinematics_fn_end_effector = vmap(forward_kinematics_fn_end_effector)

# evaluate the forward kinematics along the trajectory
chi_ee_ts = forward_kinematics_fn_end_effector(q_ts)
# plot the configuration vs time
plt.figure()
for segment_idx in range(num_segments):
plt.plot(
video_ts, q_ts[:, 3 * segment_idx + 0],
label=r"$\kappa_\mathrm{be," + str(segment_idx + 1) + "}$ [rad/m]"
video_ts,
q_ts[:, 3 * segment_idx + 0],
label=r"$\kappa_\mathrm{be," + str(segment_idx + 1) + "}$ [rad/m]",
)
plt.plot(
video_ts, q_ts[:, 3 * segment_idx + 1],
label=r"$\sigma_\mathrm{sh," + str(segment_idx + 1) + "}$ [-]"
video_ts,
q_ts[:, 3 * segment_idx + 1],
label=r"$\sigma_\mathrm{sh," + str(segment_idx + 1) + "}$ [-]",
)
plt.plot(
video_ts, q_ts[:, 3 * segment_idx + 2],
label=r"$\sigma_\mathrm{ax," + str(segment_idx + 1) + "}$ [-]"
video_ts,
q_ts[:, 3 * segment_idx + 2],
label=r"$\sigma_\mathrm{ax," + str(segment_idx + 1) + "}$ [-]",
)
plt.xlabel("Time [s]")
plt.ylabel("Configuration")
Expand Down
Loading
Loading