Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

import jsrm
from jsrm import ode_factory
from jsrm.systems import pneumatic_planar_pcs
from jsrm.systems import pneumatically_actuated_planar_pcs

num_segments = 1

Expand Down Expand Up @@ -48,7 +48,7 @@
strain_selector = jnp.array([True, False, True])[None, :].repeat(num_segments, axis=0).flatten()

B_xi, forward_kinematics_fn, dynamical_matrices_fn, auxiliary_fns = (
pneumatic_planar_pcs.factory(
pneumatically_actuated_planar_pcs.factory(
num_segments, sym_exp_filepath, strain_selector, # simplified_actuation_mapping=True
)
)
Expand Down
250 changes: 250 additions & 0 deletions examples/simulate_tendon_actuated_planar_pcs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,250 @@
import cv2 # importing cv2
from functools import partial
import jax

jax.config.update("jax_enable_x64", True) # double precision
from diffrax import diffeqsolve, Euler, ODETerm, SaveAt, Tsit5
from jax import Array, vmap
from jax import numpy as jnp
import matplotlib.pyplot as plt
import numpy as onp
from pathlib import Path
from typing import Callable, Dict

import jsrm
from jsrm import ode_factory
from jsrm.systems import tendon_actuated_planar_pcs as planar_pcs

num_segments = 1

# filepath to symbolic expressions
sym_exp_filepath = (
Path(jsrm.__file__).parent
/ "symbolic_expressions"
/ f"planar_pcs_ns-{num_segments}.dill"
)

# set parameters
rho = 1070 * jnp.ones((num_segments,)) # Volumetric density of Dragon Skin 20 [kg/m^3]
params = {
"th0": jnp.array(0.0), # initial orientation angle [rad]
"l": 1e-1 * jnp.ones((num_segments,)),
"r": 2e-2 * jnp.ones((num_segments,)),
"rho": rho,
"g": jnp.array([0.0, 9.81]),
"E": 2e3 * jnp.ones((num_segments,)), # Elastic modulus [Pa]
"G": 1e3 * jnp.ones((num_segments,)), # Shear modulus [Pa]
"d": 2e-2 * jnp.array([[1.0, -1.0]]).repeat(num_segments, axis=0), # distance of tendons from the central axis [m]
}
params["D"] = 1e-3 * jnp.diag(
(jnp.repeat(
jnp.array([[1e0, 1e3, 1e3]]), num_segments, axis=0
) * params["l"][:, None]).flatten()
)

# activate all strains (i.e. bending, shear, and axial)
strain_selector = jnp.ones((3 * num_segments,), dtype=bool)
# actuation selector for the segments
segment_actuation_selector = jnp.ones((num_segments,), dtype=bool)
# segment_actuation_selector = jnp.array([False, True]) # only the last segment is actuated

# define initial configuration
q0 = jnp.repeat(jnp.array([5.0 * jnp.pi, 0.1, 0.2])[None, :], num_segments, axis=0).flatten()
# number of generalized coordinates
n_q = q0.shape[0]

# set simulation parameters
dt = 1e-4 # time step
ts = jnp.arange(0.0, 10.0, dt) # time steps
skip_step = 10 # how many time steps to skip in between video frames
video_ts = ts[::skip_step] # time steps for video

# video settings
video_width, video_height = 700, 700 # img height and width
video_path = Path(__file__).parent / "videos" / f"planar_pcs_ns-{num_segments}.mp4"


def draw_robot(
batched_forward_kinematics_fn: Callable,
params: Dict[str, Array],
q: Array,
width: int,
height: int,
num_points: int = 50,
) -> onp.ndarray:
# plotting in OpenCV
h, w = height, width # img height and width
ppm = h / (2.0 * jnp.sum(params["l"])) # pixel per meter
base_color = (0, 0, 0) # black robot_color in BGR
robot_color = (255, 0, 0) # black robot_color in BGR

# we use for plotting N points along the length of the robot
s_ps = jnp.linspace(0, jnp.sum(params["l"]), num_points)

# poses along the robot of shape (3, N)
chi_ps = batched_forward_kinematics_fn(params, q, s_ps)

img = 255 * onp.ones((w, h, 3), dtype=jnp.uint8) # initialize background to white
curve_origin = onp.array(
[w // 2, 0.1 * h], dtype=onp.int32
) # in x-y pixel coordinates
# draw base
cv2.rectangle(img, (0, h - curve_origin[1]), (w, h), color=base_color, thickness=-1)
# transform robot poses to pixel coordinates
# should be of shape (N, 2)
curve = onp.array((curve_origin + chi_ps[:2, :].T * ppm), dtype=onp.int32)
# invert the v pixel coordinate
curve[:, 1] = h - curve[:, 1]
cv2.polylines(img, [curve], isClosed=False, color=robot_color, thickness=10)

return img


if __name__ == "__main__":
strain_basis, forward_kinematics_fn, dynamical_matrices_fn, auxiliary_fns = (
planar_pcs.factory(num_segments, sym_exp_filepath, strain_selector, segment_actuation_selector=segment_actuation_selector)
)
actuation_mapping_fn = auxiliary_fns["actuation_mapping_fn"]
# jit the functions
dynamical_matrices_fn = jax.jit(partial(dynamical_matrices_fn))
batched_forward_kinematics = vmap(
forward_kinematics_fn, in_axes=(None, None, 0), out_axes=-1
)

# test the actuation mapping function
xi_eq = jnp.array([0.0, 0.0, 1.0])[None].repeat(num_segments, axis=0).flatten()
B_xi = strain_basis
# call the actuation mapping function
A = actuation_mapping_fn(
forward_kinematics_fn,
auxiliary_fns["jacobian_fn"],
params,
B_xi,
xi_eq,
jnp.zeros_like(q0),
)
print("A =\n", A)

x0 = jnp.concatenate([q0, jnp.zeros_like(q0)]) # initial condition
u = jnp.array([1.0, 1.0])[None].repeat(num_segments, axis=0).flatten() # tendon tensions
# u = 2e-1 * jnp.array([2.0, 0.0, 0.0, 1.0])
print("u =\n", u)

ode_fn = ode_factory(dynamical_matrices_fn, params, u)
term = ODETerm(ode_fn)

sol = diffeqsolve(
term,
solver=Tsit5(),
t0=ts[0],
t1=ts[-1],
dt0=dt,
y0=x0,
max_steps=None,
saveat=SaveAt(ts=video_ts),
)

print("sol.ys =\n", sol.ys)
# the evolution of the generalized coordinates
q_ts = sol.ys[:, :n_q]
# the evolution of the generalized velocities
q_d_ts = sol.ys[:, n_q:]

# evaluate the forward kinematics along the trajectory
chi_ee_ts = vmap(forward_kinematics_fn, in_axes=(None, 0, None))(
params, q_ts, jnp.array([jnp.sum(params["l"])])
)
# plot the configuration vs time
plt.figure()
for segment_idx in range(num_segments):
plt.plot(
video_ts, q_ts[:, 3 * segment_idx + 0],
label=r"$\kappa_\mathrm{be," + str(segment_idx + 1) + "}$ [rad/m]"
)
plt.plot(
video_ts, q_ts[:, 3 * segment_idx + 1],
label=r"$\sigma_\mathrm{sh," + str(segment_idx + 1) + "}$ [-]"
)
plt.plot(
video_ts, q_ts[:, 3 * segment_idx + 2],
label=r"$\sigma_\mathrm{ax," + str(segment_idx + 1) + "}$ [-]"
)
plt.xlabel("Time [s]")
plt.ylabel("Configuration")
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()
# plot end-effector position vs time
plt.figure()
plt.plot(video_ts, chi_ee_ts[:, 0], label="x")
plt.plot(video_ts, chi_ee_ts[:, 1], label="y")
plt.xlabel("Time [s]")
plt.ylabel("End-effector Position [m]")
plt.legend()
plt.grid(True)
plt.box(True)
plt.tight_layout()
plt.show()
# plot the end-effector position in the x-y plane as a scatter plot with the time as the color
plt.figure()
plt.scatter(chi_ee_ts[:, 0], chi_ee_ts[:, 1], c=video_ts, cmap="viridis")
plt.axis("equal")
plt.grid(True)
plt.xlabel("End-effector x [m]")
plt.ylabel("End-effector y [m]")
plt.colorbar(label="Time [s]")
plt.tight_layout()
plt.show()
# plt.figure()
# plt.plot(chi_ee_ts[:, 0], chi_ee_ts[:, 1])
# plt.axis("equal")
# plt.grid(True)
# plt.xlabel("End-effector x [m]")
# plt.ylabel("End-effector y [m]")
# plt.tight_layout()
# plt.show()

# plot the energy along the trajectory
kinetic_energy_fn_vmapped = vmap(
partial(auxiliary_fns["kinetic_energy_fn"], params)
)
potential_energy_fn_vmapped = vmap(
partial(auxiliary_fns["potential_energy_fn"], params)
)
U_ts = potential_energy_fn_vmapped(q_ts)
T_ts = kinetic_energy_fn_vmapped(q_ts, q_d_ts)
plt.figure()
plt.plot(video_ts, U_ts, label="Potential energy")
plt.plot(video_ts, T_ts, label="Kinetic energy")
plt.xlabel("Time [s]")
plt.ylabel("Energy [J]")
plt.legend()
plt.grid(True)
plt.box(True)
plt.tight_layout()
plt.show()

# create video
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
video_path.parent.mkdir(parents=True, exist_ok=True)
video = cv2.VideoWriter(
str(video_path),
fourcc,
1 / (skip_step * dt), # fps
(video_width, video_height),
)

for time_idx, t in enumerate(video_ts):
x = sol.ys[time_idx]
img = draw_robot(
batched_forward_kinematics,
params,
x[: (x.shape[0] // 2)],
video_width,
video_height,
)
video.write(img)

video.release()
print(f"Video saved at {video_path}")
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ name = "jsrm" # Required
#
# For a discussion on single-sourcing the version, see
# https://packaging.python.org/guides/single-sourcing-package-version/
version = "0.0.14" # Required
version = "0.0.15" # Required

# This is a one-line description or tagline of what your project does. This
# corresponds to the "Summary" metadata field:
Expand Down
30 changes: 25 additions & 5 deletions src/jsrm/symbolic_derivation/planar_pcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def symbolically_derive_planar_pcs_model(
sp.symbols(f"r1:{num_segments + 1}", nonnegative=True)
) # radius of each segment [m]
g_syms = list(sp.symbols(f"g1:3")) # gravity vector
d = sp.symbols("d", real=True, nonnegative=True) # # distance of the tendon from the neutral axis

# planar strains and their derivatives
xi_syms = list(sp.symbols(f"xi1:{num_dof + 1}", nonzero=True)) # strains
Expand All @@ -59,6 +60,10 @@ def symbolically_derive_planar_pcs_model(
chi_sms = []
# Jacobians (positional + orientation) in each segment as a function of the point coordinate s and its time derivative
J_sms, J_d_sms = [], []
# tendon lengths for each segment as a function of the point coordinate s
L_tend_sms = []
# tendon length jacobians for each segment as a function of the point coordinate s
J_tend_sms = []
# cross-sectional area of each segment
A = sp.zeros(num_segments)
# second area moment of inertia of each segment
Expand All @@ -74,13 +79,14 @@ def symbolically_derive_planar_pcs_model(
# initialize
th_prev = th0
p_prev = sp.Matrix([0, 0])
L_tend = 0 # tendon length
for i in range(num_segments):
# bending strain
kappa = xi[3 * i]
kappa_be = xi[3 * i]
# shear strain
sigma_x = xi[3 * i + 1]
sigma_sh = xi[3 * i + 1]
# axial strain
sigma_y = xi[3 * i + 2]
sigma_ax = xi[3 * i + 2]

# compute the cross-sectional area of the rod
A[i] = sp.pi * r[i] ** 2
Expand All @@ -89,13 +95,13 @@ def symbolically_derive_planar_pcs_model(
I[i] = A[i] ** 2 / (4 * sp.pi)

# planar orientation of robot as a function of the point s
th = th_prev + s * kappa
th = th_prev + s * kappa_be

# absolute rotation of link
R = sp.Matrix([[sp.cos(th), -sp.sin(th)], [sp.sin(th), sp.cos(th)]])

# derivative of Cartesian position as function of the point s
dp_ds = R @ sp.Matrix([sigma_x, sigma_y])
dp_ds = R @ sp.Matrix([sigma_sh, sigma_ax])

# position along the current rod as a function of the point s
p = p_prev + sp.integrate(dp_ds, (s, 0.0, s))
Expand Down Expand Up @@ -146,12 +152,24 @@ def symbolically_derive_planar_pcs_model(
# add potential energy of segment to previous segments
U_g = U_g + U_gi

# simplify derived tendon length
L_tend = L_tend + s * (1 + kappa_be * d) * sp.sqrt(sigma_sh**2 + sigma_ax**2)
L_tend_sms.append(L_tend)
print(f"L_tend of segment {i+1}:\n", L_tend)
# take the derivative of the tendon length with respect to the configuration
J_tend = sp.simplify(sp.Matrix([L_tend]).jacobian(xi))
J_tend_sms.append(J_tend)
print(f"J_tend of segment {i+1}:\n", J_tend)

# update the orientation for the next segment
th_prev = th.subs(s, l[i])

# update the position for the next segment
p_prev = p.subs(s, l[i])

# previous tendon length
L_tend = L_tend.subs(s, l[i])

if simplify_expressions:
# simplify mass matrix
B = sp.simplify(B)
Expand Down Expand Up @@ -197,6 +215,8 @@ def symbolically_derive_planar_pcs_model(
"C": C, # coriolis matrix
"G": G, # gravity vector
"U_g": U_g, # gravitational potential energy
"L_tend_sms": L_tend_sms, # list of tendon lengths for each segment
"J_tend_sms": J_tend_sms, # list of tendon length Jacobians for each segment
},
}

Expand Down
Binary file modified src/jsrm/symbolic_expressions/planar_pcs_ns-1.dill
Binary file not shown.
Binary file modified src/jsrm/symbolic_expressions/planar_pcs_ns-2.dill
Binary file not shown.
4 changes: 3 additions & 1 deletion src/jsrm/systems/planar_pcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,7 @@ def actuation_mapping_fn(
jacobian_fn: Callable,
params: Dict[str, Array],
B_xi: Array,
xi_eq: Array,
q: Array,
) -> Array:
"""
Expand All @@ -252,6 +253,7 @@ def actuation_mapping_fn(
jacobian_fn: function to compute the Jacobian
params: dictionary with robot parameters
B_xi: strain basis matrix
xi_eq: equilibrium strains as array of shape (n_xi,)
q: configuration of the robot
Returns:
A: actuation matrix of shape (n_xi, n_xi) where n_xi is the number of strains.
Expand Down Expand Up @@ -359,7 +361,7 @@ def dynamical_matrices_fn(
# compute the stiffness matrix
K = stiffness_fn(params, B_xi, formulate_in_strain_space=True)
# compute the actuation matrix
A = actuation_mapping_fn(forward_kinematics_fn, jacobian_fn, params, B_xi, q)
A = actuation_mapping_fn(forward_kinematics_fn, jacobian_fn, params, B_xi, xi_eq, q)

# dissipative matrix from the parameters
D = params.get("D", jnp.zeros((n_xi, n_xi)))
Expand Down
Loading